Working on rewrite

This commit is contained in:
Matthias Fulz 2020-02-24 17:03:24 +01:00
parent 211d067467
commit ec9289224b
4 changed files with 173 additions and 124 deletions

177
client.go
View File

@ -1,97 +1,92 @@
package srpc
import (
"bytes"
"errors"
"fmt"
//"bytes"
//"errors"
//"fmt"
"io"
"net"
"sync"
"time"
//"time"
)
type Client struct {
sync.RWMutex
ed IEncoderDecoder
conn IRPCConn
running bool
}
//func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) {
//var b bytes.Buffer
func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) {
var b bytes.Buffer
//enc := client.ed.NewEncoder(&b)
//for _, a := range args {
//enc.Encode(a)
//}
enc := client.ed.NewEncoder(&b)
for _, a := range args {
enc.Encode(a)
}
//payload := b.Bytes()
//req := RPCRequest{funcName, payload}
payload := b.Bytes()
req := RPCRequest{funcName, payload}
//client.RLock()
//defer client.RUnlock()
client.RLock()
defer client.RUnlock()
//if err = client.conn.Send(&req); err != nil {
//return nil, err
//}
if err = client.conn.Send(&req); err != nil {
return nil, err
}
//var call *RPCCall
//if call, err = client.conn.Receive(); err != nil {
//return nil, err
//}
var call *RPCCall
if call, err = client.conn.Receive(); err != nil {
return nil, err
}
//var response *RPCResponse
//var ok bool
//if response, ok = call.Payload.(*RPCResponse); !ok {
//return nil, errors.New("srpc - Expected response")
//}
var response *RPCResponse
var ok bool
if response, ok = call.Payload.(*RPCResponse); !ok {
return nil, errors.New("srpc - Expected response")
}
//if response.Status != RPCOK {
//err = errors.New("srpc - Response contained error: '" + response.Error + "'")
//}
if response.Status != RPCOK {
err = errors.New("srpc - Response contained error: '" + response.Error + "'")
}
//return response.Payload, err
//}
return response.Payload, err
}
//func (client *Client) Close() {
//client.Lock()
//defer client.Unlock()
func (client *Client) Close() {
client.Lock()
defer client.Unlock()
//client.conn.Send(new(RPCClose))
//client.running = false
//}
client.conn.Send(new(RPCClose))
client.running = false
}
//func (client *Client) NewDecoder(r io.Reader) IDecoder {
//return client.ed.NewDecoder(r)
//}
func (client *Client) NewDecoder(r io.Reader) IDecoder {
return client.ed.NewDecoder(r)
}
//func NewClient(conn net.Conn) *Client {
//ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true}
//go func() {
//for {
//fmt.Println("heartbeat")
//ret.Lock()
func NewClient(conn net.Conn) *Client {
ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true}
go func() {
for {
fmt.Println("heartbeat")
ret.Lock()
//if ret.running == false {
//ret.Unlock()
//return
//}
//ret.conn.Send(new(RPCHeartbeat))
//if _, err := ret.conn.Receive(); err != nil {
//ret.conn.Send(new(RPCClose))
//ret.Unlock()
//return
//}
//fmt.Println("got heartbeat")
//ret.Unlock()
//time.Sleep(15 * time.Second)
//}
//}()
if ret.running == false {
ret.Unlock()
return
}
ret.conn.Send(new(RPCHeartbeat))
if _, err := ret.conn.Receive(); err != nil {
ret.conn.Send(new(RPCClose))
ret.Unlock()
return
}
fmt.Println("got heartbeat")
ret.Unlock()
time.Sleep(15 * time.Second)
}
}()
//return ret
//}
return ret
}
type ResponseDataHandlerFunc func() (response interface{})
type DialFunc func(addr string) (conn io.ReadWriteCloser, err error)
type DialHandlerFunc func(addr string) (conn io.ReadWriteCloser, err error)
func unixDial(addr string) (conn io.ReadWriteCloser, err error) {
if conn, err = net.Dial("unix", addr); err != nil {
@ -100,3 +95,45 @@ func unixDial(addr string) (conn io.ReadWriteCloser, err error) {
return conn, nil
}
type RPCClient struct {
Addr string
LogError LogErrorFunc
ConnectHander ConnectHandlerFunc
ResponseDataHandler ResponseDataHandlerFunc
DialHandler DialHandlerFunc
Ed IEncoderDecoder
stopChan chan struct{}
stopWg sync.WaitGroup
}
func (c *RPCClient) Start() (err error) {
if c.LogError == nil {
c.LogError = logError
}
if c.stopChan != nil {
return errors.New("srpc - Client already running")
}
c.stopChan = make(chan struct{})
if c.ResponseDataHandler == nil {
return errors.New("srpc - Client needs a ResponseDataHandler")
}
if c.DialHandler == nil {
return errors.New("srpc - Client needs a DialHandler")
}
c.stopWg.Add(1)
return nil
}
func (c *RPCClient) Stop() {
if c.stopChan == nil {
return
}
close(c.stopChan)
c.stopWg.Wait()
c.stopChan = nil
}

View File

@ -24,13 +24,6 @@ type RPCHeader struct {
Size uint64
}
type RPCStatus uint8
const (
RPCOK RPCStatus = 0
RPCERR RPCStatus = 1
)
type RPCHeartbeat struct {
OK bool
}
@ -39,17 +32,6 @@ type RPCClose struct {
OK bool
}
type RPCRequest struct {
FuncName string
Payload []byte
}
type RPCResponse struct {
Status RPCStatus
Error string
Payload []byte
}
type RPCCall struct {
Header RPCHeader
Payload interface{}

57
rpc.go
View File

@ -8,6 +8,24 @@ import (
"sync"
)
type RPCStatus uint8
const (
RPCOK RPCStatus = 0
RPCERR RPCStatus = 1
)
type RPCRequest struct {
FuncName string
Payload []byte
}
type RPCResponse struct {
Status RPCStatus
Error string
Payload []byte
}
type parseInputsFunc func(in []byte) (ret []reflect.Value, err error)
type parseOutputsFunc func(in []reflect.Value) (ret []byte, err error)
@ -84,42 +102,51 @@ func (r *RPC) RegisterName(name string, rcvr interface{}) (err error) {
return nil
}
func (r *RPC) RequestHandler(clientAddr string, request []byte) (response []byte, err error) {
func (r *RPC) RequestHandler(clientAddr string, request interface{}) (response interface{}) {
var fs service
var ok bool
var req RPCRequest
var req *RPCRequest
var res RPCResponse
in := bytes.NewBuffer(request)
dec := r.ed.NewDecoder(in)
if err = dec.Decode(&req); err != nil {
return nil, errors.New("srpc - Invalid request: '" + err.Error() + "'")
if req, ok = request.(*RPCRequest); !ok {
res.Error = "srpc - Invalid request type"
res.Status = RPCERR
return res
}
fsRaw, ok := r.serviceMap.Load(req.FuncName)
if !ok {
return nil, errors.New("Unknown method: '" + req.FuncName + "'")
res.Error = "Unknown method: '" + req.FuncName + "'"
res.Status = RPCERR
return res
}
fs = fsRaw.(service)
inputs, err := fs.fin(req.Payload)
if err != nil {
return nil, errors.New("srpc - Error parsing inputs '" + err.Error() + "'")
res.Error = "srpc - Error parsing inputs '" + err.Error() + "'"
res.Status = RPCERR
return res
}
outputs := fs.f.Call(inputs)
payload, err := fs.fout(outputs)
if err != nil {
return nil, errors.New("srpc - Error parsing outputs '" + err.Error() + "'")
res.Error = "srpc - Error parsing outputs '" + err.Error() + "'"
res.Status = RPCERR
return res
}
res.Status = RPCOK
res.Payload = append([]byte{}, payload...)
return res
}
var out bytes.Buffer
enc := r.ed.NewEncoder(&out)
if err = enc.Encode(res); err != nil {
return nil, errors.New("srpc - Error encoding response '" + err.Error() + "'")
}
return out.Bytes(), nil
func RequestDataHandler() (request interface{}) {
return new(RPCRequest)
}
func ResponseDataHandler() (request interface{}) {
return new(RPCRequest)
}
func NewRPC(ed IEncoderDecoder) *RPC {

View File

@ -12,13 +12,14 @@ import (
type serverMessage struct {
ID uint64
Request []byte
Response []byte
Data interface{}
Error string
ClientAddr string
}
type RequestHandlerFunc func(clientAddr string, request []byte) (response []byte, err error)
type RequestDataHandlerFunc func() (request interface{})
type RequestHandlerFunc func(clientAddr string, request interface{}) (response interface{})
type ConnectHandlerFunc func(remoteAddr string, rwc io.ReadWriteCloser) (io.ReadWriteCloser, error)
@ -60,14 +61,15 @@ func (nl *netListener) Addr() net.Addr {
}
type RPCServer struct {
Addr string
Listener IListener
LogError LogErrorFunc
ConnectHandler ConnectHandlerFunc
RequestHandler RequestHandlerFunc
Ed IEncoderDecoder
stopChan chan struct{}
stopWg sync.WaitGroup
Addr string
Listener IListener
LogError LogErrorFunc
ConnectHandler ConnectHandlerFunc
RequestDataHandler RequestDataHandlerFunc
RequestHandler RequestHandlerFunc
Ed IEncoderDecoder
stopChan chan struct{}
stopWg sync.WaitGroup
}
func (s *RPCServer) Start() (err error) {
@ -80,6 +82,10 @@ func (s *RPCServer) Start() (err error) {
}
s.stopChan = make(chan struct{})
if s.RequestDataHandler == nil {
return errors.New("srpc - Server needs a RequestHandlerData")
}
if s.RequestHandler == nil {
return errors.New("srpc - Server needs a RequestHandler")
}
@ -170,8 +176,9 @@ func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr st
}
}
var msg serverMessage
var err error
var msg serverMessage
msg.Data = s.RequestDataHandler()
dec := s.Ed.NewDecoder(conn)
enc := s.Ed.NewEncoder(conn)
@ -188,13 +195,8 @@ func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr st
var response serverMessage
response.ID = msg.ID
response.ClientAddr = msg.ClientAddr
response.Request = msg.Request
if response.Response, err = s.RequestHandler(msg.ClientAddr, msg.Request); err != nil {
s.LogError("srpc - '%s'=>'%s': Error handling request: '%s'\n", clientAddr, s.Addr, err)
response.Response = []byte{}
response.Error = err.Error()
}
response.Data = s.RequestHandler(msg.ClientAddr, msg.Data)
if err = enc.Encode(response); err != nil {
if !clientDisconnect(err) && !serverStop(s.stopChan) {
@ -219,10 +221,11 @@ func serverStop(stopChan <-chan struct{}) bool {
}
}
func NewUnixServer(addr string, handler RequestHandlerFunc) *RPCServer {
func NewUnixServer(addr string, dhandler RequestDataHandlerFunc, handler RequestHandlerFunc) *RPCServer {
return &RPCServer{
Addr: addr,
RequestHandler: handler,
Addr: addr,
RequestDataHandler: dhandler,
RequestHandler: handler,
Listener: &netListener{
F: func(addr string) (net.Listener, error) {
return net.Listen("unix", addr)