164 lines
3.1 KiB
Go
164 lines
3.1 KiB
Go
package srpc
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"sync"
|
|
)
|
|
|
|
// RPCStatus is the outcome code carried in RPCResponse.Status.
type RPCStatus uint8

// Outcome codes set by RequestHandler.
const (
	// RPCOK means the registered function ran and its results were encoded
	// into the response payload. An error returned by the function itself is
	// part of that payload and still yields RPCOK.
	RPCOK RPCStatus = iota // 0

	// RPCERR means srpc could not serve the request (bad request type,
	// unknown function name, or a payload codec failure); the reason is in
	// RPCResponse.Error.
	RPCERR // 1
)
|
|
|
|
type RPCRequest struct {
|
|
FuncName string
|
|
Payload []byte
|
|
}
|
|
|
|
type RPCResponse struct {
|
|
Status RPCStatus
|
|
Error string
|
|
Payload []byte
|
|
}
|
|
|
|
// Per-function codec hooks and the registry entry that RegisterName stores.
type (
	// parseInputsFunc turns a request payload into the argument list for a
	// registered function.
	parseInputsFunc func([]byte) ([]reflect.Value, error)

	// parseOutputsFunc turns a registered function's return values into a
	// response payload.
	parseOutputsFunc func([]reflect.Value) ([]byte, error)

	// service is the registry entry for one named function.
	service struct {
		f    reflect.Value    // the registered function value
		fin  parseInputsFunc  // decodes the arguments for f
		fout parseOutputsFunc // encodes the results of f
	}
)
|
|
|
|
type RPC struct {
|
|
ed IEncoderDecoder
|
|
serviceMap sync.Map
|
|
}
|
|
|
|
func (r *RPC) RegisterName(name string, rcvr interface{}) (err error) {
|
|
if _, ok := r.serviceMap.Load(name); ok {
|
|
return errors.New("srpc: Function with name '" + name + "' already registerd")
|
|
}
|
|
ft := reflect.TypeOf(rcvr)
|
|
fv := reflect.ValueOf(rcvr)
|
|
|
|
nIn := ft.NumIn()
|
|
|
|
fs := service{}
|
|
|
|
fs.fin = func(in []byte) (ret []reflect.Value, err error) {
|
|
var b bytes.Buffer
|
|
if _, err = b.Write(in); err != nil {
|
|
return nil, err
|
|
}
|
|
decoder := r.ed.NewDecoder(&b)
|
|
|
|
ret = make([]reflect.Value, nIn)
|
|
for i := 0; i < nIn; i++ {
|
|
arg := reflect.New(ft.In(i))
|
|
if err = decoder.Decode(arg.Interface()); err != nil {
|
|
return nil, err
|
|
}
|
|
ret[i] = reflect.Indirect(arg)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
fs.fout = func(in []reflect.Value) (ret []byte, err error) {
|
|
var b bytes.Buffer
|
|
encoder := r.ed.NewEncoder(&b)
|
|
|
|
for _, v := range in {
|
|
if v.Type() == reflect.TypeOf((*error)(nil)).Elem() {
|
|
if v.IsNil() {
|
|
if err = encoder.Encode(string("")); err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
if err = encoder.Encode(v.Interface().(error).Error()); err != nil {
|
|
fmt.Println(err)
|
|
return nil, err
|
|
}
|
|
}
|
|
} else {
|
|
if err = encoder.Encode(v.Interface()); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
return b.Bytes(), nil
|
|
}
|
|
|
|
fs.f = fv
|
|
r.serviceMap.Store(name, fs)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *RPC) RequestHandler(clientAddr string, request interface{}) (response interface{}) {
|
|
var fs service
|
|
var ok bool
|
|
var req *RPCRequest
|
|
var res RPCResponse
|
|
|
|
if req, ok = request.(*RPCRequest); !ok {
|
|
res.Error = "srpc - Invalid request type"
|
|
res.Status = RPCERR
|
|
return res
|
|
}
|
|
|
|
fsRaw, ok := r.serviceMap.Load(req.FuncName)
|
|
if !ok {
|
|
res.Error = "Unknown method: '" + req.FuncName + "'"
|
|
res.Status = RPCERR
|
|
return res
|
|
}
|
|
|
|
fs = fsRaw.(service)
|
|
inputs, err := fs.fin(req.Payload)
|
|
if err != nil {
|
|
res.Error = "srpc - Error parsing inputs '" + err.Error() + "'"
|
|
res.Status = RPCERR
|
|
return res
|
|
}
|
|
|
|
outputs := fs.f.Call(inputs)
|
|
payload, err := fs.fout(outputs)
|
|
if err != nil {
|
|
res.Error = "srpc - Error parsing outputs '" + err.Error() + "'"
|
|
res.Status = RPCERR
|
|
return res
|
|
}
|
|
res.Status = RPCOK
|
|
res.Payload = append([]byte{}, payload...)
|
|
return res
|
|
}
|
|
|
|
func RequestDataHandler() (request interface{}) {
|
|
return new(RPCRequest)
|
|
}
|
|
|
|
func ResponseDataHandler() (request interface{}) {
|
|
return new(RPCRequest)
|
|
}
|
|
|
|
func NewRPC(ed IEncoderDecoder) *RPC {
|
|
ret := &RPC{}
|
|
ret.ed = ed
|
|
return ret
|
|
}
|
|
|
|
func NewDefaultRPC() *RPC {
|
|
ret := NewRPC(NewEncoderDecoder())
|
|
return ret
|
|
}
|
|
|
|
// DefaultRPC is a package-level RPC created with NewDefaultRPC during package
// initialization, for callers that do not need a separate registry.
var DefaultRPC = NewDefaultRPC()
|