// Package srpc frames RPC messages over a net.Conn: every message on the wire
// is an encoded RPCHeader immediately followed by the encoded payload.
package srpc

import (
	"bytes"
	"errors"
	"io"
	"net"
	"sync"
)

// RPCMsgType identifies the kind of message carried by a frame.
type RPCMsgType uint8

// Message kinds that NetConn can send and receive.
const (
	RPC_REQUEST   RPCMsgType = 0
	RPC_RESPONSE  RPCMsgType = 1
	RPC_HEARTBEAT RPCMsgType = 2
	RPC_CLOSE     RPCMsgType = 3
)

// RPCHeaderSize is the number of bytes Receive reads from the connection
// before decoding an RPCHeader.
//
// NOTE(review): this assumes the configured IEncoderDecoder always encodes an
// RPCHeader into exactly this many bytes — confirm against the encoder in use.
const RPCHeaderSize = 13

// RPCHeader precedes every payload on the wire.
type RPCHeader struct {
	RPCType RPCMsgType // kind of payload that follows
	Size    uint64     // length in bytes of the encoded payload
}

// RPCHeartbeat is the payload of an RPC_HEARTBEAT message.
type RPCHeartbeat struct {
	OK bool
}

// RPCClose is the payload of an RPC_CLOSE message.
type RPCClose struct {
	OK bool
}

// RPCCall is a received message: its header plus the decoded payload.
// Payload holds a pointer to the type selected by Header.RPCType.
type RPCCall struct {
	Header  RPCHeader
	Payload interface{}
}

// IRPCConn is a message-oriented connection that exchanges RPC payloads.
type IRPCConn interface {
	Send(payload interface{}) error
	Receive() (*RPCCall, error)
	Close()
}

// NetConn is an IRPCConn backed by a net.Conn. Messages are encoded and
// decoded with the supplied IEncoderDecoder.
//
// Send and Receive both take the embedded mutex exclusively, so they never
// run concurrently on the same NetConn; in particular a Send waits for a
// Receive that is blocked reading from the connection.
type NetConn struct {
	sync.RWMutex
	conn net.Conn
	ed   IEncoderDecoder
}

// Compile-time check that NetConn implements IRPCConn.
var _ IRPCConn = (*NetConn)(nil)

// Send encodes payload, prefixes it with an RPCHeader carrying the message
// type and the encoded payload length, and writes both to the connection in a
// single Write call.
//
// payload must be a *RPCHeartbeat, *RPCClose, *RPCRequest or *RPCResponse;
// any other type is rejected with an error. Encoding failures and failed or
// short writes are also reported as errors.
func (tc *NetConn) Send(payload interface{}) (err error) {
	var header RPCHeader

	switch payload.(type) {
	case *RPCHeartbeat:
		header.RPCType = RPC_HEARTBEAT
	case *RPCClose:
		header.RPCType = RPC_CLOSE
	case *RPCRequest:
		header.RPCType = RPC_REQUEST
	case *RPCResponse:
		header.RPCType = RPC_RESPONSE
	default:
		return errors.New("srpc - Invalid RPC message type")
	}

	// The payload is encoded first because the header needs its length.
	var body bytes.Buffer
	enc := tc.ed.NewEncoder(&body)
	if err = enc.Encode(payload); err != nil {
		return errors.New("srpc - Error: '" + err.Error() + "'")
	}
	header.Size = uint64(body.Len())

	var head bytes.Buffer
	enc = tc.ed.NewEncoder(&head)
	if err = enc.Encode(header); err != nil {
		return errors.New("srpc - Error: '" + err.Error() + "'")
	}

	// Header and payload are joined so the frame goes out in one Write.
	frame := make([]byte, 0, head.Len()+body.Len())
	frame = append(frame, head.Bytes()...)
	frame = append(frame, body.Bytes()...)

	tc.Lock()
	defer tc.Unlock()

	if n, werr := tc.conn.Write(frame); werr != nil || n != len(frame) {
		return errors.New("srpc - Error writing message")
	}
	return nil
}

// Receive reads one frame from the connection and returns it as an RPCCall.
//
// If the connection reports io.EOF before any header byte arrives, Receive
// returns a synthetic RPC_CLOSE call and a nil error. Any other read failure,
// an unknown message type, or a decoding failure returns a nil call and an
// error.
func (tc *NetConn) Receive() (ret *RPCCall, err error) {
	var header RPCHeader
	headBuf := make([]byte, RPCHeaderSize)

	tc.Lock()
	defer tc.Unlock()

	// io.ReadFull is required here: a single Read on a stream connection may
	// legally return fewer bytes than requested without an error.
	if _, err = io.ReadFull(tc.conn, headBuf); err != nil {
		if err == io.EOF {
			// The peer closed the connection between frames.
			header.RPCType = RPC_CLOSE
			return &RPCCall{Header: header, Payload: new(RPCClose)}, nil
		}
		return nil, errors.New("srpc - Error receiving message")
	}

	dec := tc.ed.NewDecoder(bytes.NewReader(headBuf))
	if err = dec.Decode(&header); err != nil {
		return nil, errors.New("srpc - Error: '" + err.Error() + "'")
	}

	var payload interface{}
	switch header.RPCType {
	case RPC_HEARTBEAT:
		payload = new(RPCHeartbeat)
	case RPC_CLOSE:
		payload = new(RPCClose)
	case RPC_REQUEST:
		payload = new(RPCRequest)
	case RPC_RESPONSE:
		payload = new(RPCResponse)
	default:
		return nil, errors.New("srpc - Invalid RPC message type")
	}

	// NOTE(review): header.Size comes from the peer and is not bounded here,
	// so a hostile or corrupt header can request an arbitrarily large
	// allocation. Consider enforcing a maximum frame size.
	bodyBuf := make([]byte, header.Size)
	if _, err = io.ReadFull(tc.conn, bodyBuf); err != nil {
		return nil, errors.New("srpc - Error receiving message")
	}

	dec = tc.ed.NewDecoder(bytes.NewReader(bodyBuf))
	if err = dec.Decode(payload); err != nil {
		return nil, errors.New("srpc - Error: '" + err.Error() + "'")
	}

	return &RPCCall{Header: header, Payload: payload}, nil
}

// Close closes the underlying connection. The error returned by the
// connection is discarded because IRPCConn.Close has no return value.
func (tc *NetConn) Close() {
	_ = tc.conn.Close()
}

// NewNetConn returns a NetConn that exchanges messages over conn, using ed to
// encode and decode headers and payloads.
func NewNetConn(conn net.Conn, ed IEncoderDecoder) *NetConn {
	return &NetConn{conn: conn, ed: ed}
}