package srpc import ( "bytes" "errors" "io" "net" "sync" "sync/atomic" "time" ) type clientMessage struct { ID uint64 Data interface{} ClientAddr string } type call struct { Response serverMessage Request clientMessage Error error state messageState done chan bool sync.RWMutex } func newCall() (ret *call) { ret = new(call) ret.state = FREE ret.done = make(chan bool) return ret } type ResponseDataHandlerFunc func() (response interface{}) type DialHandlerFunc func(addr string) (conn io.ReadWriteCloser, err error) func unixDial(addr string) (conn io.ReadWriteCloser, err error) { if conn, err = net.Dial("unix", addr); err != nil { return nil, err } return conn, nil } type RPCClient struct { Addr string MaxRequests int LogError LogErrorFunc ConnectHander ConnectHandlerFunc ResponseDataHandler ResponseDataHandlerFunc DialHandler DialHandlerFunc Ed IEncoderDecoder RequestTimeout time.Duration calls []*call callChan chan *call nextID uint64 stopChan chan struct{} stopWg sync.WaitGroup } func (c *RPCClient) Start() (err error) { if c.MaxRequests <= 0 { c.MaxRequests = DefaultMaxClientRequests } c.calls = make([]*call, c.MaxRequests) for i := 0; i < c.MaxRequests; i++ { c.calls[i] = newCall() } c.nextID = 0 c.callChan = make(chan *call) if c.LogError == nil { c.LogError = logError } if c.stopChan != nil { return errors.New("srpc - Client already running") } c.stopChan = make(chan struct{}) if c.ResponseDataHandler == nil { return errors.New("srpc - Client needs a ResponseDataHandler") } if c.DialHandler == nil { return errors.New("srpc - Client needs a DialHandler") } if c.RequestTimeout == 0 { c.RequestTimeout = DefaultRequestTimeout } c.stopWg.Add(1) go clientHandler(c) return nil } func (c *RPCClient) Stop() { if c.stopChan == nil { return } close(c.stopChan) c.stopWg.Wait() c.stopChan = nil } func (c *RPCClient) MakeRequest(fctName string, params ...interface{}) (request RPCRequest, err error) { data := bytes.Buffer{} enc := c.Ed.NewEncoder(&data) for _, param := range params { if err = enc.Encode(param); err != nil { return request, err } } request.FuncName = fctName request.Payload = data.Bytes() return request, nil } func (c *RPCClient) Call(fctName string, params ...interface{}) (response interface{}, err error) { var requestCall *call var request RPCRequest if request, err = c.MakeRequest(fctName, params...); err != nil { return nil, errors.New("srpc - Client request encode failed: '" + err.Error() + "'") } for _, e := range c.calls { e.Lock() if e.state == FREE { requestCall = e requestCall.state = PENDING requestCall.Request.ID = c.nextID requestCall.Request.ClientAddr = c.Addr requestCall.Request.Data = request c.nextID++ e.Unlock() break } e.Unlock() } if requestCall == nil { return nil, errors.New("srpc - Client requests are full") } c.callChan <- requestCall select { case <-requestCall.done: requestCall.Lock() response = requestCall.Response.Data err = requestCall.Error requestCall.state = FREE requestCall.Unlock() case <-time.After(c.RequestTimeout): requestCall.Lock() response = nil err = errors.New("srpc - Request timed out") requestCall.state = FREE requestCall.Unlock() } return response, err } func clientHandler(c *RPCClient) { defer c.stopWg.Done() var conn io.ReadWriteCloser var err error var stopping atomic.Value for { dialChan := make(chan struct{}) go func() { if conn, err = c.DialHandler(c.Addr); err != nil { if stopping.Load() == nil { c.LogError("srpc - '%s' cannot establish rpc connection: '%s'\n", c.Addr, err) } } close(dialChan) }() select { case <-c.stopChan: stopping.Store(true) <-dialChan return case <-dialChan: } if err != nil { select { case <-c.stopChan: return case <-time.After(time.Second): } continue } clientHandleConnection(c, conn) select { case <-c.stopChan: return default: } } } func clientHandleConnection(c *RPCClient, conn io.ReadWriteCloser) { if c.ConnectHander != nil { if newConn, err := c.ConnectHander(c.Addr, conn); err != nil { c.LogError("srpc - Client: [%s] - Connect error: '%s'\n", c.Addr, err) conn.Close() return } else { conn = newConn } } stopChan := make(chan struct{}) writerDone := make(chan error, 1) go clientWriter(c, conn, stopChan, writerDone) readerDone := make(chan error, 1) go clientReader(c, conn, readerDone) var err error select { case err = <-writerDone: close(stopChan) conn.Close() <-readerDone case err = <-readerDone: close(stopChan) conn.Close() <-writerDone case <-c.stopChan: close(stopChan) conn.Close() <-readerDone <-writerDone } if err != nil { c.LogError("srpc - '%s'\n", err) } for _, e := range c.calls { e.Lock() if e.state == PENDING { e.Error = err e.done <- true } e.Unlock() } } func clientWriter(c *RPCClient, conn io.Writer, stopChan <-chan struct{}, done chan<- error) { var err error defer func() { done <- err }() enc := c.Ed.NewEncoder(conn) for { err = nil var requestCall *call select { case requestCall = <-c.callChan: requestCall.Lock() if err = enc.Encode(requestCall.Request); err != nil { if !serverDisconnect(err) && !clientStop(c.stopChan) { requestCall.Error = errors.New("srpc - '%s'=>'%s': Cannot encode request: '%s'\n") requestCall.Unlock() requestCall.done <- true } else { requestCall.Unlock() return } } requestCall.Unlock() case <-stopChan: return } } } func clientReader(c *RPCClient, conn io.Reader, done chan<- error) { var err error defer func() { done <- err }() dec := c.Ed.NewDecoder(conn) for { var response serverMessage response.Data = c.ResponseDataHandler() if err = dec.Decode(&response); err != nil { if serverDisconnect(err) || clientStop(c.stopChan) { return } } ok := false for _, e := range c.calls { e.Lock() if e.Request.ID == response.ID { e.Response = response if e.Response.Error != "" { e.Error = errors.New(e.Response.Error) } ok = true e.done <- true e.Unlock() break } e.Unlock() } if !ok { c.LogError("srpc - Client response for unknown request ID '%d'\n", response.ID) } } } func serverDisconnect(err error) bool { return err == io.ErrUnexpectedEOF || err == io.EOF } func clientStop(stopChan <-chan struct{}) bool { select { case <-stopChan: return true default: return false } } func NewUnixClient(addr string, handler ResponseDataHandlerFunc, timeout time.Duration) *RPCClient { return &RPCClient{ Addr: addr, ResponseDataHandler: handler, DialHandler: unixDial, Ed: NewEncoderDecoder(), RequestTimeout: timeout, } }