srpc/connection.go

241 lines
4.5 KiB
Go
Raw Normal View History

package srpc
import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"time"
)
2020-02-17 19:18:39 +01:00
// RPCMsgType identifies which message type follows an RPCHeader on the wire.
type RPCMsgType uint8

// Wire values for RPCMsgType. send writes these into RPCHeader.RPCType and
// receive switches on them, so the numeric values are part of the protocol.
const (
	RPC_REQUEST   RPCMsgType = iota // 0: payload is an *RPCRequest
	RPC_RESPONSE                    // 1: payload is an *RPCResponse
	RPC_HEARTBEAT                   // 2: payload is an *RPCHeartbeat
)
// RPCHeaderSize is the exact number of bytes read for an encoded RPCHeader.
// NOTE(review): this assumes the IEncoderDecoder in use encodes RPCHeader to
// a fixed length — confirm for the encoder actually configured.
const RPCHeaderSize = 10

// RPCHeader prefixes every framed message and describes the payload after it.
// Field order is left unchanged because positional encoders may depend on it.
type RPCHeader struct {
	RPCType RPCMsgType // which message type the payload decodes into
	Size    uint64     // length in bytes of the encoded payload
}
// RPCStatus is the outcome code carried in RPCResponse.Status.
type RPCStatus uint8

// Known RPCStatus values.
const (
	RPCOK  RPCStatus = iota // 0
	RPCERR                  // 1
)
2020-02-17 19:18:39 +01:00
// RPCHeartbeat is a message with no fields; its header type is RPC_HEARTBEAT.
type RPCHeartbeat struct{}
// RPCRequest names the function to invoke and carries its encoded arguments.
type RPCRequest struct {
	FuncName string // name of the function to call
	Payload  []byte // opaque request body
}
// RPCResponse reports the outcome of a request along with its result bytes.
type RPCResponse struct {
	Status  RPCStatus // RPCOK or RPCERR
	Error   string    // error text accompanying the status
	Payload []byte    // opaque response body
}
// IRPCConn is a bidirectional connection that exchanges RPC requests and
// responses.
type IRPCConn interface {
	// ReceiveRequest reads and decodes the next incoming request.
	ReceiveRequest() (*RPCRequest, error)
	// ReceiveResponse reads and decodes the next incoming response.
	ReceiveResponse() (*RPCResponse, error)
	// SendRequest encodes and writes a request.
	SendRequest(request *RPCRequest) error
	// SendResponse encodes and writes a response.
	SendResponse(response *RPCResponse) error
	// Close shuts down the underlying connection.
	Close()
}
type TCPConn struct {
2020-02-17 19:18:39 +01:00
sync.Mutex
conn net.Conn
ed IEncoderDecoder
// interface
IRPCConn
}
2020-02-17 19:18:39 +01:00
// send frames payload and writes it to the connection as one message.
//
// Wire format: an RPCHeader encoded to exactly RPCHeaderSize bytes, holding
// the message type and the payload length. The encoded payload follows it.
// payload must be a *RPCHeartbeat, *RPCRequest or *RPCResponse; any other
// value is rejected.
//
// Encoder and write errors are wrapped with %w, so callers can still inspect
// the cause with errors.Is / errors.As.
func (tc *TCPConn) send(payload interface{}) (err error) {
	var header RPCHeader
	switch payload.(type) {
	case *RPCHeartbeat:
		header.RPCType = RPC_HEARTBEAT
	case *RPCRequest:
		header.RPCType = RPC_REQUEST
	case *RPCResponse:
		header.RPCType = RPC_RESPONSE
	default:
		return errors.New("srpc - Invalid RPC message type")
	}

	// Encode the payload first; its length goes into the header.
	var body bytes.Buffer
	if err = tc.ed.NewEncoder(&body).Encode(payload); err != nil {
		return fmt.Errorf("srpc - Error: '%w'", err)
	}
	header.Size = uint64(body.Len())

	var msg bytes.Buffer
	if err = tc.ed.NewEncoder(&msg).Encode(header); err != nil {
		return fmt.Errorf("srpc - Error: '%w'", err)
	}
	// receive reads exactly RPCHeaderSize bytes for the header. Any other
	// encoded length would desynchronize the stream, so fail here instead.
	if msg.Len() != RPCHeaderSize {
		return fmt.Errorf("srpc - Error: encoded header is %d bytes, want %d", msg.Len(), RPCHeaderSize)
	}
	msg.Write(body.Bytes())

	// Write header and payload in a single call. Per io.Writer, a short
	// write always comes with a non-nil error.
	if _, err = tc.conn.Write(msg.Bytes()); err != nil {
		return fmt.Errorf("srpc - Error writing message: %w", err)
	}
	return nil
}
// receive reads one message framed by send and decodes it.
//
// It reads exactly RPCHeaderSize bytes and decodes them into an RPCHeader.
// It then reads exactly header.Size payload bytes and decodes them into a
// *RPCHeartbeat, *RPCRequest or *RPCResponse, chosen by header.RPCType. The
// returned payload holds one of those pointer types.
//
// I/O and decoding errors are wrapped with %w. For example, a peer that
// closes the stream before a header arrives yields an error wrapping io.EOF.
func (tc *TCPConn) receive() (payload interface{}, err error) {
	hb := make([]byte, RPCHeaderSize)
	// A single Read on a stream may return fewer bytes than requested;
	// ReadFull keeps reading until the whole header has arrived.
	if _, err = io.ReadFull(tc.conn, hb); err != nil {
		return nil, fmt.Errorf("srpc - Error receiving message: %w", err)
	}
	var header RPCHeader
	if err = tc.ed.NewDecoder(bytes.NewReader(hb)).Decode(&header); err != nil {
		return nil, fmt.Errorf("srpc - Error: '%w'", err)
	}

	switch header.RPCType {
	case RPC_HEARTBEAT:
		payload = new(RPCHeartbeat)
	case RPC_REQUEST:
		payload = new(RPCRequest)
	case RPC_RESPONSE:
		payload = new(RPCResponse)
	default:
		return nil, errors.New("srpc - Invalid RPC message type")
	}

	// header.Size comes from the peer. Values above MaxInt64 become negative
	// here and are rejected rather than handed to the allocator.
	size := int64(header.Size)
	if size < 0 {
		return nil, errors.New("srpc - Invalid RPC message size")
	}
	// CopyN grows the buffer only as bytes actually arrive, so a lying header
	// cannot force a huge up-front allocation. It returns io.EOF if the stream
	// ends before size bytes are read.
	// NOTE(review): no maximum message size is enforced; consider one.
	var body bytes.Buffer
	if _, err = io.CopyN(&body, tc.conn, size); err != nil {
		return nil, fmt.Errorf("srpc - Error receiving message: %w", err)
	}
	if err = tc.ed.NewDecoder(&body).Decode(payload); err != nil {
		return nil, fmt.Errorf("srpc - Error: '%w'", err)
	}
	return payload, nil
}
// ReceiveRequest decodes the next RPCRequest directly from the connection,
// using a decoder created for this call. On failure it returns a nil request
// and the decoder's error unchanged.
//
// NOTE(review): a fresh decoder is built on every call and the framed
// send/receive helpers are not used. If the IEncoderDecoder implementation
// buffers ahead of the current message, bytes past it could be lost between
// calls — confirm against the encoder in use.
func (tc *TCPConn) ReceiveRequest() (ret *RPCRequest, err error) {
	req := &RPCRequest{}
	if decodeErr := tc.ed.NewDecoder(tc.conn).Decode(req); decodeErr != nil {
		return nil, decodeErr
	}
	return req, nil
}
// ReceiveResponse decodes the next RPCResponse directly from the connection,
// using a decoder created for this call. On failure it returns a nil response
// and the decoder's error unchanged.
//
// NOTE(review): like ReceiveRequest, this builds a new decoder per call and
// bypasses the framed send/receive helpers; confirm the encoder does not
// read ahead.
func (tc *TCPConn) ReceiveResponse() (ret *RPCResponse, err error) {
	resp := &RPCResponse{}
	if decodeErr := tc.ed.NewDecoder(tc.conn).Decode(resp); decodeErr != nil {
		return nil, decodeErr
	}
	return resp, nil
}
// SendRequest encodes request straight onto the connection with a new
// encoder and returns the encoder's error as-is.
func (tc *TCPConn) SendRequest(request *RPCRequest) (err error) {
	return tc.ed.NewEncoder(tc.conn).Encode(request)
}
// SendResponse encodes response straight onto the connection with a new
// encoder and returns the encoder's error as-is.
func (tc *TCPConn) SendResponse(response *RPCResponse) (err error) {
	return tc.ed.NewEncoder(tc.conn).Encode(response)
}
// Close closes the underlying net.Conn.
func (tc *TCPConn) Close() {
	// IRPCConn.Close has no error result, so the close error is deliberately
	// discarded.
	_ = tc.conn.Close()
}
// NewTCPConn wraps conn in a TCPConn that uses ed to encode and decode
// messages. The embedded mutex and embedded IRPCConn are left at their zero
// values.
func NewTCPConn(conn net.Conn, ed IEncoderDecoder) *TCPConn {
	return &TCPConn{conn: conn, ed: ed}
}
2020-02-17 01:30:24 +01:00
// NetHeartbeatConn is an IRPCConn over an io.ReadWriteCloser. Unlike
// TCPConn, it keeps a single long-lived encoder and decoder. All of its
// methods hold the embedded mutex.
type NetHeartbeatConn struct {
	sync.Mutex
	conn io.ReadWriteCloser
	dec  IDecoder
	enc  IEncoder
	// IRPCConn is embedded only to signal interface conformance.
	// NetHeartbeatConn defines every interface method itself, so the
	// embedded value is shadowed.
	IRPCConn
}
// ReceiveRequest decodes the next RPCRequest with the shared decoder while
// holding the connection's mutex. On failure it returns a nil request and
// the decoder's error unchanged.
//
// NOTE(review): sends use the same mutex. If Decode blocks waiting for
// input, concurrent SendRequest/SendResponse/Close calls wait as well.
func (nhc *NetHeartbeatConn) ReceiveRequest() (ret *RPCRequest, err error) {
	nhc.Lock()
	defer nhc.Unlock()

	req := &RPCRequest{}
	if decodeErr := nhc.dec.Decode(req); decodeErr != nil {
		return nil, decodeErr
	}
	return req, nil
}
// ReceiveResponse decodes the next RPCResponse with the shared decoder while
// holding the connection's mutex. On failure it returns a nil response and
// the decoder's error unchanged.
//
// NOTE(review): a blocking Decode here also blocks sends and Close, which
// use the same mutex.
func (nhc *NetHeartbeatConn) ReceiveResponse() (ret *RPCResponse, err error) {
	nhc.Lock()
	defer nhc.Unlock()

	resp := &RPCResponse{}
	if decodeErr := nhc.dec.Decode(resp); decodeErr != nil {
		return nil, decodeErr
	}
	return resp, nil
}
// SendRequest encodes request with the shared encoder while holding the
// connection's mutex. It returns the encoder's error as-is.
func (nhc *NetHeartbeatConn) SendRequest(request *RPCRequest) (err error) {
	nhc.Lock()
	defer nhc.Unlock()
	return nhc.enc.Encode(request)
}
// SendResponse encodes response with the shared encoder while holding the
// connection's mutex. It returns the encoder's error as-is.
func (nhc *NetHeartbeatConn) SendResponse(response *RPCResponse) (err error) {
	nhc.Lock()
	defer nhc.Unlock()
	return nhc.enc.Encode(response)
}
// Close closes the underlying connection while holding the mutex.
//
// NOTE(review): if a Receive* call is blocked in Decode, it holds the mutex,
// so Close waits for that read to finish. Close therefore cannot be used to
// unblock a pending read. Consider closing outside the lock if the conn type
// tolerates a concurrent Close.
func (nhc *NetHeartbeatConn) Close() {
	nhc.Lock()
	defer nhc.Unlock()
	// IRPCConn.Close has no error result, so the close error is deliberately
	// discarded.
	_ = nhc.conn.Close()
}