// Listing metadata: 276 lines, 6.3 KiB, Go.

package srpc
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"olznet.de/slog"
|
|
"reflect"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type (
	// parseInputsFunc turns an encoded argument payload into the reflect
	// values that are handed to a registered function.
	parseInputsFunc func(in []byte) (ret []reflect.Value, err error)

	// parseOutputsFunc turns the reflect values returned by a registered
	// function into an encoded result payload.
	parseOutputsFunc func(in []reflect.Value) (ret []byte, err error)
)
|
|
|
|
type service struct {
|
|
f reflect.Value
|
|
fin parseInputsFunc
|
|
fout parseOutputsFunc
|
|
}
|
|
|
|
type Server struct {
|
|
ed IEncoderDecoder
|
|
serviceMap sync.Map
|
|
}
|
|
|
|
func (server *Server) ServeConn(conn IRPCConn) {
|
|
var fs service
|
|
var call *RPCCall
|
|
var err error
|
|
received := make(chan bool, 1)
|
|
|
|
defer conn.Close()
|
|
|
|
for {
|
|
go func() {
|
|
if call, err = conn.Receive(); err != nil {
|
|
slog.LOG_ERROR("srpc - Malformed request received: '%s'\n", err.Error())
|
|
return
|
|
}
|
|
received <- true
|
|
}()
|
|
|
|
select {
|
|
case <-received:
|
|
switch call.Header.RPCType {
|
|
case RPC_HEARTBEAT:
|
|
slog.LOG_INFO("srpc - Client sent hearbeat\n")
|
|
if err = conn.Send(new(RPCHeartbeat)); err != nil {
|
|
slog.LOG_ERROR("srpc - Error sending heartbeat: '%s'\n", err.Error())
|
|
}
|
|
break
|
|
case RPC_CLOSE:
|
|
slog.LOG_INFO("srpc - Client closed connection.\n")
|
|
if err = conn.Send(new(RPCClose)); err != nil {
|
|
slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error())
|
|
}
|
|
return
|
|
case RPC_RESPONSE:
|
|
slog.LOG_ERROR("srpc - Got response WTF?!\n")
|
|
break
|
|
case RPC_REQUEST:
|
|
request, ok := call.Payload.(*RPCRequest)
|
|
if !ok {
|
|
slog.LOG_ERROR("srpc - Expected request, but got: %v\n", call.Header.RPCType)
|
|
break
|
|
}
|
|
|
|
var response RPCResponse
|
|
response.Status = RPCOK
|
|
response.Error = ""
|
|
|
|
fsRaw, ok := server.serviceMap.Load(request.FuncName)
|
|
if !ok {
|
|
slog.LOG_ERROR("srpc - Call to unknown method: '%s'\n", request.FuncName)
|
|
err := fmt.Sprintf("Unknown method: '%s'", request.FuncName)
|
|
response.Status = RPCERR
|
|
response.Error = err
|
|
if err := conn.Send(&response); err != nil {
|
|
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
|
|
}
|
|
break
|
|
}
|
|
|
|
fs = fsRaw.(service)
|
|
inputs, err := fs.fin(request.Payload)
|
|
if err != nil {
|
|
slog.LOG_ERROR("srpc - Error parsing inputs '%s'\n", err.Error())
|
|
response.Status = RPCERR
|
|
response.Error = err.Error()
|
|
if err := conn.Send(&response); err != nil {
|
|
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
|
|
}
|
|
break
|
|
}
|
|
|
|
outputs := fs.f.Call(inputs)
|
|
payload, err := fs.fout(outputs)
|
|
if err != nil {
|
|
slog.LOG_ERROR("srpc - Error parsing outputs '%s'\n", err.Error())
|
|
response.Status = RPCERR
|
|
response.Error = err.Error()
|
|
if err := conn.Send(&response); err != nil {
|
|
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
|
|
}
|
|
break
|
|
}
|
|
slog.LOG_DEBUG("%v\n", outputs)
|
|
response.Payload = append([]byte{}, payload...)
|
|
if err := conn.Send(&response); err != nil {
|
|
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
|
|
}
|
|
break
|
|
default:
|
|
slog.LOG_ERROR("srpc - Unknown rpc call received\n")
|
|
break
|
|
}
|
|
case <-time.After(30 * time.Second):
|
|
slog.LOG_INFO("srpc - Client gone. Closing connection\n")
|
|
if err := conn.Send(new(RPCClose)); err != nil {
|
|
slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error())
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (server *Server) Accept(ln net.Listener) {
|
|
for {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
slog.LOG_INFO("srpc - Accept: '%s'\n", err.Error())
|
|
return
|
|
}
|
|
go server.ServeConn(NewNetConn(conn, server.ed))
|
|
}
|
|
}
|
|
|
|
func (server *Server) RegisterName(name string, rcvr interface{}) (err error) {
|
|
if _, ok := server.serviceMap.Load(name); ok {
|
|
return errors.New("srpc: Function with name '" + name + "' already registerd")
|
|
}
|
|
ft := reflect.TypeOf(rcvr)
|
|
fv := reflect.ValueOf(rcvr)
|
|
|
|
nIn := ft.NumIn()
|
|
|
|
fs := service{}
|
|
|
|
fs.fin = func(in []byte) (ret []reflect.Value, err error) {
|
|
var b bytes.Buffer
|
|
if _, err = b.Write(in); err != nil {
|
|
return nil, err
|
|
}
|
|
decoder := server.ed.NewDecoder(&b)
|
|
|
|
ret = make([]reflect.Value, nIn)
|
|
for i := 0; i < nIn; i++ {
|
|
arg := reflect.New(ft.In(i))
|
|
if err = decoder.Decode(arg.Interface()); err != nil {
|
|
return nil, err
|
|
}
|
|
ret[i] = reflect.Indirect(arg)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
fs.fout = func(in []reflect.Value) (ret []byte, err error) {
|
|
var b bytes.Buffer
|
|
encoder := server.ed.NewEncoder(&b)
|
|
|
|
for _, v := range in {
|
|
if v.Type() == reflect.TypeOf((*error)(nil)).Elem() {
|
|
if v.IsNil() {
|
|
if err = encoder.Encode(string("")); err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
if err = encoder.Encode(v.Interface().(error).Error()); err != nil {
|
|
fmt.Println(err)
|
|
return nil, err
|
|
}
|
|
}
|
|
} else {
|
|
if err = encoder.Encode(v.Interface()); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
return b.Bytes(), nil
|
|
}
|
|
|
|
fs.f = fv
|
|
server.serviceMap.Store(name, fs)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (server *Server) CallName(name string, args ...interface{}) (ret []byte, err error) {
|
|
var b bytes.Buffer
|
|
var fs service
|
|
encoder := server.ed.NewEncoder(&b)
|
|
|
|
if fsRaw, ok := server.serviceMap.Load(name); !ok {
|
|
return nil, errors.New("srpc - Call to unknown method: '" + name + "'")
|
|
} else {
|
|
fs = fsRaw.(service)
|
|
}
|
|
|
|
for _, a := range args {
|
|
if err := encoder.Encode(a); err != nil {
|
|
return nil, errors.New("srpc - Error: '" + err.Error() + "'")
|
|
}
|
|
}
|
|
if inputs, err := fs.fin(b.Bytes()); err != nil {
|
|
return nil, err
|
|
} else {
|
|
outputs := fs.f.Call(inputs)
|
|
return fs.fout(outputs)
|
|
}
|
|
}
|
|
|
|
func NewServer(ed IEncoderDecoder) *Server {
|
|
ret := &Server{}
|
|
ret.ed = ed
|
|
return ret
|
|
}
|
|
|
|
func NewDefaultServer() *Server {
|
|
ret := NewServer(NewEncoderDecoder())
|
|
return ret
|
|
}
|
|
|
|
var DefaultServer = NewDefaultServer()
|
|
|
|
// IListener is the listener abstraction of this package: something that can
// be bound to an address, hand out client connections and be shut down.
type IListener interface {
	// Init prepares the listener for the given address.
	Init(addr string) error

	// Accept returns the next client connection together with a textual
	// form of the client's address.
	Accept() (conn io.ReadWriteCloser, clientAddr string, err error)

	// Close shuts the listener down.
	Close() error

	// Addr returns the address the listener is bound to.
	Addr() net.Addr
}
|
|
|
|
// netListener provides the IListener method set on top of a net.Listener
// that is created lazily by a factory function.
type netListener struct {
	// F creates the underlying listener for an address; Init calls it.
	F func(addr string) (net.Listener, error)
	// L is the underlying listener. It is nil until Init has assigned it.
	L net.Listener
}

// Init creates the underlying listener for addr by calling F and stores the
// result in L. The error reported by F is returned unchanged.
func (nl *netListener) Init(addr string) (err error) {
	nl.L, err = nl.F(addr)
	return err
}

// Accept waits for the next connection on the underlying listener and
// returns it together with the string form of the remote address. It returns
// an error, rather than panicking, when no listener has been set up.
func (nl *netListener) Accept() (conn io.ReadWriteCloser, clientAddr string, err error) {
	if nl.L == nil {
		return nil, "", errors.New("srpc: listener not initialised")
	}

	c, err := nl.L.Accept()
	if err != nil {
		return nil, "", err
	}
	return c, c.RemoteAddr().String(), nil
}

// Close closes the underlying listener. It is a no-op returning nil when no
// listener has been set up, so it is safe to call after a failed Init.
func (nl *netListener) Close() error {
	if nl.L == nil {
		return nil
	}
	return nl.L.Close()
}

// Addr returns the address of the underlying listener, or nil when no
// listener has been set up.
func (nl *netListener) Addr() net.Addr {
	if nl.L == nil {
		return nil
	}
	return nl.L.Addr()
}
|