From c99ec6195798108a15c285838896825ce415eb40 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Thu, 10 Nov 2022 14:59:25 +0100 Subject: [PATCH] feat: Return mailbox counts when user added --- events/user.go | 4 ++++ internal/backend/backend.go | 33 +++++++++++++++++++++++++++++++++ server.go | 8 +++++++- 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/events/user.go b/events/user.go index c46cc37f..5f36426e 100644 --- a/events/user.go +++ b/events/user.go @@ -1,9 +1,13 @@ package events +import "github.com/ProtonMail/gluon/imap" + type UserAdded struct { eventBase UserID string + + Counts map[imap.MailboxID]int } type UserRemoved struct { diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 77fc455c..25b1a420 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -9,7 +9,10 @@ import ( "time" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db" + "github.com/ProtonMail/gluon/internal/db/ent" + "github.com/ProtonMail/gluon/internal/db/ent/mailbox" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" @@ -122,6 +125,36 @@ func (b *Backend) RemoveUser(ctx context.Context, userID string, removeFiles boo return nil } +func (b *Backend) GetMailboxMessageCounts(ctx context.Context, userID string) (map[imap.MailboxID]int, error) { + b.usersLock.Lock() + defer b.usersLock.Unlock() + + user, ok := b.users[userID] + if !ok { + return nil, ErrNoSuchUser + } + + return db.ReadResult(ctx, user.db, func(ctx context.Context, c *ent.Client) (map[imap.MailboxID]int, error) { + counts := make(map[imap.MailboxID]int) + + mailboxes, err := c.Mailbox.Query().Select(mailbox.FieldRemoteID).All(ctx) + if err != nil { + return nil, err + } + + for _, mailbox := range mailboxes { + messageCount, err := mailbox.QueryUIDs().Count(ctx) + if err != nil { + return nil, err + } + + counts[mailbox.RemoteID] = messageCount + } + + return counts, nil + }) +} + func (b *Backend) GetState(ctx context.Context, username string, password []byte, sessionID int) (*state.State, error) { b.usersLock.Lock() defer b.usersLock.Unlock() diff --git a/server.go b/server.go index 1f59ffb6..1210d4cf 100644 --- a/server.go +++ b/server.go @@ -116,11 +116,17 @@ func (s *Server) LoadUser(ctx context.Context, conn connector.Connector, userID ctx = reporter.NewContextWithReporter(ctx, s.reporter) if err := s.backend.AddUser(ctx, userID, conn, passphrase); err != nil { - return err + return fmt.Errorf("failed to add user: %w", err) + } + + counts, err := s.backend.GetMailboxMessageCounts(ctx, userID) + if err != nil { + return fmt.Errorf("failed to get counts: %w", err) } s.publish(events.UserAdded{ UserID: userID, + Counts: counts, }) return nil