diff --git a/server/oauth2.go b/server/oauth2.go index 9f5e95ec24..1a22cd621b 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -95,6 +95,7 @@ const ( errUnauthorizedClient = "unauthorized_client" errAccessDenied = "access_denied" errUnsupportedResponseType = "unsupported_response_type" + errRequestNotSupported = "request_not_supported" errInvalidScope = "invalid_scope" errServerError = "server_error" errTemporarilyUnavailable = "temporarily_unavailable" @@ -453,6 +454,12 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)} } + // dex doesn't support request parameter and must return request_not_supported error + // https://openid.net/specs/openid-connect-core-1_0.html#6.1 + if q.Get("request") != "" { + return nil, newErr(errRequestNotSupported, "Server does not support request parameter.") + } + if codeChallengeMethod != CodeChallengeMethodS256 && codeChallengeMethod != CodeChallengeMethodPlain { description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod) return nil, newErr(errInvalidRequest, description) diff --git a/server/server_test.go b/server/server_test.go index fa73743a4c..a09e4b727c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -25,6 +25,7 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" "golang.org/x/oauth2" jose "gopkg.in/square/go-jose.v2" @@ -223,6 +224,9 @@ type test struct { // extra parameters to pass when retrieving id token retrieveTokenOptions []oauth2.AuthCodeOption + // define an error response, when the test expects an error on the auth endpoint + authError *OAuth2ErrorResponse + // define an error response, when the test expects an error on the token endpoint tokenError ErrorResponse } @@ -607,6 +611,19 @@ func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) StatusCode: http.StatusBadRequest, }, }, + { + name: "Request parameter in authorization query", + authCodeOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("request", "anything"), + }, + authError: &OAuth2ErrorResponse{ + Error: errRequestNotSupported, + ErrorDescription: "Server does not support request parameter.", + }, + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + return nil + }, + }, }, } } @@ -665,7 +682,7 @@ func TestOAuth2CodeFlow(t *testing.T) { state = "a_state" ) defer func() { - if !gotCode { + if !gotCode && tc.authError == nil { t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump) } }() @@ -684,12 +701,18 @@ func TestOAuth2CodeFlow(t *testing.T) { // Did dex return an error? if errType := q.Get("error"); errType != "" { - if desc := q.Get("error_description"); desc != "" { - t.Errorf("got error from server %s: %s", errType, desc) - } else { - t.Errorf("got error from server %s", errType) + description := q.Get("error_description") + + if tc.authError == nil { + if description != "" { + t.Errorf("got error from server %s: %s", errType, description) + } else { + t.Errorf("got error from server %s", errType) + } + w.WriteHeader(http.StatusInternalServerError) + return } - w.WriteHeader(http.StatusInternalServerError) + require.Equal(t, *tc.authError, OAuth2ErrorResponse{Error: errType, ErrorDescription: description}) return }