diff --git a/client.go b/client.go index 6303427..3cf641f 100644 --- a/client.go +++ b/client.go @@ -2,30 +2,20 @@ package srpc import ( "bytes" - "fmt" + "errors" "io" "net" ) -var REQUEST_HEADER_SIZE = int32(8) - -type requestHeader struct { - Size int64 -} - -type request struct { - FuncName string - Payload []byte -} - type Client struct { ed IEncoderDecoder - conn net.Conn + conn IRPCConn } -func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, ok bool) { +func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) { + defer client.conn.Close() + var b bytes.Buffer - ok = true enc := client.ed.NewEncoder(&b) for _, a := range args { @@ -33,26 +23,22 @@ func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, ok } payload := b.Bytes() - req := request{funcName, payload} - fmt.Println(req) + req := RPCRequest{funcName, payload} - enc = client.ed.NewEncoder(client.conn) - enc.Encode(req) - - var header responseHeader - dec := client.ed.NewDecoder(client.conn) - dec.Decode(&header) - fmt.Println(header.Size) - fmt.Println(header.Status) - - respData := make([]byte, header.Size) - dec.Decode(&respData) - - if header.Status != OK { - ok = false + if err = client.conn.SendRequest(&req); err != nil { + return nil, err } - return respData, ok + var response *RPCResponse + if response, err = client.conn.ReceiveResponse(); err != nil { + return nil, err + } + + if response.Status != RPCOK { + err = errors.New("srpc - Response contained error: '" + response.Error + "'") + } + + return response.Payload, err } func (client *Client) NewDecoder(r io.Reader) IDecoder { @@ -60,5 +46,5 @@ func (client *Client) NewDecoder(r io.Reader) IDecoder { } func NewClient(conn net.Conn) *Client { - return &Client{NewEncoderDecoder(), conn} + return &Client{NewEncoderDecoder(), NewTCPConn(conn, NewEncoderDecoder())} } diff --git a/connection.go b/connection.go new file mode 100644 index 0000000..6026ffc --- /dev/null +++ b/connection.go @@ -0,0 +1,93 @@ +package srpc + +import ( + "net" +) + +type RPCStatus uint8 + +const ( + RPCOK RPCStatus = 0 + RPCERR RPCStatus = 1 +) + +type RPCRequest struct { + FuncName string + Payload []byte +} + +type RPCResponse struct { + Status RPCStatus + Error string + Payload []byte +} + +type IRPCConn interface { + ReceiveRequest() (*RPCRequest, error) + ReceiveResponse() (*RPCResponse, error) + SendRequest(request *RPCRequest) error + SendResponse(response *RPCResponse) error + Close() +} + +type TCPConn struct { + conn net.Conn + ed IEncoderDecoder + + // interface + IRPCConn +} + +func (tc *TCPConn) ReceiveRequest() (ret *RPCRequest, err error) { + ret = new(RPCRequest) + err = nil + + dec := tc.ed.NewDecoder(tc.conn) + if err = dec.Decode(ret); err != nil { + return nil, err + } + + return ret, err +} + +func (tc *TCPConn) ReceiveResponse() (ret *RPCResponse, err error) { + ret = new(RPCResponse) + err = nil + + dec := tc.ed.NewDecoder(tc.conn) + if err = dec.Decode(ret); err != nil { + return nil, err + } + + return ret, err +} + +func (tc *TCPConn) SendRequest(request *RPCRequest) (err error) { + err = nil + + enc := tc.ed.NewEncoder(tc.conn) + err = enc.Encode(request) + + return err +} + +func (tc *TCPConn) SendResponse(response *RPCResponse) (err error) { + err = nil + + enc := tc.ed.NewEncoder(tc.conn) + err = enc.Encode(response) + + return err +} + +func (tc *TCPConn) Close() { + tc.conn.Close() +} + +func NewTCPConn(conn net.Conn, ed IEncoderDecoder) *TCPConn { + ret := &TCPConn{} + ret.conn = conn + ret.ed = ed + + return ret +} diff --git a/server.go b/server.go index f2e27d4..e4b9028 100644 --- a/server.go +++ b/server.go @@ -10,24 +10,8 @@ import ( "sync" ) -type argumentType uint8 - -type responseStatus uint8 - -const ( - OK responseStatus = 0 - ERR responseStatus = 1 -) - -var RESPONSE_HEADER_SIZE = int32(9) - -type responseHeader struct { - Size int64 - Status responseStatus -} - -type parseInputsFunc func(in []byte) (ret []reflect.Value) -type parseOutputsFunc func(in []reflect.Value) (ret []byte) +type parseInputsFunc func(in []byte) (ret []reflect.Value, err error) +type parseOutputsFunc func(in []reflect.Value) (ret []byte, err error) type service struct { f reflect.Value @@ -42,33 +26,48 @@ type Server struct { func (server *Server) ServeConn(conn net.Conn) { var fs service + var request *RPCRequest + var err error - dec := server.ed.NewDecoder(conn) + tcpConn := NewTCPConn(conn, server.ed) + defer tcpConn.Close() - var req request - dec.Decode(&req) + if request, err = tcpConn.ReceiveRequest(); err != nil { + slog.LOG_ERROR("srpc - Malformed request received: '%s'\n", err.Error()) + return + } - header := responseHeader{} - var respData []byte + var response RPCResponse + response.Status = RPCOK + response.Error = "" - if fsRaw, ok := server.serviceMap.Load(req.FuncName); !ok { - slog.LOG_ERROR("srpc - Call to unknown method: '%s'\n", req.FuncName) - err := fmt.Sprintf("Unknown method: '%s'", req.FuncName) - header.Status = ERR - respData = append([]byte{}, []byte(err)...) + if fsRaw, ok := server.serviceMap.Load(request.FuncName); !ok { + slog.LOG_ERROR("srpc - Call to unknown method: '%s'\n", request.FuncName) + err := fmt.Sprintf("Unknown method: '%s'", request.FuncName) + response.Status = RPCERR + response.Error = err } else { fs = fsRaw.(service) - inputs := fs.fin(req.Payload) - outputs := fs.f.Call(inputs) - slog.LOG_DEBUG("%v\n", outputs) - header.Status = OK - respData = fs.fout(outputs) + if inputs, err := fs.fin(request.Payload); err != nil { + slog.LOG_ERROR("srpc - Error parsing inputs '%s'\n", err.Error()) + response.Status = RPCERR + response.Error = err.Error() + } else { + outputs := fs.f.Call(inputs) + if payload, err := fs.fout(outputs); err != nil { + slog.LOG_ERROR("srpc - Error parsing outputs '%s'\n", err.Error()) + response.Status = RPCERR + response.Error = err.Error() + } else { + slog.LOG_DEBUG("%v\n", outputs) + response.Payload = append([]byte{}, payload...) + } + } } - header.Size = int64(len(respData)) - enc := server.ed.NewEncoder(conn) - enc.Encode(header) - enc.Encode(respData) + if err = tcpConn.SendResponse(&response); err != nil { + slog.LOG_ERROR("srpc - Error sending response '%s'\n", err.Error()) + } } func (server *Server) Accept(ln net.Listener) { @@ -93,52 +92,49 @@ func (server *Server) RegisterName(name string, rcvr interface{}) (err error) { fs := service{} - fs.fin = func(in []byte) (ret []reflect.Value) { + fs.fin = func(in []byte) (ret []reflect.Value, err error) { var b bytes.Buffer if _, err = b.Write(in); err != nil { - return nil + return nil, err } decoder := server.ed.NewDecoder(&b) ret = make([]reflect.Value, nIn) for i := 0; i < nIn; i++ { - in := reflect.New(ft.In(i)) - if err = decoder.Decode(in.Interface()); err != nil { - fmt.Println(err) - return nil + arg := reflect.New(ft.In(i)) + if err = decoder.Decode(arg.Interface()); err != nil { + return nil, err } - ret[i] = reflect.Indirect(in) + ret[i] = reflect.Indirect(arg) } - return ret + return ret, nil } - fs.fout = func(in []reflect.Value) (ret []byte) { + fs.fout = func(in []reflect.Value) (ret []byte, err error) { var b bytes.Buffer encoder := server.ed.NewEncoder(&b) for _, v := range in { if v.Type() == reflect.TypeOf((*error)(nil)).Elem() { if v.IsNil() { - if err := encoder.Encode(string("")); err != nil { - fmt.Println(err) - return nil + if err = encoder.Encode(string("")); err != nil { + return nil, err } } else { - if err := encoder.Encode(v.Interface().(error).Error()); err != nil { + if err = encoder.Encode(v.Interface().(error).Error()); err != nil { fmt.Println(err) - return nil + return nil, err } } } else { - if err := encoder.Encode(v.Interface()); err != nil { - fmt.Println(err) - return nil + if err = encoder.Encode(v.Interface()); err != nil { + return nil, err } } } - return b.Bytes() + return b.Bytes(), nil } fs.f = fv @@ -163,9 +159,12 @@ func (server *Server) CallName(name string, args ...interface{}) (ret []byte, er return nil, errors.New("srpc - Error: '" + err.Error() + "'") } } - inputs := fs.fin(b.Bytes()) - outputs := fs.f.Call(inputs) - return fs.fout(outputs), nil + if inputs, err := fs.fin(b.Bytes()); err != nil { + return nil, err + } else { + outputs := fs.f.Call(inputs) + return fs.fout(outputs) + } } func NewServer(ed IEncoderDecoder) *Server {