Skip to content

Commit

Permalink
fix(oidc): grace period for continuity container on oidc callbacks (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed May 17, 2024
1 parent d9dbaad commit 1a9a096
Show file tree
Hide file tree
Showing 11 changed files with 241 additions and 21 deletions.
27 changes: 13 additions & 14 deletions continuity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,18 @@ type Manager interface {
}

type managerOptions struct {
iid uuid.UUID
ttl time.Duration
payload json.RawMessage
payloadRaw interface{}
cleanUp bool
iid uuid.UUID
ttl time.Duration
setExpiresIn time.Duration
payload json.RawMessage
payloadRaw interface{}
}

type ManagerOption func(*managerOptions) error

func newManagerOptions(opts []ManagerOption) (*managerOptions, error) {
var o = &managerOptions{
ttl: time.Minute,
cleanUp: true,
ttl: time.Minute * 10,
}
for _, opt := range opts {
if err := opt(o); err != nil {
Expand All @@ -49,13 +48,6 @@ func newManagerOptions(opts []ManagerOption) (*managerOptions, error) {
return o, nil
}

func DontCleanUp() ManagerOption {
return func(o *managerOptions) error {
o.cleanUp = false
return nil
}
}

func WithIdentity(i *identity.Identity) ManagerOption {
return func(o *managerOptions) error {
if i != nil {
Expand Down Expand Up @@ -83,3 +75,10 @@ func WithPayload(payload interface{}) ManagerOption {
return nil
}
}

func WithExpireInsteadOfDelete(duration time.Duration) ManagerOption {
return func(o *managerOptions) error {
o.setExpiresIn = duration
return nil
}
}
24 changes: 19 additions & 5 deletions continuity/manager_cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"encoding/json"
"net/http"
"time"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
Expand Down Expand Up @@ -93,12 +94,22 @@ func (m *ManagerCookie) Continue(ctx context.Context, w http.ResponseWriter, r *
}
}

if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil {
return nil, err
}
if o.setExpiresIn > 0 {
if err := m.d.ContinuityPersister().SetContinuitySessionExpiry(
ctx,
container.ID,
time.Now().UTC().Add(o.setExpiresIn).Truncate(time.Second),
); err != nil && !errors.Is(err, sqlcon.ErrNoRows) {
return nil, err
}
} else {
if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil {
return nil, err
}

if err := m.d.ContinuityPersister().DeleteContinuitySession(ctx, container.ID); err != nil && !errors.Is(err, sqlcon.ErrNoRows) {
return nil, err
if err := m.d.ContinuityPersister().DeleteContinuitySession(ctx, container.ID); err != nil && !errors.Is(err, sqlcon.ErrNoRows) {
return nil, err
}
}

return container, nil
Expand Down Expand Up @@ -136,6 +147,9 @@ func (m *ManagerCookie) container(ctx context.Context, w http.ResponseWriter, r
return nil, errors.WithStack(ErrNotResumable.WithDebugf("Resumable ID from cookie could not be found in the datastore: %+v", err))
} else if err != nil {
return nil, err
} else if container.ExpiresAt.Before(time.Now()) {
_ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name)
return nil, errors.WithStack(ErrNotResumable.WithDebugf("Resumable session has expired"))
}

return container, err
Expand Down
2 changes: 1 addition & 1 deletion continuity/manager_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestManagerOptions(t *testing.T) {
}{
{
e: func(t *testing.T, actual *managerOptions) {
assert.EqualValues(t, time.Minute, actual.ttl)
assert.EqualValues(t, time.Minute*10, actual.ttl)
},
},
{
Expand Down
45 changes: 45 additions & 0 deletions continuity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/ory/kratos/driver/config"

Expand Down Expand Up @@ -181,6 +182,50 @@ func TestManager(t *testing.T) {
assert.Contains(t, href, gjson.GetBytes(body, "name").String(), "%s", body)
})

t.Run("case=pause and use session with expiry", func(t *testing.T) {
cl := newClient()

tc := &persisterTestCase{
ro: []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{"bar"}), continuity.WithExpireInsteadOfDelete(time.Minute)},
wo: []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{}), continuity.WithExpireInsteadOfDelete(time.Minute)},
}
ts := newServer(t, p, tc)
genid := func() string {
return ts.URL + "/" + x.NewUUID().String()
}

href := genid()
res, err := cl.Do(testhelpers.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

res, err = cl.Do(testhelpers.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusOK, res.StatusCode)

res, err = cl.Do(testhelpers.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusOK, res.StatusCode)

tc.ro = []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{"bar"}), continuity.WithExpireInsteadOfDelete(-time.Minute)}
tc.wo = []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{""}), continuity.WithExpireInsteadOfDelete(-time.Minute)}

res, err = cl.Do(testhelpers.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusOK, res.StatusCode)

res, err = cl.Do(testhelpers.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, res.StatusCode)
body := ioutilx.MustReadAll(res.Body)
require.NoError(t, res.Body.Close())
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)
})

for k, tc := range []persisterTestCase{
{},
{
Expand Down
1 change: 1 addition & 0 deletions continuity/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ type Persister interface {
SaveContinuitySession(ctx context.Context, c *Container) error
GetContinuitySession(ctx context.Context, id uuid.UUID) (*Container, error)
DeleteContinuitySession(ctx context.Context, id uuid.UUID) error
SetContinuitySessionExpiry(ctx context.Context, id uuid.UUID, expiresAt time.Time) error
DeleteExpiredContinuitySessions(ctx context.Context, deleteOlder time.Time, pageSize int) error
}
24 changes: 24 additions & 0 deletions continuity/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,30 @@ func TestPersister(ctx context.Context, p interface {
})
})

t.Run("case=set expiry", func(t *testing.T) {
// Create a new continuity session
expected := createContainer(t)
require.NoError(t, p.SaveContinuitySession(ctx, &expected))

// Set the expiry of the continuity session
newExpiry := time.Now().Add(48 * time.Hour).UTC().Truncate(time.Second)
require.NoError(t, p.SetContinuitySessionExpiry(ctx, expected.ID, newExpiry))

// Retrieve the continuity session
actual, err := p.GetContinuitySession(ctx, expected.ID)
require.NoError(t, err)

// Check if the expiry has been updated
assert.EqualValues(t, newExpiry, actual.ExpiresAt)

t.Run("can not update on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
newExpiry := time.Now().Add(12 * time.Hour).UTC().Truncate(time.Second)
err := p.SetContinuitySessionExpiry(ctx, expected.ID, newExpiry)
require.ErrorIs(t, err, sqlcon.ErrNoRows)
})
})

t.Run("case=cleanup", func(t *testing.T) {
id := x.NewUUID()
yesterday := time.Now().Add(-24 * time.Hour).UTC().Truncate(time.Second)
Expand Down
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
17 changes: 17 additions & 0 deletions persistence/sql/persister_continuity.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@ func (p *Persister) SaveContinuitySession(ctx context.Context, c *continuity.Con
return sqlcon.HandleError(p.GetConnection(ctx).Create(c))
}

func (p *Persister) SetContinuitySessionExpiry(ctx context.Context, id uuid.UUID, expiresAt time.Time) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetContinuitySessionExpiry")
defer otelx.End(span, &err)

if rows, err := p.GetConnection(ctx).
Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).
UpdateQuery(&continuity.Container{
ExpiresAt: expiresAt,
}, "expires_at"); err != nil {
return sqlcon.HandleError(err)
} else if rows == 0 {
return errors.WithStack(sqlcon.ErrNoRows)
}

return nil
}

func (p *Persister) GetContinuitySession(ctx context.Context, id uuid.UUID) (_ *continuity.Container, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetContinuitySession")
defer otelx.End(span, &err)
Expand Down
7 changes: 6 additions & 1 deletion selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net/url"
"path/filepath"
"strings"
"time"

"golang.org/x/exp/maps"

Expand Down Expand Up @@ -316,7 +317,10 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flo

cntnr := AuthCodeContainer{}
if f.GetType() == flow.TypeBrowser || !hasSessionTokenCode {
if _, err := s.d.ContinuityManager().Continue(r.Context(), w, r, sessionName, continuity.WithPayload(&cntnr)); err != nil {
if _, err := s.d.ContinuityManager().Continue(r.Context(), w, r, sessionName,
continuity.WithPayload(&cntnr),
continuity.WithExpireInsteadOfDelete(time.Minute),
); err != nil {
return nil, nil, err
}
if stateParam != cntnr.State {
Expand All @@ -334,6 +338,7 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flo
if errorParam != "" {
return f, &cntnr, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`Unable to complete OpenID Connect flow because the OpenID Provider returned error "%s": %s`, r.URL.Query().Get("error"), r.URL.Query().Get("error_description")))
}

if codeParam == "" {
return f, &cntnr, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`Unable to complete OpenID Connect flow because the OpenID Provider did not return the code query parameter.`))
}
Expand Down
102 changes: 102 additions & 0 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (
"testing"
"time"

"github.com/davecgh/go-spew/spew"
"github.com/samber/lo"

"github.com/ory/kratos/selfservice/hook/hooktest"
"github.com/ory/x/sqlxx"

Expand Down Expand Up @@ -495,6 +498,105 @@ func TestStrategy(t *testing.T) {

postLoginWebhook.AssertTransientPayload(t, transientPayload)
})

t.Run("case=should pass double submit", func(t *testing.T) {
// This test checks that the continuity manager uses a grace period to handle potential double-submit issues.
//
// It addresses issues where Facebook and Apple consent screens on mobile behave in a way that makes it
// easy for users to experience double-submit issues.
j, err := cookiejar.New(nil)
require.NoError(t, err)

makeInitialRequest := func(t *testing.T, provider, action string, fv url.Values) (*http.Response, []byte, []string) {
fv.Set("provider", provider)

var lastVia []*http.Request
hc := &http.Client{
Jar: j,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
lastVia = via
return nil
},
}
res, err := hc.PostForm(action, fv)
require.NoError(t, err, action)

body, err := io.ReadAll(res.Body)
require.NoError(t, res.Body.Close())
require.NoError(t, err)
require.NotEmpty(t, lastVia)

vias := make([]string, len(lastVia))
for k, v := range lastVia {
vias[k] = v.URL.String()
}

return res, body, vias
}

r := newBrowserLoginFlow(t, returnTS.URL, time.Minute)
action := assertFormValues(t, r.ID, "valid")

// First login
res, body, via := makeInitialRequest(t, "valid", action, url.Values{})
assertIdentity(t, res, body)
expectTokens(t, "valid", body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body)

// We fetch the URL which includes the `?code` query parameter.
result := lo.Filter(via, func(s string, _ int) bool {
return strings.Contains(s, "code=")
})
require.Len(t, result, 1)

// And call that URL again. What's interesting here is that the whole requets passes because we are already authenticated.
//
// In this scenario, Ory Kratos correctly forwards the user to the return URL, which in our case returns the identity.
//
// We essentially run into this bit:
//
// if authenticated, err := s.alreadyAuthenticated(w, r, req); err != nil {
// s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
// } else if authenticated {
// return <-- we end up here on the second call
// }
res, err = (&http.Client{Jar: j}).Get(result[0])
require.NoError(t, err)
body, err = io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, res.Body.Close())

assertIdentity(t, res, body)
expectTokens(t, "valid", body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body)

// Trying this flow again without the Ory Session cookie will fail as we run into code reuse:
cookies := j.Cookies(urlx.ParseOrPanic(ts.URL))
t.Logf("Cookies: %s", spew.Sdump(cookies))

secondJar, err := cookiejar.New(nil)
require.NoError(t, err)

secondJar.SetCookies(urlx.ParseOrPanic(ts.URL), lo.Filter(cookies, func(item *http.Cookie, index int) bool {
return item.Name != "ory_kratos_session"
}))

cookies = secondJar.Cookies(urlx.ParseOrPanic(ts.URL))
t.Logf("Cookies after: %s", spew.Sdump(cookies))

// Doing the request but this time without the Ory Session Cookie. This may be the case in scenarios where we run into race conditions
// where the server sent a response but the client did not process it.
res, err = (&http.Client{Jar: secondJar}).Get(result[0])
require.NoError(t, err)
body, err = io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, res.Body.Close())

// The reason for `invalid_client` here is that the code was already used and the session was already authenticated. The invalid_client
// happens because of the way Golang's OAuth2 library is trying out different auth methods when a token request fails, which obfuscates
// the underlying error.
assert.Contains(t, string(body), "invalid_client", "%s", body)
})
})

t.Run("case=login without registered account", func(t *testing.T) {
Expand Down
Loading

0 comments on commit 1a9a096

Please sign in to comment.