srpc/client.go

356 lines
6.9 KiB
Go
Raw Normal View History

2019-10-22 01:32:56 +02:00
package srpc
2020-02-10 02:53:05 +01:00
import (
2020-05-25 03:40:38 +02:00
"bytes"
2020-02-25 01:37:05 +01:00
"errors"
2020-02-10 02:53:05 +01:00
"io"
"net"
2020-02-18 03:03:02 +01:00
"sync"
2020-02-25 01:37:05 +01:00
"sync/atomic"
"time"
2020-02-10 02:53:05 +01:00
)
2020-03-24 08:08:27 +01:00
// clientMessage is the envelope the client encodes onto the connection for
// each request (clientWriter encodes call.Request, which has this type).
type clientMessage struct {
	// ID identifies the request; clientReader compares the ID of each decoded
	// response against it to find the call waiting for that response.
	ID uint64
	// Data is the request payload; Call stores the RPCRequest built by
	// MakeRequest here.
	Data interface{}
	// ClientAddr is copied from RPCClient.Addr by Call.
	ClientAddr string
}
type call struct {
Response serverMessage
2020-03-25 09:45:23 +01:00
Request clientMessage
Error error
state messageState
done chan bool
sync.RWMutex
}
func newCall() (ret *call) {
ret = new(call)
ret.state = FREE
ret.done = make(chan bool)
return ret
2020-03-24 08:08:27 +01:00
}
2020-02-24 17:03:24 +01:00
type (
	// ResponseDataHandlerFunc is called by the client once per incoming
	// response, before decoding it; the value it returns is placed in the
	// response's Data field, presumably so the decoder has a concrete value
	// to decode the payload into.
	ResponseDataHandlerFunc func() (response interface{})

	// DialHandlerFunc opens the connection the client uses to reach addr.
	DialHandlerFunc func(addr string) (conn io.ReadWriteCloser, err error)
)
2020-02-10 02:53:05 +01:00
2020-02-24 17:03:24 +01:00
func unixDial(addr string) (conn io.ReadWriteCloser, err error) {
if conn, err = net.Dial("unix", addr); err != nil {
return nil, err
2020-02-10 02:53:05 +01:00
}
2020-02-24 17:03:24 +01:00
return conn, nil
}
2020-02-10 02:53:05 +01:00
2020-02-24 17:03:24 +01:00
type RPCClient struct {
Addr string
2020-03-25 09:45:23 +01:00
MaxRequests int
2020-02-24 17:03:24 +01:00
LogError LogErrorFunc
ConnectHander ConnectHandlerFunc
ResponseDataHandler ResponseDataHandlerFunc
DialHandler DialHandlerFunc
Ed IEncoderDecoder
2020-03-26 01:28:15 +01:00
RequestTimeout time.Duration
2020-03-25 09:45:23 +01:00
calls []*call
callChan chan *call
nextID uint64
2020-02-24 17:03:24 +01:00
stopChan chan struct{}
stopWg sync.WaitGroup
}
2020-02-18 03:03:02 +01:00
2020-02-24 17:03:24 +01:00
func (c *RPCClient) Start() (err error) {
2020-03-25 09:45:23 +01:00
if c.MaxRequests <= 0 {
c.MaxRequests = DefaultMaxClientRequests
}
c.calls = make([]*call, c.MaxRequests)
for i := 0; i < c.MaxRequests; i++ {
c.calls[i] = newCall()
}
c.nextID = 0
c.callChan = make(chan *call)
2020-02-24 17:03:24 +01:00
if c.LogError == nil {
c.LogError = logError
}
2020-02-10 02:53:05 +01:00
2020-02-24 17:03:24 +01:00
if c.stopChan != nil {
return errors.New("srpc - Client already running")
}
2020-02-24 17:03:24 +01:00
c.stopChan = make(chan struct{})
2020-02-10 02:53:05 +01:00
2020-02-24 17:03:24 +01:00
if c.ResponseDataHandler == nil {
return errors.New("srpc - Client needs a ResponseDataHandler")
2020-02-18 03:03:02 +01:00
}
2020-02-24 17:03:24 +01:00
if c.DialHandler == nil {
return errors.New("srpc - Client needs a DialHandler")
2020-02-14 01:51:30 +01:00
}
2020-03-26 01:28:15 +01:00
if c.RequestTimeout == 0 {
c.RequestTimeout = DefaultRequestTimeout
}
2020-02-24 17:03:24 +01:00
c.stopWg.Add(1)
2020-03-25 09:45:23 +01:00
go clientHandler(c)
2020-02-24 17:03:24 +01:00
return nil
2020-02-10 02:53:05 +01:00
}
2020-02-24 17:03:24 +01:00
func (c *RPCClient) Stop() {
if c.stopChan == nil {
return
2020-02-19 01:36:18 +01:00
}
2020-02-24 17:03:24 +01:00
close(c.stopChan)
c.stopWg.Wait()
c.stopChan = nil
2020-02-19 01:36:18 +01:00
}
2020-02-25 01:37:05 +01:00
2020-05-25 03:40:38 +02:00
func (c *RPCClient) MakeRequest(fctName string, params ...interface{}) (request RPCRequest, err error) {
data := bytes.Buffer{}
enc := c.Ed.NewEncoder(&data)
for _, param := range params {
if err = enc.Encode(param); err != nil {
return request, err
}
}
request.FuncName = fctName
request.Payload = data.Bytes()
return request, nil
}
func (c *RPCClient) Call(fctName string, params ...interface{}) (response interface{}, err error) {
2020-03-25 09:45:23 +01:00
var requestCall *call
2020-05-25 03:40:38 +02:00
var request RPCRequest
if request, err = c.MakeRequest(fctName, params...); err != nil {
return nil, errors.New("srpc - Client request encode failed: '" + err.Error() + "'")
}
2020-03-25 09:45:23 +01:00
for _, e := range c.calls {
e.Lock()
if e.state == FREE {
requestCall = e
requestCall.state = PENDING
requestCall.Request.ID = c.nextID
requestCall.Request.ClientAddr = c.Addr
requestCall.Request.Data = request
c.nextID++
e.Unlock()
break
}
e.Unlock()
}
if requestCall == nil {
return nil, errors.New("srpc - Client requests are full")
}
c.callChan <- requestCall
2020-03-26 01:28:15 +01:00
select {
case <-requestCall.done:
requestCall.Lock()
response = requestCall.Response.Data
err = requestCall.Error
requestCall.state = FREE
requestCall.Unlock()
case <-time.After(c.RequestTimeout):
requestCall.Lock()
response = nil
err = errors.New("srpc - Request timed out")
requestCall.state = FREE
requestCall.Unlock()
}
2020-03-25 09:45:23 +01:00
return response, err
}
2020-02-25 01:37:05 +01:00
// clientHandler is the background loop started by Start. It repeatedly dials
// c.Addr with c.DialHandler and serves each established connection with
// clientHandleConnection; after a failed dial it waits one second before
// retrying. It returns once c.stopChan is closed, marking c.stopWg done.
func clientHandler(c *RPCClient) {
	defer c.stopWg.Done()
	// stopping is set once a stop has been requested so that the failure of a
	// dial being abandoned is not logged.
	var stopping atomic.Value
	for {
		var conn io.ReadWriteCloser
		var err error
		dialed := make(chan struct{})
		// Dial in its own goroutine so a stop request arriving mid-dial is
		// noticed (and the dial's failure is not logged).
		go func() {
			defer close(dialed)
			conn, err = c.DialHandler(c.Addr)
			if err != nil && stopping.Load() == nil {
				c.LogError("srpc - '%s' cannot establish rpc connection: '%s'\n", c.Addr, err)
			}
		}()

		select {
		case <-c.stopChan:
			stopping.Store(true)
			// The dial cannot be cancelled; wait for it so that a connection
			// it managed to open is closed instead of leaked.
			<-dialed
			if err == nil && conn != nil {
				conn.Close()
			}
			return
		case <-dialed:
		}

		if err != nil {
			select {
			case <-c.stopChan:
				return
			case <-time.After(time.Second):
			}
			continue
		}

		clientHandleConnection(c, conn)
		if clientStop(c.stopChan) {
			return
		}
	}
}
// clientHandleConnection serves one established connection. It applies the
// optional ConnectHander, then runs clientWriter and clientReader on the
// connection until either of them finishes or the client is stopped. It then
// closes the connection, waits for both goroutines, logs the error that ended
// the session (if any), and fails the requests still pending on it.
func clientHandleConnection(c *RPCClient, conn io.ReadWriteCloser) {
	if c.ConnectHander != nil {
		newConn, err := c.ConnectHander(c.Addr, conn)
		if err != nil {
			c.LogError("srpc - Client: [%s] - Connect error: '%s'\n", c.Addr, err)
			conn.Close()
			return
		}
		conn = newConn
	}

	stopChan := make(chan struct{})
	writerDone := make(chan error, 1)
	readerDone := make(chan error, 1)
	go clientWriter(c, conn, stopChan, writerDone)
	go clientReader(c, conn, readerDone)

	var err error
	select {
	case err = <-writerDone:
		close(stopChan)
		conn.Close()
		<-readerDone
	case err = <-readerDone:
		close(stopChan)
		conn.Close()
		<-writerDone
	case <-c.stopChan:
		close(stopChan)
		conn.Close()
		<-readerDone
		<-writerDone
	}
	if err != nil {
		c.LogError("srpc - '%s'\n", err)
	}

	// A stopped client ends the session without an error, but the pending
	// callers must still learn that their requests failed.
	failure := err
	if failure == nil {
		failure = errors.New("srpc - Client connection closed")
	}
	for _, e := range c.calls {
		e.Lock()
		if e.state == PENDING {
			// Never block while holding the slot lock: a caller that cannot
			// take the signal right now (one still queued for the writer, or
			// one that timed out and is waiting for this lock) would deadlock
			// us. Error is assigned after the send but before Unlock, so the
			// woken caller, which must take the lock to read it, sees it.
			select {
			case e.done <- true:
				e.Error = failure
			default:
			}
		}
		e.Unlock()
	}
}
// clientWriter encodes every request received on c.callChan onto conn until
// stopChan is closed or the connection becomes unusable. The error that ended
// the writer (nil when it was stopped through stopChan) is sent on done.
//
// An encode error caused by the connection (end of stream, or the client or
// this connection being stopped) ends the writer; the affected request stays
// pending and is failed by clientHandleConnection. Any other encode error
// fails only that request, and the writer carries on.
func clientWriter(c *RPCClient, conn io.Writer, stopChan <-chan struct{}, done chan<- error) {
	var err error
	defer func() { done <- err }()
	enc := c.Ed.NewEncoder(conn)
	for {
		var requestCall *call
		select {
		case requestCall = <-c.callChan:
		case <-stopChan:
			return
		}

		requestCall.Lock()
		encErr := enc.Encode(requestCall.Request)
		if encErr == nil {
			requestCall.Unlock()
			continue
		}
		if serverDisconnect(encErr) || clientStop(c.stopChan) || clientStop(stopChan) {
			requestCall.Unlock()
			err = encErr
			return
		}
		requestCall.Error = errors.New("srpc - '" + c.Addr + "': Cannot encode request: '" + encErr.Error() + "'")
		// Read the signal channel under the lock (Call may replace it when the
		// slot is reused), then signal without holding the lock so the woken
		// caller can take it to read Error. The stopChan case keeps a caller
		// that already gave up from wedging the writer.
		signal := requestCall.done
		requestCall.Unlock()
		select {
		case signal <- true:
		case <-stopChan:
			return
		}
	}
}
2020-03-25 16:01:56 +01:00
// clientReader decodes responses from conn and hands each one to the pending
// call whose request ID it carries, logging responses that match no pending
// call. It returns on the first decode error and reports that error on done.
//
// Stopping at any decode error matters. When clientHandleConnection closes the
// connection, reads fail with a "closed connection" error rather than io.EOF,
// and a decoder that has failed on malformed input presumably cannot resync
// with the stream. Continuing after either would re-read a broken stream,
// potentially forever.
func clientReader(c *RPCClient, conn io.Reader, done chan<- error) {
	var err error
	defer func() { done <- err }()
	dec := c.Ed.NewDecoder(conn)
	for {
		var response serverMessage
		response.Data = c.ResponseDataHandler()
		if err = dec.Decode(&response); err != nil {
			return
		}
		if !clientDeliverResponse(c, response) {
			c.LogError("srpc - Client response for unknown request ID '%d'\n", response.ID)
		}
	}
}

// clientDeliverResponse records response on the PENDING call whose request ID
// matches and wakes its caller. It reports whether such a call was found.
// Only PENDING slots are considered, so a late response for a request whose
// caller already timed out (its slot is FREE again) is not delivered.
func clientDeliverResponse(c *RPCClient, response serverMessage) bool {
	for _, e := range c.calls {
		e.Lock()
		if e.state != PENDING || e.Request.ID != response.ID {
			e.Unlock()
			continue
		}
		// Never block while holding the slot lock: a caller that timed out
		// and is waiting for this lock would deadlock with us. The outcome is
		// recorded only after a successful send, still under the lock, so the
		// woken caller (which must take the lock) always sees it.
		select {
		case e.done <- true:
			e.Response = response
			e.Error = nil
			if response.Error != "" {
				e.Error = errors.New(response.Error)
			}
		default:
		}
		e.Unlock()
		return true
	}
	return false
}
// serverDisconnect reports whether err is io.EOF or io.ErrUnexpectedEOF, the
// end-of-stream errors this client treats as the server having closed the
// connection. Wrapped errors are not unwrapped.
func serverDisconnect(err error) bool {
	switch err {
	case io.EOF, io.ErrUnexpectedEOF:
		return true
	}
	return false
}
// clientStop reports, without blocking, whether a receive from stopChan is
// ready. The stop channels in this file are only ever closed, so this tells
// whether a stop has been requested. A nil channel is never ready.
func clientStop(stopChan <-chan struct{}) bool {
	stopped := false
	select {
	case <-stopChan:
		stopped = true
	default:
	}
	return stopped
}
2020-03-26 01:28:15 +01:00
func NewUnixClient(addr string, handler ResponseDataHandlerFunc, timeout time.Duration) *RPCClient {
2020-03-25 09:45:23 +01:00
return &RPCClient{
Addr: addr,
ResponseDataHandler: handler,
DialHandler: unixDial,
Ed: NewEncoderDecoder(),
2020-03-26 01:28:15 +01:00
RequestTimeout: timeout,
2020-03-25 09:45:23 +01:00
}
2020-02-25 01:37:05 +01:00
}