srpc/rpc.go

164 lines
3.1 KiB
Go

package srpc
import (
"bytes"
"errors"
"olznet.de/slog"
"reflect"
"sync"
)
// RPCStatus reports the outcome of a remote call in an RPCResponse.
type RPCStatus uint8

// Possible RPCStatus values.
const (
	RPCOK  RPCStatus = iota // the call succeeded; also the zero value
	RPCERR                  // the call failed; RPCResponse.Error says why
)

type (
	// RPCRequest asks for the function registered under FuncName to be
	// invoked with the arguments encoded in Payload.
	RPCRequest struct {
		FuncName string // name the function was registered under via RPC.RegisterName
		Payload  []byte // encoded call arguments, one value per parameter, in order
	}

	// RPCResponse carries the outcome of an RPCRequest.
	RPCResponse struct {
		Status  RPCStatus // RPCOK or RPCERR
		Error   string    // failure description when Status is RPCERR
		Payload []byte    // encoded results, one value per result, when Status is RPCOK
	}
)
type (
	// parseInputsFunc decodes a request payload into the argument values of
	// a registered function.
	parseInputsFunc func([]byte) ([]reflect.Value, error)

	// parseOutputsFunc encodes a registered function's result values into a
	// response payload.
	parseOutputsFunc func([]reflect.Value) ([]byte, error)

	// service is one registered function together with its payload codecs.
	service struct {
		f    reflect.Value    // the registered function
		fin  parseInputsFunc  // builds f's arguments from a request payload
		fout parseOutputsFunc // builds the response payload from f's results
	}

	// RPC holds the functions registered by name and the encoder/decoder
	// factory used for their payloads. It embeds a sync.Map, so it must not
	// be copied after first use; obtain one from NewRPC.
	RPC struct {
		ed         IEncoderDecoder // creates the payload decoders and encoders
		serviceMap sync.Map        // registered name (string) -> service
	}
)
// RegisterName exposes the function rcvr to remote callers under name.
//
// rcvr must be a non-nil function value. For each request, the payload is
// decoded with a decoder from r's IEncoderDecoder: one value per parameter,
// in parameter order. rcvr's results are then encoded in result order. A
// result whose static type is error is sent as its Error() text, or as ""
// when it is nil.
//
// RegisterName returns an error, and registers nothing, when rcvr is not a
// non-nil function or when name is already registered. Concurrent calls for
// the same name are safe: at most one of them succeeds.
func (r *RPC) RegisterName(name string, rcvr interface{}) (err error) {
	duplicate := func() error {
		return errors.New("srpc: Function with name '" + name + "' already registered")
	}
	// Cheap pre-check so duplicates fail before any reflection work. The
	// LoadOrStore at the end repeats the check atomically.
	if _, ok := r.serviceMap.Load(name); ok {
		return duplicate()
	}

	fv := reflect.ValueOf(rcvr)
	// Without this guard ft.NumIn() panics for a nil or non-function rcvr.
	// A nil func value would otherwise panic on every incoming call.
	if fv.Kind() != reflect.Func || fv.IsNil() {
		return errors.New("srpc: value registered as '" + name + "' is not a non-nil function")
	}
	ft := fv.Type()
	nIn := ft.NumIn()
	errorType := reflect.TypeOf((*error)(nil)).Elem()

	fs := service{f: fv}

	fs.fin = func(in []byte) ([]reflect.Value, error) {
		// The decoder only reads, so the buffer can wrap the payload directly
		// instead of copying it first.
		decoder := r.ed.NewDecoder(bytes.NewBuffer(in))
		args := make([]reflect.Value, nIn)
		for i := range args {
			arg := reflect.New(ft.In(i))
			if err := decoder.Decode(arg.Interface()); err != nil {
				return nil, err
			}
			args[i] = arg.Elem()
		}
		return args, nil
	}

	fs.fout = func(results []reflect.Value) ([]byte, error) {
		var b bytes.Buffer
		encoder := r.ed.NewEncoder(&b)
		for _, v := range results {
			if v.Type() != errorType {
				if err := encoder.Encode(v.Interface()); err != nil {
					return nil, err
				}
				continue
			}
			// Error results travel as plain strings ("" for nil).
			// Presumably this is because concrete error types are often not
			// encodable — confirm.
			msg := ""
			if !v.IsNil() {
				msg = v.Interface().(error).Error()
			}
			if err := encoder.Encode(msg); err != nil {
				slog.LOG_ERRORLN(err)
				return nil, err
			}
		}
		return b.Bytes(), nil
	}

	if _, loaded := r.serviceMap.LoadOrStore(name, fs); loaded {
		return duplicate()
	}
	return nil
}
// RequestHandler executes one RPC call and returns its RPCResponse as a value
// (not a pointer).
//
// request must be a non-nil *RPCRequest; clientAddr is currently unused. The
// handler:
//   - looks up the function registered under req.FuncName,
//   - decodes its arguments from req.Payload,
//   - invokes it,
//   - encodes its results into the response Payload with Status RPCOK.
//
// Any failure yields Status RPCERR with a description in Error. Failures
// include an invalid request, an unknown name, a decode/encode error, or a
// panic inside the registered function.
func (r *RPC) RequestHandler(clientAddr string, request interface{}) (response interface{}) {
	failure := func(msg string) RPCResponse {
		return RPCResponse{Status: RPCERR, Error: msg}
	}

	// A typed nil *RPCRequest passes the type assertion but would panic on
	// the field access below, so it is rejected as well.
	req, ok := request.(*RPCRequest)
	if !ok || req == nil {
		return failure("srpc - Invalid request type")
	}

	raw, ok := r.serviceMap.Load(req.FuncName)
	if !ok {
		return failure("Unknown method: '" + req.FuncName + "'")
	}
	fs := raw.(service)

	inputs, err := fs.fin(req.Payload)
	if err != nil {
		return failure("srpc - Error parsing inputs '" + err.Error() + "'")
	}

	outputs, err := func() (out []reflect.Value, err error) {
		// Turn a panic into an RPC error instead of propagating it to
		// RequestHandler's caller. The panic may come from the registered
		// function or from reflect itself, e.g. on an argument mismatch.
		defer func() {
			if p := recover(); p != nil {
				switch v := p.(type) {
				case error:
					err = v
				case string:
					err = errors.New(v)
				default:
					err = errors.New("panic with value of type " + reflect.TypeOf(p).String())
				}
			}
		}()
		// fin decodes a variadic function's final parameter as a whole slice.
		// CallSlice passes that slice through unchanged. Call would instead
		// treat it as a single variadic element.
		if fs.f.Type().IsVariadic() {
			return fs.f.CallSlice(inputs), nil
		}
		return fs.f.Call(inputs), nil
	}()
	if err != nil {
		return failure("srpc - Error calling function '" + err.Error() + "'")
	}

	payload, err := fs.fout(outputs)
	if err != nil {
		return failure("srpc - Error parsing outputs '" + err.Error() + "'")
	}
	// The copy also guarantees a non-nil Payload when the function has no
	// results (fout then returns a nil slice).
	return RPCResponse{Status: RPCOK, Payload: append([]byte{}, payload...)}
}
// RequestDataHandler returns a new, zero-valued *RPCRequest. It presumably
// serves as the value an incoming request is decoded into — confirm against
// the transport that calls it.
func RequestDataHandler() (request interface{}) {
	return &RPCRequest{}
}
// ResponseDataHandler returns a new, zero-valued *RPCResponse. It presumably
// serves as the value an incoming response is decoded into — confirm against
// the transport that calls it.
func ResponseDataHandler() (response interface{}) {
	return &RPCResponse{}
}
// NewRPC returns an RPC with no registered functions. It uses ed to build the
// decoders for call arguments and the encoders for call results.
func NewRPC(ed IEncoderDecoder) *RPC {
	return &RPC{ed: ed}
}
// NewDefaultRPC returns an RPC backed by the encoder/decoder from
// NewEncoderDecoder.
func NewDefaultRPC() *RPC {
	return NewRPC(NewEncoderDecoder())
}

// DefaultRPC is a package-level RPC, created with NewDefaultRPC during package
// initialization.
var DefaultRPC = NewDefaultRPC()