diff --git a/client.go b/client.go index e6017bc..f21ebfd 100644 --- a/client.go +++ b/client.go @@ -10,80 +10,6 @@ import ( "time" ) -//func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) { -//var b bytes.Buffer - -//enc := client.ed.NewEncoder(&b) -//for _, a := range args { -//enc.Encode(a) -//} - -//payload := b.Bytes() -//req := RPCRequest{funcName, payload} - -//client.RLock() -//defer client.RUnlock() - -//if err = client.conn.Send(&req); err != nil { -//return nil, err -//} - -//var call *RPCCall -//if call, err = client.conn.Receive(); err != nil { -//return nil, err -//} - -//var response *RPCResponse -//var ok bool -//if response, ok = call.Payload.(*RPCResponse); !ok { -//return nil, errors.New("srpc - Expected response") -//} - -//if response.Status != RPCOK { -//err = errors.New("srpc - Response contained error: '" + response.Error + "'") -//} - -//return response.Payload, err -//} - -//func (client *Client) Close() { -//client.Lock() -//defer client.Unlock() - -//client.conn.Send(new(RPCClose)) -//client.running = false -//} - -//func (client *Client) NewDecoder(r io.Reader) IDecoder { -//return client.ed.NewDecoder(r) -//} - -//func NewClient(conn net.Conn) *Client { -//ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true} -//go func() { -//for { -//fmt.Println("heartbeat") -//ret.Lock() - -//if ret.running == false { -//ret.Unlock() -//return -//} -//ret.conn.Send(new(RPCHeartbeat)) -//if _, err := ret.conn.Receive(); err != nil { -//ret.conn.Send(new(RPCClose)) -//ret.Unlock() -//return -//} -//fmt.Println("got heartbeat") -//ret.Unlock() -//time.Sleep(15 * time.Second) -//} -//}() - -//return ret -//} - type clientMessage struct { ID uint64 Data interface{} @@ -95,7 +21,6 @@ type call struct { Request clientMessage Error error state messageState - t time.Time done chan bool sync.RWMutex } @@ -128,6 +53,7 @@ type RPCClient struct { ResponseDataHandler ResponseDataHandlerFunc DialHandler DialHandlerFunc Ed IEncoderDecoder + RequestTimeout time.Duration calls []*call callChan chan *call nextID uint64 @@ -164,6 +90,10 @@ func (c *RPCClient) Start() (err error) { return errors.New("srpc - Client needs a DialHandler") } + if c.RequestTimeout == 0 { + c.RequestTimeout = DefaultRequestTimeout + } + c.stopWg.Add(1) go clientHandler(c) return nil @@ -199,16 +129,25 @@ func (c *RPCClient) Call(request interface{}) (response interface{}, err error) return nil, errors.New("srpc - Client requests are full") } + fmt.Printf("Sending request with ID %d\n", requestCall.Request.ID) + fmt.Printf("Request %v\n", requestCall.Request.Data) c.callChan <- requestCall - <-requestCall.done - requestCall.Lock() - response = requestCall.Response.Data - err = requestCall.Error - requestCall.state = FREE - requestCall.Unlock() - fmt.Println(response) - fmt.Println(err) + select { + case <-requestCall.done: + fmt.Printf("Finished request with ID %d\n", requestCall.Request.ID) + requestCall.Lock() + response = requestCall.Response.Data + err = requestCall.Error + requestCall.state = FREE + requestCall.Unlock() + case <-time.After(c.RequestTimeout): + requestCall.Lock() + response = nil + err = errors.New("srpc - Request timed out") + requestCall.state = FREE + requestCall.Unlock() + } return response, err } @@ -352,12 +291,9 @@ func clientReader(c *RPCClient, conn io.Reader, done chan<- error) { } } - fmt.Println(response) ok := false for _, e := range c.calls { - fmt.Println(e.Request.ID) e.Lock() - fmt.Println(e.Request.ID) if e.Request.ID == response.ID { e.Response = response if e.Response.Error != "" { @@ -390,11 +326,12 @@ func clientStop(stopChan <-chan struct{}) bool { } } -func NewUnixClient(addr string, handler ResponseDataHandlerFunc) *RPCClient { +func NewUnixClient(addr string, handler ResponseDataHandlerFunc, timeout time.Duration) *RPCClient { return &RPCClient{ Addr: addr, ResponseDataHandler: handler, DialHandler: unixDial, Ed: NewEncoderDecoder(), + RequestTimeout: timeout, } } diff --git a/configuration.go b/configuration.go index 29ad877..c330b25 100644 --- a/configuration.go +++ b/configuration.go @@ -2,6 +2,7 @@ package srpc import ( "log" + "time" ) type LogErrorFunc func(format string, args ...interface{}) @@ -21,4 +22,5 @@ const ( const ( DefaultMaxClientRequests = int(128) + DefaultRequestTimeout = 30 * time.Second ) diff --git a/server.go b/server.go index 0faef2f..5376c86 100644 --- a/server.go +++ b/server.go @@ -176,13 +176,15 @@ func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr st } var err error - var msg clientMessage - msg.Data = s.RequestDataHandler() dec := s.Ed.NewDecoder(conn) enc := s.Ed.NewEncoder(conn) for { + //TODO: check for multiple requests per client + var msg clientMessage + msg.Data = s.RequestDataHandler() + if err = dec.Decode(&msg); err != nil { if !clientDisconnect(err) && !serverStop(s.stopChan) { s.LogError("srpc - '%s'=>'%s': Cannot decode request: '%s'\n", clientAddr, s.Addr, err) @@ -191,20 +193,21 @@ func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr st return } - var response serverMessage - response.ID = msg.ID - response.ClientAddr = msg.ClientAddr + fmt.Println(msg.Data) + go func() { + var response serverMessage + response.ID = msg.ID + response.ClientAddr = msg.ClientAddr - response.Data = s.RequestHandler(msg.ClientAddr, msg.Data) - fmt.Println(response.Data) - - if err = enc.Encode(response); err != nil { - if !clientDisconnect(err) && !serverStop(s.stopChan) { - s.LogError("srpc - '%s'=>'%s': Cannot encode response: '%s'\n", clientAddr, s.Addr, err) + response.Data = s.RequestHandler(msg.ClientAddr, msg.Data) + if err = enc.Encode(response); err != nil { + if !clientDisconnect(err) && !serverStop(s.stopChan) { + s.LogError("srpc - '%s'=>'%s': Cannot encode response: '%s'\n", clientAddr, s.Addr, err) + } + conn.Close() + return } - conn.Close() - return - } + }() } }