diff --git a/client.go b/client.go index a8998a5..c89c177 100644 --- a/client.go +++ b/client.go @@ -1,97 +1,92 @@ package srpc import ( - "bytes" - "errors" - "fmt" + //"bytes" + //"errors" + //"fmt" "io" "net" "sync" - "time" + //"time" ) -type Client struct { - sync.RWMutex - ed IEncoderDecoder - conn IRPCConn - running bool -} +//func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) { +//var b bytes.Buffer -func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) { - var b bytes.Buffer +//enc := client.ed.NewEncoder(&b) +//for _, a := range args { +//enc.Encode(a) +//} - enc := client.ed.NewEncoder(&b) - for _, a := range args { - enc.Encode(a) - } +//payload := b.Bytes() +//req := RPCRequest{funcName, payload} - payload := b.Bytes() - req := RPCRequest{funcName, payload} +//client.RLock() +//defer client.RUnlock() - client.RLock() - defer client.RUnlock() +//if err = client.conn.Send(&req); err != nil { +//return nil, err +//} - if err = client.conn.Send(&req); err != nil { - return nil, err - } +//var call *RPCCall +//if call, err = client.conn.Receive(); err != nil { +//return nil, err +//} - var call *RPCCall - if call, err = client.conn.Receive(); err != nil { - return nil, err - } +//var response *RPCResponse +//var ok bool +//if response, ok = call.Payload.(*RPCResponse); !ok { +//return nil, errors.New("srpc - Expected response") +//} - var response *RPCResponse - var ok bool - if response, ok = call.Payload.(*RPCResponse); !ok { - return nil, errors.New("srpc - Expected response") - } +//if response.Status != RPCOK { +//err = errors.New("srpc - Response contained error: '" + response.Error + "'") +//} - if response.Status != RPCOK { - err = errors.New("srpc - Response contained error: '" + response.Error + "'") - } +//return response.Payload, err +//} - return response.Payload, err -} +//func (client *Client) Close() { +//client.Lock() +//defer client.Unlock() -func (client *Client) Close() { - client.Lock() - defer client.Unlock() +//client.conn.Send(new(RPCClose)) +//client.running = false +//} - client.conn.Send(new(RPCClose)) - client.running = false -} +//func (client *Client) NewDecoder(r io.Reader) IDecoder { +//return client.ed.NewDecoder(r) +//} -func (client *Client) NewDecoder(r io.Reader) IDecoder { - return client.ed.NewDecoder(r) -} +//func NewClient(conn net.Conn) *Client { +//ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true} +//go func() { +//for { +//fmt.Println("heartbeat") +//ret.Lock() -func NewClient(conn net.Conn) *Client { - ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true} - go func() { - for { - fmt.Println("heartbeat") - ret.Lock() +//if ret.running == false { +//ret.Unlock() +//return +//} +//ret.conn.Send(new(RPCHeartbeat)) +//if _, err := ret.conn.Receive(); err != nil { +//ret.conn.Send(new(RPCClose)) +//ret.Unlock() +//return +//} +//fmt.Println("got heartbeat") +//ret.Unlock() +//time.Sleep(15 * time.Second) +//} +//}() - if ret.running == false { - ret.Unlock() - return - } - ret.conn.Send(new(RPCHeartbeat)) - if _, err := ret.conn.Receive(); err != nil { - ret.conn.Send(new(RPCClose)) - ret.Unlock() - return - } - fmt.Println("got heartbeat") - ret.Unlock() - time.Sleep(15 * time.Second) - } - }() +//return ret +//} - return ret -} +type ResponseDataHandlerFunc func() (response interface{}) -type DialFunc func(addr string) (conn io.ReadWriteCloser, err error) +type DialHandlerFunc func(addr string) (conn io.ReadWriteCloser, err error) func unixDial(addr string) (conn io.ReadWriteCloser, err error) { if conn, err = net.Dial("unix", addr); err != nil { @@ -100,3 +95,45 @@ func unixDial(addr string) (conn io.ReadWriteCloser, err error) { return conn, nil } + +type RPCClient struct { + Addr string + LogError LogErrorFunc + ConnectHander ConnectHandlerFunc + ResponseDataHandler ResponseDataHandlerFunc + DialHandler DialHandlerFunc + Ed IEncoderDecoder + stopChan chan struct{} + stopWg sync.WaitGroup +} + +func (c *RPCClient) Start() (err error) { + if c.LogError == nil { + c.LogError = logError + } + + if c.stopChan != nil { + return errors.New("srpc - Client already running") + } + c.stopChan = make(chan struct{}) + + if c.ResponseDataHandler == nil { + return errors.New("srpc - Client needs a ResponseDataHandler") + } + + if c.DialHandler == nil { + return errors.New("srpc - Client needs a DialHandler") + } + + c.stopWg.Add(1) + return nil +} + +func (c *RPCClient) Stop() { + if c.stopChan == nil { + return + } + close(c.stopChan) + c.stopWg.Wait() + c.stopChan = nil +} diff --git a/connection.go b/connection.go index e3dcb4c..a9c4919 100644 --- a/connection.go +++ b/connection.go @@ -24,13 +24,6 @@ type RPCHeader struct { Size uint64 } -type RPCStatus uint8 - -const ( - RPCOK RPCStatus = 0 - RPCERR RPCStatus = 1 -) - type RPCHeartbeat struct { OK bool } @@ -39,17 +32,6 @@ type RPCClose struct { OK bool } -type RPCRequest struct { - FuncName string - Payload []byte -} - -type RPCResponse struct { - Status RPCStatus - Error string - Payload []byte -} - type RPCCall struct { Header RPCHeader Payload interface{} diff --git a/rpc.go b/rpc.go index 22857ec..910143b 100644 --- a/rpc.go +++ b/rpc.go @@ -8,6 +8,24 @@ import ( "sync" ) +type RPCStatus uint8 + +const ( + RPCOK RPCStatus = 0 + RPCERR RPCStatus = 1 +) + +type RPCRequest struct { + FuncName string + Payload []byte +} + +type RPCResponse struct { + Status RPCStatus + Error string + Payload []byte +} + type parseInputsFunc func(in []byte) (ret []reflect.Value, err error) type parseOutputsFunc func(in []reflect.Value) (ret []byte, err error) @@ -84,42 +102,51 @@ func (r *RPC) RegisterName(name string, rcvr interface{}) (err error) { return nil } -func (r *RPC) RequestHandler(clientAddr string, request []byte) (response []byte, err error) { +func (r *RPC) RequestHandler(clientAddr string, request interface{}) (response interface{}) { var fs service var ok bool - var req RPCRequest + var req *RPCRequest var res RPCResponse - in := bytes.NewBuffer(request) - dec := r.ed.NewDecoder(in) - if err = dec.Decode(&req); err != nil { - return nil, errors.New("srpc - Invalid request: '" + err.Error() + "'") + if req, ok = request.(*RPCRequest); !ok { + res.Error = "srpc - Invalid request type" + res.Status = RPCERR + return res } fsRaw, ok := r.serviceMap.Load(req.FuncName) if !ok { - return nil, errors.New("Unknown method: '" + req.FuncName + "'") + res.Error = "Unknown method: '" + req.FuncName + "'" + res.Status = RPCERR + return res } fs = fsRaw.(service) inputs, err := fs.fin(req.Payload) if err != nil { - return nil, errors.New("srpc - Error parsing inputs '" + err.Error() + "'") + res.Error = "srpc - Error parsing inputs '" + err.Error() + "'" + res.Status = RPCERR + return res } outputs := fs.f.Call(inputs) payload, err := fs.fout(outputs) if err != nil { - return nil, errors.New("srpc - Error parsing outputs '" + err.Error() + "'") + res.Error = "srpc - Error parsing outputs '" + err.Error() + "'" + res.Status = RPCERR + return res } + res.Status = RPCOK res.Payload = append([]byte{}, payload...) + return res +} - var out bytes.Buffer - enc := r.ed.NewEncoder(&out) - if err = enc.Encode(res); err != nil { - return nil, errors.New("srpc - Error encoding response '" + err.Error() + "'") - } - return out.Bytes(), nil +func RequestDataHandler() (request interface{}) { + return new(RPCRequest) +} + +func ResponseDataHandler() (request interface{}) { + return new(RPCRequest) } func NewRPC(ed IEncoderDecoder) *RPC { diff --git a/server.go b/server.go index 01f66a6..31544de 100644 --- a/server.go +++ b/server.go @@ -12,13 +12,14 @@ import ( type serverMessage struct { ID uint64 - Request []byte - Response []byte + Data interface{} Error string ClientAddr string } -type RequestHandlerFunc func(clientAddr string, request []byte) (response []byte, err error) +type RequestDataHandlerFunc func() (request interface{}) + +type RequestHandlerFunc func(clientAddr string, request interface{}) (response interface{}) type ConnectHandlerFunc func(remoteAddr string, rwc io.ReadWriteCloser) (io.ReadWriteCloser, error) @@ -60,14 +61,15 @@ func (nl *netListener) Addr() net.Addr { } type RPCServer struct { - Addr string - Listener IListener - LogError LogErrorFunc - ConnectHandler ConnectHandlerFunc - RequestHandler RequestHandlerFunc - Ed IEncoderDecoder - stopChan chan struct{} - stopWg sync.WaitGroup + Addr string + Listener IListener + LogError LogErrorFunc + ConnectHandler ConnectHandlerFunc + RequestDataHandler RequestDataHandlerFunc + RequestHandler RequestHandlerFunc + Ed IEncoderDecoder + stopChan chan struct{} + stopWg sync.WaitGroup } func (s *RPCServer) Start() (err error) { @@ -80,6 +82,10 @@ func (s *RPCServer) Start() (err error) { } s.stopChan = make(chan struct{}) + if s.RequestDataHandler == nil { + return errors.New("srpc - Server needs a RequestHandlerData") + } + if s.RequestHandler == nil { return errors.New("srpc - Server needs a RequestHandler") } @@ -170,8 +176,9 @@ func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr st } } - var msg serverMessage var err error + var msg serverMessage + msg.Data = s.RequestDataHandler() dec := s.Ed.NewDecoder(conn) enc := s.Ed.NewEncoder(conn) @@ -188,13 +195,8 @@ func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr st var response serverMessage response.ID = msg.ID response.ClientAddr = msg.ClientAddr - response.Request = msg.Request - if response.Response, err = s.RequestHandler(msg.ClientAddr, msg.Request); err != nil { - s.LogError("srpc - '%s'=>'%s': Error handling request: '%s'\n", clientAddr, s.Addr, err) - response.Response = []byte{} - response.Error = err.Error() - } + response.Data = s.RequestHandler(msg.ClientAddr, msg.Data) if err = enc.Encode(response); err != nil { if !clientDisconnect(err) && !serverStop(s.stopChan) { @@ -219,10 +221,11 @@ func serverStop(stopChan <-chan struct{}) bool { } } -func NewUnixServer(addr string, handler RequestHandlerFunc) *RPCServer { +func NewUnixServer(addr string, dhandler RequestDataHandlerFunc, handler RequestHandlerFunc) *RPCServer { return &RPCServer{ - Addr: addr, - RequestHandler: handler, + Addr: addr, + RequestDataHandler: dhandler, + RequestHandler: handler, Listener: &netListener{ F: func(addr string) (net.Listener, error) { return net.Listen("unix", addr)