diff --git a/client.go b/client.go index b3d80c5..4971906 100644 --- a/client.go +++ b/client.go @@ -1,9 +1,9 @@ package srpc import ( - //"bytes" + "bytes" "errors" - //"fmt" + "fmt" "io" "net" "sync" @@ -93,9 +93,20 @@ type clientMessage struct { type call struct { Response serverMessage + Request clientMessage + Error error + state messageState t time.Time - done chan struct{} - request clientMessage + done chan bool + sync.RWMutex +} + +func newCall() (ret *call) { + ret = new(call) + ret.state = FREE + ret.done = make(chan bool) + + return ret } type ResponseDataHandlerFunc func() (response interface{}) @@ -112,16 +123,31 @@ func unixDial(addr string) (conn io.ReadWriteCloser, err error) { type RPCClient struct { Addr string + MaxRequests int LogError LogErrorFunc ConnectHander ConnectHandlerFunc ResponseDataHandler ResponseDataHandlerFunc DialHandler DialHandlerFunc Ed IEncoderDecoder + calls []*call + callChan chan *call + nextID uint64 stopChan chan struct{} stopWg sync.WaitGroup } func (c *RPCClient) Start() (err error) { + if c.MaxRequests <= 0 { + c.MaxRequests = DefaultMaxClientRequests + } + + c.calls = make([]*call, c.MaxRequests) + for i := 0; i < c.MaxRequests; i++ { + c.calls[i] = newCall() + } + c.nextID = 0 + c.callChan = make(chan *call) + if c.LogError == nil { c.LogError = logError } @@ -140,6 +166,7 @@ func (c *RPCClient) Start() (err error) { } c.stopWg.Add(1) + go clientHandler(c) return nil } @@ -152,6 +179,39 @@ func (c *RPCClient) Stop() { c.stopChan = nil } +func (c *RPCClient) Call(request interface{}) (response interface{}, err error) { + var requestCall *call + + for _, e := range c.calls { + e.Lock() + if e.state == FREE { + requestCall = e + requestCall.state = PENDING + requestCall.Request.ID = c.nextID + requestCall.Request.ClientAddr = c.Addr + requestCall.Request.Data = request + c.nextID++ + e.Unlock() + break + } + e.Unlock() + } + if requestCall == nil { + return nil, errors.New("srpc - Client requests are full") + } + + c.callChan <- requestCall + + <-requestCall.done + requestCall.Lock() + response = requestCall.Response.Data + err = requestCall.Error + requestCall.state = FREE + requestCall.Unlock() + + return response, err +} + func clientHandler(c *RPCClient) { defer c.stopWg.Done() @@ -187,8 +247,12 @@ func clientHandler(c *RPCClient) { continue } - c.stopWg.Add(1) - go clientHandleConnection(c, conn) + clientHandleConnection(c, conn) + select { + case <-c.stopChan: + return + default: + } } } @@ -202,4 +266,99 @@ func clientHandleConnection(c *RPCClient, conn io.ReadWriteCloser) { conn = newConn } } + + stopChan := make(chan struct{}) + + writerDone := make(chan error, 1) + go clientWriter(c, conn, stopChan, writerDone) + + readerDone := make(chan error, 1) + go clientReader(c, conn, readerDone) + + var err error + select { + case err = <-writerDone: + close(stopChan) + conn.Close() + <-readerDone + case err = <-readerDone: + close(stopChan) + conn.Close() + <-writerDone + case <-c.stopChan: + close(stopChan) + conn.Close() + <-readerDone + <-writerDone + } + + if err != nil { + c.LogError("srpc - '%s'\n", err) + } + for _, e := range c.calls { + e.Lock() + if e.state == PENDING { + e.Error = err + e.done <- true + } + e.Unlock() + } +} + +func clientWriter(c *RPCClient, conn io.Writer, stopChan <-chan struct{}, done chan<- error) { + var err error + defer func() { done <- err }() + + enc := c.Ed.NewEncoder(conn) + b := bytes.Buffer{} + enc2 := c.Ed.NewEncoder(&b) + + for { + err = nil + var requestCall *call + + select { + case requestCall = <-c.callChan: + requestCall.Lock() + enc2.Encode(requestCall.Request) + fmt.Println(b.Bytes()) + if err = enc.Encode(requestCall.Request); err != nil { + if !serverDisconnect(err) && !clientStop(c.stopChan) { + requestCall.Error = errors.New("srpc - '%s'=>'%s': Cannot encode request: '%s'\n") + requestCall.Unlock() + requestCall.done <- true + } else { + requestCall.Unlock() + return + } + } + case <-stopChan: + return + } + } +} + +func clientReader(c *RPCClient, conn io.Reader, done <-chan error) { +} + +func serverDisconnect(err error) bool { + return err == io.ErrUnexpectedEOF || err == io.EOF +} + +func clientStop(stopChan <-chan struct{}) bool { + select { + case <-stopChan: + return true + default: + return false + } +} + +func NewUnixClient(addr string, handler ResponseDataHandlerFunc) *RPCClient { + return &RPCClient{ + Addr: addr, + ResponseDataHandler: handler, + DialHandler: unixDial, + Ed: NewEncoderDecoder(), + } } diff --git a/configuration.go b/configuration.go index bc7bc65..29ad877 100644 --- a/configuration.go +++ b/configuration.go @@ -11,3 +11,14 @@ var logError = LogErrorFunc(log.Printf) func SetLogError(f LogErrorFunc) { logError = f } + +type messageState uint8 + +const ( + FREE messageState = 0 + PENDING messageState = 1 +) + +const ( + DefaultMaxClientRequests = int(128) +) diff --git a/rpc.go b/rpc.go index 910143b..b1007ec 100644 --- a/rpc.go +++ b/rpc.go @@ -145,8 +145,8 @@ func RequestDataHandler() (request interface{}) { return new(RPCRequest) } -func ResponseDataHandler() (request interface{}) { - return new(RPCRequest) +func ResponseDataHandler() (response interface{}) { + return new(RPCResponse) } func NewRPC(ed IEncoderDecoder) *RPC { diff --git a/server.go b/server.go index 31544de..68f76f1 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package srpc import ( + "bytes" "errors" "fmt" "io" @@ -177,16 +178,23 @@ func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr st } var err error - var msg serverMessage + var msg clientMessage msg.Data = s.RequestDataHandler() + fmt.Println(msg) + p := make([]byte, 62) - dec := s.Ed.NewDecoder(conn) + conn.Read(p) + fmt.Println(p) + + dec := s.Ed.NewDecoder(bytes.NewBuffer(p)) + //dec := s.Ed.NewDecoder(conn) enc := s.Ed.NewEncoder(conn) for { if err = dec.Decode(&msg); err != nil { if !clientDisconnect(err) && !serverStop(s.stopChan) { s.LogError("srpc - '%s'=>'%s': Cannot decode request: '%s'\n", clientAddr, s.Addr, err) + fmt.Println(msg) } conn.Close() return