diff --git a/driver/config/provider.go b/driver/config/provider.go index a6abecaad0e..40a4a5206a8 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -102,6 +102,7 @@ const ( KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim" KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims" KeyMirrorTopLevelClaims = "oauth2.mirror_top_level_claims" + KeyRefreshTokenRotationGracePeriod = "oauth2.grant.refresh_token.rotation_grace_period" // #nosec G101 KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional" KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional" KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl" @@ -669,3 +670,11 @@ func (p *DefaultProvider) cookieSuffix(ctx context.Context, key string) string { return p.getProvider(ctx).String(key) + suffix } + +func (p *DefaultProvider) RefreshTokenRotationGracePeriod(ctx context.Context) time.Duration { + var duration = p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0) + if duration > time.Hour { + return time.Hour + } + return p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0) +} diff --git a/driver/config/provider_test.go b/driver/config/provider_test.go index 8e5c44a9e2e..168ca81d69f 100644 --- a/driver/config/provider_test.go +++ b/driver/config/provider_test.go @@ -291,6 +291,13 @@ func TestViperProviderValidates(t *testing.T) { assert.Equal(t, "random_salt", c.SubjectIdentifierAlgorithmSalt(ctx)) assert.Equal(t, []string{"whatever"}, c.DefaultClientScope(ctx)) + // refresh + assert.Equal(t, time.Duration(0), c.RefreshTokenRotationGracePeriod(ctx)) + require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "1s")) + assert.Equal(t, time.Second, c.RefreshTokenRotationGracePeriod(ctx)) + require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "2h")) + assert.Equal(t, time.Hour, c.RefreshTokenRotationGracePeriod(ctx)) + // urls assert.Equal(t, urlx.ParseOrPanic("https://issuer"), c.IssuerURL(ctx)) assert.Equal(t, urlx.ParseOrPanic("https://public/"), c.PublicURL(ctx)) diff --git a/internal/config/config.yaml b/internal/config/config.yaml index f3e8bff399c..ad3e1ba74d6 100644 --- a/internal/config/config.yaml +++ b/internal/config/config.yaml @@ -402,6 +402,19 @@ oauth2: session: # store encrypted data in database, default true encrypt_at_rest: true + ## refresh_token_rotation + # By default Refresh Tokens are rotated and invalidated with each use. See https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics#section-4.13.2 for more details + refresh_token_rotation: + # + ## grace_period + # + # Set the grace period for a refresh token to allow it to be used for the duration of this configuration after its + # first use. New refresh tokens will continue to be issued. + # + # Examples: + # - 5s + # - 1m + grace_period: 0s # The secrets section configures secrets used for encryption and signing of several systems. All secrets can be rotated, # for more information on this topic navigate to: diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers.go index 6b64d61991e..fb7162650c8 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers.go @@ -225,6 +225,7 @@ func TestHelperRunner(t *testing.T, store InternalRegistry, k string) { t.Run(fmt.Sprintf("case=testHelperDeleteAccessTokens/db=%s", k), testHelperDeleteAccessTokens(store)) t.Run(fmt.Sprintf("case=testHelperRevokeAccessToken/db=%s", k), testHelperRevokeAccessToken(store)) t.Run(fmt.Sprintf("case=testFositeJWTBearerGrantStorage/db=%s", k), testFositeJWTBearerGrantStorage(store)) + t.Run(fmt.Sprintf("case=testHelperRevokeRefreshTokenMaybeGracePeriod/db=%s", k), testHelperRevokeRefreshTokenMaybeGracePeriod(store)) } func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) { @@ -553,6 +554,68 @@ func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) { } } +func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *testing.T) { + + return func(t *testing.T) { + t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) { + // SETUP + m := x.OAuth2Storage() + ctx := context.Background() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + assert.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession) + assert.NoError(t, err) + + tmpSession := new(fosite.Session) + _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + + // ASSERT + // a revoked refresh token returns an error when getting the token again + assert.Error(t, err) + assert.True(t, errors.Is(err, fosite.ErrInactiveToken)) + }) + + t.Run("refresh token enters grace period when configured,", func(t *testing.T) { + ctx := context.Background() + + // SETUP + x.Config().MustSet(ctx, "oauth2.refresh_token_rotation.grace_period", "1m") + + // always reset back to the default + t.Cleanup(func() { + x.Config().MustSet(ctx, "oauth2.refresh_token_rotation.grace_period", "0m") + }) + + m := x.OAuth2Storage() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + assert.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + + tmpSession := new(fosite.Session) + req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + assert.NoError(t, err) + + // ASSERT + // when grace period is configured the refresh token can be obtained within + // the grace period without error + assert.NoError(t, err) + + assert.Equal(t, defaultRequest.GetID(), req.GetID()) + }) + } + +} + func testHelperCreateGetDeletePKCERequestSession(x InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index aa6934062ed..971dd2a65a7 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -176,8 +176,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { } assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { - actualExp, err := strconv.ParseInt(testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS).Get("exp").String(), 10, 64) - require.NoError(t, err) + introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) + actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) + require.NoError(t, err, "%s", introspect) requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second) } @@ -330,6 +331,150 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { }) }) + t.Run("case=graceful token rotation", func(t *testing.T) { + run := func(t *testing.T, strategy string) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + t.Cleanup(func() { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) + }) + + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, nil), + acceptConsentHandler(t, c, subject, nil), + ) + + issueTokens := func(t *testing.T) *oauth2.Token { + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() + require.NoError(t, err) + + introspectAccessToken(t, conf, token, subject) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return token + } + + refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = token.Expiry.Add(-time.Hour * 24) + iat := time.Now() + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) + + introspectAccessToken(t, conf, refreshedToken, subject) + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return refreshedToken + } + + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + start := time.Now() + + token := issueTokens(t) + var first, second *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + first = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + second = refreshTokens(t, token) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + }) + + t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) { + start := time.Now() + token := issueTokens(t) + var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + a1Refresh = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + b1Refresh = refreshTokens(t, token) + }) + + t.Run("followup=first refresh from first refresh", func(t *testing.T) { + a2RefreshA = refreshTokens(t, a1Refresh) + }) + + t.Run("followup=second refresh from first refresh", func(t *testing.T) { + a2RefreshB = refreshTokens(t, a1Refresh) + }) + + t.Run("followup=first refresh from second refresh", func(t *testing.T) { + b2RefreshA = refreshTokens(t, b1Refresh) + }) + + t.Run("followup=second refresh from second refresh", func(t *testing.T) { + b2RefreshB = refreshTokens(t, b1Refresh) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + for k, token := range []*oauth2.Token{ + a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + } + }) + }) + } + + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") + }) + + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") + }) + }) + t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) { // Make sure we test against all crypto suites that we advertise. cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute() diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 1fa0ce3836d..8564cfab969 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -144,7 +144,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_authentication_session", func(t *testing.T) { ss := []flow.LoginSession{} - c.All(&ss) + require.NoError(t, c.All(&ss)) require.Equal(t, 17, len(ss)) for _, s := range ss { @@ -157,7 +157,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_obfuscated_authentication_session", func(t *testing.T) { ss := []consent.ForcedObfuscatedLoginSession{} - c.All(&ss) + require.NoError(t, c.All(&ss)) require.Equal(t, 13, len(ss)) for _, s := range ss { @@ -169,7 +169,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_logout_request", func(t *testing.T) { lrs := []flow.LogoutRequest{} - c.All(&lrs) + require.NoError(t, c.All(&lrs)) require.Equal(t, 7, len(lrs)) for _, s := range lrs { @@ -182,7 +182,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_jti_blacklist", func(t *testing.T) { bjtis := []oauth2.BlacklistedJTI{} - c.All(&bjtis) + require.NoError(t, c.All(&bjtis)) require.Equal(t, 1, len(bjtis)) for _, bjti := range bjtis { testhelpersuuid.AssertUUID(t, bjti.NID) @@ -194,7 +194,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_access", func(t *testing.T) { as := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as)) require.Equal(t, 13, len(as)) for _, a := range as { @@ -210,7 +210,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_refresh", func(t *testing.T) { rs := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_refresh").All(&rs) + require.NoError(t, c.RawQuery(`SELECT signature, nid, request_id, challenge_id, requested_at, client_id, scope, granted_scope, requested_audience, granted_audience, form_data, subject, active, session_data, expires_at FROM hydra_oauth2_refresh`).All(&rs)) require.Equal(t, 13, len(rs)) for _, r := range rs { @@ -226,7 +226,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_code", func(t *testing.T) { cs := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs)) require.Equal(t, 13, len(cs)) for _, c := range cs { @@ -242,7 +242,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_oidc", func(t *testing.T) { os := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os)) require.Equal(t, 13, len(os)) for _, o := range os { @@ -258,7 +258,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_pkce", func(t *testing.T) { ps := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps)) require.Equal(t, 11, len(ps)) for _, p := range ps { @@ -274,7 +274,7 @@ func TestMigrations(t *testing.T) { t.Run("case=networks", func(t *testing.T) { ns := []networkx.Network{} - c.RawQuery("SELECT * FROM networks").All(&ns) + require.NoError(t, c.RawQuery("SELECT * FROM networks").All(&ns)) require.Equal(t, 1, len(ns)) for _, n := range ns { testhelpersuuid.AssertUUID(t, n.ID) diff --git a/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql new file mode 100644 index 00000000000..a30a127e902 --- /dev/null +++ b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh DROP COLUMN first_used_at; diff --git a/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql new file mode 100644 index 00000000000..8ae823047f7 --- /dev/null +++ b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh ADD first_used_at TIMESTAMP DEFAULT NULL; diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index 93649fc46ef..98161c55cf6 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -60,6 +60,7 @@ type ( contextx.Provider x.RegistryLogger x.TracingProvider + config.Provider } ) diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 6e1336b80de..0adce30cfb6 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -58,6 +58,10 @@ type ( // InternalExpiresAt denormalizes the expiry from the session to additionally store it as a row. InternalExpiresAt sqlxx.NullTime `db:"expires_at" json:"-"` } + OAuth2RefreshTable struct { + OAuth2RequestSQL + FirstUsedAt sql.NullTime `db:"first_used_at"` + } ) const ( @@ -72,6 +76,10 @@ func (r OAuth2RequestSQL) TableName() string { return "hydra_oauth2_" + string(r.Table) } +func (r OAuth2RefreshTable) TableName() string { + return "hydra_oauth2_refresh" +} + func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName, expiresAt time.Time) (*OAuth2RequestSQL, error) { subject := "" if r.GetSession() == nil { @@ -122,6 +130,24 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, }, nil } +func (p *Persister) marshalSession(ctx context.Context, session fosite.Session) ([]byte, error) { + sessionBytes, err := json.Marshal(session) + if err != nil { + return nil, err + } + + if !p.config.EncryptSessionData(ctx) { + return sessionBytes, nil + } + + ciphertext, err := p.r.KeyCipher().Encrypt(ctx, sessionBytes, nil) + if err != nil { + return nil, err + } + + return []byte(ciphertext), nil +} + func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (_ *fosite.Request, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.toRequest") defer otelx.End(span, &err) @@ -429,7 +455,34 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) - return p.findSessionBySignature(ctx, signature, session, sqlTableRefresh) + + r := OAuth2RefreshTable{OAuth2RequestSQL: OAuth2RequestSQL{Table: sqlTableRefresh}} + err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(fosite.ErrNotFound) + } else if err != nil { + return nil, sqlcon.HandleError(err) + } + + fositeRequest, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + + if r.Active { + return fositeRequest, nil + } + + if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 && r.FirstUsedAt.Valid { + if r.FirstUsedAt.Time.Add(gracePeriod).Before(time.Now()) { + return fositeRequest, errorsx.WithStack(fosite.ErrInactiveToken) + } + + r.Active = true // We set active to true because we are in the grace period. + return r.toRequest(ctx, session, p) // And re-generate the request + } + + return fositeRequest, errorsx.WithStack(fosite.ErrInactiveToken) } func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) { @@ -483,10 +536,25 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) } -func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { +func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, signature string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") defer otelx.End(span, &err) - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + + gracePeriod := p.config.RefreshTokenRotationGracePeriod(ctx) + if gracePeriod <= 0 { + return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + } + + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.Connection(ctx). + RawQuery( + fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active=true", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), + id, + p.NetworkID(ctx), + ). + Exec(), + ) } func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { diff --git a/spec/config.json b/spec/config.json index 2445cbc6a24..72f81534c66 100644 --- a/spec/config.json +++ b/spec/config.json @@ -1068,6 +1068,21 @@ "type": "object", "additionalProperties": false, "properties": { + "refresh_token": { + "type": "object", + "properties": { + "grace_period": { + "title": "Refresh Token Rotation Grace Period", + "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is one hour.", + "default": "0s", + "allOf": [ + { + "$ref": "#/definitions/duration" + } + ] + } + } + }, "jwt": { "type": "object", "additionalProperties": false, @@ -1122,8 +1137,8 @@ } ] } - } - }, + } + }, "secrets": { "type": "object", "additionalProperties": false, diff --git a/x/fosite_storer.go b/x/fosite_storer.go index 23654c519b9..546cfc98870 100644 --- a/x/fosite_storer.go +++ b/x/fosite_storer.go @@ -18,16 +18,13 @@ import ( type FositeStorer interface { fosite.Storage oauth2.CoreStorage + oauth2.TokenRevocationStorage openid.OpenIDConnectRequestStorage pkce.PKCERequestStorage rfc7523.RFC7523KeyStorage verifiable.NonceManager oauth2.ResourceOwnerPasswordCredentialsGrantStorage - RevokeRefreshToken(ctx context.Context, requestID string) error - - RevokeAccessToken(ctx context.Context, requestID string) error - // flush the access token requests from the database. // no data will be deleted after the 'notAfter' timeframe. FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error