srpc/client.go

401 lines
7.6 KiB
Go
Raw Normal View History

2019-10-22 01:32:56 +02:00
package srpc
2020-02-10 02:53:05 +01:00
import (
2020-02-25 01:37:05 +01:00
"errors"
2020-03-25 09:45:23 +01:00
"fmt"
2020-02-10 02:53:05 +01:00
"io"
"net"
2020-02-18 03:03:02 +01:00
"sync"
2020-02-25 01:37:05 +01:00
"sync/atomic"
"time"
2020-02-10 02:53:05 +01:00
)
2020-02-24 17:03:24 +01:00
//func (client *Client) Call(funcName string, args ...interface{}) (ret []byte, err error) {
//var b bytes.Buffer
//enc := client.ed.NewEncoder(&b)
//for _, a := range args {
//enc.Encode(a)
//}
//payload := b.Bytes()
//req := RPCRequest{funcName, payload}
//client.RLock()
//defer client.RUnlock()
//if err = client.conn.Send(&req); err != nil {
//return nil, err
//}
//var call *RPCCall
//if call, err = client.conn.Receive(); err != nil {
//return nil, err
//}
//var response *RPCResponse
//var ok bool
//if response, ok = call.Payload.(*RPCResponse); !ok {
//return nil, errors.New("srpc - Expected response")
//}
//if response.Status != RPCOK {
//err = errors.New("srpc - Response contained error: '" + response.Error + "'")
//}
//return response.Payload, err
//}
//func (client *Client) Close() {
//client.Lock()
//defer client.Unlock()
//client.conn.Send(new(RPCClose))
//client.running = false
//}
//func (client *Client) NewDecoder(r io.Reader) IDecoder {
//return client.ed.NewDecoder(r)
//}
//func NewClient(conn net.Conn) *Client {
//ret := &Client{sync.RWMutex{}, NewEncoderDecoder(), NewNetConn(conn, NewEncoderDecoder()), true}
//go func() {
//for {
//fmt.Println("heartbeat")
//ret.Lock()
//if ret.running == false {
//ret.Unlock()
//return
//}
//ret.conn.Send(new(RPCHeartbeat))
//if _, err := ret.conn.Receive(); err != nil {
//ret.conn.Send(new(RPCClose))
//ret.Unlock()
//return
//}
//fmt.Println("got heartbeat")
//ret.Unlock()
//time.Sleep(15 * time.Second)
//}
//}()
//return ret
//}
2020-03-24 08:08:27 +01:00
// clientMessage is the request envelope that clientWriter encodes onto the
// connection for every Call.
type clientMessage struct {
	// ID identifies the request; clientReader routes a server response back
	// to the waiting call whose Request.ID equals the response's ID.
	ID uint64
	// Data is the caller-supplied payload passed to Call.
	Data interface{}
	// ClientAddr is set by Call from RPCClient.Addr.
	ClientAddr string
}
type call struct {
Response serverMessage
2020-03-25 09:45:23 +01:00
Request clientMessage
Error error
state messageState
2020-03-24 08:08:27 +01:00
t time.Time
2020-03-25 09:45:23 +01:00
done chan bool
sync.RWMutex
}
func newCall() (ret *call) {
ret = new(call)
ret.state = FREE
ret.done = make(chan bool)
return ret
2020-03-24 08:08:27 +01:00
}
2020-02-24 17:03:24 +01:00
type (
	// ResponseDataHandlerFunc returns a fresh value for a server response's
	// Data to be decoded into; clientReader calls it once per response.
	ResponseDataHandlerFunc func() (response interface{})

	// DialHandlerFunc opens the transport connection to addr.
	DialHandlerFunc func(addr string) (conn io.ReadWriteCloser, err error)
)
2020-02-10 02:53:05 +01:00
2020-02-24 17:03:24 +01:00
// unixDial is the DialHandlerFunc installed by NewUnixClient: it connects to
// the Unix domain socket at addr and returns a nil connection on failure.
func unixDial(addr string) (conn io.ReadWriteCloser, err error) {
	socket, dialErr := net.Dial("unix", addr)
	if dialErr != nil {
		return nil, dialErr
	}
	return socket, nil
}
2020-02-10 02:53:05 +01:00
2020-02-24 17:03:24 +01:00
type RPCClient struct {
Addr string
2020-03-25 09:45:23 +01:00
MaxRequests int
2020-02-24 17:03:24 +01:00
LogError LogErrorFunc
ConnectHander ConnectHandlerFunc
ResponseDataHandler ResponseDataHandlerFunc
DialHandler DialHandlerFunc
Ed IEncoderDecoder
2020-03-25 09:45:23 +01:00
calls []*call
callChan chan *call
nextID uint64
2020-02-24 17:03:24 +01:00
stopChan chan struct{}
stopWg sync.WaitGroup
}
2020-02-18 03:03:02 +01:00
2020-02-24 17:03:24 +01:00
func (c *RPCClient) Start() (err error) {
2020-03-25 09:45:23 +01:00
if c.MaxRequests <= 0 {
c.MaxRequests = DefaultMaxClientRequests
}
c.calls = make([]*call, c.MaxRequests)
for i := 0; i < c.MaxRequests; i++ {
c.calls[i] = newCall()
}
c.nextID = 0
c.callChan = make(chan *call)
2020-02-24 17:03:24 +01:00
if c.LogError == nil {
c.LogError = logError
}
2020-02-10 02:53:05 +01:00
2020-02-24 17:03:24 +01:00
if c.stopChan != nil {
return errors.New("srpc - Client already running")
}
2020-02-24 17:03:24 +01:00
c.stopChan = make(chan struct{})
2020-02-10 02:53:05 +01:00
2020-02-24 17:03:24 +01:00
if c.ResponseDataHandler == nil {
return errors.New("srpc - Client needs a ResponseDataHandler")
2020-02-18 03:03:02 +01:00
}
2020-02-24 17:03:24 +01:00
if c.DialHandler == nil {
return errors.New("srpc - Client needs a DialHandler")
2020-02-14 01:51:30 +01:00
}
2020-02-24 17:03:24 +01:00
c.stopWg.Add(1)
2020-03-25 09:45:23 +01:00
go clientHandler(c)
2020-02-24 17:03:24 +01:00
return nil
2020-02-10 02:53:05 +01:00
}
2020-02-24 17:03:24 +01:00
func (c *RPCClient) Stop() {
if c.stopChan == nil {
return
2020-02-19 01:36:18 +01:00
}
2020-02-24 17:03:24 +01:00
close(c.stopChan)
c.stopWg.Wait()
c.stopChan = nil
2020-02-19 01:36:18 +01:00
}
2020-02-25 01:37:05 +01:00
2020-03-25 09:45:23 +01:00
func (c *RPCClient) Call(request interface{}) (response interface{}, err error) {
var requestCall *call
for _, e := range c.calls {
e.Lock()
if e.state == FREE {
requestCall = e
requestCall.state = PENDING
requestCall.Request.ID = c.nextID
requestCall.Request.ClientAddr = c.Addr
requestCall.Request.Data = request
c.nextID++
e.Unlock()
break
}
e.Unlock()
}
if requestCall == nil {
return nil, errors.New("srpc - Client requests are full")
}
c.callChan <- requestCall
<-requestCall.done
requestCall.Lock()
response = requestCall.Response.Data
err = requestCall.Error
requestCall.state = FREE
requestCall.Unlock()
2020-03-25 16:01:56 +01:00
fmt.Println(response)
fmt.Println(err)
2020-03-25 09:45:23 +01:00
return response, err
}
2020-02-25 01:37:05 +01:00
func clientHandler(c *RPCClient) {
defer c.stopWg.Done()
var conn io.ReadWriteCloser
var err error
var stopping atomic.Value
for {
dialChan := make(chan struct{})
go func() {
if conn, err = c.DialHandler(c.Addr); err != nil {
if stopping.Load() == nil {
2020-03-24 08:08:27 +01:00
c.LogError("srpc - '%s' cannot estable rpc connection: '%s'\n", c.Addr, err)
2020-02-25 01:37:05 +01:00
}
}
close(dialChan)
}()
select {
case <-c.stopChan:
stopping.Store(true)
<-dialChan
return
case <-dialChan:
}
if err != nil {
select {
case <-c.stopChan:
return
case <-time.After(time.Second):
}
continue
}
2020-03-25 09:45:23 +01:00
clientHandleConnection(c, conn)
select {
case <-c.stopChan:
return
default:
}
2020-03-24 08:08:27 +01:00
}
}
func clientHandleConnection(c *RPCClient, conn io.ReadWriteCloser) {
if c.ConnectHander != nil {
if newConn, err := c.ConnectHander(c.Addr, conn); err != nil {
c.LogError("srpc - Client: [%s] - Connect error: '%s'\n", c.Addr, err)
conn.Close()
return
} else {
conn = newConn
}
2020-02-25 01:37:05 +01:00
}
2020-03-25 09:45:23 +01:00
stopChan := make(chan struct{})
writerDone := make(chan error, 1)
go clientWriter(c, conn, stopChan, writerDone)
readerDone := make(chan error, 1)
go clientReader(c, conn, readerDone)
var err error
select {
case err = <-writerDone:
close(stopChan)
conn.Close()
<-readerDone
case err = <-readerDone:
close(stopChan)
conn.Close()
<-writerDone
case <-c.stopChan:
close(stopChan)
conn.Close()
<-readerDone
<-writerDone
}
if err != nil {
c.LogError("srpc - '%s'\n", err)
}
for _, e := range c.calls {
e.Lock()
if e.state == PENDING {
e.Error = err
e.done <- true
}
e.Unlock()
}
}
func clientWriter(c *RPCClient, conn io.Writer, stopChan <-chan struct{}, done chan<- error) {
var err error
defer func() { done <- err }()
enc := c.Ed.NewEncoder(conn)
for {
err = nil
var requestCall *call
select {
case requestCall = <-c.callChan:
requestCall.Lock()
if err = enc.Encode(requestCall.Request); err != nil {
if !serverDisconnect(err) && !clientStop(c.stopChan) {
requestCall.Error = errors.New("srpc - '%s'=>'%s': Cannot encode request: '%s'\n")
requestCall.Unlock()
requestCall.done <- true
} else {
requestCall.Unlock()
return
}
}
2020-03-25 16:01:56 +01:00
requestCall.Unlock()
2020-03-25 09:45:23 +01:00
case <-stopChan:
return
}
}
}
2020-03-25 16:01:56 +01:00
func clientReader(c *RPCClient, conn io.Reader, done chan<- error) {
var err error
defer func() { done <- err }()
dec := c.Ed.NewDecoder(conn)
for {
var response serverMessage
response.Data = c.ResponseDataHandler()
if err = dec.Decode(&response); err != nil {
if serverDisconnect(err) || clientStop(c.stopChan) {
return
}
}
fmt.Println(response)
ok := false
for _, e := range c.calls {
fmt.Println(e.Request.ID)
e.Lock()
fmt.Println(e.Request.ID)
if e.Request.ID == response.ID {
e.Response = response
if e.Response.Error != "" {
e.Error = errors.New(e.Response.Error)
}
ok = true
e.done <- true
e.Unlock()
break
}
e.Unlock()
}
if !ok {
c.LogError("srpc - Client response for unknown request ID '%d'\n", response.ID)
}
}
2020-03-25 09:45:23 +01:00
}
// serverDisconnect reports whether err is exactly io.EOF or
// io.ErrUnexpectedEOF, the errors this file treats as the peer having closed
// the connection. Wrapped errors are not unwrapped.
func serverDisconnect(err error) bool {
	switch err {
	case io.EOF, io.ErrUnexpectedEOF:
		return true
	}
	return false
}
// clientStop polls stopChan without blocking. It returns true once the
// channel has been closed (or has a value ready), and false otherwise,
// including for a nil channel.
func clientStop(stopChan <-chan struct{}) bool {
	stopped := false
	select {
	case <-stopChan:
		stopped = true
	default:
	}
	return stopped
}
func NewUnixClient(addr string, handler ResponseDataHandlerFunc) *RPCClient {
return &RPCClient{
Addr: addr,
ResponseDataHandler: handler,
DialHandler: unixDial,
Ed: NewEncoderDecoder(),
}
2020-02-25 01:37:05 +01:00
}