Initial heartbeat and protocol stuff.

This commit is contained in:
Matthias Fulz 2020-02-18 03:03:02 +01:00
parent 38a3047664
commit da05b60694
3 changed files with 194 additions and 155 deletions

View File

@ -3,18 +3,21 @@ package srpc
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"sync"
"time"
) )
type Client struct { type Client struct {
ed IEncoderDecoder sync.RWMutex
conn IRPCConn ed IEncoderDecoder
conn IRPCConn
running bool
} }
func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) { func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) {
defer client.conn.Close()
var b bytes.Buffer var b bytes.Buffer
enc := client.ed.NewEncoder(&b) enc := client.ed.NewEncoder(&b)
@ -25,13 +28,22 @@ func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, er
payload := b.Bytes() payload := b.Bytes()
req := RPCRequest{funcName, payload} req := RPCRequest{funcName, payload}
if err = client.conn.SendRequest(&req); err != nil { client.RLock()
defer client.RUnlock()
if err = client.conn.Send(&req); err != nil {
return nil, err
}
var call *RPCCall
if call, err = client.conn.Receive(); err != nil {
return nil, err return nil, err
} }
var response *RPCResponse var response *RPCResponse
if response, err = client.conn.ReceiveResponse(); err != nil { var ok bool
return nil, err if response, ok = call.Payload.(*RPCResponse); !ok {
return nil, errors.New("srpc - Expected response")
} }
if response.Status != RPCOK { if response.Status != RPCOK {
@ -41,10 +53,40 @@ func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, er
return response.Payload, err return response.Payload, err
} }
func (client *Client) Close() {
client.Lock()
defer client.Unlock()
client.conn.Send(new(RPCClose))
client.running = false
}
func (client *Client) NewDecoder(r io.Reader) IDecoder { func (client *Client) NewDecoder(r io.Reader) IDecoder {
return client.ed.NewDecoder(r) return client.ed.NewDecoder(r)
} }
func NewClient(conn net.Conn) *Client { func NewClient(conn net.Conn) *Client {
return &Client{NewEncoderDecoder(), NewTCPConn(conn, NewEncoderDecoder())} ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true}
go func() {
for {
fmt.Println("heartbeat")
ret.Lock()
if ret.running == false {
ret.Unlock()
return
}
ret.conn.Send(new(RPCHeartbeat))
if _, err := ret.conn.Receive(); err != nil {
ret.conn.Send(new(RPCClose))
ret.Unlock()
return
}
fmt.Println("got heartbeat")
ret.Unlock()
time.Sleep(15 * time.Second)
}
}()
return ret
} }

View File

@ -2,10 +2,10 @@ package srpc
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"net" "net"
"sync" "sync"
"time"
) )
type RPCMsgType uint8 type RPCMsgType uint8
@ -14,9 +14,10 @@ const (
RPC_REQUEST RPCMsgType = 0 RPC_REQUEST RPCMsgType = 0
RPC_RESPONSE RPCMsgType = 1 RPC_RESPONSE RPCMsgType = 1
RPC_HEARTBEAT RPCMsgType = 2 RPC_HEARTBEAT RPCMsgType = 2
RPC_CLOSE RPCMsgType = 3
) )
const RPCHeaderSize = 10 const RPCHeaderSize = 13
type RPCHeader struct { type RPCHeader struct {
RPCType RPCMsgType RPCType RPCMsgType
@ -31,6 +32,11 @@ const (
) )
type RPCHeartbeat struct { type RPCHeartbeat struct {
OK bool
}
type RPCClose struct {
OK bool
} }
type RPCRequest struct { type RPCRequest struct {
@ -44,16 +50,19 @@ type RPCResponse struct {
Payload []byte Payload []byte
} }
type RPCCall struct {
Header RPCHeader
Payload interface{}
}
type IRPCConn interface { type IRPCConn interface {
ReceiveRequest() (*RPCRequest, error) Send(payload interface{}) error
ReceiveResponse() (*RPCResponse, error) Receive() (*RPCCall, error)
SendRequest(request *RPCRequest) error
SendResponse(response *RPCResponse) error
Close() Close()
} }
type TCPConn struct { type NetConn struct {
sync.Mutex sync.RWMutex
conn net.Conn conn net.Conn
ed IEncoderDecoder ed IEncoderDecoder
@ -61,12 +70,14 @@ type TCPConn struct {
IRPCConn IRPCConn
} }
func (tc *TCPConn) send(payload interface{}) (err error) { func (tc *NetConn) Send(payload interface{}) (err error) {
var header RPCHeader var header RPCHeader
var hb, b bytes.Buffer var hb, b bytes.Buffer
if _, ok := payload.(*RPCHeartbeat); ok { if _, ok := payload.(*RPCHeartbeat); ok {
header.RPCType = RPC_HEARTBEAT header.RPCType = RPC_HEARTBEAT
} else if _, ok := payload.(*RPCClose); ok {
header.RPCType = RPC_CLOSE
} else if _, ok := payload.(*RPCRequest); ok { } else if _, ok := payload.(*RPCRequest); ok {
header.RPCType = RPC_REQUEST header.RPCType = RPC_REQUEST
} else if _, ok := payload.(*RPCResponse); ok { } else if _, ok := payload.(*RPCResponse); ok {
@ -81,7 +92,7 @@ func (tc *TCPConn) send(payload interface{}) (err error) {
} }
header.Size = uint64(len(b.Bytes())) header.Size = uint64(len(b.Bytes()))
enc := tc.ed.NewEncoder(&hb) enc = tc.ed.NewEncoder(&hb)
if err = enc.Encode(header); err != nil { if err = enc.Encode(header); err != nil {
return errors.New("srpc - Error: '" + err.Error() + "'") return errors.New("srpc - Error: '" + err.Error() + "'")
} }
@ -89,16 +100,32 @@ func (tc *TCPConn) send(payload interface{}) (err error) {
data := append([]byte{}, hb.Bytes()...) data := append([]byte{}, hb.Bytes()...)
data = append(data, b.Bytes()...) data = append(data, b.Bytes()...)
if n, err = tc.conn.Write(data); err != nil || n != len(data) { tc.Lock()
defer tc.Unlock()
if n, err := tc.conn.Write(data); err != nil || n != len(data) {
return errors.New("srpc - Error writing message") return errors.New("srpc - Error writing message")
} }
return nil
} }
func (tc *TCPConn) receive() (payload interface{}, err error) { func (tc *NetConn) Receive() (ret *RPCCall, err error) {
var header RPCHeader var header RPCHeader
var payload interface{}
hb := make([]byte, RPCHeaderSize) hb := make([]byte, RPCHeaderSize)
if n, err = tc.conn.Read(hb); err != nil || n != RPCHeaderSize { tc.Lock()
defer tc.Unlock()
if n, err := tc.conn.Read(hb); err != nil || n != RPCHeaderSize {
if err != nil {
if err == io.EOF {
header.RPCType = RPC_CLOSE
ret = &RPCCall{header, new(RPCClose)}
return ret, nil
}
}
return nil, errors.New("srpc - Error receiving message") return nil, errors.New("srpc - Error receiving message")
} }
@ -111,6 +138,9 @@ func (tc *TCPConn) receive() (payload interface{}, err error) {
case RPC_HEARTBEAT: case RPC_HEARTBEAT:
payload = new(RPCHeartbeat) payload = new(RPCHeartbeat)
break break
case RPC_CLOSE:
payload = new(RPCClose)
break
case RPC_REQUEST: case RPC_REQUEST:
payload = new(RPCRequest) payload = new(RPCRequest)
break break
@ -122,119 +152,30 @@ func (tc *TCPConn) receive() (payload interface{}, err error) {
} }
b := make([]byte, header.Size) b := make([]byte, header.Size)
if n, err = tc.conn.Read(b); err != nil || n != int(header.Size) { if n, err := tc.conn.Read(b); err != nil || n != int(header.Size) {
return nil, errors.New("srpc - Error receiving message") return nil, errors.New("srpc - Error receiving message")
} }
dec := tc.ed.NewDecoder(bytes.NewReader(b)) dec = tc.ed.NewDecoder(bytes.NewReader(b))
if err = dec.Decode(payload); err != nil { if err = dec.Decode(payload); err != nil {
return errors.New("srpc - Error: '" + err.Error() + "'") return nil, errors.New("srpc - Error: '" + err.Error() + "'")
} }
return payload, nil ret = new(RPCCall)
ret.Header = header
ret.Payload = payload
return ret, nil
} }
func (tc *TCPConn) ReceiveRequest() (ret *RPCRequest, err error) { func (tc *NetConn) Close() {
ret = new(RPCRequest)
dec := tc.ed.NewDecoder(tc.conn)
if err = dec.Decode(ret); err != nil {
return nil, err
}
return ret, err
}
func (tc *TCPConn) ReceiveResponse() (ret *RPCResponse, err error) {
ret = new(RPCResponse)
dec := tc.ed.NewDecoder(tc.conn)
if err = dec.Decode(ret); err != nil {
return nil, err
}
return ret, err
}
func (tc *TCPConn) SendRequest(request *RPCRequest) (err error) {
enc := tc.ed.NewEncoder(tc.conn)
err = enc.Encode(request)
return err
}
func (tc *TCPConn) SendResponse(response *RPCResponse) (err error) {
enc := tc.ed.NewEncoder(tc.conn)
err = enc.Encode(response)
return err
}
func (tc *TCPConn) Close() {
tc.conn.Close() tc.conn.Close()
} }
func NewTCPConn(conn net.Conn, ed IEncoderDecoder) *TCPConn { func NewNetConn(conn net.Conn, ed IEncoderDecoder) *NetConn {
ret := &TCPConn{} ret := &NetConn{}
ret.conn = conn ret.conn = conn
ret.ed = ed ret.ed = ed
return ret return ret
} }
type NetHeartbeatConn struct {
sync.Mutex
conn io.ReadWriteCloser
dec IDecoder
enc IEncoder
// interfaces
IRPCConn
}
func (nhc *NetHeartbeatConn) ReceiveRequest() (ret *RPCRequest, err error) {
nhc.Lock()
defer nhc.Unlock()
ret = new(RPCRequest)
if err = nhc.dec.Decode(ret); err != nil {
return nil, err
}
return ret, err
}
func (nhc *NetHeartbeatConn) ReceiveResponse() (ret *RPCResponse, err error) {
nhc.Lock()
defer nhc.Unlock()
ret = new(RPCResponse)
if err = nhc.dec.Decode(ret); err != nil {
return nil, err
}
return ret, err
}
func (nhc *NetHeartbeatConn) SendRequest(request *RPCRequest) (err error) {
nhc.Lock()
defer nhc.Unlock()
err = nhc.enc.Encode(request)
return err
}
func (nhc *NetHeartbeatConn) SendResponse(response *RPCResponse) (err error) {
nhc.Lock()
defer nhc.Unlock()
err = nhc.enc.Encode(response)
return err
}
func (nhc *NetHeartbeatConn) Close() {
nhc.Lock()
defer nhc.Unlock()
nhc.conn.Close()
}

124
server.go
View File

@ -8,6 +8,7 @@ import (
"olznet.de/slog" "olznet.de/slog"
"reflect" "reflect"
"sync" "sync"
"time"
) )
type parseInputsFunc func(in []byte) (ret []reflect.Value, err error) type parseInputsFunc func(in []byte) (ret []reflect.Value, err error)
@ -24,50 +25,105 @@ type Server struct {
serviceMap sync.Map serviceMap sync.Map
} }
func (server *Server) ServeConn(conn net.Conn) { func (server *Server) ServeConn(conn IRPCConn) {
var fs service var fs service
var request *RPCRequest var call *RPCCall
var err error var err error
received := make(chan bool, 1)
tcpConn := NewTCPConn(conn, server.ed) defer conn.Close()
defer tcpConn.Close()
if request, err = tcpConn.ReceiveRequest(); err != nil { for {
slog.LOG_ERROR("srpc - Malformed request received: '%s'\n", err.Error()) go func() {
return if call, err = conn.Receive(); err != nil {
} slog.LOG_ERROR("srpc - Malformed request received: '%s'\n", err.Error())
return
}
received <- true
}()
var response RPCResponse select {
response.Status = RPCOK case <-received:
response.Error = "" switch call.Header.RPCType {
case RPC_HEARTBEAT:
slog.LOG_INFO("srpc - Client sent hearbeat\n")
if err = conn.Send(new(RPCHeartbeat)); err != nil {
slog.LOG_ERROR("srpc - Error sending heartbeat: '%s'\n", err.Error())
}
break
case RPC_CLOSE:
slog.LOG_INFO("srpc - Client closed connection.\n")
if err = conn.Send(new(RPCClose)); err != nil {
slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error())
}
return
case RPC_RESPONSE:
slog.LOG_ERROR("srpc - Got response WTF?!\n")
break
case RPC_REQUEST:
request, ok := call.Payload.(*RPCRequest)
if !ok {
slog.LOG_ERROR("srpc - Expected request, but got: %v\n", call.Header.RPCType)
break
}
if fsRaw, ok := server.serviceMap.Load(request.FuncName); !ok { var response RPCResponse
slog.LOG_ERROR("srpc - Call to unknown method: '%s'\n", request.FuncName) response.Status = RPCOK
err := fmt.Sprintf("Unknown method: '%s'", request.FuncName) response.Error = ""
response.Status = RPCERR
response.Error = err fsRaw, ok := server.serviceMap.Load(request.FuncName)
} else { if !ok {
fs = fsRaw.(service) slog.LOG_ERROR("srpc - Call to unknown method: '%s'\n", request.FuncName)
if inputs, err := fs.fin(request.Payload); err != nil { err := fmt.Sprintf("Unknown method: '%s'", request.FuncName)
slog.LOG_ERROR("srpc - Error parsing inputs '%s'\n", err.Error()) response.Status = RPCERR
response.Status = RPCERR response.Error = err
response.Error = err.Error() if err := conn.Send(&response); err != nil {
} else { slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
outputs := fs.f.Call(inputs) }
if payload, err := fs.fout(outputs); err != nil { break
slog.LOG_ERROR("srpc - Error parsing outputs '%s'\n", err.Error()) }
response.Status = RPCERR
response.Error = err.Error() fs = fsRaw.(service)
} else { inputs, err := fs.fin(request.Payload)
if err != nil {
slog.LOG_ERROR("srpc - Error parsing inputs '%s'\n", err.Error())
response.Status = RPCERR
response.Error = err.Error()
if err := conn.Send(&response); err != nil {
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
}
break
}
outputs := fs.f.Call(inputs)
payload, err := fs.fout(outputs)
if err != nil {
slog.LOG_ERROR("srpc - Error parsing outputs '%s'\n", err.Error())
response.Status = RPCERR
response.Error = err.Error()
if err := conn.Send(&response); err != nil {
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
}
break
}
slog.LOG_DEBUG("%v\n", outputs) slog.LOG_DEBUG("%v\n", outputs)
response.Payload = append([]byte{}, payload...) response.Payload = append([]byte{}, payload...)
if err := conn.Send(&response); err != nil {
slog.LOG_ERROR("srpc - Error sending response: '%s'\n", err.Error())
}
break
default:
slog.LOG_ERROR("srpc - Unknown rpc call received\n")
break
} }
case <-time.After(30 * time.Second):
slog.LOG_INFO("srpc - Client gone. Closing connection\n")
if err := conn.Send(new(RPCClose)); err != nil {
slog.LOG_ERROR("srpc - Error sending close: '%s'\n", err.Error())
}
return
} }
} }
if err = tcpConn.SendResponse(&response); err != nil {
slog.LOG_ERROR("srpc - Error sending response '%s'\n", err.Error())
}
} }
func (server *Server) Accept(ln net.Listener) { func (server *Server) Accept(ln net.Listener) {
@ -77,7 +133,7 @@ func (server *Server) Accept(ln net.Listener) {
slog.LOG_INFO("srpc - Accept: '%s'\n", err.Error()) slog.LOG_INFO("srpc - Accept: '%s'\n", err.Error())
return return
} }
go server.ServeConn(conn) go server.ServeConn(NewNetConn(conn, server.ed))
} }
} }