137 lines
2.9 KiB
Go
137 lines
2.9 KiB
Go
|
package srpc
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
"sync"
|
||
|
)
|
||
|
|
||
|
// Per-function codecs that RegisterName builds for every registered function.
type (
	// parseInputsFunc decodes an encoded argument list into the values that
	// are handed to reflect.Value.Call on the registered function.
	parseInputsFunc func(in []byte) (ret []reflect.Value, err error)

	// parseOutputsFunc encodes the values returned by a reflected call into a
	// single byte slice.
	parseOutputsFunc func(in []reflect.Value) (ret []byte, err error)
)
|
||
|
|
||
|
type service struct {
|
||
|
f reflect.Value
|
||
|
fin parseInputsFunc
|
||
|
fout parseOutputsFunc
|
||
|
}
|
||
|
|
||
|
type RPC struct {
|
||
|
ed IEncoderDecoder
|
||
|
serviceMap sync.Map
|
||
|
}
|
||
|
|
||
|
func (r *RPC) RegisterName(name string, rcvr interface{}) (err error) {
|
||
|
if _, ok := r.serviceMap.Load(name); ok {
|
||
|
return errors.New("srpc: Function with name '" + name + "' already registerd")
|
||
|
}
|
||
|
ft := reflect.TypeOf(rcvr)
|
||
|
fv := reflect.ValueOf(rcvr)
|
||
|
|
||
|
nIn := ft.NumIn()
|
||
|
|
||
|
fs := service{}
|
||
|
|
||
|
fs.fin = func(in []byte) (ret []reflect.Value, err error) {
|
||
|
var b bytes.Buffer
|
||
|
if _, err = b.Write(in); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
decoder := r.ed.NewDecoder(&b)
|
||
|
|
||
|
ret = make([]reflect.Value, nIn)
|
||
|
for i := 0; i < nIn; i++ {
|
||
|
arg := reflect.New(ft.In(i))
|
||
|
if err = decoder.Decode(arg.Interface()); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
ret[i] = reflect.Indirect(arg)
|
||
|
}
|
||
|
|
||
|
return ret, nil
|
||
|
}
|
||
|
|
||
|
fs.fout = func(in []reflect.Value) (ret []byte, err error) {
|
||
|
var b bytes.Buffer
|
||
|
encoder := r.ed.NewEncoder(&b)
|
||
|
|
||
|
for _, v := range in {
|
||
|
if v.Type() == reflect.TypeOf((*error)(nil)).Elem() {
|
||
|
if v.IsNil() {
|
||
|
if err = encoder.Encode(string("")); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
} else {
|
||
|
if err = encoder.Encode(v.Interface().(error).Error()); err != nil {
|
||
|
fmt.Println(err)
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
if err = encoder.Encode(v.Interface()); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return b.Bytes(), nil
|
||
|
}
|
||
|
|
||
|
fs.f = fv
|
||
|
r.serviceMap.Store(name, fs)
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (r *RPC) RequestHandler(clientAddr string, request []byte) (response []byte, err error) {
|
||
|
var fs service
|
||
|
var ok bool
|
||
|
var req RPCRequest
|
||
|
var res RPCResponse
|
||
|
|
||
|
in := bytes.NewBuffer(request)
|
||
|
dec := r.ed.NewDecoder(in)
|
||
|
if err = dec.Decode(&req); err != nil {
|
||
|
return nil, errors.New("srpc - Invalid request: '" + err.Error() + "'")
|
||
|
}
|
||
|
|
||
|
fsRaw, ok := r.serviceMap.Load(req.FuncName)
|
||
|
if !ok {
|
||
|
return nil, errors.New("Unknown method: '" + req.FuncName + "'")
|
||
|
}
|
||
|
|
||
|
fs = fsRaw.(service)
|
||
|
inputs, err := fs.fin(req.Payload)
|
||
|
if err != nil {
|
||
|
return nil, errors.New("srpc - Error parsing inputs '" + err.Error() + "'")
|
||
|
}
|
||
|
|
||
|
outputs := fs.f.Call(inputs)
|
||
|
payload, err := fs.fout(outputs)
|
||
|
if err != nil {
|
||
|
return nil, errors.New("srpc - Error parsing outputs '" + err.Error() + "'")
|
||
|
}
|
||
|
res.Payload = append([]byte{}, payload...)
|
||
|
|
||
|
var out bytes.Buffer
|
||
|
enc := r.ed.NewEncoder(&out)
|
||
|
if err = enc.Encode(res); err != nil {
|
||
|
return nil, errors.New("srpc - Error encoding response '" + err.Error() + "'")
|
||
|
}
|
||
|
return out.Bytes(), nil
|
||
|
}
|
||
|
|
||
|
func NewRPC(ed IEncoderDecoder) *RPC {
|
||
|
ret := &RPC{}
|
||
|
ret.ed = ed
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
func NewDefaultRPC() *RPC {
|
||
|
ret := NewRPC(NewEncoderDecoder())
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
var (
	// DefaultRPC is a package-level RPC instance created by NewDefaultRPC
	// during package initialization.
	DefaultRPC = NewDefaultRPC()
)
|