diff --git a/connection.go b/connection.go index 6026ffc..d6f3436 100644 --- a/connection.go +++ b/connection.go @@ -1,7 +1,9 @@ package srpc import ( + "io" "net" + "sync" ) type RPCStatus uint8 @@ -40,7 +42,6 @@ type TCPConn struct { func (tc *TCPConn) ReceiveRequest() (ret *RPCRequest, err error) { ret = new(RPCRequest) - err = nil dec := tc.ed.NewDecoder(tc.conn) if err = dec.Decode(ret); err != nil { @@ -52,7 +53,6 @@ func (tc *TCPConn) ReceiveRequest() (ret *RPCRequest, err error) { func (tc *TCPConn) ReceiveResponse() (ret *RPCResponse, err error) { ret = new(RPCResponse) - err = nil dec := tc.ed.NewDecoder(tc.conn) if err = dec.Decode(ret); err != nil { @@ -63,8 +63,6 @@ func (tc *TCPConn) ReceiveResponse() (ret *RPCResponse, err error) { } func (tc *TCPConn) SendRequest(request *RPCRequest) (err error) { - err = nil - enc := tc.ed.NewEncoder(tc.conn) err = enc.Encode(request) @@ -72,8 +70,6 @@ func (tc *TCPConn) SendRequest(request *RPCRequest) (err error) { } func (tc *TCPConn) SendResponse(response *RPCResponse) (err error) { - err = nil - enc := tc.ed.NewEncoder(tc.conn) err = enc.Encode(response) @@ -91,3 +87,60 @@ func NewTCPConn(conn net.Conn, ed IEncoderDecoder) *TCPConn { return ret } + +type NetHeartbeatConn struct { + sync.Mutex + conn io.ReadWriteCloser + dec IDecoder + enc IEncoder + + // interfaces + IRPCConn +} + +func (nhc *NetHeartbeatConn) ReceiveRequest() (ret *RPCRequest, err error) { + nhc.Lock() + defer nhc.Unlock() + + ret = new(RPCRequest) + if err = nhc.dec.Decode(ret); err != nil { + return nil, err + } + + return ret, err +} + +func (nhc *NetHeartbeatConn) ReceiveResponse() (ret *RPCResponse, err error) { + nhc.Lock() + defer nhc.Unlock() + + ret = new(RPCResponse) + if err = nhc.dec.Decode(ret); err != nil { + return nil, err + } + + return ret, err +} + +func (nhc *NetHeartbeatConn) SendRequest(request *RPCRequest) (err error) { + nhc.Lock() + defer nhc.Unlock() + + err = nhc.enc.Encode(request) + return err +} + +func (nhc *NetHeartbeatConn) SendResponse(response *RPCResponse) (err error) { + nhc.Lock() + defer nhc.Unlock() + + err = nhc.enc.Encode(response) + return err +} + +func (nhc *NetHeartbeatConn) Close() { + nhc.Lock() + defer nhc.Unlock() + + nhc.conn.Close() +}