Skip to content

Commit

Permalink
transit from jwt-go to go-jose
Browse files Browse the repository at this point in the history
  • Loading branch information
narg95 committed May 7, 2021
1 parent 893aae4 commit b8ef805
Show file tree
Hide file tree
Showing 21 changed files with 585 additions and 250 deletions.
25 changes: 9 additions & 16 deletions authorize_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import (
"net/http"
"strings"

"github.com/ory/fosite/token/jwt"
"github.com/ory/x/errorsx"
"gopkg.in/square/go-jose.v2"

jwt "github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"

"github.com/ory/go-convenience/stringslice"
Expand Down Expand Up @@ -101,7 +102,7 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut
assertion = string(body)
}

token, err := jwt.ParseWithClaims(assertion, new(jwt.MapClaims), func(t *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (interface{}, error) {
// request_object_signing_alg - OPTIONAL.
// JWS [JWS] alg algorithm [JWA] that MUST be used for signing Request Objects sent to the OP. All Request Objects from this Client MUST be rejected,
// if not signed with this algorithm. Request Objects are described in Section 6.1 of OpenID Connect Core 1.0 [OpenID.Core]. This algorithm MUST
Expand All @@ -111,26 +112,22 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut
return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf("The request object uses signing algorithm '%s', but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", t.Header["alg"], oidcClient.GetRequestObjectSigningAlgorithm()))
}

if t.Method == jwt.SigningMethodNone {
return jwt.UnsafeAllowNoneSignatureType, nil
}

switch t.Method.(type) {
case *jwt.SigningMethodRSA:
switch t.Method {
case jose.RS256, jose.RS384, jose.RS512:
key, err := f.findClientPublicJWK(oidcClient, t, true)
if err != nil {
return nil, wrapSigningKeyFailure(
ErrInvalidRequestObject.WithHint("Unable to retrieve RSA signing key from OAuth 2.0 Client."), err)
}
return key, nil
case *jwt.SigningMethodECDSA:
case jose.ES256, jose.ES384, jose.ES512:
key, err := f.findClientPublicJWK(oidcClient, t, false)
if err != nil {
return nil, wrapSigningKeyFailure(
ErrInvalidRequestObject.WithHint("Unable to retrieve ECDSA signing key from OAuth 2.0 Client."), err)
}
return key, nil
case *jwt.SigningMethodRSAPSS:
case jose.PS256, jose.PS384, jose.PS512:
key, err := f.findClientPublicJWK(oidcClient, t, true)
if err != nil {
return nil, wrapSigningKeyFailure(
Expand All @@ -155,12 +152,8 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut
return errorsx.WithStack(ErrInvalidRequestObject.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithWrap(err).WithDebug(err.Error()))
}

claims, ok := token.Claims.(*jwt.MapClaims)
if !ok {
return errorsx.WithStack(ErrInvalidRequestObject.WithHint("Unable to type assert claims from request object.").WithDebugf(`Got claims of type %T but expected type '*jwt.MapClaims'.`, token.Claims))
}

for k, v := range *claims {
claims := token.Claims
for k, v := range claims {
request.Form.Set(k, fmt.Sprintf("%s", v))
}

Expand Down
38 changes: 12 additions & 26 deletions authorize_request_handler_oidc_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ import (

"github.com/pkg/errors"

jwt "github.com/dgrijalva/jwt-go"
"github.com/ory/fosite/token/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
jose "gopkg.in/square/go-jose.v2"
)

func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token := jwt.NewWithClaims(jose.RS256, claims)
if kid != "" {
token.Header["kid"] = kid
}
Expand All @@ -50,19 +50,12 @@ func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateK
}

func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token := jwt.NewWithClaims(jose.HS256, claims)
tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd"))
require.NoError(t, err)
return tokenString
}

func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
require.NoError(t, err)
return tokenString
}

func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
Expand All @@ -78,9 +71,9 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) {
},
}

validRequestObject := mustGenerateAssertion(t, jwt.MapClaims{"scope": "foo", "foo": "bar", "baz": "baz", "response_type": "token", "response_mode": "post_form"}, key, "kid-foo")
claims := jwt.MapClaims{"scope": "foo", "foo": "bar", "baz": "baz", "response_type": "token", "response_mode": "post_form"}
validRequestObject := mustGenerateAssertion(t, claims, key, "kid-foo")
validRequestObjectWithoutKid := mustGenerateAssertion(t, jwt.MapClaims{"scope": "foo", "foo": "bar", "baz": "baz"}, key, "")
validNoneRequestObject := mustGenerateNoneAssertion(t, jwt.MapClaims{"scope": "foo", "foo": "bar", "baz": "baz", "state": "some-state"})

var reqH http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte(validRequestObject))
Expand Down Expand Up @@ -190,18 +183,6 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) {
client: &DefaultOpenIDConnectClient{JSONWebKeysURI: reqJWK.URL, RequestObjectSigningAlgorithm: "RS256", RequestURIs: []string{reqTS.URL}},
expectForm: url.Values{"response_type": {"token"}, "response_mode": {"post_form"}, "scope": {"foo openid"}, "request_uri": {reqTS.URL}, "foo": {"bar"}, "baz": {"baz"}},
},
{
d: "should pass when request object uses algorithm none",
form: url.Values{"scope": {"openid"}, "request": {validNoneRequestObject}},
client: &DefaultOpenIDConnectClient{JSONWebKeysURI: reqJWK.URL, RequestObjectSigningAlgorithm: "none"},
expectForm: url.Values{"state": {"some-state"}, "scope": {"foo openid"}, "request": {validNoneRequestObject}, "foo": {"bar"}, "baz": {"baz"}},
},
{
d: "should pass when request object uses algorithm none and the client did not explicitly allow any algorithm",
form: url.Values{"scope": {"openid"}, "request": {validNoneRequestObject}},
client: &DefaultOpenIDConnectClient{JSONWebKeysURI: reqJWK.URL},
expectForm: url.Values{"state": {"some-state"}, "scope": {"foo openid"}, "request": {validNoneRequestObject}, "foo": {"bar"}, "baz": {"baz"}},
},
} {
t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) {
req := &AuthorizeRequest{
Expand All @@ -217,10 +198,15 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) {
if tc.expectErrReason != "" {
real := new(RFC6749Error)
require.True(t, errors.As(err, &real))
assert.EqualValues(t, real.Reason(), tc.expectErrReason)
assert.EqualValues(t, tc.expectErrReason, real.Reason())
}
} else {
require.NoError(t, err)
if err != nil {
real := new(RFC6749Error)
errors.As(err, &real)
require.NoErrorf(t, err, "Hint: %v\nDebug:%v", real.HintField, real.DebugField)
}
require.NoErrorf(t, err, "%+v", err)
require.Equal(t, len(tc.expectForm), len(req.Form))
for k, v := range tc.expectForm {
assert.EqualValues(t, v, req.Form[k])
Expand Down
37 changes: 16 additions & 21 deletions client_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (

"github.com/ory/x/errorsx"

jwt "github.com/dgrijalva/jwt-go"
"github.com/ory/fosite/token/jwt"
"github.com/pkg/errors"
jose "gopkg.in/square/go-jose.v2"
)
Expand Down Expand Up @@ -90,17 +90,16 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
var clientID string
var client Client

token, err := jwt.ParseWithClaims(assertion, new(jwt.MapClaims), func(t *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (interface{}, error) {
var err error
clientID, _, err = clientCredentialsFromRequestBody(form, false)
if err != nil {
return nil, err
}

if clientID == "" {
if claims, ok := t.Claims.(*jwt.MapClaims); !ok {
return nil, errorsx.WithStack(ErrRequestUnauthorized.WithHint("Unable to type assert claims from client_assertion.").WithDebugf(`Expected claims to be of type '*jwt.MapClaims' but got '%T'.`, t.Claims))
} else if sub, ok := (*claims)["sub"].(string); !ok {
claims := t.Claims
if sub, ok := claims["sub"].(string); !ok {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the client_assertion JSON Web Token is undefined."))
} else {
clientID = sub
Expand Down Expand Up @@ -135,18 +134,18 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
if oidcClient.GetTokenEndpointAuthSigningAlgorithm() != fmt.Sprintf("%s", t.Header["alg"]) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' uses signing algorithm '%s' but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", t.Header["alg"], oidcClient.GetTokenEndpointAuthSigningAlgorithm()))
}

if _, ok := t.Method.(*jwt.SigningMethodRSA); ok {
switch t.Method {
case jose.RS256, jose.RS384, jose.RS512:
return f.findClientPublicJWK(oidcClient, t, true)
} else if _, ok := t.Method.(*jwt.SigningMethodECDSA); ok {
case jose.ES256, jose.ES384, jose.ES512:
return f.findClientPublicJWK(oidcClient, t, false)
} else if _, ok := t.Method.(*jwt.SigningMethodRSAPSS); ok {
case jose.PS256, jose.PS384, jose.PS512:
return f.findClientPublicJWK(oidcClient, t, true)
} else if _, ok := t.Method.(*jwt.SigningMethodHMAC); ok {
case jose.HS256, jose.HS384, jose.HS512:
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This authorization server does not support client authentication method 'client_secret_jwt'."))
default:
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' request parameter uses unsupported signing algorithm '%s'.", t.Header["alg"]))
}

return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' request parameter uses unsupported signing algorithm '%s'.", t.Header["alg"]))
})
if err != nil {
// Do not re-process already enhanced errors
Expand All @@ -162,19 +161,15 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithWrap(err).WithDebug(err.Error()))
}

claims, ok := token.Claims.(*jwt.MapClaims)
if !ok {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to type assert claims from request parameter 'client_assertion'.").WithDebugf("Got claims of type %T but expected type '*jwt.MapClaims'.", token.Claims))
}

claims := token.Claims
var jti string
if !claims.VerifyIssuer(clientID, true) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client."))
} else if f.TokenURL == "" {
return nil, errorsx.WithStack(ErrMisconfiguration.WithHint("The authorization server's token endpoint URL has not been set."))
} else if sub, ok := (*claims)["sub"].(string); !ok || sub != clientID {
} else if sub, ok := claims["sub"].(string); !ok || sub != clientID {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client."))
} else if jti, ok = (*claims)["jti"].(string); !ok || len(jti) == 0 {
} else if jti, ok = claims["jti"].(string); !ok || len(jti) == 0 {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'jti' from 'client_assertion' must be set but is not."))
} else if f.Store.ClientAssertionJWTValid(ctx, jti) != nil {
return nil, errorsx.WithStack(ErrJTIKnown.WithHint("Claim 'jti' from 'client_assertion' MUST only be used once."))
Expand All @@ -183,7 +178,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
// type conversion according to jwt.MapClaims.VerifyExpiresAt
var expiry int64
err = nil
switch exp := (*claims)["exp"].(type) {
switch exp := claims["exp"].(type) {
case float64:
expiry = int64(exp)
case json.Number:
Expand All @@ -199,7 +194,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, err
}

if auds, ok := (*claims)["aud"].([]interface{}); !ok {
if auds, ok := claims["aud"].([]interface{}); !ok {
if !claims.VerifyAudience(f.TokenURL, true) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("Claim 'audience' from 'client_assertion' must match the authorization server's token endpoint '%s'.", f.TokenURL))
}
Expand Down
29 changes: 5 additions & 24 deletions client_authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
"testing"
"time"

jwt "github.com/dgrijalva/jwt-go"
"github.com/ory/fosite/token/jwt"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -46,35 +46,28 @@ import (
)

func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token := jwt.NewWithClaims(jose.RS256, claims)
token.Header["kid"] = kid
tokenString, err := token.SignedString(key)
require.NoError(t, err)
return tokenString
}

func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
token := jwt.NewWithClaims(jose.ES256, claims)
token.Header["kid"] = kid
tokenString, err := token.SignedString(key)
require.NoError(t, err)
return tokenString
}

func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token := jwt.NewWithClaims(jose.HS256, claims)
tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd"))
require.NoError(t, err)
return tokenString
}

func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
require.NoError(t, err)
return tokenString
}

// returns an http basic authorization header, encoded using application/x-www-form-urlencoded
func clientBasicAuthHeader(clientID, clientSecret string) http.Header {
creds := url.QueryEscape(clientID) + ":" + url.QueryEscape(clientSecret)
Expand Down Expand Up @@ -408,19 +401,6 @@ func TestAuthenticateClient(t *testing.T) {
r: new(http.Request),
expectErr: ErrInvalidClient,
},
{
d: "should fail because JWT algorithm is none",
client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"},
form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateNoneAssertion(t, jwt.MapClaims{
"sub": "bar",
"exp": time.Now().Add(time.Hour).Unix(),
"iss": "bar",
"jti": "12345",
"aud": "token-url",
}, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}},
r: new(http.Request),
expectErr: ErrInvalidClient,
},
{
d: "should pass with proper assertion when JWKs URI is set",
client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeysURI: ts.URL, TokenEndpointAuthMethod: "private_key_jwt"},
Expand Down Expand Up @@ -503,6 +483,7 @@ func TestAuthenticateClient(t *testing.T) {
t.Logf("Error is: %s", validationError.Inner)
} else if errors.As(err, &rfcError) {
t.Logf("DebugField is: %s", rfcError.DebugField)
t.Logf("HintField is: %s", rfcError.HintField)
}
}
require.NoError(t, err)
Expand Down
2 changes: 0 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2
require (
github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535
github.com/dgraph-io/ristretto v0.0.3 // indirect
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/form3tech-oss/jwt-go v3.2.2+incompatible // indirect
github.com/golang/mock v1.4.4
github.com/gorilla/mux v1.7.3
github.com/gorilla/websocket v1.4.2
Expand Down
3 changes: 0 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,7 @@ github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/form3tech-oss/jwt-go v3.2.1+incompatible h1:xdtqez379uWVJ9P3qQMX8W+F/nqsTdUvyMZB36tnacA=
github.com/form3tech-oss/jwt-go v3.2.1+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
Expand Down
6 changes: 2 additions & 4 deletions handler/oauth2/introspector_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import (
"context"
"time"

jwtx "github.com/dgrijalva/jwt-go"

"github.com/ory/fosite"
"github.com/ory/fosite/token/jwt"
)
Expand All @@ -37,8 +35,8 @@ type StatelessJWTValidator struct {
}

// AccessTokenJWTToRequest tries to reconstruct fosite.Request from a JWT.
func AccessTokenJWTToRequest(token *jwtx.Token) fosite.Requester {
mapClaims := token.Claims.(jwtx.MapClaims)
func AccessTokenJWTToRequest(token *jwt.Token) fosite.Requester {
mapClaims := token.Claims
claims := jwt.JWTClaims{}
claims.FromMapClaims(mapClaims)

Expand Down
Loading

0 comments on commit b8ef805

Please sign in to comment.