package srpc import ( "bytes" "errors" "fmt" "io" "net" "olznet.de/slog" "reflect" "sync" "time" ) type parseInputsFunc func(in []byte) (ret []reflect.Value, err error) type parseOutputsFunc func(in []reflect.Value) (ret []byte, err error) type service struct { f reflect.Value fin parseInputsFunc fout parseOutputsFunc } type Server struct { ed IEncoderDecoder serviceMap sync.Map } func (server *Server) ServeConn(conn IRPCConn) { var fs service var call *RPCCall var err error received := make(chan bool, 1) defer conn.Close() for { go func() { if call, err = conn.Receive(); err != nil { slog.LOG_ERROR("srpc - Malformed request received: '%s'\n", err.Error()) return } received <- true }() select { case <-received: switch call.Header.RPCType { case RPC_HEARTBEAT: slog.LOG_INFO("srpc - Client sent hearbeat\n") if err = conn.Send(new(RPCHeartbeat)); err != nil { slog.LOG_ERROR("srpc - Error sending heartbeat: '%s'\n", err.Error()) } break case RPC_CLOSE: slog.LOG_INFO("srpc - Client closed connection.\n") if err = conn.Send(new(RPCClose)); err != nil { slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error()) } return case RPC_RESPONSE: slog.LOG_ERROR("srpc - Got response WTF?!\n") break case RPC_REQUEST: request, ok := call.Payload.(*RPCRequest) if !ok { slog.LOG_ERROR("srpc - Expected request, but got: %v\n", call.Header.RPCType) break } var response RPCResponse response.Status = RPCOK response.Error = "" fsRaw, ok := server.serviceMap.Load(request.FuncName) if !ok { slog.LOG_ERROR("srpc - Call to unknown method: '%s'\n", request.FuncName) err := fmt.Sprintf("Unknown method: '%s'", request.FuncName) response.Status = RPCERR response.Error = err if err := conn.Send(&response); err != nil { slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error()) } break } fs = fsRaw.(service) inputs, err := fs.fin(request.Payload) if err != nil { slog.LOG_ERROR("srpc - Error parsing inputs '%s'\n", err.Error()) response.Status = RPCERR response.Error = err.Error() if err := conn.Send(&response); err != nil { slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error()) } break } outputs := fs.f.Call(inputs) payload, err := fs.fout(outputs) if err != nil { slog.LOG_ERROR("srpc - Error parsing outputs '%s'\n", err.Error()) response.Status = RPCERR response.Error = err.Error() if err := conn.Send(&response); err != nil { slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error()) } break } slog.LOG_DEBUG("%v\n", outputs) response.Payload = append([]byte{}, payload...) if err := conn.Send(&response); err != nil { slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error()) } break default: slog.LOG_ERROR("srpc - Unknown rpc call received\n") break } case <-time.After(30 * time.Second): slog.LOG_INFO("srpc - Client gone. Closing connection\n") if err := conn.Send(new(RPCClose)); err != nil { slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error()) } return } } } func (server *Server) Accept(ln net.Listener) { for { conn, err := ln.Accept() if err != nil { slog.LOG_INFO("srpc - Accept: '%s'\n", err.Error()) return } go server.ServeConn(NewNetConn(conn, server.ed)) } } func (server *Server) RegisterName(name string, rcvr interface{}) (err error) { if _, ok := server.serviceMap.Load(name); ok { return errors.New("srpc: Function with name '" + name + "' already registerd") } ft := reflect.TypeOf(rcvr) fv := reflect.ValueOf(rcvr) nIn := ft.NumIn() fs := service{} fs.fin = func(in []byte) (ret []reflect.Value, err error) { var b bytes.Buffer if _, err = b.Write(in); err != nil { return nil, err } decoder := server.ed.NewDecoder(&b) ret = make([]reflect.Value, nIn) for i := 0; i < nIn; i++ { arg := reflect.New(ft.In(i)) if err = decoder.Decode(arg.Interface()); err != nil { return nil, err } ret[i] = reflect.Indirect(arg) } return ret, nil } fs.fout = func(in []reflect.Value) (ret []byte, err error) { var b bytes.Buffer encoder := server.ed.NewEncoder(&b) for _, v := range in { if v.Type() == reflect.TypeOf((*error)(nil)).Elem() { if v.IsNil() { if err = encoder.Encode(string("")); err != nil { return nil, err } } else { if err = encoder.Encode(v.Interface().(error).Error()); err != nil { fmt.Println(err) return nil, err } } } else { if err = encoder.Encode(v.Interface()); err != nil { return nil, err } } } return b.Bytes(), nil } fs.f = fv server.serviceMap.Store(name, fs) return nil } func (server *Server) CallName(name string, args ...interface{}) (ret []byte, err error) { var b bytes.Buffer var fs service encoder := server.ed.NewEncoder(&b) if fsRaw, ok := server.serviceMap.Load(name); !ok { return nil, errors.New("srpc - Call to unknown method: '" + name + "'") } else { fs = fsRaw.(service) } for _, a := range args { if err := encoder.Encode(a); err != nil { return nil, errors.New("srpc - Error: '" + err.Error() + "'") } } if inputs, err := fs.fin(b.Bytes()); err != nil { return nil, err } else { outputs := fs.f.Call(inputs) return fs.fout(outputs) } } func NewServer(ed IEncoderDecoder) *Server { ret := &Server{} ret.ed = ed return ret } func NewDefaultServer() *Server { ret := NewServer(NewEncoderDecoder()) return ret } var DefaultServer = NewDefaultServer() type IListener interface { Init(addr string) error Accept() (conn io.ReadWriteCloser, clientAddr string, err error) Close() error Addr() net.Addr } type netListener struct { F func(addr string) (net.Listener, error) L net.Listener } func (nl *netListener) Init(addr string) (err error) { nl.L, err = nl.F(addr) return err } func (nl *netListener) Accept() (conn io.ReadWriteCloser, clientAddr string, err error) { var c net.Conn if c, err = nl.L.Accept(); err != nil { return nil, "", err } return c, c.RemoteAddr().String(), nil } func (nl *netListener) Close() error { return nl.L.Close() } func (nl *netListener) Addr() net.Addr { if nl.L == nil { return nil } return nl.L.Addr() }