diff --git a/client.go b/client.go index c24a080..a8998a5 100644 --- a/client.go +++ b/client.go @@ -90,3 +90,13 @@ func NewClient(conn net.Conn) *Client { return ret } + +type DialFunc func(addr string) (conn io.ReadWriteCloser, err error) + +func unixDial(addr string) (conn io.ReadWriteCloser, err error) { + if conn, err = net.Dial("unix", addr); err != nil { + return nil, err + } + + return conn, nil +} diff --git a/configuration.go b/configuration.go new file mode 100644 index 0000000..bc7bc65 --- /dev/null +++ b/configuration.go @@ -0,0 +1,13 @@ +package srpc + +import ( + "log" +) + +type LogErrorFunc func(format string, args ...interface{}) + +var logError = LogErrorFunc(log.Printf) + +func SetLogError(f LogErrorFunc) { + logError = f +} diff --git a/connection.go b/connection.go index ab37a7d..e3dcb4c 100644 --- a/connection.go +++ b/connection.go @@ -74,15 +74,20 @@ func (tc *NetConn) Send(payload interface{}) (err error) { var header RPCHeader var hb, b bytes.Buffer - if _, ok := payload.(*RPCHeartbeat); ok { + switch payload.(type) { + case *RPCHeartbeat: header.RPCType = RPC_HEARTBEAT - } else if _, ok := payload.(*RPCClose); ok { + break + case *RPCClose: header.RPCType = RPC_CLOSE - } else if _, ok := payload.(*RPCRequest); ok { + break + case *RPCRequest: header.RPCType = RPC_REQUEST - } else if _, ok := payload.(*RPCResponse); ok { + break + case *RPCResponse: header.RPCType = RPC_RESPONSE - } else { + break + default: return errors.New("srpc - Invalid RPC message type") } diff --git a/server.go b/server.go index 797e53f..c08017b 100644 --- a/server.go +++ b/server.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "io" "net" "olznet.de/slog" "reflect" @@ -235,3 +236,40 @@ func NewDefaultServer() *Server { } var DefaultServer = NewDefaultServer() + +type IListener interface { + Init(addr string) error + Accept() (conn io.ReadWriteCloser, clientAddr string, err error) + Close() error + Addr() net.Addr +} + +type netListener struct { + F func(addr string) (net.Listener, error) + L net.Listener +} + +func (nl *netListener) Init(addr string) (err error) { + nl.L, err = nl.F(addr) + return err +} + +func (nl *netListener) Accept() (conn io.ReadWriteCloser, clientAddr string, err error) { + var c net.Conn + + if c, err = nl.L.Accept(); err != nil { + return nil, "", err + } + return c, c.RemoteAddr().String(), nil +} + +func (nl *netListener) Close() error { + return nl.L.Close() +} + +func (nl *netListener) Addr() net.Addr { + if nl.L == nil { + return nil + } + return nl.L.Addr() +}