diff --git a/builder.go b/builder.go index 5cb8f1bb..76a8b665 100644 --- a/builder.go +++ b/builder.go @@ -2,6 +2,8 @@ package gluon import ( "crypto/tls" + "github.com/ProtonMail/gluon/internal/db" + "github.com/sirupsen/logrus" "io" "os" "time" @@ -86,6 +88,12 @@ func (builder *serverBuilder) build() (*Server, error) { return nil, err } + // Defer delete all the previous databases from removed user accounts. This is required since we can't + // close ent databases on demand. + if err := db.DeleteDeferredDBFiles(builder.databaseDir); err != nil { + logrus.WithError(err).Error("Failed to remove old database files") + } + return &Server{ dataDir: builder.dataDir, databaseDir: builder.databaseDir, diff --git a/internal/db/db.go b/internal/db/db.go index 2efa1d0f..8531224b 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/ProtonMail/gluon/internal/utils" + "github.com/google/uuid" "io/fs" "os" "path/filepath" @@ -109,6 +110,10 @@ func getDatabasePath(dir, userID string) string { return filepath.Join(dir, fmt.Sprintf("%v.db", userID)) } +func getDeferredDeleteDBPath(dir string) string { + return filepath.Join(dir, "deferred_delete") +} + // NewDB creates a new database instance. // If the database does not exist, it will be created and the second return value will be true. func NewDB(dir, userID string) (*DB, bool, error) { @@ -132,8 +137,42 @@ func NewDB(dir, userID string) (*DB, bool, error) { return &DB{db: client}, !exists, nil } +// DeleteDB will rename all the database files for the given user to a directory within the same folder to avoid +// issues with ent not being able to close the database on demand. The database will be cleaned up on the next +// run on the Gluon server. func DeleteDB(dir, userID string) error { - return os.Remove(getDatabasePath(dir, userID)) + // Rather than deleting the files immediately move them to a directory to be cleaned up later. + deferredDeletePath := getDeferredDeleteDBPath(dir) + + if err := os.MkdirAll(deferredDeletePath, 0o700); err != nil { + return fmt.Errorf("failed to create deferred delete dir: %w", err) + } + + matchingFiles, err := filepath.Glob(filepath.Join(dir, userID)) + if err != nil { + return fmt.Errorf("failed to match db files:%w", err) + } + + for _, file := range matchingFiles { + // Use new UUID to avoid conflict with existing files + if err := os.Rename(file, filepath.Join(deferredDeletePath, uuid.NewString())); err != nil { + return fmt.Errorf("failed to move db file '%v' :%w", file, err) + } + } + + return nil +} + +// DeleteDeferredDBFiles deletes all data from previous databases that were scheduled for removal. +func DeleteDeferredDBFiles(dir string) error { + deferredDeleteDir := getDeferredDeleteDBPath(dir) + if err := os.RemoveAll(deferredDeleteDir); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return err + } + } + + return nil } // pathExists returns whether the given file exists.