package srpc import ( "bytes" "io" "net" "sync" "time" ) type RPCMsgType uint8 const ( RPC_REQUEST RPCMsgType = 0 RPC_RESPONSE RPCMsgType = 1 RPC_HEARTBEAT RPCMsgType = 2 ) const RPCHeaderSize = 10 type RPCHeader struct { RPCType RPCMsgType Size uint64 } type RPCStatus uint8 const ( RPCOK RPCStatus = 0 RPCERR RPCStatus = 1 ) type RPCHeartbeat struct { } type RPCRequest struct { FuncName string Payload []byte } type RPCResponse struct { Status RPCStatus Error string Payload []byte } type IRPCConn interface { ReceiveRequest() (*RPCRequest, error) ReceiveResponse() (*RPCResponse, error) SendRequest(request *RPCRequest) error SendResponse(response *RPCResponse) error Close() } type TCPConn struct { sync.Mutex conn net.Conn ed IEncoderDecoder // interface IRPCConn } func (tc *TCPConn) send(payload interface{}) (err error) { var header RPCHeader var hb, b bytes.Buffer if _, ok := payload.(*RPCHeartbeat); ok { header.RPCType = RPC_HEARTBEAT } else if _, ok := payload.(*RPCRequest); ok { header.RPCType = RPC_REQUEST } else if _, ok := payload.(*RPCResponse); ok { header.RPCType = RPC_RESPONSE } else { return errors.New("srpc - Invalid RPC message type") } enc := tc.ed.NewEncoder(&b) if err = enc.Encode(payload); err != nil { return errors.New("srpc - Error: '" + err.Error() + "'") } header.Size = uint64(len(b.Bytes())) enc := tc.ed.NewEncoder(&hb) if err = enc.Encode(header); err != nil { return errors.New("srpc - Error: '" + err.Error() + "'") } data := append([]byte{}, hb.Bytes()...) data = append(data, b.Bytes()...) if n, err = tc.conn.Write(data); err != nil || n != len(data) { return errors.New("srpc - Error writing message") } } func (tc *TCPConn) receive() (payload interface{}, err error) { var header RPCHeader hb := make([]byte, RPCHeaderSize) if n, err = tc.conn.Read(hb); err != nil || n != RPCHeaderSize { return nil, errors.New("srpc - Error receiving message") } dec := tc.ed.NewDecoder(bytes.NewReader(hb)) if err = dec.Decode(&header); err != nil { return nil, errors.New("srpc - Error: '" + err.Error() + "'") } switch header.RPCType { case RPC_HEARTBEAT: payload = new(RPCHeartbeat) break case RPC_REQUEST: payload = new(RPCRequest) break case RPC_RESPONSE: payload = new(RPCResponse) break default: return nil, errors.New("srpc - Invalid RPC message type") } b := make([]byte, header.Size) if n, err = tc.conn.Read(b); err != nil || n != int(header.Size) { return nil, errors.New("srpc - Error receiving message") } dec := tc.ed.NewDecoder(bytes.NewReader(b)) if err = dec.Decode(payload); err != nil { return errors.New("srpc - Error: '" + err.Error() + "'") } return payload, nil } func (tc *TCPConn) ReceiveRequest() (ret *RPCRequest, err error) { ret = new(RPCRequest) dec := tc.ed.NewDecoder(tc.conn) if err = dec.Decode(ret); err != nil { return nil, err } return ret, err } func (tc *TCPConn) ReceiveResponse() (ret *RPCResponse, err error) { ret = new(RPCResponse) dec := tc.ed.NewDecoder(tc.conn) if err = dec.Decode(ret); err != nil { return nil, err } return ret, err } func (tc *TCPConn) SendRequest(request *RPCRequest) (err error) { enc := tc.ed.NewEncoder(tc.conn) err = enc.Encode(request) return err } func (tc *TCPConn) SendResponse(response *RPCResponse) (err error) { enc := tc.ed.NewEncoder(tc.conn) err = enc.Encode(response) return err } func (tc *TCPConn) Close() { tc.conn.Close() } func NewTCPConn(conn net.Conn, ed IEncoderDecoder) *TCPConn { ret := &TCPConn{} ret.conn = conn ret.ed = ed return ret } type NetHeartbeatConn struct { sync.Mutex conn io.ReadWriteCloser dec IDecoder enc IEncoder // interfaces IRPCConn } func (nhc *NetHeartbeatConn) ReceiveRequest() (ret *RPCRequest, err error) { nhc.Lock() defer nhc.Unlock() ret = new(RPCRequest) if err = nhc.dec.Decode(ret); err != nil { return nil, err } return ret, err } func (nhc *NetHeartbeatConn) ReceiveResponse() (ret *RPCResponse, err error) { nhc.Lock() defer nhc.Unlock() ret = new(RPCResponse) if err = nhc.dec.Decode(ret); err != nil { return nil, err } return ret, err } func (nhc *NetHeartbeatConn) SendRequest(request *RPCRequest) (err error) { nhc.Lock() defer nhc.Unlock() err = nhc.enc.Encode(request) return err } func (nhc *NetHeartbeatConn) SendResponse(response *RPCResponse) (err error) { nhc.Lock() defer nhc.Unlock() err = nhc.enc.Encode(response) return err } func (nhc *NetHeartbeatConn) Close() { nhc.Lock() defer nhc.Unlock() nhc.conn.Close() }