srpc/server.go

238 lines
5.6 KiB
Go
Raw Normal View History

2019-10-22 01:32:56 +02:00
package srpc
import (
2019-11-05 18:21:41 +01:00
"bytes"
"errors"
2019-11-05 18:21:41 +01:00
"fmt"
2019-11-08 02:14:36 +01:00
"net"
2020-02-14 01:51:30 +01:00
"olznet.de/slog"
2019-11-05 18:21:41 +01:00
"reflect"
2019-10-22 01:32:56 +02:00
"sync"
2020-02-18 03:03:02 +01:00
"time"
2019-10-22 01:32:56 +02:00
)
type (
	// parseInputsFunc turns an encoded argument payload into the reflect
	// values that are passed to a registered function.
	parseInputsFunc func(in []byte) ([]reflect.Value, error)

	// parseOutputsFunc turns the reflect values returned by a registered
	// function into an encoded result payload.
	parseOutputsFunc func(in []reflect.Value) ([]byte, error)
)
2019-11-05 18:21:41 +01:00
type service struct {
f reflect.Value
fin parseInputsFunc
fout parseOutputsFunc
2019-11-05 18:21:41 +01:00
}
2019-10-22 01:32:56 +02:00
type Server struct {
2020-02-10 02:53:05 +01:00
ed IEncoderDecoder
2019-10-22 01:32:56 +02:00
serviceMap sync.Map
}
2020-02-18 03:03:02 +01:00
func (server *Server) ServeConn(conn IRPCConn) {
2019-11-08 02:14:36 +01:00
var fs service
2020-02-18 03:03:02 +01:00
var call *RPCCall
var err error
2020-02-18 03:03:02 +01:00
received := make(chan bool, 1)
2019-11-08 02:14:36 +01:00
2020-02-18 03:03:02 +01:00
defer conn.Close()
2019-11-08 02:14:36 +01:00
2020-02-18 03:03:02 +01:00
for {
go func() {
if call, err = conn.Receive(); err != nil {
slog.LOG_ERROR("srpc - Malformed request received: '%s'\n", err.Error())
return
}
received <- true
}()
select {
case <-received:
switch call.Header.RPCType {
case RPC_HEARTBEAT:
slog.LOG_INFO("srpc - Client sent hearbeat\n")
if err = conn.Send(new(RPCHeartbeat)); err != nil {
slog.LOG_ERROR("srpc - Error sending heartbeat: '%s'\n", err.Error())
}
break
case RPC_CLOSE:
slog.LOG_INFO("srpc - Client closed connection.\n")
if err = conn.Send(new(RPCClose)); err != nil {
slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error())
}
return
case RPC_RESPONSE:
slog.LOG_ERROR("srpc - Got response WTF?!\n")
break
case RPC_REQUEST:
request, ok := call.Payload.(*RPCRequest)
if !ok {
slog.LOG_ERROR("srpc - Expected request, but got: %v\n", call.Header.RPCType)
break
}
var response RPCResponse
response.Status = RPCOK
response.Error = ""
2020-02-14 01:51:30 +01:00
2020-02-18 03:03:02 +01:00
fsRaw, ok := server.serviceMap.Load(request.FuncName)
if !ok {
slog.LOG_ERROR("srpc - Call to unknown method: '%s'\n", request.FuncName)
err := fmt.Sprintf("Unknown method: '%s'", request.FuncName)
response.Status = RPCERR
response.Error = err
if err := conn.Send(&response); err != nil {
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
}
break
}
2019-11-08 02:14:36 +01:00
2020-02-18 03:03:02 +01:00
fs = fsRaw.(service)
inputs, err := fs.fin(request.Payload)
if err != nil {
slog.LOG_ERROR("srpc - Error parsing inputs '%s'\n", err.Error())
response.Status = RPCERR
response.Error = err.Error()
if err := conn.Send(&response); err != nil {
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
}
break
}
outputs := fs.f.Call(inputs)
payload, err := fs.fout(outputs)
if err != nil {
slog.LOG_ERROR("srpc - Error parsing outputs '%s'\n", err.Error())
response.Status = RPCERR
response.Error = err.Error()
if err := conn.Send(&response); err != nil {
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
}
break
}
slog.LOG_DEBUG("%v\n", outputs)
response.Payload = append([]byte{}, payload...)
2020-02-18 03:03:02 +01:00
if err := conn.Send(&response); err != nil {
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
}
break
default:
slog.LOG_ERROR("srpc - Unknown rpc call received\n")
break
}
case <-time.After(30 * time.Second):
slog.LOG_INFO("srpc - Client gone. Closing connection\n")
if err := conn.Send(new(RPCClose)); err != nil {
slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error())
}
2020-02-18 03:03:02 +01:00
return
}
2019-11-08 02:14:36 +01:00
}
2019-10-22 01:32:56 +02:00
}
2019-11-08 02:14:36 +01:00
func (server *Server) Accept(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
2020-02-14 01:51:30 +01:00
slog.LOG_INFO("srpc - Accept: '%s'\n", err.Error())
2019-11-08 02:14:36 +01:00
return
}
2020-02-18 03:03:02 +01:00
go server.ServeConn(NewNetConn(conn, server.ed))
2019-11-08 02:14:36 +01:00
}
}
2019-11-05 18:21:41 +01:00
func (server *Server) RegisterName(name string, rcvr interface{}) (err error) {
if _, ok := server.serviceMap.Load(name); ok {
return errors.New("srpc: Function with name '" + name + "' already registerd")
}
2019-11-05 18:21:41 +01:00
ft := reflect.TypeOf(rcvr)
fv := reflect.ValueOf(rcvr)
nIn := ft.NumIn()
fs := service{}
fs.fin = func(in []byte) (ret []reflect.Value, err error) {
2019-11-05 18:21:41 +01:00
var b bytes.Buffer
if _, err = b.Write(in); err != nil {
return nil, err
2019-11-05 18:21:41 +01:00
}
2020-02-10 02:53:05 +01:00
decoder := server.ed.NewDecoder(&b)
2019-11-05 18:21:41 +01:00
ret = make([]reflect.Value, nIn)
for i := 0; i < nIn; i++ {
arg := reflect.New(ft.In(i))
if err = decoder.Decode(arg.Interface()); err != nil {
return nil, err
2019-11-05 18:21:41 +01:00
}
ret[i] = reflect.Indirect(arg)
2019-11-05 18:21:41 +01:00
}
return ret, nil
2019-11-05 18:21:41 +01:00
}
fs.fout = func(in []reflect.Value) (ret []byte, err error) {
var b bytes.Buffer
2020-02-10 02:53:05 +01:00
encoder := server.ed.NewEncoder(&b)
for _, v := range in {
if v.Type() == reflect.TypeOf((*error)(nil)).Elem() {
if v.IsNil() {
if err = encoder.Encode(string("")); err != nil {
return nil, err
}
} else {
if err = encoder.Encode(v.Interface().(error).Error()); err != nil {
fmt.Println(err)
return nil, err
}
}
} else {
if err = encoder.Encode(v.Interface()); err != nil {
return nil, err
}
}
}
return b.Bytes(), nil
}
2019-11-05 18:21:41 +01:00
fs.f = fv
server.serviceMap.Store(name, fs)
return nil
}
2019-11-08 02:14:36 +01:00
func (server *Server) CallName(name string, args ...interface{}) (ret []byte, err error) {
var b bytes.Buffer
2019-11-08 02:14:36 +01:00
var fs service
2020-02-10 02:53:05 +01:00
encoder := server.ed.NewEncoder(&b)
2019-11-08 02:14:36 +01:00
if fsRaw, ok := server.serviceMap.Load(name); !ok {
2020-02-14 01:51:30 +01:00
return nil, errors.New("srpc - Call to unknown method: '" + name + "'")
2019-11-08 02:14:36 +01:00
} else {
fs = fsRaw.(service)
}
2019-11-05 18:21:41 +01:00
for _, a := range args {
if err := encoder.Encode(a); err != nil {
2020-02-14 01:51:30 +01:00
return nil, errors.New("srpc - Error: '" + err.Error() + "'")
}
}
if inputs, err := fs.fin(b.Bytes()); err != nil {
return nil, err
} else {
outputs := fs.f.Call(inputs)
return fs.fout(outputs)
}
2019-11-05 18:21:41 +01:00
}
2019-11-08 02:14:36 +01:00
2020-02-10 02:53:05 +01:00
// NewServer returns a new Server that uses ed to encode and decode payloads.
func NewServer(ed IEncoderDecoder) *Server {
	return &Server{ed: ed}
}
func NewDefaultServer() *Server {
ret := NewServer(NewEncoderDecoder())
return ret
2019-11-08 02:14:36 +01:00
}
2020-02-10 02:53:05 +01:00
// DefaultServer is a package-level Server created with NewDefaultServer when
// the package is initialized.
var DefaultServer = NewDefaultServer()