Skip to content

Commit

Permalink
refactor(openid/client): extract authorization code parameters
Browse files Browse the repository at this point in the history
Co-authored-by: sindrerh2 <sindre.rodseth.hansen@nav.no>
  • Loading branch information
tronghn and sindrerh2 committed Jan 23, 2025
1 parent 642457b commit 110dd64
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 128 deletions.
6 changes: 3 additions & 3 deletions pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ func (s *Standalone) Login(w http.ResponseWriter, r *http.Request) {
"redirect_after_login": canonicalRedirect,
}

if acr := login.Acr; acr != "" {
fields["acr"] = acr
if acrValues := login.AcrValues; acrValues != "" {
fields["acr"] = acrValues
}

if locale := login.Locale; locale != "" {
if locale := login.UILocales; locale != "" {
fields["locale"] = locale
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/openid/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,16 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A
}

func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) {
params, err := c.AuthParams()
params, err := c.ClientAuthenticationParams()
if err != nil {
return nil, err
}

payload := params.URLValues(map[string]string{
payload := params.Merge(openid.AuthParams{
"grant_type": "refresh_token",
"refresh_token": refreshToken,
"client_id": c.cfg.Client().ClientID(),
}).Encode()
}).URLValues().Encode()

endpoint := c.cfg.Provider().TokenEndpoint()
body, err := c.oauthPostRequest(ctx, endpoint, payload)
Expand All @@ -114,18 +114,18 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid
return &tokenResponse, nil
}

func (c *Client) AuthParams() (openid.AuthParams, error) {
func (c *Client) ClientAuthenticationParams() (openid.AuthParams, error) {
switch c.cfg.Client().AuthMethod() {
case openidconfig.AuthMethodPrivateKeyJWT:
assertion, err := c.MakeAssertion(DefaultClientAssertionLifetime)
if err != nil {
return nil, fmt.Errorf("creating client assertion: %w", err)
}

return openid.AuthParamsJwtBearer(assertion), nil
return openid.ClientAuthParamsJwtBearer(assertion), nil

case openidconfig.AuthMethodClientSecret:
return openid.AuthParamsClientSecret(c.cfg.Client().ClientSecret()), nil
return openid.ClientAuthParamsSecret(c.cfg.Client().ClientSecret()), nil
}

return nil, fmt.Errorf("unsupported client authentication method: %q", c.cfg.Client().AuthMethod())
Expand Down
118 changes: 24 additions & 94 deletions pkg/openid/client/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ const (
LocaleURLParameter = "locale"
SecurityLevelURLParameter = "level"
PromptURLParameter = "prompt"
MaxAgeURLParameter = "max_age"
)

var (
Expand All @@ -34,43 +33,17 @@ var (
ErrInvalidPrompt = errors.New("InvalidPrompt")
ErrInvalidLoginParameter = errors.New("InvalidLoginParameter")

// LoginParameterMapping maps incoming login parameters to OpenID Connect parameters
LoginParameterMapping = map[string]string{
LocaleURLParameter: "ui_locales",
SecurityLevelURLParameter: "acr_values",
}

PromptAllowedValues = []string{"login", "select_account"}
)

type Login struct {
authorizationRequest
openid.AuthorizationCodeParams
AuthCodeURL string
Cookie openid.LoginCookie
}

type authorizationRequest struct {
Acr string
CallbackURL string
CodeVerifier string
Locale string
Nonce string
Prompt string
State string
}

func (a authorizationRequest) ToCookie() openid.LoginCookie {
return openid.LoginCookie{
Acr: a.Acr,
CodeVerifier: a.CodeVerifier,
Nonce: a.Nonce,
State: a.State,
RedirectURI: a.CallbackURL,
}
}

func (c *Client) Login(r *http.Request) (*Login, error) {
request, err := c.newAuthorizationRequest(r)
request, err := c.newAuthorizationCodeParams(r)
if err != nil {
return nil, fmt.Errorf("login: %w", err)
}
Expand All @@ -81,14 +54,14 @@ func (c *Client) Login(r *http.Request) (*Login, error) {
}

return &Login{
AuthCodeURL: authCodeURL,
authorizationRequest: request,
Cookie: request.ToCookie(),
AuthCodeURL: authCodeURL,
AuthorizationCodeParams: request,
Cookie: request.Cookie(),
}, nil
}

func (c *Client) newAuthorizationRequest(r *http.Request) (authorizationRequest, error) {
var req authorizationRequest
func (c *Client) newAuthorizationCodeParams(r *http.Request) (openid.AuthorizationCodeParams, error) {
var req openid.AuthorizationCodeParams

callbackURL, err := url.LoginCallback(r)
if err != nil {
Expand Down Expand Up @@ -120,86 +93,42 @@ func (c *Client) newAuthorizationRequest(r *http.Request) (authorizationRequest,
return req, fmt.Errorf("creating state: %w", err)
}

resource := c.cfg.Client().ResourceIndicator()
codeVerifier := oauth2.GenerateVerifier()

return authorizationRequest{
Acr: acrParam,
CallbackURL: callbackURL,
return openid.AuthorizationCodeParams{
AcrValues: acrParam,
ClientID: c.oauth2Config.ClientID,
CodeVerifier: codeVerifier,
Locale: locale,
Nonce: nonce,
Prompt: prompt,
RedirectURI: callbackURL,
Resource: resource,
Scope: c.oauth2Config.Scopes,
State: state,
UILocales: locale,
}, nil
}

func (c *Client) authCodeURL(ctx context.Context, request authorizationRequest) (string, error) {
func (c *Client) authCodeURL(ctx context.Context, request openid.AuthorizationCodeParams) (string, error) {
var authCodeURL string

if c.cfg.Provider().PushedAuthorizationRequestEndpoint() == "" {
opts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("nonce", request.Nonce),
oauth2.SetAuthURLParam("response_mode", "query"),
oauth2.S256ChallengeOption(request.CodeVerifier),
openid.RedirectURIOption(request.CallbackURL),
}

if resource := c.cfg.Client().ResourceIndicator(); resource != "" {
opts = append(opts, oauth2.SetAuthURLParam("resource", resource))
}

if len(request.Acr) > 0 {
opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[SecurityLevelURLParameter], request.Acr))
}

if len(request.Locale) > 0 {
opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[LocaleURLParameter], request.Locale))
}

if len(request.Prompt) > 0 {
opts = append(opts, oauth2.SetAuthURLParam(PromptURLParameter, request.Prompt))
opts = append(opts, oauth2.SetAuthURLParam(MaxAgeURLParameter, "0"))
}
opts := request.AuthParams().AuthCodeOptions()

// TODO: replace with separate function
authCodeURL = c.oauth2Config.AuthCodeURL(request.State, opts...)
} else {
params := map[string]string{
"client_id": c.oauth2Config.ClientID,
"code_challenge": oauth2.S256ChallengeFromVerifier(request.CodeVerifier),
"code_challenge_method": "S256",
"nonce": request.Nonce,
"redirect_uri": request.CallbackURL,
"response_mode": "query",
"response_type": "code",
"scope": stringslib.Join(c.oauth2Config.Scopes, " "),
"state": request.State,
}

if resource := c.cfg.Client().ResourceIndicator(); resource != "" {
params["resource"] = resource
}

if len(request.Acr) > 0 {
params[LoginParameterMapping[SecurityLevelURLParameter]] = request.Acr
}

if len(request.Locale) > 0 {
params[LoginParameterMapping[LocaleURLParameter]] = request.Locale
}

if len(request.Prompt) > 0 {
params[PromptURLParameter] = request.Prompt
params[MaxAgeURLParameter] = "0"
}

authParams, err := c.AuthParams()
clientAuthParams, err := c.ClientAuthenticationParams()
if err != nil {
return "", fmt.Errorf("generating client authentication parameters: %w", err)
}

payload := authParams.URLValues(params).Encode()
endpoint := c.cfg.Provider().PushedAuthorizationRequestEndpoint()
body, err := c.oauthPostRequest(ctx, endpoint, payload)
body, err := c.oauthPostRequest(ctx, endpoint, request.AuthParams().
Merge(clientAuthParams).
URLValues().
Encode())
if err != nil {
return "", err
}
Expand All @@ -209,6 +138,7 @@ func (c *Client) authCodeURL(ctx context.Context, request authorizationRequest)
return "", fmt.Errorf("unmarshalling token response: %w", err)
}

// TODO: this can a separate function to replace oauth2config.AuthCodeURL
v := urllib.Values{
"client_id": {c.oauth2Config.ClientID},
"request_uri": {pushedAuthorizationResponse.RequestUri},
Expand Down
12 changes: 5 additions & 7 deletions pkg/openid/client/login_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"fmt"
"net/http"

"golang.org/x/oauth2"

"github.com/nais/wonderwall/pkg/openid"
)

Expand Down Expand Up @@ -66,15 +64,15 @@ func (c *Client) authorizationServerIssuerIdentification(iss string) error {
}

func (c *Client) redeemTokens(ctx context.Context, code string, cookie *openid.LoginCookie) (*openid.Tokens, error) {
params, err := c.AuthParams()
params, err := c.ClientAuthenticationParams()
if err != nil {
return nil, err
}

rawTokens, err := c.AuthCodeGrant(ctx, code, params.AuthCodeOptions([]oauth2.AuthCodeOption{
openid.RedirectURIOption(cookie.RedirectURI),
oauth2.VerifierOption(cookie.CodeVerifier),
}))
rawTokens, err := c.AuthCodeGrant(ctx, code, params.Merge(openid.AuthParams{
"redirect_uri": cookie.RedirectURI,
"code_verifier": cookie.CodeVerifier,
}).AuthCodeOptions())
if err != nil {
return nil, fmt.Errorf("exchanging authorization code for token: %w", err)
}
Expand Down
Loading

0 comments on commit 110dd64

Please sign in to comment.