diff --git a/client.go b/client.go index 3cf641f..c24a080 100644 --- a/client.go +++ b/client.go @@ -3,18 +3,21 @@ package srpc import ( "bytes" "errors" + "fmt" "io" "net" + "sync" + "time" ) type Client struct { - ed IEncoderDecoder - conn IRPCConn + sync.RWMutex + ed IEncoderDecoder + conn IRPCConn + running bool } func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) { - defer client.conn.Close() - var b bytes.Buffer enc := client.ed.NewEncoder(&b) @@ -25,13 +28,22 @@ func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, er payload := b.Bytes() req := RPCRequest{funcName, payload} - if err = client.conn.SendRequest(&req); err != nil { + client.RLock() + defer client.RUnlock() + + if err = client.conn.Send(&req); err != nil { + return nil, err + } + + var call *RPCCall + if call, err = client.conn.Receive(); err != nil { return nil, err } var response *RPCResponse - if response, err = client.conn.ReceiveResponse(); err != nil { - return nil, err + var ok bool + if response, ok = call.Payload.(*RPCResponse); !ok { + return nil, errors.New("srpc - Expected response") } if response.Status != RPCOK { @@ -41,10 +53,40 @@ func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, er return response.Payload, err } +func (client *Client) Close() { + client.Lock() + defer client.Unlock() + + client.conn.Send(new(RPCClose)) + client.running = false +} + func (client *Client) NewDecoder(r io.Reader) IDecoder { return client.ed.NewDecoder(r) } func NewClient(conn net.Conn) *Client { - return &Client{NewEncoderDecoder(), NewTCPConn(conn, NewEncoderDecoder())} + ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true} + go func() { + for { + fmt.Println("heartbeat") + ret.Lock() + + if ret.running == false { + ret.Unlock() + return + } + ret.conn.Send(new(RPCHeartbeat)) + if _, err := ret.conn.Receive(); err != nil { + ret.conn.Send(new(RPCClose)) + ret.Unlock() + return + } + fmt.Println("got heartbeat") + ret.Unlock() + time.Sleep(15 * time.Second) + } + }() + + return ret } diff --git a/connection.go b/connection.go index 98d13b6..ab37a7d 100644 --- a/connection.go +++ b/connection.go @@ -2,10 +2,10 @@ package srpc import ( "bytes" + "errors" "io" "net" "sync" - "time" ) type RPCMsgType uint8 @@ -14,9 +14,10 @@ const ( RPC_REQUEST RPCMsgType = 0 RPC_RESPONSE RPCMsgType = 1 RPC_HEARTBEAT RPCMsgType = 2 + RPC_CLOSE RPCMsgType = 3 ) -const RPCHeaderSize = 10 +const RPCHeaderSize = 13 type RPCHeader struct { RPCType RPCMsgType @@ -31,6 +32,11 @@ const ( ) type RPCHeartbeat struct { + OK bool +} + +type RPCClose struct { + OK bool } type RPCRequest struct { @@ -44,16 +50,19 @@ type RPCResponse struct { Payload []byte } +type RPCCall struct { + Header RPCHeader + Payload interface{} +} + type IRPCConn interface { - ReceiveRequest() (*RPCRequest, error) - ReceiveResponse() (*RPCResponse, error) - SendRequest(request *RPCRequest) error - SendResponse(response *RPCResponse) error + Send(payload interface{}) error + Receive() (*RPCCall, error) Close() } -type TCPConn struct { - sync.Mutex +type NetConn struct { + sync.RWMutex conn net.Conn ed IEncoderDecoder @@ -61,12 +70,14 @@ type TCPConn struct { IRPCConn } -func (tc *TCPConn) send(payload interface{}) (err error) { +func (tc *NetConn) Send(payload interface{}) (err error) { var header RPCHeader var hb, b bytes.Buffer if _, ok := payload.(*RPCHeartbeat); ok { header.RPCType = RPC_HEARTBEAT + } else if _, ok := payload.(*RPCClose); ok { + header.RPCType = RPC_CLOSE } else if _, ok := payload.(*RPCRequest); ok { header.RPCType = RPC_REQUEST } else if _, ok := payload.(*RPCResponse); ok { @@ -81,7 +92,7 @@ func (tc *TCPConn) send(payload interface{}) (err error) { } header.Size = uint64(len(b.Bytes())) - enc := tc.ed.NewEncoder(&hb) + enc = tc.ed.NewEncoder(&hb) if err = enc.Encode(header); err != nil { return errors.New("srpc - Error: '" + err.Error() + "'") } @@ -89,16 +100,32 @@ func (tc *TCPConn) send(payload interface{}) (err error) { data := append([]byte{}, hb.Bytes()...) data = append(data, b.Bytes()...) - if n, err = tc.conn.Write(data); err != nil || n != len(data) { + tc.Lock() + defer tc.Unlock() + + if n, err := tc.conn.Write(data); err != nil || n != len(data) { return errors.New("srpc - Error writing message") } + + return nil } -func (tc *TCPConn) receive() (payload interface{}, err error) { +func (tc *NetConn) Receive() (ret *RPCCall, err error) { var header RPCHeader + var payload interface{} hb := make([]byte, RPCHeaderSize) - if n, err = tc.conn.Read(hb); err != nil || n != RPCHeaderSize { + tc.Lock() + defer tc.Unlock() + + if n, err := tc.conn.Read(hb); err != nil || n != RPCHeaderSize { + if err != nil { + if err == io.EOF { + header.RPCType = RPC_CLOSE + ret = &RPCCall{header, new(RPCClose)} + return ret, nil + } + } return nil, errors.New("srpc - Error receiving message") } @@ -111,6 +138,9 @@ func (tc *TCPConn) receive() (payload interface{}, err error) { case RPC_HEARTBEAT: payload = new(RPCHeartbeat) break + case RPC_CLOSE: + payload = new(RPCClose) + break case RPC_REQUEST: payload = new(RPCRequest) break @@ -122,119 +152,30 @@ func (tc *TCPConn) receive() (payload interface{}, err error) { } b := make([]byte, header.Size) - if n, err = tc.conn.Read(b); err != nil || n != int(header.Size) { + if n, err := tc.conn.Read(b); err != nil || n != int(header.Size) { return nil, errors.New("srpc - Error receiving message") } - dec := tc.ed.NewDecoder(bytes.NewReader(b)) + dec = tc.ed.NewDecoder(bytes.NewReader(b)) if err = dec.Decode(payload); err != nil { - return errors.New("srpc - Error: '" + err.Error() + "'") + return nil, errors.New("srpc - Error: '" + err.Error() + "'") } - return payload, nil + ret = new(RPCCall) + ret.Header = header + ret.Payload = payload + + return ret, nil } -func (tc *TCPConn) ReceiveRequest() (ret *RPCRequest, err error) { - ret = new(RPCRequest) - - dec := tc.ed.NewDecoder(tc.conn) - if err = dec.Decode(ret); err != nil { - return nil, err - } - - return ret, err -} - -func (tc *TCPConn) ReceiveResponse() (ret *RPCResponse, err error) { - ret = new(RPCResponse) - - dec := tc.ed.NewDecoder(tc.conn) - if err = dec.Decode(ret); err != nil { - return nil, err - } - - return ret, err -} - -func (tc *TCPConn) SendRequest(request *RPCRequest) (err error) { - enc := tc.ed.NewEncoder(tc.conn) - err = enc.Encode(request) - - return err -} - -func (tc *TCPConn) SendResponse(response *RPCResponse) (err error) { - enc := tc.ed.NewEncoder(tc.conn) - err = enc.Encode(response) - - return err -} - -func (tc *TCPConn) Close() { +func (tc *NetConn) Close() { tc.conn.Close() } -func NewTCPConn(conn net.Conn, ed IEncoderDecoder) *TCPConn { - ret := &TCPConn{} +func NewNetConn(conn net.Conn, ed IEncoderDecoder) *NetConn { + ret := &NetConn{} ret.conn = conn ret.ed = ed return ret } - -type NetHeartbeatConn struct { - sync.Mutex - conn io.ReadWriteCloser - dec IDecoder - enc IEncoder - - // interfaces - IRPCConn -} - -func (nhc *NetHeartbeatConn) ReceiveRequest() (ret *RPCRequest, err error) { - nhc.Lock() - defer nhc.Unlock() - - ret = new(RPCRequest) - if err = nhc.dec.Decode(ret); err != nil { - return nil, err - } - - return ret, err -} - -func (nhc *NetHeartbeatConn) ReceiveResponse() (ret *RPCResponse, err error) { - nhc.Lock() - defer nhc.Unlock() - - ret = new(RPCResponse) - if err = nhc.dec.Decode(ret); err != nil { - return nil, err - } - - return ret, err -} - -func (nhc *NetHeartbeatConn) SendRequest(request *RPCRequest) (err error) { - nhc.Lock() - defer nhc.Unlock() - - err = nhc.enc.Encode(request) - return err -} - -func (nhc *NetHeartbeatConn) SendResponse(response *RPCResponse) (err error) { - nhc.Lock() - defer nhc.Unlock() - - err = nhc.enc.Encode(response) - return err -} - -func (nhc *NetHeartbeatConn) Close() { - nhc.Lock() - defer nhc.Unlock() - - nhc.conn.Close() -} diff --git a/server.go b/server.go index e4b9028..797e53f 100644 --- a/server.go +++ b/server.go @@ -8,6 +8,7 @@ import ( "olznet.de/slog" "reflect" "sync" + "time" ) type parseInputsFunc func(in []byte) (ret []reflect.Value, err error) @@ -24,50 +25,105 @@ type Server struct { serviceMap sync.Map } -func (server *Server) ServeConn(conn net.Conn) { +func (server *Server) ServeConn(conn IRPCConn) { var fs service - var request *RPCRequest + var call *RPCCall var err error + received := make(chan bool, 1) - tcpConn := NewTCPConn(conn, server.ed) - defer tcpConn.Close() + defer conn.Close() - if request, err = tcpConn.ReceiveRequest(); err != nil { - slog.LOG_ERROR("srpc - Malformed request received: '%s'\n", err.Error()) - return - } + for { + go func() { + if call, err = conn.Receive(); err != nil { + slog.LOG_ERROR("srpc - Malformed request received: '%s'\n", err.Error()) + return + } + received <- true + }() - var response RPCResponse - response.Status = RPCOK - response.Error = "" + select { + case <-received: + switch call.Header.RPCType { + case RPC_HEARTBEAT: + slog.LOG_INFO("srpc - Client sent hearbeat\n") + if err = conn.Send(new(RPCHeartbeat)); err != nil { + slog.LOG_ERROR("srpc - Error sending heartbeat: '%s'\n", err.Error()) + } + break + case RPC_CLOSE: + slog.LOG_INFO("srpc - Client closed connection.\n") + if err = conn.Send(new(RPCClose)); err != nil { + slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error()) + } + return + case RPC_RESPONSE: + slog.LOG_ERROR("srpc - Got response WTF?!\n") + break + case RPC_REQUEST: + request, ok := call.Payload.(*RPCRequest) + if !ok { + slog.LOG_ERROR("srpc - Expected request, but got: %v\n", call.Header.RPCType) + break + } - if fsRaw, ok := server.serviceMap.Load(request.FuncName); !ok { - slog.LOG_ERROR("srpc - Call to unknown method: '%s'\n", request.FuncName) - err := fmt.Sprintf("Unknown method: '%s'", request.FuncName) - response.Status = RPCERR - response.Error = err - } else { - fs = fsRaw.(service) - if inputs, err := fs.fin(request.Payload); err != nil { - slog.LOG_ERROR("srpc - Error parsing inputs '%s'\n", err.Error()) - response.Status = RPCERR - response.Error = err.Error() - } else { - outputs := fs.f.Call(inputs) - if payload, err := fs.fout(outputs); err != nil { - slog.LOG_ERROR("srpc - Error parsing outputs '%s'\n", err.Error()) - response.Status = RPCERR - response.Error = err.Error() - } else { + var response RPCResponse + response.Status = RPCOK + response.Error = "" + + fsRaw, ok := server.serviceMap.Load(request.FuncName) + if !ok { + slog.LOG_ERROR("srpc - Call to unknown method: '%s'\n", request.FuncName) + err := fmt.Sprintf("Unknown method: '%s'", request.FuncName) + response.Status = RPCERR + response.Error = err + if err := conn.Send(&response); err != nil { + slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error()) + } + break + } + + fs = fsRaw.(service) + inputs, err := fs.fin(request.Payload) + if err != nil { + slog.LOG_ERROR("srpc - Error parsing inputs '%s'\n", err.Error()) + response.Status = RPCERR + response.Error = err.Error() + if err := conn.Send(&response); err != nil { + slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error()) + } + break + } + + outputs := fs.f.Call(inputs) + payload, err := fs.fout(outputs) + if err != nil { + slog.LOG_ERROR("srpc - Error parsing outputs '%s'\n", err.Error()) + response.Status = RPCERR + response.Error = err.Error() + if err := conn.Send(&response); err != nil { + slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error()) + } + break + } slog.LOG_DEBUG("%v\n", outputs) response.Payload = append([]byte{}, payload...) + if err := conn.Send(&response); err != nil { + slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error()) + } + break + default: + slog.LOG_ERROR("srpc - Unknown rpc call received\n") + break } + case <-time.After(30 * time.Second): + slog.LOG_INFO("srpc - Client gone. Closing connection\n") + if err := conn.Send(new(RPCClose)); err != nil { + slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error()) + } + return } } - - if err = tcpConn.SendResponse(&response); err != nil { - slog.LOG_ERROR("srpc - Error sending response '%s'\n", err.Error()) - } } func (server *Server) Accept(ln net.Listener) { @@ -77,7 +133,7 @@ func (server *Server) Accept(ln net.Listener) { slog.LOG_INFO("srpc - Accept: '%s'\n", err.Error()) return } - go server.ServeConn(conn) + go server.ServeConn(NewNetConn(conn, server.ed)) } }