2019-10-22 01:32:56 +02:00
|
|
|
package srpc
|
|
|
|
|
|
|
|
import (
|
2019-11-06 01:16:23 +01:00
|
|
|
"errors"
|
2019-11-05 18:21:41 +01:00
|
|
|
"fmt"
|
2020-02-19 01:36:18 +01:00
|
|
|
"io"
|
2019-11-08 02:14:36 +01:00
|
|
|
"net"
|
2019-10-22 01:32:56 +02:00
|
|
|
"sync"
|
2020-02-23 03:45:57 +01:00
|
|
|
"sync/atomic"
|
2020-02-18 03:03:02 +01:00
|
|
|
"time"
|
2019-10-22 01:32:56 +02:00
|
|
|
)
|
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
// serverMessage is the envelope exchanged with clients over the encoder /
// decoder. A request fills ID, Request and ClientAddr; the server replies with
// a serverMessage that echoes ID, ClientAddr and Request and carries either the
// handler's Response or, when the handler failed, its Error text.
type serverMessage struct {
	ID         uint64 // correlates a response with its request
	Request    []byte // request payload from the client
	Response   []byte // handler result; set to an empty slice when the handler fails
	Error      string // handler error text; empty when the handler succeeds
	ClientAddr string // client-supplied address forwarded to RequestHandler
}
|
|
|
|
|
|
|
|
type (
	// RequestHandlerFunc answers one request. It receives the ClientAddr field
	// of the incoming message and the request payload. A non-nil err is logged
	// by the server and its text is sent back to the client in the Error field.
	RequestHandlerFunc func(clientAddr string, request []byte) (response []byte, err error)

	// ConnectHandlerFunc runs once for every accepted connection, before any
	// request is read. It gets the remote address reported by the listener and
	// the raw connection. The returned ReadWriteCloser is used for all further
	// I/O on that connection. Returning an error closes the connection.
	ConnectHandlerFunc func(remoteAddr string, rwc io.ReadWriteCloser) (io.ReadWriteCloser, error)
)
|
2019-11-05 18:21:41 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
// IListener abstracts the transport an RPCServer accepts connections on.
// The server calls Init with its Addr when starting, then calls Accept in a
// loop, and calls Close when it is being stopped.
type IListener interface {
	// Init prepares the listener to accept connections on addr.
	Init(addr string) error

	// Accept blocks until a client connects and returns the connection
	// together with the client's address.
	Accept() (conn io.ReadWriteCloser, clientAddr string, err error)

	// Close stops listening. During shutdown the server relies on Close to
	// make a pending Accept return.
	Close() error

	// Addr returns the address the listener is bound to.
	Addr() net.Addr
}
|
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
// netListener implements IListener on top of a net.Listener.
type netListener struct {
	// F opens the listener for an address (for example net.Listen with a
	// fixed network); Init calls it.
	F func(addr string) (net.Listener, error)
	// L is the listener produced by Init; nil until Init has run.
	L net.Listener
}
|
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func (nl *netListener) Init(addr string) (err error) {
|
|
|
|
nl.L, err = nl.F(addr)
|
|
|
|
return err
|
|
|
|
}
|
2019-11-08 02:14:36 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func (nl *netListener) Accept() (conn io.ReadWriteCloser, clientAddr string, err error) {
|
|
|
|
var c net.Conn
|
2019-11-08 02:14:36 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
if c, err = nl.L.Accept(); err != nil {
|
|
|
|
return nil, "", err
|
|
|
|
}
|
|
|
|
return c, c.RemoteAddr().String(), nil
|
|
|
|
}
|
2020-02-18 03:03:02 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func (nl *netListener) Close() error {
|
|
|
|
return nl.L.Close()
|
|
|
|
}
|
2020-02-18 03:03:02 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func (nl *netListener) Addr() net.Addr {
|
|
|
|
if nl.L == nil {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
return nl.L.Addr()
|
|
|
|
}
|
2019-11-08 02:14:36 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
type RPCServer struct {
|
|
|
|
Addr string
|
|
|
|
Listener IListener
|
|
|
|
LogError LogErrorFunc
|
|
|
|
ConnectHandler ConnectHandlerFunc
|
|
|
|
RequestHandler RequestHandlerFunc
|
|
|
|
Ed IEncoderDecoder
|
|
|
|
stopChan chan struct{}
|
|
|
|
stopWg sync.WaitGroup
|
|
|
|
}
|
2020-02-18 03:03:02 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func (s *RPCServer) Start() (err error) {
|
|
|
|
if s.LogError == nil {
|
|
|
|
s.LogError = logError
|
2019-11-08 02:14:36 +01:00
|
|
|
}
|
2019-10-22 01:32:56 +02:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
if s.stopChan != nil {
|
|
|
|
return errors.New("srpc - Server already running")
|
|
|
|
}
|
|
|
|
s.stopChan = make(chan struct{})
|
|
|
|
|
|
|
|
if s.RequestHandler == nil {
|
|
|
|
return errors.New("srpc - Server needs a RequestHandler")
|
|
|
|
}
|
|
|
|
|
|
|
|
if s.Listener == nil {
|
|
|
|
s.Listener = &netListener{}
|
|
|
|
}
|
|
|
|
if err = s.Listener.Init(s.Addr); err != nil {
|
|
|
|
err = fmt.Errorf("srpc - '%s' cannot listen to: '%s'", s.Addr, err)
|
|
|
|
s.LogError("%s\n", err)
|
|
|
|
return err
|
2019-11-08 02:14:36 +01:00
|
|
|
}
|
2020-02-23 03:45:57 +01:00
|
|
|
|
|
|
|
s.stopWg.Add(1)
|
|
|
|
go serverHandler(s)
|
|
|
|
return nil
|
2019-11-08 02:14:36 +01:00
|
|
|
}
|
2019-11-05 18:21:41 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func (s *RPCServer) Serve() (err error) {
|
|
|
|
if err = s.Start(); err != nil {
|
|
|
|
return err
|
2019-11-06 01:16:23 +01:00
|
|
|
}
|
2020-02-23 03:45:57 +01:00
|
|
|
s.stopWg.Wait()
|
|
|
|
return nil
|
|
|
|
}
|
2019-11-05 18:21:41 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func (s *RPCServer) Stop() {
|
|
|
|
if s.stopChan == nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
close(s.stopChan)
|
|
|
|
s.stopWg.Wait()
|
|
|
|
s.stopChan = nil
|
|
|
|
}
|
2019-11-05 18:21:41 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func serverHandler(s *RPCServer) {
|
|
|
|
defer s.stopWg.Done()
|
2019-11-05 18:21:41 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
var conn io.ReadWriteCloser
|
|
|
|
var clientAddr string
|
|
|
|
var err error
|
|
|
|
var stopping atomic.Value
|
2019-11-05 18:21:41 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
for {
|
|
|
|
acceptChan := make(chan struct{})
|
|
|
|
go func() {
|
|
|
|
if conn, clientAddr, err = s.Listener.Accept(); err != nil {
|
|
|
|
if stopping.Load() == nil {
|
|
|
|
s.LogError("srpc - '%s' error accepting connection: '%s'\n", s.Addr, err)
|
|
|
|
}
|
2019-11-05 18:21:41 +01:00
|
|
|
}
|
2020-02-23 03:45:57 +01:00
|
|
|
close(acceptChan)
|
|
|
|
}()
|
2019-11-05 18:21:41 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
select {
|
|
|
|
case <-s.stopChan:
|
|
|
|
stopping.Store(true)
|
|
|
|
s.Listener.Close()
|
|
|
|
<-acceptChan
|
|
|
|
return
|
|
|
|
case <-acceptChan:
|
|
|
|
}
|
2019-11-05 18:21:41 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
if err != nil {
|
|
|
|
select {
|
|
|
|
case <-s.stopChan:
|
|
|
|
return
|
|
|
|
case <-time.After(time.Second):
|
2019-11-06 01:16:23 +01:00
|
|
|
}
|
2020-02-23 03:45:57 +01:00
|
|
|
continue
|
2019-11-06 01:16:23 +01:00
|
|
|
}
|
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
s.stopWg.Add(1)
|
|
|
|
go serverHandleConnection(s, conn, clientAddr)
|
2019-11-06 01:16:23 +01:00
|
|
|
}
|
2019-11-05 18:21:41 +01:00
|
|
|
}
|
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr string) {
|
|
|
|
defer s.stopWg.Done()
|
2019-11-06 01:16:23 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
if s.ConnectHandler != nil {
|
|
|
|
if newConn, err := s.ConnectHandler(clientAddr, conn); err != nil {
|
|
|
|
s.LogError("srpc - '%s'=>'%s': Connect error: '%s'\n", clientAddr, s.Addr, err)
|
|
|
|
conn.Close()
|
|
|
|
return
|
|
|
|
} else {
|
|
|
|
conn = newConn
|
2019-11-06 01:16:23 +01:00
|
|
|
}
|
|
|
|
}
|
2019-11-08 02:14:36 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
var msg serverMessage
|
|
|
|
var err error
|
2020-02-10 02:53:05 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
dec := s.Ed.NewDecoder(conn)
|
|
|
|
enc := s.Ed.NewEncoder(conn)
|
|
|
|
|
|
|
|
for {
|
|
|
|
if err = dec.Decode(&msg); err != nil {
|
|
|
|
if !clientDisconnect(err) && !serverStop(s.stopChan) {
|
|
|
|
s.LogError("srpc - '%s'=>'%s': Cannot decode request: '%s'\n", clientAddr, s.Addr, err)
|
|
|
|
}
|
|
|
|
conn.Close()
|
|
|
|
return
|
|
|
|
}
|
2019-11-08 02:14:36 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
var response serverMessage
|
|
|
|
response.ID = msg.ID
|
|
|
|
response.ClientAddr = msg.ClientAddr
|
|
|
|
response.Request = msg.Request
|
2020-02-19 01:36:18 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
if response.Response, err = s.RequestHandler(msg.ClientAddr, msg.Request); err != nil {
|
|
|
|
s.LogError("srpc - '%s'=>'%s': Error handling request: '%s'\n", clientAddr, s.Addr, err)
|
|
|
|
response.Response = []byte{}
|
|
|
|
response.Error = err.Error()
|
|
|
|
}
|
2020-02-19 01:36:18 +01:00
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
if err = enc.Encode(response); err != nil {
|
|
|
|
if !clientDisconnect(err) && !serverStop(s.stopChan) {
|
|
|
|
s.LogError("srpc - '%s'=>'%s': Cannot encode response: '%s'\n", clientAddr, s.Addr, err)
|
|
|
|
}
|
|
|
|
conn.Close()
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
2020-02-19 01:36:18 +01:00
|
|
|
}
|
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
// clientDisconnect reports whether err is io.EOF or io.ErrUnexpectedEOF. The
// connection handler treats these as the client closing the connection and
// does not log them.
func clientDisconnect(err error) bool {
	switch err {
	case io.EOF, io.ErrUnexpectedEOF:
		return true
	default:
		return false
	}
}
|
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
// serverStop reports, without blocking, whether a receive from stopChan is
// immediately possible. In this server that means Stop has closed the channel
// and shutdown is in progress. A nil channel is never ready, so it reports false.
func serverStop(stopChan <-chan struct{}) bool {
	stopped := false
	select {
	case <-stopChan:
		stopped = true
	default:
	}
	return stopped
}
|
|
|
|
|
2020-02-23 03:45:57 +01:00
|
|
|
func NewUnixServer(addr string, handler RequestHandlerFunc) *RPCServer {
|
|
|
|
return &RPCServer{
|
|
|
|
Addr: addr,
|
|
|
|
RequestHandler: handler,
|
|
|
|
Listener: &netListener{
|
|
|
|
F: func(addr string) (net.Listener, error) {
|
|
|
|
return net.Listen("unix", addr)
|
|
|
|
},
|
|
|
|
},
|
|
|
|
Ed: NewEncoderDecoder(),
|
2020-02-19 01:36:18 +01:00
|
|
|
}
|
|
|
|
}
|