From f658b1d0cb6edc9d4815ca78c259f4420a3ed873 Mon Sep 17 00:00:00 2001
From: Stojan Dimitrovski
Date: Tue, 9 May 2023 15:16:26 +0200
Subject: [PATCH] feat: add database cleanup logic, runs after each request
 (#875)

Certain database entities such as refresh tokens and sessions pile up
through normal operation without being cleaned up. This PR addresses
the problem with a `models.Cleanup` function that takes care of these
entities.

The cleanup runs after each request that uses a mutating HTTP method
(POST, PUT, DELETE, PATCH). It performs fast deletes and updates based
on [`FOR UPDATE SKIP LOCKED`](https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE)
so that cleanup never waits for other transactions to complete.

Running the cleanup after each request scales better than a periodic
background job: resources are used only while the API is serving
external traffic, which keeps database cleanup work proportional to
the work performed by the API.

Rows are only deleted about 24-72 hours after they have expired, so
they remain available for debugging if that is ever necessary.
---
 internal/api/api.go                           |  4 +
 internal/api/middleware.go                    | 25 +++++
 internal/conf/configuration.go                |  1 +
 internal/models/cleanup.go                    | 94 +++++++++++++++++++
 internal/models/cleanup_test.go               | 38 ++++++++
 internal/models/db_test.go                    |  9 +-
 internal/models/factor_test.go                |  1 +
 .../20230508135423_add_cleanup_indexes.up.sql | 17 ++++
 8 files changed, 184 insertions(+), 5 deletions(-)
 create mode 100644 internal/models/cleanup.go
 create mode 100644 internal/models/cleanup_test.go
 create mode 100644 migrations/20230508135423_add_cleanup_indexes.up.sql

diff --git a/internal/api/api.go b/internal/api/api.go
index 77bd27590..49b176cf6 100644
--- a/internal/api/api.go
+++ b/internal/api/api.go
@@ -81,6 +81,10 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
     r.UseBypass(xffmw.Handler)
     r.Use(recoverer)
 
+    if globalConfig.DB.CleanupEnabled {
+        r.UseBypass(api.databaseCleanup)
+    }
+
     r.Get("/health", api.HealthCheck)
 
     r.Route("/callback", func(r *router) {
diff --git a/internal/api/middleware.go b/internal/api/middleware.go
index 8300a19f4..a0bf15591 100644
--- a/internal/api/middleware.go
+++ b/internal/api/middleware.go
@@ -7,6 +7,7 @@ import (
     "strings"
     "time"
 
+    "github.com/supabase/gotrue/internal/models"
     "github.com/supabase/gotrue/internal/observability"
     "github.com/supabase/gotrue/internal/security"
 
@@ -186,3 +187,27 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont
     }
     return ctx, nil
 }
+
+func (a *API) databaseCleanup(next http.Handler) http.Handler {
+    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        next.ServeHTTP(w, r)
+
+        switch r.Method {
+        case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
+            // continue
+
+        default:
+            return
+        }
+
+        db := a.db.WithContext(r.Context())
+        log := observability.GetLogEntry(r)
+
+        affectedRows, err := models.Cleanup(db)
+        if err != nil {
+            log.WithError(err).WithField("affected_rows", affectedRows).Warn("database cleanup failed")
+        } else if affectedRows > 0 {
+            log.WithField("affected_rows", affectedRows).Debug("cleaned up expired or stale rows")
+        }
+    })
+}
diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go
index b104e6eb9..576cd920c 100644
--- a/internal/conf/configuration.go
+++ b/internal/conf/configuration.go
@@ -43,6 +43,7 @@ type DBConfiguration struct {
     ConnMaxIdleTime   time.Duration `json:"conn_max_idle_time,omitempty" split_words:"true"`
     HealthCheckPeriod time.Duration `json:"health_check_period" split_words:"true"`
     MigrationsPath    string        `json:"migrations_path" split_words:"true" default:"./migrations"`
+    CleanupEnabled    bool          `json:"cleanup_enabled" split_words:"true" default:"false"`
 }
 
 func (c *DBConfiguration) Validate() error {
diff --git a/internal/models/cleanup.go b/internal/models/cleanup.go
new file mode 100644
index 000000000..ea669b79f
--- /dev/null
+++ b/internal/models/cleanup.go
@@ -0,0 +1,94 @@
+package models
+
+import (
+    "fmt"
+    "sync/atomic"
+
+    "github.com/sirupsen/logrus"
+    "go.opentelemetry.io/otel/attribute"
+    metricglobal "go.opentelemetry.io/otel/metric/global"
+    metricinstrument "go.opentelemetry.io/otel/metric/instrument"
+    otelasyncint64instrument "go.opentelemetry.io/otel/metric/instrument/asyncint64"
+
+    "github.com/supabase/gotrue/internal/observability"
+    "github.com/supabase/gotrue/internal/storage"
+)
+
+// cleanupNext holds an atomically incrementing value that determines which of
+// the CleanupStatements will be run next.
+var cleanupNext uint32
+
+// CleanupStatements holds all of the possible cleanup raw SQL. Only one at a
+// time is executed, chosen via cleanupNext % len(CleanupStatements).
+var CleanupStatements []string
+
+// cleanupAffectedRows tracks an OpenTelemetry metric on the total number of
+// cleaned up rows.
+var cleanupAffectedRows otelasyncint64instrument.Counter
+
+func init() {
+    tableRefreshTokens := RefreshToken{}.TableName()
+    tableSessions := Session{}.TableName()
+    tableRelayStates := SAMLRelayState{}.TableName()
+    tableFlowStates := FlowState{}.TableName()
+
+    // These statements intentionally use SELECT ... FOR UPDATE SKIP LOCKED
+    // as this makes sure that only rows that are not being used in another
+    // transaction are deleted. These deletes are thus very quick and
+    // efficient, as they don't wait on other transactions.
+    CleanupStatements = append(CleanupStatements,
+        fmt.Sprintf("delete from %q where id in (select id from %q where revoked is true and updated_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens),
+        fmt.Sprintf("update %q set revoked = true, updated_at = now() where id in (select %q.id from %q join %q on %q.session_id = %q.id where %q.not_after < now() - interval '24 hours' and %q.revoked is false limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens, tableRefreshTokens, tableSessions, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens),
+        // sessions are deleted after 72 hours to allow refresh tokens
+        // to be deleted piecemeal; 10 at once so that cascades don't
+        // overwork the database
+        fmt.Sprintf("delete from %q where id in (select id from %q where not_after < now() - interval '72 hours' limit 10 for update skip locked);", tableSessions, tableSessions),
+        fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRelayStates, tableRelayStates),
+        fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableFlowStates, tableFlowStates),
+    )
+
+    var err error
+    cleanupAffectedRows, err = metricglobal.Meter("gotrue").AsyncInt64().Counter(
+        "gotrue_cleanup_affected_rows",
+        metricinstrument.WithDescription("Number of affected rows from cleaning up stale entities"),
+    )
+    if err != nil {
+        logrus.WithError(err).Error("unable to get gotrue.gotrue_cleanup_affected_rows counter metric")
+    }
+}
+
+// Cleanup removes stale entities in the database. You can call it on each
+// request or as a periodic background job. It does quick lockless updates or
+// deletes, has an execution timeout and acquire timeout so that cleanups do
+// not affect performance of other database jobs. Note that calling this does
+// not clean up the whole database, but does a small piecemeal clean up each
+// time it is called.
+func Cleanup(db *storage.Connection) (int, error) {
+    ctx, span := observability.Tracer("gotrue").Start(db.Context(), "database-cleanup")
+    defer span.End()
+
+    affectedRows := 0
+    defer span.SetAttributes(attribute.Int64("gotrue.cleanup.affected_rows", int64(affectedRows)))
+
+    if err := db.WithContext(ctx).Transaction(func(tx *storage.Connection) error {
+        nextIndex := atomic.AddUint32(&cleanupNext, 1) % uint32(len(CleanupStatements))
+        statement := CleanupStatements[nextIndex]
+
+        count, terr := tx.RawQuery(statement).ExecWithCount()
+        if terr != nil {
+            return terr
+        }
+
+        affectedRows += count
+
+        return nil
+    }); err != nil {
+        return affectedRows, err
+    }
+
+    if cleanupAffectedRows != nil {
+        cleanupAffectedRows.Observe(ctx, int64(affectedRows))
+    }
+
+    return affectedRows, nil
+}
diff --git a/internal/models/cleanup_test.go b/internal/models/cleanup_test.go
new file mode 100644
index 000000000..c38cf72f3
--- /dev/null
+++ b/internal/models/cleanup_test.go
@@ -0,0 +1,38 @@
+package models
+
+import (
+    "fmt"
+    "testing"
+
+    "github.com/stretchr/testify/require"
+
+    "github.com/supabase/gotrue/internal/conf"
+    "github.com/supabase/gotrue/internal/storage/test"
+)
+
+func TestCleanupSQL(t *testing.T) {
+    globalConfig, err := conf.LoadGlobal(modelsTestConfig)
+    require.NoError(t, err)
+    conn, err := test.SetupDBConnection(globalConfig)
+    require.NoError(t, err)
+
+    for _, statement := range CleanupStatements {
+        _, err := conn.RawQuery(statement).ExecWithCount()
+        require.NoError(t, err, statement)
+    }
+}
+
+func TestCleanup(t *testing.T) {
+    globalConfig, err := conf.LoadGlobal(modelsTestConfig)
+    require.NoError(t, err)
+    conn, err := test.SetupDBConnection(globalConfig)
+    require.NoError(t, err)
+
+    for _, statement := range CleanupStatements {
+        _, err := Cleanup(conn)
+        if err != nil {
+            fmt.Printf("%v %t\n", err, err)
+        }
+        require.NoError(t, err, statement)
+    }
+}
diff --git a/internal/models/db_test.go b/internal/models/db_test.go
index 6d41798cc..c3d6ab250 100644
--- a/internal/models/db_test.go
+++ b/internal/models/db_test.go
@@ -1,11 +1,10 @@
-package models_test
+package models
 
 import (
     "testing"
 
     "github.com/gobuffalo/pop/v6"
     "github.com/stretchr/testify/assert"
-    "github.com/supabase/gotrue/internal/models"
 )
 
 func TestTableNameNamespacing(t *testing.T) {
@@ -13,9 +12,9 @@ func TestTableNameNamespacing(t *testing.T) {
         expected string
         value    interface{}
     }{
-        {expected: "audit_log_entries", value: []*models.AuditLogEntry{}},
-        {expected: "refresh_tokens", value: []*models.RefreshToken{}},
-        {expected: "users", value: []*models.User{}},
+        {expected: "audit_log_entries", value: []*AuditLogEntry{}},
+        {expected: "refresh_tokens", value: []*RefreshToken{}},
+        {expected: "users", value: []*User{}},
     }
 
     for _, tc := range cases {
diff --git a/internal/models/factor_test.go b/internal/models/factor_test.go
index f725de5ee..d2e29749d 100644
--- a/internal/models/factor_test.go
+++ b/internal/models/factor_test.go
@@ -7,6 +7,7 @@ import (
     "github.com/gofrs/uuid"
     "github.com/stretchr/testify/require"
     "github.com/stretchr/testify/suite"
+
     "github.com/supabase/gotrue/internal/conf"
     "github.com/supabase/gotrue/internal/storage"
     "github.com/supabase/gotrue/internal/storage/test"
diff --git a/migrations/20230508135423_add_cleanup_indexes.up.sql b/migrations/20230508135423_add_cleanup_indexes.up.sql
new file mode 100644
index 000000000..162acee15
--- /dev/null
+++ b/migrations/20230508135423_add_cleanup_indexes.up.sql
@@ -0,0 +1,17 @@
+-- Indexes used for cleaning up old or stale objects.
+
+create index if not exists
+    refresh_tokens_updated_at_idx
+    on {{ index .Options "Namespace" }}.refresh_tokens (updated_at desc);
+
+create index if not exists
+    flow_state_created_at_idx
+    on {{ index .Options "Namespace" }}.flow_state (created_at desc);
+
+create index if not exists
+    saml_relay_states_created_at_idx
+    on {{ index .Options "Namespace" }}.saml_relay_states (created_at desc);
+
+create index if not exists
+    sessions_not_after_idx
+    on {{ index .Options "Namespace" }}.sessions (not_after desc);
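
For illustration, the first generated cleanup statement expands to roughly the
following SQL (a sketch assuming the default "refresh_tokens" table name
returned by RefreshToken{}.TableName()):

    -- delete a small batch of long-revoked refresh tokens, skipping any rows
    -- currently locked by another transaction
    delete from "refresh_tokens"
    where id in (
        select id
        from "refresh_tokens"
        where revoked is true
          and updated_at < now() - interval '24 hours'
        limit 100
        for update skip locked
    );

Because the inner select skips rows locked by other transactions, the delete
never blocks on concurrent refresh-token activity; any skipped rows are simply
picked up by a later cleanup pass.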