Skip to content

Commit

Permalink
feat: add database cleanup logic, runs after each request
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed May 8, 2023
1 parent 184fa38 commit 0729bf6
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 5 deletions.
4 changes: 4 additions & 0 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
r.UseBypass(xffmw.Handler)
r.Use(recoverer)

if globalConfig.DB.CleanupEnabled {
r.UseBypass(api.databaseCleanup)
}

r.Get("/health", api.HealthCheck)

r.Route("/callback", func(r *router) {
Expand Down
25 changes: 25 additions & 0 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"time"

"github.com/supabase/gotrue/internal/models"
"github.com/supabase/gotrue/internal/observability"
"github.com/supabase/gotrue/internal/security"

Expand Down Expand Up @@ -186,3 +187,27 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont
}
return ctx, nil
}

// databaseCleanup is a middleware that lets the wrapped handler serve the
// request first and then, for POST, PUT, PATCH and DELETE requests only, runs
// a single models.Cleanup pass on a connection bound to the request context.
// The cleanup outcome is only logged: failures at warn level, a non-zero
// number of affected rows at debug level. It never writes to the response.
func (a *API) databaseCleanup(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)

		// Only requests using a mutating HTTP method trigger a cleanup pass.
		method := r.Method
		mutating := method == http.MethodPost ||
			method == http.MethodPut ||
			method == http.MethodPatch ||
			method == http.MethodDelete
		if !mutating {
			return
		}

		conn := a.db.WithContext(r.Context())
		entry := observability.GetLogEntry(r)

		cleaned, err := models.Cleanup(conn)
		switch {
		case err != nil:
			entry.WithError(err).WithField("affected_rows", cleaned).Warn("database cleanup failed")
		case cleaned > 0:
			entry.WithField("affected_rows", cleaned).Debug("cleaned up expired or stale rows")
		}
	})
}
1 change: 1 addition & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type DBConfiguration struct {
ConnMaxIdleTime time.Duration `json:"conn_max_idle_time,omitempty" split_words:"true"`
HealthCheckPeriod time.Duration `json:"health_check_period" split_words:"true"`
MigrationsPath string `json:"migrations_path" split_words:"true" default:"./migrations"`
CleanupEnabled bool `json:"cleanup_enabled" split_words:"true" default:"false"`
}

func (c *DBConfiguration) Validate() error {
Expand Down
94 changes: 94 additions & 0 deletions internal/models/cleanup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package models

import (
"fmt"
"sync/atomic"

"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
metricglobal "go.opentelemetry.io/otel/metric/global"
metricinstrument "go.opentelemetry.io/otel/metric/instrument"
otelasyncint64instrument "go.opentelemetry.io/otel/metric/instrument/asyncint64"

"github.com/supabase/gotrue/internal/observability"
"github.com/supabase/gotrue/internal/storage"
)

// cleanupNext holds an atomically incrementing value that determines which of
// the CleanupStatements will be run next. Cleanup increments it once per call.
var cleanupNext uint32

// CleanupStatements holds all of the possible cleanup raw SQL. It is filled
// in by this package's init function. Only one statement is executed per
// Cleanup call, selected by cleanupNext % len(CleanupStatements).
var CleanupStatements []string

// cleanupAffectedRows tracks an OpenTelemetry metric on the total number of
// cleaned up rows. It is assigned in init and nil-checked before use in
// Cleanup.
var cleanupAffectedRows otelasyncint64instrument.Counter

// init builds the list of cleanup SQL statements from the models' table names
// and registers the OpenTelemetry counter that records how many rows the
// cleanup has affected. A failure to create the counter is only logged.
func init() {
	tableRefreshTokens := RefreshToken{}.TableName()
	tableSessions := Session{}.TableName()
	tableRelayStates := SAMLRelayState{}.TableName()
	tableFlowStates := FlowState{}.TableName()

	// These statements intentionally use SELECT ... FOR UPDATE SKIP LOCKED
	// as this makes sure that only rows that are not being used in another
	// transaction are deleted. These deletes are thus very quick and
	// efficient, as they don't wait on other transactions.
	//
	// Table names are interpolated with %q, which wraps them in double
	// quotes so they are used as quoted SQL identifiers. Every statement
	// carries a LIMIT so a single run only touches a bounded number of rows.
	CleanupStatements = append(CleanupStatements,
		// delete refresh tokens that were revoked more than 24 hours ago
		fmt.Sprintf("delete from %q where id in (select id from %q where revoked is true and updated_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens),
		// revoke refresh tokens whose session's not_after passed more
		// than 24 hours ago
		fmt.Sprintf("update %q set revoked = true, updated_at = now() where id in (select %q.id from %q join %q on %q.session_id = %q.id where %q.not_after < now() - interval '24 hours' and %q.revoked is false limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens, tableRefreshTokens, tableSessions, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens),
		// sessions are deleted after 72 hours to allow refresh tokens
		// to be deleted piecemeal; 10 at once so that cascades don't
		// overwork the database
		fmt.Sprintf("delete from %q where id in (select id from %q where not_after < now() - interval '72 hours' limit 10 for update skip locked);", tableSessions, tableSessions),
		// delete SAML relay states created more than 24 hours ago
		fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRelayStates, tableRelayStates),
		// delete flow states created more than 24 hours ago
		fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableFlowStates, tableFlowStates),
	)

	var err error
	cleanupAffectedRows, err = metricglobal.Meter("gotrue").AsyncInt64().Counter(
		"gotrue_cleanup_affected_rows",
		metricinstrument.WithDescription("Number of affected rows from cleaning up stale entities"),
	)
	if err != nil {
		// The message names the metric exactly as it is registered above so
		// the log line can be matched to the instrument.
		logrus.WithError(err).Error("unable to get gotrue.gotrue_cleanup_affected_rows counter metric")
	}
}

// Cleanup removes stale entities in the database. You can call it on each
// request or as a periodic background job. Each call executes exactly one of
// CleanupStatements inside its own transaction, picked round-robin through
// the atomically incremented cleanupNext counter. Calling it therefore does
// not clean up the whole database, but does a small piecemeal clean up each
// time.
//
// It returns the number of rows affected by the executed statement. If the
// transaction fails, the error is returned together with whatever count had
// been accumulated up to that point. When CleanupStatements is empty it does
// nothing and returns (0, nil).
//
// NOTE(review): an earlier version of this comment mentioned an execution
// timeout and an acquire timeout; this function sets neither, so confirm they
// are configured on the connection elsewhere.
func Cleanup(db *storage.Connection) (int, error) {
	// Nothing to run; this also avoids a modulo-by-zero panic below.
	if len(CleanupStatements) == 0 {
		return 0, nil
	}

	ctx, span := observability.Tracer("gotrue").Start(db.Context(), "database-cleanup")
	defer span.End()

	affectedRows := 0
	// Arguments of a deferred call are evaluated when the defer statement
	// runs, so the attribute has to be set from inside a closure to record
	// the final count rather than the initial 0. Deferred functions run in
	// last-in-first-out order, so this executes before span.End().
	defer func() {
		span.SetAttributes(attribute.Int64("gotrue.cleanup.affected_rows", int64(affectedRows)))
	}()

	if err := db.WithContext(ctx).Transaction(func(tx *storage.Connection) error {
		// Advance the shared counter and wrap it onto the statement list so
		// that consecutive calls rotate through all statements.
		nextIndex := atomic.AddUint32(&cleanupNext, 1) % uint32(len(CleanupStatements))
		statement := CleanupStatements[nextIndex]

		count, terr := tx.RawQuery(statement).ExecWithCount()
		if terr != nil {
			return terr
		}

		affectedRows += count

		return nil
	}); err != nil {
		return affectedRows, err
	}

	// cleanupAffectedRows stays nil-checked because creating it in init may
	// have failed.
	// NOTE(review): this is an asynchronous instrument observed outside of a
	// registered callback; confirm against the OpenTelemetry metric API that
	// the observation is actually recorded.
	if cleanupAffectedRows != nil {
		cleanupAffectedRows.Observe(ctx, int64(affectedRows))
	}

	return affectedRows, nil
}
38 changes: 38 additions & 0 deletions internal/models/cleanup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package models

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

"github.com/supabase/gotrue/internal/conf"
"github.com/supabase/gotrue/internal/storage/test"
)

// TestCleanupSQL executes every entry of CleanupStatements directly against
// the test database, so an SQL error in any of them fails the test with the
// offending statement as the failure message.
func TestCleanupSQL(t *testing.T) {
	cfg, err := conf.LoadGlobal(modelsTestConfig)
	require.NoError(t, err)

	db, err := test.SetupDBConnection(cfg)
	require.NoError(t, err)

	for i := range CleanupStatements {
		sql := CleanupStatements[i]

		_, execErr := db.RawQuery(sql).ExecWithCount()
		require.NoError(t, execErr, sql)
	}
}

// TestCleanup calls Cleanup once per entry in CleanupStatements. Cleanup runs
// one statement per call and rotates through the list, so this many calls
// exercise every statement once.
func TestCleanup(t *testing.T) {
	globalConfig, err := conf.LoadGlobal(modelsTestConfig)
	require.NoError(t, err)
	conn, err := test.SetupDBConnection(globalConfig)
	require.NoError(t, err)

	// NOTE(review): statement is only the loop's current entry and is used as
	// the failure message; Cleanup picks its own statement from its internal
	// counter, so the two are not guaranteed to match.
	for _, statement := range CleanupStatements {
		_, err := Cleanup(conn)
		if err != nil {
			// %T prints the error's dynamic type; the boolean verb %t does
			// not apply to errors and would render as a bad-verb marker.
			fmt.Printf("%v %T\n", err, err)
		}
		require.NoError(t, err, statement)
	}
}
9 changes: 4 additions & 5 deletions internal/models/db_test.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
package models_test
package models

import (
"testing"

"github.com/gobuffalo/pop/v6"
"github.com/stretchr/testify/assert"
"github.com/supabase/gotrue/internal/models"
)

func TestTableNameNamespacing(t *testing.T) {
cases := []struct {
expected string
value interface{}
}{
{expected: "audit_log_entries", value: []*models.AuditLogEntry{}},
{expected: "refresh_tokens", value: []*models.RefreshToken{}},
{expected: "users", value: []*models.User{}},
{expected: "audit_log_entries", value: []*AuditLogEntry{}},
{expected: "refresh_tokens", value: []*RefreshToken{}},
{expected: "users", value: []*User{}},
}

for _, tc := range cases {
Expand Down
1 change: 1 addition & 0 deletions internal/models/factor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/gofrs/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"

"github.com/supabase/gotrue/internal/conf"
"github.com/supabase/gotrue/internal/storage"
"github.com/supabase/gotrue/internal/storage/test"
Expand Down
13 changes: 13 additions & 0 deletions migrations/20230508135423_add_cleanup_indexes.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- Indexes used for cleaning up old or stale objects.
--
-- Each index is on the timestamp column that the matching cleanup statement
-- in internal/models/cleanup.go filters on. "if not exists" makes every
-- statement a no-op when the index is already present.
-- NOTE(review): the {{ index .Options "Namespace" }} placeholder is presumably
-- replaced with the schema name when the migration is rendered; confirm
-- against the migration runner.

-- Supports the delete of revoked refresh tokens filtered by updated_at.
create index if not exists
refresh_tokens_updated_at_idx
on {{ index .Options "Namespace" }}.refresh_tokens (updated_at desc);

-- Supports the delete of old flow state rows filtered by created_at.
create index if not exists
flow_state_created_at_idx
on {{ index .Options "Namespace" }}.flow_state (created_at desc);

-- Supports the delete of old SAML relay state rows filtered by created_at.
create index if not exists
saml_relay_states_created_at_idx
on {{ index .Options "Namespace" }}.saml_relay_states (created_at desc);

0 comments on commit 0729bf6

Please sign in to comment.