package srpc

import (
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"olznet.de/slog"
)

// serverMessage is the response envelope the server encodes back to a client.
type serverMessage struct {
	// ID is copied from the request so the client can match the response.
	ID uint64
	// Data is the value returned by the RequestHandler.
	Data interface{}
	// Error is not set by the server code in this file.
	Error string
	// ClientAddr is copied from the request.
	ClientAddr string
}

// RequestDataHandlerFunc returns the value that is stored in the request's
// Data field before the request is decoded.
type RequestDataHandlerFunc func() (request interface{})

// RequestHandlerFunc handles one decoded request and returns the value that is
// sent back as the response's Data.
type RequestHandlerFunc func(clientAddr string, request interface{}) (response interface{})

// ConnectHandlerFunc is called once per accepted connection. It may return a
// replacement stream to use for the rest of the connection; a non-nil error
// makes the server close the connection.
type ConnectHandlerFunc func(remoteAddr string, rwc io.ReadWriteCloser) (io.ReadWriteCloser, error)

// IListener abstracts the transport the server accepts connections from.
type IListener interface {
	// Init starts listening on addr.
	Init(addr string) error
	// Accept blocks until a connection arrives and returns it together with
	// the client's address.
	Accept() (conn io.ReadWriteCloser, clientAddr string, err error)
	// Close stops listening.
	Close() error
	// Addr returns the address being listened on.
	Addr() net.Addr
}

// netListener adapts a net.Listener, created by F, to IListener.
type netListener struct {
	// F creates the underlying listener for an address.
	F func(addr string) (net.Listener, error)
	// L is the listener created by Init; nil until Init succeeds.
	L net.Listener
}

// Init creates the underlying listener via F. It returns an error instead of
// panicking when no listen function is configured.
func (nl *netListener) Init(addr string) (err error) {
	if nl.F == nil {
		return errors.New("srpc - listener has no listen function")
	}
	nl.L, err = nl.F(addr)
	return err
}

// Accept waits for the next connection and reports the remote address as the
// client address.
func (nl *netListener) Accept() (conn io.ReadWriteCloser, clientAddr string, err error) {
	if nl.L == nil {
		return nil, "", errors.New("srpc - listener is not initialized")
	}
	c, err := nl.L.Accept()
	if err != nil {
		return nil, "", err
	}
	return c, c.RemoteAddr().String(), nil
}

// Close closes the underlying listener; it is a no-op when Init never
// succeeded.
func (nl *netListener) Close() error {
	if nl.L == nil {
		return nil
	}
	return nl.L.Close()
}

// Addr returns the underlying listener's address, or nil before Init.
func (nl *netListener) Addr() net.Addr {
	if nl.L == nil {
		return nil
	}
	return nl.L.Addr()
}

// RPCServer accepts connections from a Listener, decodes requests with Ed and
// answers each of them through RequestHandler.
type RPCServer struct {
	// Addr is passed to Listener.Init.
	Addr string
	// Listener is the transport; Start installs a TCP listener when nil.
	Listener IListener
	// LogError receives error messages; Start installs logError when nil.
	LogError LogErrorFunc
	// ConnectHandler, when set, is called for each accepted connection.
	ConnectHandler ConnectHandlerFunc
	// RequestDataHandler supplies the Data value each request is decoded into.
	RequestDataHandler RequestDataHandlerFunc
	// RequestHandler produces the response data for a request.
	RequestHandler RequestHandlerFunc
	// SimultanousClientRequests bounds the per-connection request counter
	// (which also counts the read in progress); a connection that reaches it
	// is dropped. Zero selects DefaultSimultanousClientRequests.
	SimultanousClientRequests int
	// Ed creates the per-connection encoder and decoder; Start installs
	// NewEncoderDecoder() when nil.
	Ed IEncoderDecoder

	// stopChan is non-nil while the server runs and is closed by Stop.
	stopChan chan struct{}
	// stopWg tracks the accept loop and the per-connection goroutines.
	stopWg sync.WaitGroup
}

// Start validates the configuration, fills in defaults, starts listening and
// launches the accept loop in the background.
//
// The running state is only recorded once the listener is initialized, so a
// failed Start can be retried.
func (s *RPCServer) Start() (err error) {
	if s.LogError == nil {
		s.LogError = logError
	}
	if s.stopChan != nil {
		return errors.New("srpc - Server already running")
	}
	if s.RequestDataHandler == nil {
		return errors.New("srpc - Server needs a RequestDataHandler")
	}
	if s.RequestHandler == nil {
		return errors.New("srpc - Server needs a RequestHandler")
	}
	if s.SimultanousClientRequests == 0 {
		s.SimultanousClientRequests = DefaultSimultanousClientRequests
	}
	if s.Ed == nil {
		s.Ed = NewEncoderDecoder()
	}
	if s.Listener == nil {
		// NOTE(review): TCP is chosen as the default transport here; confirm
		// this matches the package's intended default.
		s.Listener = &netListener{
			F: func(addr string) (net.Listener, error) {
				return net.Listen("tcp", addr)
			},
		}
	}
	if err = s.Listener.Init(s.Addr); err != nil {
		s.LogError("%s\n", err)
		return err
	}

	s.stopChan = make(chan struct{})
	s.stopWg.Add(1)
	go serverHandler(s)
	return nil
}

// Serve starts the server and blocks until the accept loop and all connection
// goroutines have finished.
func (s *RPCServer) Serve() (err error) {
	if err = s.Start(); err != nil {
		return err
	}
	s.stopWg.Wait()
	return nil
}

// Stop signals shutdown, waits for the accept loop and the connection
// goroutines to finish and marks the server as stopped. It is a no-op when the
// server is not running.
func (s *RPCServer) Stop() {
	if s.stopChan == nil {
		return
	}
	close(s.stopChan)
	s.stopWg.Wait()
	s.stopChan = nil
}

// serverHandler is the accept loop. Each Accept runs in its own goroutine so
// the loop can react to the stop signal while Accept blocks. The listener is
// closed on every exit path.
func serverHandler(s *RPCServer) {
	defer s.stopWg.Done()

	stopChan := s.stopChan

	// stopping is set before the listener is closed so the accept goroutine
	// does not log the error caused by the shutdown itself.
	var stopping atomic.Value

	closeListener := func() {
		if err := s.Listener.Close(); err != nil {
			s.LogError("srpc - '%s' error closing listener: '%s'\n", s.Addr, err)
		}
	}

	for {
		var (
			conn       io.ReadWriteCloser
			clientAddr string
			err        error
		)

		// The results above are written by the goroutine and only read after
		// acceptChan is closed.
		acceptChan := make(chan struct{})
		go func() {
			defer close(acceptChan)
			conn, clientAddr, err = s.Listener.Accept()
			if err != nil && stopping.Load() == nil {
				s.LogError("srpc - '%s' error accepting connection: '%s'\n", s.Addr, err)
			}
		}()

		select {
		case <-stopChan:
			stopping.Store(true)
			closeListener()
			<-acceptChan
			// A connection may have been accepted just before the listener
			// was closed; do not leak it.
			if err == nil && conn != nil {
				conn.Close()
			}
			return
		case <-acceptChan:
		}

		if err != nil {
			// Back off for a second before retrying, unless stopped.
			select {
			case <-stopChan:
				closeListener()
				return
			case <-time.After(time.Second):
			}
			continue
		}

		s.stopWg.Add(1)
		go serverHandleConnection(s, conn, clientAddr)
	}
}

// serverHandleConnection serves one client connection: it decodes requests in
// a loop and answers each one in its own goroutine. The connection is closed
// when the function returns and when the server is stopped.
//
// The response goroutines are not tracked by stopWg; they may still be running
// RequestHandler after this function has returned.
func serverHandleConnection(s *RPCServer, conn io.ReadWriteCloser, clientAddr string) {
	defer func() {
		slog.LOG_INFO("Client disconnected\n")
		s.stopWg.Done()
	}()

	// Captured once: Stop resets the field while response goroutines may
	// still be running.
	stopChan := s.stopChan

	if s.ConnectHandler != nil {
		newConn, err := s.ConnectHandler(clientAddr, conn)
		if err != nil {
			s.LogError("srpc - '%s'=>'%s': Connect error: '%s'\n", clientAddr, s.Addr, err)
			conn.Close()
			return
		}
		conn = newConn
	}
	defer conn.Close()

	// Close the connection on stop so a blocked Decode returns and Stop does
	// not wait for the client to disconnect on its own.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-stopChan:
			conn.Close()
		case <-done:
		}
	}()

	dec := s.Ed.NewDecoder(conn)
	enc := s.Ed.NewEncoder(conn)

	var (
		countLock sync.Mutex // guards inFlight
		inFlight  int        // read in progress plus unanswered requests
		encLock   sync.Mutex // serializes writes of concurrent responses
	)

	for {
		countLock.Lock()
		limitReached := inFlight == s.SimultanousClientRequests
		if !limitReached {
			inFlight++
		}
		countLock.Unlock()

		if limitReached {
			if !serverStop(stopChan) {
				s.LogError("srpc - '%s'=>'%s': Client reached max requests\n", clientAddr, s.Addr)
			}
			return
		}

		var msg clientMessage
		msg.Data = s.RequestDataHandler()
		if err := dec.Decode(&msg); err != nil {
			if !clientDisconnect(err) && !serverStop(stopChan) {
				s.LogError("srpc - '%s'=>'%s': Cannot decode request: '%s'\n", clientAddr, s.Addr, err)
			}
			return
		}
		slog.LOG_DEBUG(fmt.Sprintln(msg.Data))

		go func(msg clientMessage) {
			defer func() {
				countLock.Lock()
				inFlight--
				countLock.Unlock()
			}()

			response := serverMessage{
				ID:         msg.ID,
				ClientAddr: msg.ClientAddr,
			}
			response.Data = s.RequestHandler(msg.ClientAddr, msg.Data)

			encLock.Lock()
			err := enc.Encode(response)
			encLock.Unlock()
			if err != nil {
				if !clientDisconnect(err) && !serverStop(stopChan) {
					s.LogError("srpc - '%s'=>'%s': Cannot encode response: '%s'\n", clientAddr, s.Addr, err)
				}
				// Closing makes the pending Decode fail and ends the loop.
				conn.Close()
			}
		}(msg)
	}
}

// clientDisconnect reports whether err is, or wraps, io.EOF or
// io.ErrUnexpectedEOF.
func clientDisconnect(err error) bool {
	return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
}

// serverStop reports, without blocking, whether stopChan has been closed.
func serverStop(stopChan <-chan struct{}) bool {
	select {
	case <-stopChan:
		return true
	default:
		return false
	}
}

// NewUnixServer returns an RPCServer that listens on the Unix domain socket
// addr, using dhandler to supply request data values and handler to answer
// requests. The server is not started.
func NewUnixServer(addr string, dhandler RequestDataHandlerFunc, handler RequestHandlerFunc) *RPCServer {
	return &RPCServer{
		Addr:               addr,
		RequestDataHandler: dhandler,
		RequestHandler:     handler,
		Listener: &netListener{
			F: func(addr string) (net.Listener, error) {
				return net.Listen("unix", addr)
			},
		},
		Ed: NewEncoderDecoder(),
	}
}