srpc/server.go

240 lines
4.8 KiB
Go

package srpc
import (
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// serverMessage is the envelope the server encodes back to a client for each
// request it has handled.
type serverMessage struct {
// ID is copied from the ID of the request being answered.
ID uint64
// Data carries the value returned by the server's RequestHandler.
Data interface{}
// Error is presumably an error description for the client.
// NOTE(review): confirm its meaning against the client implementation.
Error string
// ClientAddr is copied from the ClientAddr field of the request message.
ClientAddr string
}
// RequestDataHandlerFunc is called before each request is decoded; the value
// it returns is stored in the incoming message's Data field and is therefore
// what the decoder fills in.
type RequestDataHandlerFunc func() (request interface{})
// RequestHandlerFunc processes one decoded request. It receives the client
// address carried in the request message together with the decoded request
// data, and its return value is sent back as the response Data.
type RequestHandlerFunc func(clientAddr string, request interface{}) (response interface{})
// ConnectHandlerFunc is called once for every accepted connection, before any
// request is read. The returned io.ReadWriteCloser replaces the connection for
// all further traffic; a non-nil error makes the server log the failure and
// close the connection.
type ConnectHandlerFunc func(remoteAddr string, rwc io.ReadWriteCloser) (io.ReadWriteCloser, error)
// IListener abstracts the transport the server accepts connections from.
type IListener interface {
// Init prepares the listener for the given address. The server calls it
// once from Start, passing RPCServer.Addr.
Init(addr string) error
// Accept waits for the next connection and returns it together with the
// address of the connecting client.
Accept() (conn io.ReadWriteCloser, clientAddr string, err error)
// Close shuts the listener down. The server calls it while stopping, to
// interrupt a pending Accept.
Close() error
// Addr returns the address the listener is bound to.
Addr() net.Addr
}
// netListener adapts a net.Listener to the IListener interface.
//
// F is the factory Init uses to create the underlying listener for an
// address. L holds that listener once Init has succeeded and is nil before.
type netListener struct {
	F func(addr string) (net.Listener, error)
	L net.Listener
}

// Init creates the underlying listener by calling F with addr.
//
// A netListener without a factory cannot listen on anything, so a nil F is
// reported as an error instead of panicking on the nil function call.
func (nl *netListener) Init(addr string) (err error) {
	if nl.F == nil {
		return errors.New("srpc - listener has no listen function")
	}
	nl.L, err = nl.F(addr)
	return err
}

// Accept waits for the next connection on the underlying listener and returns
// it together with the remote address of the peer. It returns an error when
// the listener has not been initialized.
func (nl *netListener) Accept() (conn io.ReadWriteCloser, clientAddr string, err error) {
	if nl.L == nil {
		return nil, "", errors.New("srpc - listener is not initialized")
	}
	c, err := nl.L.Accept()
	if err != nil {
		return nil, "", err
	}
	return c, c.RemoteAddr().String(), nil
}

// Close closes the underlying listener. Closing a listener that was never
// initialized is a no-op, mirroring the nil guard in Addr.
func (nl *netListener) Close() error {
	if nl.L == nil {
		return nil
	}
	return nl.L.Close()
}

// Addr returns the address of the underlying listener, or nil when the
// listener has not been initialized.
func (nl *netListener) Addr() net.Addr {
	if nl.L == nil {
		return nil
	}
	return nl.L.Addr()
}
// RPCServer accepts connections from a Listener, decodes requests from each
// connection and answers them with the result of RequestHandler.
type RPCServer struct {
// Addr is the address handed to Listener.Init; it also appears in log
// messages.
Addr string
// Listener is the source of incoming connections. When nil, Start installs
// a netListener.
Listener IListener
// LogError receives formatted error messages. When nil, Start sets it to
// the package's logError.
LogError LogErrorFunc
// ConnectHandler, when set, is invoked for every accepted connection and
// may replace the connection or reject it by returning an error.
ConnectHandler ConnectHandlerFunc
// RequestDataHandler supplies the value each request is decoded into.
// Start fails when it is nil.
RequestDataHandler RequestDataHandlerFunc
// RequestHandler produces the response for a decoded request. Start fails
// when it is nil.
RequestHandler RequestHandlerFunc
// Ed creates the decoder and encoder used on each connection.
Ed IEncoderDecoder
// stopChan is created by Start and closed by Stop to signal shutdown; it
// is nil while the server is not running.
stopChan chan struct{}
// stopWg counts the accept loop and the per-connection goroutines so that
// Serve and Stop can wait for them.
stopWg sync.WaitGroup
}
// Start validates the configuration, initializes the listener for s.Addr and
// launches the accept loop in the background.
//
// Defaults applied when the corresponding field is nil:
//   - LogError falls back to the package's logError.
//   - Ed falls back to NewEncoderDecoder(), so the first connection cannot hit
//     a nil encoder/decoder.
//   - Listener falls back to a TCP netListener. A netListener without a listen
//     function cannot be initialized, so the default must carry one.
//
// It returns an error when the server is already running, when
// RequestDataHandler or RequestHandler is missing, or when the listener cannot
// be initialized. The server is only marked as running after every check has
// passed, so a failed Start can simply be retried.
func (s *RPCServer) Start() (err error) {
	if s.LogError == nil {
		s.LogError = logError
	}
	if s.stopChan != nil {
		return errors.New("srpc - Server already running")
	}
	if s.RequestDataHandler == nil {
		return errors.New("srpc - Server needs a RequestDataHandler")
	}
	if s.RequestHandler == nil {
		return errors.New("srpc - Server needs a RequestHandler")
	}
	if s.Ed == nil {
		s.Ed = NewEncoderDecoder()
	}
	if s.Listener == nil {
		s.Listener = &netListener{
			F: func(addr string) (net.Listener, error) {
				return net.Listen("tcp", addr)
			},
		}
	}
	if err = s.Listener.Init(s.Addr); err != nil {
		s.LogError("%s\n", err)
		return err
	}

	// Create the stop channel last: it doubles as the "running" flag and must
	// not be left set when Start bails out early.
	s.stopChan = make(chan struct{})
	s.stopWg.Add(1)
	go serverHandler(s)
	return nil
}
// Serve starts the server and then blocks until every server goroutine
// tracked by stopWg has finished. If Start fails, its error is returned
// immediately and nothing is waited for.
func (s *RPCServer) Serve() (err error) {
	err = s.Start()
	if err == nil {
		s.stopWg.Wait()
	}
	return err
}
// Stop signals shutdown by closing stopChan, waits for the goroutines tracked
// by stopWg to finish, and then clears stopChan so the server can be started
// again. Calling Stop on a server that is not running does nothing.
func (s *RPCServer) Stop() {
	if ch := s.stopChan; ch != nil {
		close(ch)
		s.stopWg.Wait()
		s.stopChan = nil
	}
}
// serverHandler is the accept loop. It runs until s.stopChan is closed,
// handing every accepted connection to its own serverHandleConnection
// goroutine (each registered with s.stopWg).
//
// Accept is run in a helper goroutine so the loop can react to a stop request
// while Accept is blocked. After a failed Accept the loop backs off for one
// second before retrying. The listener is closed on every exit path.
func serverHandler(s *RPCServer) {
	defer s.stopWg.Done()

	var (
		conn       io.ReadWriteCloser
		clientAddr string
		err        error
		// stopping is set before the listener is closed so the accept
		// goroutine does not log the resulting Accept error.
		stopping atomic.Value
	)

	for {
		acceptChan := make(chan struct{})
		go func() {
			if conn, clientAddr, err = s.Listener.Accept(); err != nil {
				if stopping.Load() == nil {
					s.LogError("srpc - '%s' error accepting connection: '%s'\n", s.Addr, err)
				}
			}
			// Closing the channel publishes conn, clientAddr and err to the loop.
			close(acceptChan)
		}()

		select {
		case <-s.stopChan:
			stopping.Store(true)
			// Closing the listener unblocks the pending Accept.
			s.Listener.Close()
			<-acceptChan
			// Accept may have succeeded just before the listener was closed;
			// that connection will never be served, so release it.
			if err == nil && conn != nil {
				conn.Close()
			}
			return
		case <-acceptChan:
		}

		if err != nil {
			select {
			case <-s.stopChan:
				// No Accept is pending here, but the listener is still open
				// and must not outlive the server.
				s.Listener.Close()
				return
			case <-time.After(time.Second):
			}
			continue
		}

		s.stopWg.Add(1)
		go serverHandleConnection(s, conn, clientAddr)
	}
}
// serverHandleConnection serves a single client connection until the client
// disconnects, a decode/encode error occurs, or the server is stopped.
//
// If s.ConnectHandler is set it is called first and may replace or reject the
// connection. Afterwards requests are decoded one at a time; each request is
// handled in its own goroutine so a slow handler does not block later
// requests on the same connection.
//
// Guarantees provided here:
//   - The connection is closed exactly once, on every exit path.
//   - The connection is closed as soon as the server is stopped, so a Decode
//     blocked on an idle client cannot keep Stop waiting.
//   - Responses are encoded one at a time; the encoder is shared by all
//     handler goroutines of the connection.
//   - A panic inside s.RequestHandler is logged and reported to the client
//     through serverMessage.Error instead of crashing the whole process.
//
// NOTE(review): the handler receives the ClientAddr carried inside the request
// message, not the clientAddr of the accepted connection — confirm that this
// is intended.
func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr string) {
	defer s.stopWg.Done()

	if s.ConnectHandler != nil {
		newConn, err := s.ConnectHandler(clientAddr, conn)
		if err != nil {
			s.LogError("srpc - '%s'=>'%s': Connect error: '%s'\n", clientAddr, s.Addr, err)
			conn.Close()
			return
		}
		conn = newConn
	}

	// Snapshot the stop channel while this goroutine is still registered with
	// stopWg; Stop resets the field once all registered goroutines are done.
	stopChan := s.stopChan

	var closeOnce sync.Once
	closeConn := func() {
		closeOnce.Do(func() { conn.Close() })
	}
	defer closeConn()

	// Watcher: unblock a pending Decode when the server is stopped.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-stopChan:
			closeConn()
		case <-done:
		}
	}()

	dec := s.Ed.NewDecoder(conn)
	enc := s.Ed.NewEncoder(conn)
	var encMu sync.Mutex // serializes Encode calls on the shared encoder

	for {
		//TODO: check for multiple requests per client
		var msg clientMessage
		msg.Data = s.RequestDataHandler()
		if err := dec.Decode(&msg); err != nil {
			if !clientDisconnect(err) && !serverStop(stopChan) {
				s.LogError("srpc - '%s'=>'%s': Cannot decode request: '%s'\n", clientAddr, s.Addr, err)
			}
			return
		}

		// msg is declared inside the loop body, so every goroutine gets its
		// own copy of the request.
		go func() {
			var response serverMessage
			response.ID = msg.ID
			response.ClientAddr = msg.ClientAddr

			func() {
				defer func() {
					if r := recover(); r != nil {
						response.Data = nil
						response.Error = fmt.Sprintf("srpc - request handler panic: %v", r)
						s.LogError("srpc - '%s'=>'%s': Request handler panic: '%v'\n", clientAddr, s.Addr, r)
					}
				}()
				response.Data = s.RequestHandler(msg.ClientAddr, msg.Data)
			}()

			encMu.Lock()
			err := enc.Encode(response)
			encMu.Unlock()
			if err != nil {
				if !clientDisconnect(err) && !serverStop(stopChan) {
					s.LogError("srpc - '%s'=>'%s': Cannot encode response: '%s'\n", clientAddr, s.Addr, err)
				}
				// Closing the connection also terminates the decode loop.
				closeConn()
			}
		}()
	}
}
// clientDisconnect reports whether err means the peer closed the connection,
// i.e. err is io.EOF or io.ErrUnexpectedEOF. errors.Is is used so that the
// check also matches when a decoder wraps one of those errors with extra
// context. A nil err is not a disconnect.
func clientDisconnect(err error) bool {
	return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
}
// serverStop reports, without blocking, whether stopChan is ready to be
// received from — which for the server's stop channel means it has been
// closed. A channel that is still open and empty (or nil) yields false.
func serverStop(stopChan <-chan struct{}) bool {
	stopped := false
	select {
	case <-stopChan:
		stopped = true
	default:
	}
	return stopped
}
// NewUnixServer builds an RPCServer that listens on the unix socket at addr.
// dhandler becomes the server's RequestDataHandler and handler its
// RequestHandler; the encoder/decoder comes from NewEncoderDecoder. The
// returned server is not started.
func NewUnixServer(addr string, dhandler RequestDataHandlerFunc, handler RequestHandlerFunc) *RPCServer {
	listenUnix := func(path string) (net.Listener, error) {
		return net.Listen("unix", path)
	}

	srv := new(RPCServer)
	srv.Addr = addr
	srv.RequestDataHandler = dhandler
	srv.RequestHandler = handler
	srv.Listener = &netListener{F: listenUnix}
	srv.Ed = NewEncoderDecoder()
	return srv
}