diff --git a/internal/backend/state.go b/internal/backend/state.go index 0a34deb2..7193657c 100644 --- a/internal/backend/state.go +++ b/internal/backend/state.go @@ -29,7 +29,6 @@ type State struct { ro bool doneCh chan struct{} - stopCh chan struct{} updatesQueue *queue.QueuedChannel[stateUpdate] @@ -375,8 +374,6 @@ func (state *State) Done() <-chan struct{} { } func (state *State) Close(ctx context.Context) error { - defer close(state.stopCh) - return state.user.removeState(ctx, state.stateID) } diff --git a/internal/backend/user.go b/internal/backend/user.go index 1bea0167..a9ea3cb8 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -28,6 +28,9 @@ type user struct { updateWG sync.WaitGroup updateQuitCh chan struct{} + + // statesWG is + statesWG sync.WaitGroup } func newUser(ctx context.Context, userID string, db *DB, remote *remote.User, store store.Store, delimiter string) (*user, error) { @@ -77,6 +80,10 @@ func newUser(ctx context.Context, userID string, db *DB, remote *remote.User, st func (user *user) close(ctx context.Context) error { user.closeStates() + // Ensure we wait until all states have been removed/closed by any active sessions otherwise we run into issues + // since we close the database in this function. + user.statesWG.Wait() + close(user.updateQuitCh) // Wait until the connector update go routine has finished. diff --git a/internal/backend/user_state.go b/internal/backend/user_state.go index 6a57f7d8..02a488af 100644 --- a/internal/backend/user_state.go +++ b/internal/backend/user_state.go @@ -25,6 +25,8 @@ func (user *user) newState(metadataID remote.ConnMetadataID) (*State, error) { user.states[user.nextStateID] = newState + user.statesWG.Add(1) + return newState, nil } @@ -64,6 +66,9 @@ func (user *user) removeState(ctx context.Context, stateID int) error { return err } + // After this point we need to notify the WaitGroup or we risk deadlocks. + defer user.statesWG.Done() + if err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { return DBDeleteMessages(ctx, tx, messageIDs...) }); err != nil { @@ -109,8 +114,10 @@ func (user *user) getStates() []*State { } func (user *user) closeStates() { - for _, state := range user.getStates() { + user.statesLock.RLock() + defer user.statesLock.RUnlock() + + for _, state := range user.states { close(state.doneCh) - <-state.stopCh } } diff --git a/server.go b/server.go index 7d35e633..514110ef 100644 --- a/server.go +++ b/server.go @@ -62,6 +62,8 @@ type Server struct { // versionInfo holds info about the Gluon version. versionInfo internal.VersionInfo + + connectionWG sync.WaitGroup } // New creates a new server with the given options. @@ -152,8 +154,7 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) chan error { s.addListener(l) defer s.removeListener(l) - var wg sync.WaitGroup - defer wg.Wait() + defer s.connectionWG.Wait() for { conn, err := l.Accept() @@ -161,10 +162,10 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) chan error { return } - wg.Add(1) + s.connectionWG.Add(1) go func() { - defer wg.Done() + defer s.connectionWG.Done() s.handleConn(ctx, conn, errCh) }() } @@ -180,6 +181,8 @@ func (s *Server) Close(ctx context.Context) error { s.removeListener(l) } + s.connectionWG.Wait() + if err := s.backend.Close(ctx); err != nil { return fmt.Errorf("failed to close backend: %w", err) }