package srpc import ( "bytes" "errors" "fmt" "io" "net" "sync" "sync/atomic" "time" ) //func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) { //var b bytes.Buffer //enc := client.ed.NewEncoder(&b) //for _, a := range args { //enc.Encode(a) //} //payload := b.Bytes() //req := RPCRequest{funcName, payload} //client.RLock() //defer client.RUnlock() //if err = client.conn.Send(&req); err != nil { //return nil, err //} //var call *RPCCall //if call, err = client.conn.Receive(); err != nil { //return nil, err //} //var response *RPCResponse //var ok bool //if response, ok = call.Payload.(*RPCResponse); !ok { //return nil, errors.New("srpc - Expected response") //} //if response.Status != RPCOK { //err = errors.New("srpc - Response contained error: '" + response.Error + "'") //} //return response.Payload, err //} //func (client *Client) Close() { //client.Lock() //defer client.Unlock() //client.conn.Send(new(RPCClose)) //client.running = false //} //func (client *Client) NewDecoder(r io.Reader) IDecoder { //return client.ed.NewDecoder(r) //} //func NewClient(conn net.Conn) *Client { //ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true} //go func() { //for { //fmt.Println("heartbeat") //ret.Lock() //if ret.running == false { //ret.Unlock() //return //} //ret.conn.Send(new(RPCHeartbeat)) //if _, err := ret.conn.Receive(); err != nil { //ret.conn.Send(new(RPCClose)) //ret.Unlock() //return //} //fmt.Println("got heartbeat") //ret.Unlock() //time.Sleep(15 * time.Second) //} //}() //return ret //} type clientMessage struct { ID uint64 Data interface{} ClientAddr string } type call struct { Response serverMessage Request clientMessage Error error state messageState t time.Time done chan bool sync.RWMutex } func newCall() (ret *call) { ret = new(call) ret.state = FREE ret.done = make(chan bool) return ret } type ResponseDataHandlerFunc func() (response interface{}) type DialHandlerFunc func(addr string) (conn io.ReadWriteCloser, err error) func unixDial(addr string) (conn io.ReadWriteCloser, err error) { if conn, err = net.Dial("unix", addr); err != nil { return nil, err } return conn, nil } type RPCClient struct { Addr string MaxRequests int LogError LogErrorFunc ConnectHander ConnectHandlerFunc ResponseDataHandler ResponseDataHandlerFunc DialHandler DialHandlerFunc Ed IEncoderDecoder calls []*call callChan chan *call nextID uint64 stopChan chan struct{} stopWg sync.WaitGroup } func (c *RPCClient) Start() (err error) { if c.MaxRequests <= 0 { c.MaxRequests = DefaultMaxClientRequests } c.calls = make([]*call, c.MaxRequests) for i := 0; i < c.MaxRequests; i++ { c.calls[i] = newCall() } c.nextID = 0 c.callChan = make(chan *call) if c.LogError == nil { c.LogError = logError } if c.stopChan != nil { return errors.New("srpc - Client already running") } c.stopChan = make(chan struct{}) if c.ResponseDataHandler == nil { return errors.New("srpc - Client needs a ResponseDataHandler") } if c.DialHandler == nil { return errors.New("srpc - Client needs a DialHandler") } c.stopWg.Add(1) go clientHandler(c) return nil } func (c *RPCClient) Stop() { if c.stopChan == nil { return } close(c.stopChan) c.stopWg.Wait() c.stopChan = nil } func (c *RPCClient) Call(request interface{}) (response interface{}, err error) { var requestCall *call for _, e := range c.calls { e.Lock() if e.state == FREE { requestCall = e requestCall.state = PENDING requestCall.Request.ID = c.nextID requestCall.Request.ClientAddr = c.Addr requestCall.Request.Data = request c.nextID++ e.Unlock() break } e.Unlock() } if requestCall == nil { return nil, errors.New("srpc - Client requests are full") } c.callChan <- requestCall <-requestCall.done requestCall.Lock() response = requestCall.Response.Data err = requestCall.Error requestCall.state = FREE requestCall.Unlock() return response, err } func clientHandler(c *RPCClient) { defer c.stopWg.Done() var conn io.ReadWriteCloser var err error var stopping atomic.Value for { dialChan := make(chan struct{}) go func() { if conn, err = c.DialHandler(c.Addr); err != nil { if stopping.Load() == nil { c.LogError("srpc - '%s' cannot estable rpc connection: '%s'\n", c.Addr, err) } } close(dialChan) }() select { case <-c.stopChan: stopping.Store(true) <-dialChan return case <-dialChan: } if err != nil { select { case <-c.stopChan: return case <-time.After(time.Second): } continue } clientHandleConnection(c, conn) select { case <-c.stopChan: return default: } } } func clientHandleConnection(c *RPCClient, conn io.ReadWriteCloser) { if c.ConnectHander != nil { if newConn, err := c.ConnectHander(c.Addr, conn); err != nil { c.LogError("srpc - Client: [%s] - Connect error: '%s'\n", c.Addr, err) conn.Close() return } else { conn = newConn } } stopChan := make(chan struct{}) writerDone := make(chan error, 1) go clientWriter(c, conn, stopChan, writerDone) readerDone := make(chan error, 1) go clientReader(c, conn, readerDone) var err error select { case err = <-writerDone: close(stopChan) conn.Close() <-readerDone case err = <-readerDone: close(stopChan) conn.Close() <-writerDone case <-c.stopChan: close(stopChan) conn.Close() <-readerDone <-writerDone } if err != nil { c.LogError("srpc - '%s'\n", err) } for _, e := range c.calls { e.Lock() if e.state == PENDING { e.Error = err e.done <- true } e.Unlock() } } func clientWriter(c *RPCClient, conn io.Writer, stopChan <-chan struct{}, done chan<- error) { var err error defer func() { done <- err }() enc := c.Ed.NewEncoder(conn) b := bytes.Buffer{} enc2 := c.Ed.NewEncoder(&b) for { err = nil var requestCall *call select { case requestCall = <-c.callChan: requestCall.Lock() enc2.Encode(requestCall.Request) fmt.Println(b.Bytes()) if err = enc.Encode(requestCall.Request); err != nil { if !serverDisconnect(err) && !clientStop(c.stopChan) { requestCall.Error = errors.New("srpc - '%s'=>'%s': Cannot encode request: '%s'\n") requestCall.Unlock() requestCall.done <- true } else { requestCall.Unlock() return } } case <-stopChan: return } } } func clientReader(c *RPCClient, conn io.Reader, done <-chan error) { } func serverDisconnect(err error) bool { return err == io.ErrUnexpectedEOF || err == io.EOF } func clientStop(stopChan <-chan struct{}) bool { select { case <-stopChan: return true default: return false } } func NewUnixClient(addr string, handler ResponseDataHandlerFunc) *RPCClient { return &RPCClient{ Addr: addr, ResponseDataHandler: handler, DialHandler: unixDial, Ed: NewEncoderDecoder(), } }