diff --git a/internal/session/command.go b/internal/session/command.go index 1b186d21..a7123db4 100644 --- a/internal/session/command.go +++ b/internal/session/command.go @@ -34,11 +34,10 @@ func (s *Session) startCommandReader(ctx context.Context, del string) <-chan com } if err == nil && cmd.GetStartTLS() != nil { - // TLS needs to be handled here in order to ensure that next command read is over the - // tls connection. - if e := s.handleStartTLS(tag, cmd.GetStartTLS()); e != nil { + // TLS needs to be handled here to ensure that next command read is over the TLS connection. + if startTLSErr := s.handleStartTLS(tag, cmd.GetStartTLS()); startTLSErr != nil { cmd = nil - err = e + err = startTLSErr } else { continue } @@ -46,10 +45,10 @@ func (s *Session) startCommandReader(ctx context.Context, del string) <-chan com select { case cmdCh <- command{tag: tag, cmd: cmd, err: err}: + // ... case <-ctx.Done(): return - } } }) diff --git a/internal/session/handle.go b/internal/session/handle.go index a13b20ac..7dfb38ac 100644 --- a/internal/session/handle.go +++ b/internal/session/handle.go @@ -17,10 +17,10 @@ func (s *Session) handleOther( tag string, cmd *proto.Command, profiler profiling.CmdProfiler, -) chan response.Response { +) <-chan response.Response { ch := make(chan response.Response, channelBufferCount) - go func() { + s.handleWG.Go(func() { labels := pprof.Labels("go", "handleOther()", "SessionID", strconv.Itoa(s.sessionID)) pprof.Do(ctx, labels, func(_ context.Context) { defer close(ch) @@ -35,7 +35,7 @@ func (s *Session) handleOther( } } }) - }() + }) return ch } diff --git a/internal/session/session.go b/internal/session/session.go index e190009c..7121b7f3 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -22,6 +22,7 @@ import ( "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/profiling" "github.com/ProtonMail/gluon/version" + "github.com/ProtonMail/gluon/wait" "github.com/sirupsen/logrus" "golang.org/x/exp/slices" ) @@ -64,12 +65,24 @@ type Session struct { // before the client logs in or selects a mailbox. imapID imap.IMAPID + // version is the version info of the Gluon server. version version.Info + // cmdProfilerBuilder is used in profiling command execution. cmdProfilerBuilder profiling.CmdProfilerBuilder + + // handleWG is used to wait for all commands to finish before closing the session. + handleWG wait.Group } -func New(conn net.Conn, backend *backend.Backend, sessionID int, versionInfo version.Info, profiler profiling.CmdProfilerBuilder, eventCh chan<- events.Event) *Session { +func New( + conn net.Conn, + backend *backend.Backend, + sessionID int, + version version.Info, + profiler profiling.CmdProfilerBuilder, + eventCh chan<- events.Event, +) *Session { return &Session{ conn: conn, liner: liner.New(conn), @@ -77,7 +90,7 @@ func New(conn net.Conn, backend *backend.Backend, sessionID int, versionInfo ver caps: []imap.Capability{imap.IMAP4rev1, imap.IDLE, imap.UNSELECT, imap.UIDPLUS, imap.MOVE}, sessionID: sessionID, eventCh: eventCh, - version: versionInfo, + version: version, cmdProfilerBuilder: profiler, } } @@ -110,6 +123,7 @@ func (s *Session) SetTLSConfig(cfg *tls.Config) { func (s *Session) Serve(ctx context.Context) error { defer s.done(ctx) + defer s.handleWG.Wait() if err := s.greet(); err != nil { return err diff --git a/server.go b/server.go index 9d83f25a..e6dafdaf 100644 --- a/server.go +++ b/server.go @@ -21,6 +21,7 @@ import ( "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/ProtonMail/gluon/version" + "github.com/ProtonMail/gluon/wait" "github.com/ProtonMail/gluon/watcher" _ "github.com/mattn/go-sqlite3" "github.com/sirupsen/logrus" @@ -45,7 +46,7 @@ type Server struct { serveDoneCh chan struct{} // serveWG keeps track of serving goroutines. - serveWG WaitGroup + serveWG wait.Group // nextID holds the ID that will be given to the next session. nextID int @@ -166,7 +167,7 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error { // serve handles incoming connections and starts a new goroutine for each. func (s *Server) serve(ctx context.Context, connCh <-chan net.Conn) { - var connWG WaitGroup + var connWG wait.Group defer connWG.Wait() for { diff --git a/wg.go b/wait/wg.go similarity index 52% rename from wg.go rename to wait/wg.go index 56df244f..f1ed50fe 100644 --- a/wg.go +++ b/wait/wg.go @@ -1,12 +1,12 @@ -package gluon +package wait import "sync" -type WaitGroup struct { +type Group struct { wg sync.WaitGroup } -func (wg *WaitGroup) Go(f func()) { +func (wg *Group) Go(f func()) { wg.wg.Add(1) go func() { @@ -15,6 +15,6 @@ func (wg *WaitGroup) Go(f func()) { }() } -func (wg *WaitGroup) Wait() { +func (wg *Group) Wait() { wg.wg.Wait() }