Working on rewrite

This commit is contained in:
Matthias Fulz 2020-02-24 17:03:24 +01:00
parent 211d067467
commit ec9289224b
4 changed files with 173 additions and 124 deletions

177
client.go
View File

@ -1,97 +1,92 @@
package srpc package srpc
import ( import (
"bytes" //"bytes"
"errors" //"errors"
"fmt" //"fmt"
"io" "io"
"net" "net"
"sync" "sync"
"time" //"time"
) )
type Client struct { //func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) {
sync.RWMutex //var b bytes.Buffer
ed IEncoderDecoder
conn IRPCConn
running bool
}
func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) { //enc := client.ed.NewEncoder(&b)
var b bytes.Buffer //for _, a := range args {
//enc.Encode(a)
//}
enc := client.ed.NewEncoder(&b) //payload := b.Bytes()
for _, a := range args { //req := RPCRequest{funcName, payload}
enc.Encode(a)
}
payload := b.Bytes() //client.RLock()
req := RPCRequest{funcName, payload} //defer client.RUnlock()
client.RLock() //if err = client.conn.Send(&req); err != nil {
defer client.RUnlock() //return nil, err
//}
if err = client.conn.Send(&req); err != nil { //var call *RPCCall
return nil, err //if call, err = client.conn.Receive(); err != nil {
} //return nil, err
//}
var call *RPCCall //var response *RPCResponse
if call, err = client.conn.Receive(); err != nil { //var ok bool
return nil, err //if response, ok = call.Payload.(*RPCResponse); !ok {
} //return nil, errors.New("srpc - Expected response")
//}
var response *RPCResponse //if response.Status != RPCOK {
var ok bool //err = errors.New("srpc - Response contained error: '" + response.Error + "'")
if response, ok = call.Payload.(*RPCResponse); !ok { //}
return nil, errors.New("srpc - Expected response")
}
if response.Status != RPCOK { //return response.Payload, err
err = errors.New("srpc - Response contained error: '" + response.Error + "'") //}
}
return response.Payload, err //func (client *Client) Close() {
} //client.Lock()
//defer client.Unlock()
func (client *Client) Close() { //client.conn.Send(new(RPCClose))
client.Lock() //client.running = false
defer client.Unlock() //}
client.conn.Send(new(RPCClose)) //func (client *Client) NewDecoder(r io.Reader) IDecoder {
client.running = false //return client.ed.NewDecoder(r)
} //}
func (client *Client) NewDecoder(r io.Reader) IDecoder { //func NewClient(conn net.Conn) *Client {
return client.ed.NewDecoder(r) //ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true}
} //go func() {
//for {
//fmt.Println("heartbeat")
//ret.Lock()
func NewClient(conn net.Conn) *Client { //if ret.running == false {
ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true} //ret.Unlock()
go func() { //return
for { //}
fmt.Println("heartbeat") //ret.conn.Send(new(RPCHeartbeat))
ret.Lock() //if _, err := ret.conn.Receive(); err != nil {
//ret.conn.Send(new(RPCClose))
//ret.Unlock()
//return
//}
//fmt.Println("got heartbeat")
//ret.Unlock()
//time.Sleep(15 * time.Second)
//}
//}()
if ret.running == false { //return ret
ret.Unlock() //}
return
}
ret.conn.Send(new(RPCHeartbeat))
if _, err := ret.conn.Receive(); err != nil {
ret.conn.Send(new(RPCClose))
ret.Unlock()
return
}
fmt.Println("got heartbeat")
ret.Unlock()
time.Sleep(15 * time.Second)
}
}()
return ret type ResponseDataHandlerFunc func() (response interface{})
}
type DialFunc func(addr string) (conn io.ReadWriteCloser, err error) type DialHandlerFunc func(addr string) (conn io.ReadWriteCloser, err error)
func unixDial(addr string) (conn io.ReadWriteCloser, err error) { func unixDial(addr string) (conn io.ReadWriteCloser, err error) {
if conn, err = net.Dial("unix", addr); err != nil { if conn, err = net.Dial("unix", addr); err != nil {
@ -100,3 +95,45 @@ func unixDial(addr string) (conn io.ReadWriteCloser, err error) {
return conn, nil return conn, nil
} }
type RPCClient struct {
Addr string
LogError LogErrorFunc
ConnectHander ConnectHandlerFunc
ResponseDataHandler ResponseDataHandlerFunc
DialHandler DialHandlerFunc
Ed IEncoderDecoder
stopChan chan struct{}
stopWg sync.WaitGroup
}
func (c *RPCClient) Start() (err error) {
if c.LogError == nil {
c.LogError = logError
}
if c.stopChan != nil {
return errors.New("srpc - Client already running")
}
c.stopChan = make(chan struct{})
if c.ResponseDataHandler == nil {
return errors.New("srpc - Client needs a ResponseDataHandler")
}
if c.DialHandler == nil {
return errors.New("srpc - Client needs a DialHandler")
}
c.stopWg.Add(1)
return nil
}
func (c *RPCClient) Stop() {
if c.stopChan == nil {
return
}
close(c.stopChan)
c.stopWg.Wait()
c.stopChan = nil
}

View File

@ -24,13 +24,6 @@ type RPCHeader struct {
Size uint64 Size uint64
} }
type RPCStatus uint8
const (
RPCOK RPCStatus = 0
RPCERR RPCStatus = 1
)
type RPCHeartbeat struct { type RPCHeartbeat struct {
OK bool OK bool
} }
@ -39,17 +32,6 @@ type RPCClose struct {
OK bool OK bool
} }
type RPCRequest struct {
FuncName string
Payload []byte
}
type RPCResponse struct {
Status RPCStatus
Error string
Payload []byte
}
type RPCCall struct { type RPCCall struct {
Header RPCHeader Header RPCHeader
Payload interface{} Payload interface{}

57
rpc.go
View File

@ -8,6 +8,24 @@ import (
"sync" "sync"
) )
type RPCStatus uint8
const (
RPCOK RPCStatus = 0
RPCERR RPCStatus = 1
)
type RPCRequest struct {
FuncName string
Payload []byte
}
type RPCResponse struct {
Status RPCStatus
Error string
Payload []byte
}
type parseInputsFunc func(in []byte) (ret []reflect.Value, err error) type parseInputsFunc func(in []byte) (ret []reflect.Value, err error)
type parseOutputsFunc func(in []reflect.Value) (ret []byte, err error) type parseOutputsFunc func(in []reflect.Value) (ret []byte, err error)
@ -84,42 +102,51 @@ func (r *RPC) RegisterName(name string, rcvr interface{}) (err error) {
return nil return nil
} }
func (r *RPC) RequestHandler(clientAddr string, request []byte) (response []byte, err error) { func (r *RPC) RequestHandler(clientAddr string, request interface{}) (response interface{}) {
var fs service var fs service
var ok bool var ok bool
var req RPCRequest var req *RPCRequest
var res RPCResponse var res RPCResponse
in := bytes.NewBuffer(request) if req, ok = request.(*RPCRequest); !ok {
dec := r.ed.NewDecoder(in) res.Error = "srpc - Invalid request type"
if err = dec.Decode(&req); err != nil { res.Status = RPCERR
return nil, errors.New("srpc - Invalid request: '" + err.Error() + "'") return res
} }
fsRaw, ok := r.serviceMap.Load(req.FuncName) fsRaw, ok := r.serviceMap.Load(req.FuncName)
if !ok { if !ok {
return nil, errors.New("Unknown method: '" + req.FuncName + "'") res.Error = "Unknown method: '" + req.FuncName + "'"
res.Status = RPCERR
return res
} }
fs = fsRaw.(service) fs = fsRaw.(service)
inputs, err := fs.fin(req.Payload) inputs, err := fs.fin(req.Payload)
if err != nil { if err != nil {
return nil, errors.New("srpc - Error parsing inputs '" + err.Error() + "'") res.Error = "srpc - Error parsing inputs '" + err.Error() + "'"
res.Status = RPCERR
return res
} }
outputs := fs.f.Call(inputs) outputs := fs.f.Call(inputs)
payload, err := fs.fout(outputs) payload, err := fs.fout(outputs)
if err != nil { if err != nil {
return nil, errors.New("srpc - Error parsing outputs '" + err.Error() + "'") res.Error = "srpc - Error parsing outputs '" + err.Error() + "'"
res.Status = RPCERR
return res
} }
res.Status = RPCOK
res.Payload = append([]byte{}, payload...) res.Payload = append([]byte{}, payload...)
return res
}
var out bytes.Buffer func RequestDataHandler() (request interface{}) {
enc := r.ed.NewEncoder(&out) return new(RPCRequest)
if err = enc.Encode(res); err != nil { }
return nil, errors.New("srpc - Error encoding response '" + err.Error() + "'")
} func ResponseDataHandler() (request interface{}) {
return out.Bytes(), nil return new(RPCRequest)
} }
func NewRPC(ed IEncoderDecoder) *RPC { func NewRPC(ed IEncoderDecoder) *RPC {

View File

@ -12,13 +12,14 @@ import (
type serverMessage struct { type serverMessage struct {
ID uint64 ID uint64
Request []byte Data interface{}
Response []byte
Error string Error string
ClientAddr string ClientAddr string
} }
type RequestHandlerFunc func(clientAddr string, request []byte) (response []byte, err error) type RequestDataHandlerFunc func() (request interface{})
type RequestHandlerFunc func(clientAddr string, request interface{}) (response interface{})
type ConnectHandlerFunc func(remoteAddr string, rwc io.ReadWriteCloser) (io.ReadWriteCloser, error) type ConnectHandlerFunc func(remoteAddr string, rwc io.ReadWriteCloser) (io.ReadWriteCloser, error)
@ -60,14 +61,15 @@ func (nl *netListener) Addr() net.Addr {
} }
type RPCServer struct { type RPCServer struct {
Addr string Addr string
Listener IListener Listener IListener
LogError LogErrorFunc LogError LogErrorFunc
ConnectHandler ConnectHandlerFunc ConnectHandler ConnectHandlerFunc
RequestHandler RequestHandlerFunc RequestDataHandler RequestDataHandlerFunc
Ed IEncoderDecoder RequestHandler RequestHandlerFunc
stopChan chan struct{} Ed IEncoderDecoder
stopWg sync.WaitGroup stopChan chan struct{}
stopWg sync.WaitGroup
} }
func (s *RPCServer) Start() (err error) { func (s *RPCServer) Start() (err error) {
@ -80,6 +82,10 @@ func (s *RPCServer) Start() (err error) {
} }
s.stopChan = make(chan struct{}) s.stopChan = make(chan struct{})
if s.RequestDataHandler == nil {
return errors.New("srpc - Server needs a RequestHandlerData")
}
if s.RequestHandler == nil { if s.RequestHandler == nil {
return errors.New("srpc - Server needs a RequestHandler") return errors.New("srpc - Server needs a RequestHandler")
} }
@ -170,8 +176,9 @@ func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr st
} }
} }
var msg serverMessage
var err error var err error
var msg serverMessage
msg.Data = s.RequestDataHandler()
dec := s.Ed.NewDecoder(conn) dec := s.Ed.NewDecoder(conn)
enc := s.Ed.NewEncoder(conn) enc := s.Ed.NewEncoder(conn)
@ -188,13 +195,8 @@ func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr st
var response serverMessage var response serverMessage
response.ID = msg.ID response.ID = msg.ID
response.ClientAddr = msg.ClientAddr response.ClientAddr = msg.ClientAddr
response.Request = msg.Request
if response.Response, err = s.RequestHandler(msg.ClientAddr, msg.Request); err != nil { response.Data = s.RequestHandler(msg.ClientAddr, msg.Data)
s.LogError("srpc - '%s'=>'%s': Error handling request: '%s'\n", clientAddr, s.Addr, err)
response.Response = []byte{}
response.Error = err.Error()
}
if err = enc.Encode(response); err != nil { if err = enc.Encode(response); err != nil {
if !clientDisconnect(err) && !serverStop(s.stopChan) { if !clientDisconnect(err) && !serverStop(s.stopChan) {
@ -219,10 +221,11 @@ func serverStop(stopChan <-chan struct{}) bool {
} }
} }
func NewUnixServer(addr string, handler RequestHandlerFunc) *RPCServer { func NewUnixServer(addr string, dhandler RequestDataHandlerFunc, handler RequestHandlerFunc) *RPCServer {
return &RPCServer{ return &RPCServer{
Addr: addr, Addr: addr,
RequestHandler: handler, RequestDataHandler: dhandler,
RequestHandler: handler,
Listener: &netListener{ Listener: &netListener{
F: func(addr string) (net.Listener, error) { F: func(addr string) (net.Listener, error) {
return net.Listen("unix", addr) return net.Listen("unix", addr)