Skip to content

Commit

Permalink
fix: load return_to and append to errors (#2333)
Browse files Browse the repository at this point in the history
Closes #2275
Closes #2279
Closes #2285

Co-authored-by: aeneasr <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
jacoblehr and aeneasr committed Mar 28, 2022
1 parent d942c5d commit 5efe4a3
Show file tree
Hide file tree
Showing 16 changed files with 169 additions and 25 deletions.
12 changes: 12 additions & 0 deletions selfservice/flow/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package flow
import (
"fmt"
"net/http"
"net/url"
"time"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/x"
"github.com/ory/x/urlx"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
Expand Down Expand Up @@ -86,3 +89,12 @@ func NewBrowserLocationChangeRequiredError(redirectTo string) *BrowserLocationCh
},
}
}

func GetFlowExpiredRedirectURL(config *config.Config, route, returnTo string) *url.URL {
redirectURL := urlx.AppendPaths(config.SelfPublicURL(), route)
if returnTo != "" {
redirectURL = urlx.CopyWithQuery(redirectURL, url.Values{"return_to": {returnTo}})
}

return redirectURL
}
18 changes: 17 additions & 1 deletion selfservice/flow/login/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"strings"
"time"

"github.com/gobuffalo/pop/v6"

"github.com/tidwall/gjson"

"github.com/ory/x/sqlxx"
Expand Down Expand Up @@ -185,8 +187,22 @@ func (f *Flow) EnsureInternalContext() {

func (f Flow) MarshalJSON() ([]byte, error) {
type local Flow
f.SetReturnTo()
return json.Marshal(local(f))
}

func (f *Flow) SetReturnTo() {
if u, err := url.Parse(f.RequestURL); err == nil {
f.ReturnTo = u.Query().Get("return_to")
}
return json.Marshal(local(f))
}

func (f *Flow) AfterFind(*pop.Connection) error {
f.SetReturnTo()
return nil
}

func (f *Flow) AfterSave(*pop.Connection) error {
f.SetReturnTo()
return nil
}
5 changes: 4 additions & 1 deletion selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,12 @@ func (h *Handler) fetchFlow(w http.ResponseWriter, r *http.Request, _ httprouter

if ar.ExpiresAt.Before(time.Now()) {
if ar.Type == flow.TypeBrowser {
redirectURL := flow.GetFlowExpiredRedirectURL(h.d.Config(r.Context()), RouteInitBrowserFlow, ar.ReturnTo)

h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.WithID(text.ErrIDSelfServiceFlowExpired).
WithReason("The login flow has expired. Redirect the user to the login flow init endpoint to initialize a new login flow.").
WithDetail("redirect_to", urlx.AppendPaths(h.d.Config(r.Context()).SelfPublicURL(), RouteInitBrowserFlow).String())))
WithDetail("redirect_to", redirectURL.String()).
WithDetail("return_to", ar.ReturnTo)))
return
}
h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.WithID(text.ErrIDSelfServiceFlowExpired).
Expand Down
13 changes: 10 additions & 3 deletions selfservice/flow/login/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package login_test
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
Expand Down Expand Up @@ -553,18 +554,24 @@ func TestGetFlow(t *testing.T) {
})

t.Run("case=expired with return_to", func(t *testing.T) {
conf.MustSet(config.ViperKeyURLsAllowedReturnToDomains, []string{"https://www.ory.sh/"})
returnTo := "https://www.ory.sh"
conf.MustSet(config.ViperKeyURLsAllowedReturnToDomains, []string{returnTo})

client := testhelpers.NewClientWithCookies(t)
setupLoginUI(t, client)
body := x.EasyGetBody(t, client, public.URL+login.RouteInitBrowserFlow+"?return_to=https://www.ory.sh")
body := x.EasyGetBody(t, client, public.URL+login.RouteInitBrowserFlow+"?return_to="+returnTo)

// Expire the flow
f, err := reg.LoginFlowPersister().GetLoginFlow(context.Background(), uuid.FromStringOrNil(gjson.GetBytes(body, "id").String()))
require.NoError(t, err)
f.ExpiresAt = time.Now().Add(-time.Second)
require.NoError(t, reg.LoginFlowPersister().UpdateLoginFlow(context.Background(), f))

// Retrieve the flow and verify that return_to is in the response
getURL := fmt.Sprintf("%s%s?id=%s&return_to=%s", public.URL, login.RouteGetFlow, f.ID, returnTo)
getBody := x.EasyGetBody(t, client, getURL)
assert.Equal(t, gjson.GetBytes(getBody, "error.details.return_to").String(), returnTo)

// submit the flow but it is expired
u := public.URL + login.RouteSubmitFlow + "?flow=" + f.ID.String()
res, err := client.PostForm(u, url.Values{"password_identifier": {"email@ory.sh"}, "csrf_token": {f.CSRFToken}, "password": {"password"}, "method": {"password"}})
Expand All @@ -574,7 +581,7 @@ func TestGetFlow(t *testing.T) {

f, err = reg.LoginFlowPersister().GetLoginFlow(context.Background(), uuid.FromStringOrNil(gjson.GetBytes(resBody, "id").String()))
require.NoError(t, err)
assert.Equal(t, public.URL+login.RouteInitBrowserFlow+"?return_to=https://www.ory.sh", f.RequestURL)
assert.Equal(t, public.URL+login.RouteInitBrowserFlow+"?return_to="+returnTo, f.RequestURL)
})

t.Run("case=not found", func(t *testing.T) {
Expand Down
18 changes: 17 additions & 1 deletion selfservice/flow/recovery/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/url"
"time"

"github.com/gobuffalo/pop/v6"

"github.com/gofrs/uuid"
"github.com/pkg/errors"

Expand Down Expand Up @@ -182,8 +184,22 @@ func (f *Flow) SetCSRFToken(token string) {

func (f Flow) MarshalJSON() ([]byte, error) {
type local Flow
f.SetReturnTo()
return json.Marshal(local(f))
}

func (f *Flow) SetReturnTo() {
if u, err := url.Parse(f.RequestURL); err == nil {
f.ReturnTo = u.Query().Get("return_to")
}
return json.Marshal(local(f))
}

func (f *Flow) AfterFind(*pop.Connection) error {
f.SetReturnTo()
return nil
}

func (f *Flow) AfterSave(*pop.Connection) error {
f.SetReturnTo()
return nil
}
5 changes: 4 additions & 1 deletion selfservice/flow/recovery/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,12 @@ func (h *Handler) fetch(w http.ResponseWriter, r *http.Request, _ httprouter.Par

if f.ExpiresAt.Before(time.Now().UTC()) {
if f.Type == flow.TypeBrowser {
redirectURL := flow.GetFlowExpiredRedirectURL(h.d.Config(r.Context()), RouteInitBrowserFlow, f.ReturnTo)

h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.
WithReason("The recovery flow has expired. Redirect the user to the recovery flow init endpoint to initialize a new recovery flow.").
WithDetail("redirect_to", urlx.AppendPaths(h.d.Config(r.Context()).SelfPublicURL(), RouteInitBrowserFlow).String())))
WithDetail("redirect_to", redirectURL.String()).
WithDetail("return_to", f.ReturnTo)))
return
}
h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.
Expand Down
13 changes: 10 additions & 3 deletions selfservice/flow/recovery/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package recovery_test
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -246,17 +247,23 @@ func TestGetFlow(t *testing.T) {
})

t.Run("case=expired with return_to", func(t *testing.T) {
conf.MustSet(config.ViperKeyURLsAllowedReturnToDomains, []string{"https://www.ory.sh/"})
returnTo := "https://www.ory.sh"
conf.MustSet(config.ViperKeyURLsAllowedReturnToDomains, []string{returnTo})
client := testhelpers.NewClientWithCookies(t)
setupRecoveryTS(t, client)
body := x.EasyGetBody(t, client, public.URL+recovery.RouteInitBrowserFlow+"?return_to=https://www.ory.sh")
body := x.EasyGetBody(t, client, public.URL+recovery.RouteInitBrowserFlow+"?return_to="+returnTo)

// Expire the flow
f, err := reg.RecoveryFlowPersister().GetRecoveryFlow(context.Background(), uuid.FromStringOrNil(gjson.GetBytes(body, "id").String()))
require.NoError(t, err)
f.ExpiresAt = time.Now().Add(-time.Second)
require.NoError(t, reg.RecoveryFlowPersister().UpdateRecoveryFlow(context.Background(), f))

// Retrieve the flow and verify that return_to is in the response
getURL := fmt.Sprintf("%s%s?id=%s&return_to=%s", public.URL, recovery.RouteGetFlow, f.ID, returnTo)
getBody := x.EasyGetBody(t, client, getURL)
assert.Equal(t, gjson.GetBytes(getBody, "error.details.return_to").String(), returnTo)

// submit the flow but it is expired
u := public.URL + recovery.RouteSubmitFlow + "?flow=" + f.ID.String()
res, err := client.PostForm(u, url.Values{"email": {"email@ory.sh"}, "csrf_token": {f.CSRFToken}, "method": {"link"}})
Expand All @@ -266,7 +273,7 @@ func TestGetFlow(t *testing.T) {

f, err = reg.RecoveryFlowPersister().GetRecoveryFlow(context.Background(), uuid.FromStringOrNil(gjson.GetBytes(resBody, "id").String()))
require.NoError(t, err)
assert.Equal(t, public.URL+recovery.RouteInitBrowserFlow+"?return_to=https://www.ory.sh", f.RequestURL)
assert.Equal(t, public.URL+recovery.RouteInitBrowserFlow+"?return_to="+returnTo, f.RequestURL)
})

t.Run("case=not found", func(t *testing.T) {
Expand Down
18 changes: 17 additions & 1 deletion selfservice/flow/registration/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/url"
"time"

"github.com/gobuffalo/pop/v6"

"github.com/tidwall/gjson"

"github.com/ory/x/sqlxx"
Expand Down Expand Up @@ -150,8 +152,22 @@ func (f *Flow) EnsureInternalContext() {

func (f Flow) MarshalJSON() ([]byte, error) {
type local Flow
f.SetReturnTo()
return json.Marshal(local(f))
}

func (f *Flow) SetReturnTo() {
if u, err := url.Parse(f.RequestURL); err == nil {
f.ReturnTo = u.Query().Get("return_to")
}
return json.Marshal(local(f))
}

func (f *Flow) AfterFind(*pop.Connection) error {
f.SetReturnTo()
return nil
}

func (f *Flow) AfterSave(*pop.Connection) error {
f.SetReturnTo()
return nil
}
5 changes: 4 additions & 1 deletion selfservice/flow/registration/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,12 @@ func (h *Handler) fetchFlow(w http.ResponseWriter, r *http.Request, ps httproute

if ar.ExpiresAt.Before(time.Now()) {
if ar.Type == flow.TypeBrowser {
redirectURL := flow.GetFlowExpiredRedirectURL(h.d.Config(r.Context()), RouteInitBrowserFlow, ar.ReturnTo)

h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.WithID(text.ErrIDSelfServiceFlowExpired).
WithReason("The registration flow has expired. Redirect the user to the registration flow init endpoint to initialize a new registration flow.").
WithDetail("redirect_to", urlx.AppendPaths(h.d.Config(r.Context()).SelfPublicURL(), RouteInitBrowserFlow).String())))
WithDetail("redirect_to", redirectURL.String()).
WithDetail("return_to", ar.ReturnTo)))
return
}
h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.WithID(text.ErrIDSelfServiceFlowExpired).
Expand Down
14 changes: 11 additions & 3 deletions selfservice/flow/registration/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package registration_test
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -303,17 +304,24 @@ func TestGetFlow(t *testing.T) {
})

t.Run("case=expired with return_to", func(t *testing.T) {
conf.MustSet(config.ViperKeyURLsAllowedReturnToDomains, []string{"https://www.ory.sh/"})
returnTo := "https://www.ory.sh"
conf.MustSet(config.ViperKeyURLsAllowedReturnToDomains, []string{returnTo})

client := testhelpers.NewClientWithCookies(t)
setupRegistrationUI(t, client)
body := x.EasyGetBody(t, client, public.URL+registration.RouteInitBrowserFlow+"?return_to=https://www.ory.sh")
body := x.EasyGetBody(t, client, public.URL+registration.RouteInitBrowserFlow+"?return_to="+returnTo)

// Expire the flow
f, err := reg.RegistrationFlowPersister().GetRegistrationFlow(context.Background(), uuid.FromStringOrNil(gjson.GetBytes(body, "id").String()))
require.NoError(t, err)
f.ExpiresAt = time.Now().Add(-time.Second)
require.NoError(t, reg.RegistrationFlowPersister().UpdateRegistrationFlow(context.Background(), f))

// Retrieve the flow and verify that return_to is in the response
getURL := fmt.Sprintf("%s%s?id=%s&return_to=%s", public.URL, registration.RouteGetFlow, f.ID, returnTo)
getBody := x.EasyGetBody(t, client, getURL)
assert.Equal(t, gjson.GetBytes(getBody, "error.details.return_to").String(), returnTo)

// submit the flow but it is expired
u := public.URL + registration.RouteSubmitFlow + "?flow=" + f.ID.String()
res, err := client.PostForm(u, url.Values{"method": {"password"}, "csrf_token": {f.CSRFToken}, "password": {"password"}, "traits.email": {"email@ory.sh"}})
Expand All @@ -323,7 +331,7 @@ func TestGetFlow(t *testing.T) {

f, err = reg.RegistrationFlowPersister().GetRegistrationFlow(context.Background(), uuid.FromStringOrNil(gjson.GetBytes(resBody, "id").String()))
require.NoError(t, err)
assert.Equal(t, public.URL+registration.RouteInitBrowserFlow+"?return_to=https://www.ory.sh", f.RequestURL)
assert.Equal(t, public.URL+registration.RouteInitBrowserFlow+"?return_to="+returnTo, f.RequestURL)
})

t.Run("case=not found", func(t *testing.T) {
Expand Down
18 changes: 17 additions & 1 deletion selfservice/flow/settings/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/url"
"time"

"github.com/gobuffalo/pop/v6"

"github.com/ory/kratos/text"

"github.com/tidwall/gjson"
Expand Down Expand Up @@ -194,8 +196,22 @@ func (f *Flow) EnsureInternalContext() {

func (f Flow) MarshalJSON() ([]byte, error) {
type local Flow
f.SetReturnTo()
return json.Marshal(local(f))
}

func (f *Flow) SetReturnTo() {
if u, err := url.Parse(f.RequestURL); err == nil {
f.ReturnTo = u.Query().Get("return_to")
}
return json.Marshal(local(f))
}

func (f *Flow) AfterFind(*pop.Connection) error {
f.SetReturnTo()
return nil
}

func (f *Flow) AfterSave(*pop.Connection) error {
f.SetReturnTo()
return nil
}
5 changes: 4 additions & 1 deletion selfservice/flow/settings/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,12 @@ func (h *Handler) fetchFlow(w http.ResponseWriter, r *http.Request) error {

if pr.ExpiresAt.Before(time.Now().UTC()) {
if pr.Type == flow.TypeBrowser {
redirectURL := flow.GetFlowExpiredRedirectURL(h.d.Config(r.Context()), RouteInitBrowserFlow, pr.ReturnTo)

h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.
WithReason("The settings flow has expired. Redirect the user to the settings flow init endpoint to initialize a new settings flow.").
WithDetail("redirect_to", urlx.AppendPaths(h.d.Config(r.Context()).SelfPublicURL(), RouteInitBrowserFlow).String())))
WithDetail("redirect_to", redirectURL.String()).
WithDetail("return_to", pr.ReturnTo)))
return nil
}
h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.
Expand Down
14 changes: 11 additions & 3 deletions selfservice/flow/settings/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package settings_test
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
Expand Down Expand Up @@ -223,16 +224,23 @@ func TestHandler(t *testing.T) {
})

t.Run("case=expired with return_to", func(t *testing.T) {
conf.MustSet(config.ViperKeyURLsAllowedReturnToDomains, []string{"https://www.ory.sh/"})
returnTo := "https://www.ory.sh"
conf.MustSet(config.ViperKeyURLsAllowedReturnToDomains, []string{returnTo})

client := testhelpers.NewHTTPClientWithArbitrarySessionToken(t, reg)
body := x.EasyGetBody(t, client, publicTS.URL+settings.RouteInitBrowserFlow+"?return_to=https://www.ory.sh")
body := x.EasyGetBody(t, client, publicTS.URL+settings.RouteInitBrowserFlow+"?return_to="+returnTo)

// Expire the flow
f, err := reg.SettingsFlowPersister().GetSettingsFlow(context.Background(), uuid.FromStringOrNil(gjson.GetBytes(body, "id").String()))
require.NoError(t, err)
f.ExpiresAt = time.Now().Add(-time.Second)
require.NoError(t, reg.SettingsFlowPersister().UpdateSettingsFlow(context.Background(), f))

// Retrieve the flow and verify that return_to is in the response
getURL := fmt.Sprintf("%s%s?id=%s&return_to=%s", publicTS.URL, settings.RouteGetFlow, f.ID, returnTo)
getBody := x.EasyGetBody(t, client, getURL)
assert.Equal(t, gjson.GetBytes(getBody, "error.details.return_to").String(), returnTo)

// submit the flow but it is expired
u := publicTS.URL + settings.RouteSubmitFlow + "?flow=" + f.ID.String()
res, err := client.PostForm(u, url.Values{"method": {"password"}, "csrf_token": {"csrf"}, "password": {"password"}})
Expand All @@ -242,7 +250,7 @@ func TestHandler(t *testing.T) {

f, err = reg.SettingsFlowPersister().GetSettingsFlow(context.Background(), uuid.FromStringOrNil(gjson.GetBytes(resBody, "id").String()))
require.NoError(t, err)
assert.Equal(t, publicTS.URL+settings.RouteInitBrowserFlow+"?return_to=https://www.ory.sh", f.RequestURL)
assert.Equal(t, publicTS.URL+settings.RouteInitBrowserFlow+"?return_to="+returnTo, f.RequestURL)
})

t.Run("description=should fail to fetch request if identity changed", func(t *testing.T) {
Expand Down
Loading

0 comments on commit 5efe4a3

Please sign in to comment.