diff --git a/server/service.go b/server/service.go index ec46cb2..48f51f1 100644 --- a/server/service.go +++ b/server/service.go @@ -39,6 +39,9 @@ type Service struct { upgrader websocket.Upgrader conns map[string]*wsConn // Connections by wsConn Id's wg sync.WaitGroup // Wait for all connections to be disconnected + + // handlers for testing + onWSClose func(*websocket.Conn) } // NewService creates a new Service @@ -75,6 +78,20 @@ func (s *Service) SetLogger(l logger.Logger) *Service { return s } +// SetOnWSClose sets a callback to be calld when a websocket connection is +// closed. Used for testing. +func (s *Service) SetOnWSClose(cb func(ws *websocket.Conn)) *Service { + s.mu.Lock() + defer s.mu.Unlock() + + if s.stop != nil { + panic("SetOnWSClose must be called before starting server") + } + + s.onWSClose = cb + return s +} + // Logf writes a formatted log message func (s *Service) Logf(format string, v ...interface{}) { s.logger.Log(fmt.Sprintf(format, v...)) diff --git a/server/wsHandler.go b/server/wsHandler.go index 2f4a148..7e2c594 100644 --- a/server/wsHandler.go +++ b/server/wsHandler.go @@ -93,6 +93,10 @@ func (s *Service) wsHandler(w http.ResponseWriter, r *http.Request) { if s.metrics != nil { s.metrics.WSConnections.Add(-1) } + + if s.onWSClose != nil { + s.onWSClose(ws) + } } // wsHeaderAuth sends an auth resource request if WSHeaderAuth is set, and diff --git a/test/test.go b/test/test.go index 7f60273..e543c60 100644 --- a/test/test.go +++ b/test/test.go @@ -29,6 +29,7 @@ type Session struct { *NATSTestClient s *server.Service conns map[*Conn]struct{} + dcCh chan struct{} *CountLogger } @@ -50,6 +51,15 @@ func setup(t *testing.T, cfgs ...func(*server.Config)) *Session { CountLogger: l, } + // Set on WS close handler to synchronize tests with WebSocket disconnects. + serv.SetOnWSClose(func(_ *websocket.Conn) { + ch := s.dcCh + s.dcCh = nil + if ch != nil { + close(ch) + } + }) + if err := serv.Start(); err != nil { panic("test: failed to start server: " + err.Error()) } diff --git a/test/ws.go b/test/ws.go index 349237e..f6e8110 100644 --- a/test/ws.go +++ b/test/ws.go @@ -115,8 +115,24 @@ func (c *Conn) Request(method string, params interface{}) *ClientRequest { // Disconnect closes the connection to the gateway func (c *Conn) Disconnect() { + var dcCh chan struct{} + if c.s.dcCh == nil { + dcCh = make(chan struct{}) + c.s.dcCh = dcCh + } + c.ws.Close() <-c.closeCh + + // Await synchronization + if dcCh != nil { + select { + case <-dcCh: + case <-time.After(time.Second): + } + } + + delete(c.s.conns, c) } // PanicOnError panics if the connection has encountered an error.