Skip to content

Commit

Permalink
feat: use bitwise comparison for jwt validation errors (#633)
Browse files Browse the repository at this point in the history
Co-authored-by: hackerman <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
narg95 and aeneasr authored Nov 13, 2021
1 parent 2ae47fb commit 52ee93f
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 34 deletions.
2 changes: 1 addition & 1 deletion handler/oauth2/introspector_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func TestIntrospectJWT(t *testing.T) {
},
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
t.Run(fmt.Sprintf("case=%d:%v", k, c.description), func(t *testing.T) {
if c.scopes == nil {
c.scopes = []string{}
}
Expand Down
58 changes: 27 additions & 31 deletions handler/oauth2/strategy_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ import (
"strings"
"time"

"github.com/ory/x/errorsx"

"github.com/pkg/errors"

"github.com/ory/fosite"
"github.com/ory/fosite/token/jwt"
"github.com/ory/x/errorsx"
)

// DefaultJWTStrategy is a JWT RS256 strategy.
Expand Down Expand Up @@ -100,44 +99,41 @@ func (h *DefaultJWTStrategy) ValidateAuthorizeCode(ctx context.Context, req fosi

func validate(ctx context.Context, jwtStrategy jwt.JWTStrategy, token string) (t *jwt.Token, err error) {
t, err = jwtStrategy.Decode(ctx, token)

if err == nil {
err = t.Claims.Valid()
return
}

if err != nil {
var e *jwt.ValidationError
if errors.As(err, &e) {
switch e.Errors {
case jwt.ValidationErrorMalformed:
err = errorsx.WithStack(fosite.ErrInvalidTokenFormat.WithWrap(err).WithDebug(err.Error()))
case jwt.ValidationErrorUnverifiable:
err = errorsx.WithStack(fosite.ErrTokenSignatureMismatch.WithWrap(err).WithDebug(err.Error()))
case jwt.ValidationErrorSignatureInvalid:
err = errorsx.WithStack(fosite.ErrTokenSignatureMismatch.WithWrap(err).WithDebug(err.Error()))
case jwt.ValidationErrorAudience:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwt.ValidationErrorExpired:
err = errorsx.WithStack(fosite.ErrTokenExpired.WithWrap(err).WithDebug(err.Error()))
case jwt.ValidationErrorIssuedAt:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwt.ValidationErrorIssuer:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwt.ValidationErrorNotValidYet:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwt.ValidationErrorId:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwt.ValidationErrorClaimsInvalid:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
default:
err = errorsx.WithStack(fosite.ErrRequestUnauthorized.WithWrap(err).WithDebug(err.Error()))
}
}
var e *jwt.ValidationError
if err != nil && errors.As(err, &e) {
err = errorsx.WithStack(toRFCErr(e).WithWrap(err).WithDebug(err.Error()))
}

return
}

func toRFCErr(v *jwt.ValidationError) *fosite.RFC6749Error {
switch {
case v == nil:
return nil
case v.Has(jwt.ValidationErrorMalformed):
return fosite.ErrInvalidTokenFormat
case v.Has(jwt.ValidationErrorUnverifiable | jwt.ValidationErrorSignatureInvalid):
return fosite.ErrTokenSignatureMismatch
case v.Has(jwt.ValidationErrorExpired):
return fosite.ErrTokenExpired
case v.Has(jwt.ValidationErrorAudience |
jwt.ValidationErrorIssuedAt |
jwt.ValidationErrorIssuer |
jwt.ValidationErrorNotValidYet |
jwt.ValidationErrorId |
jwt.ValidationErrorClaimsInvalid):
return fosite.ErrTokenClaim
default:
return fosite.ErrRequestUnauthorized
}
}

func (h *DefaultJWTStrategy) generate(ctx context.Context, tokenType fosite.TokenType, requester fosite.Requester) (string, string, error) {
if jwtSession, ok := requester.GetSession().(JWTSessionContainer); !ok {
return "", "", errors.Errorf("Session must be of type JWTSessionContainer but got type: %T", requester.GetSession())
Expand Down
2 changes: 1 addition & 1 deletion handler/openid/strategy_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, requester fosite.R
if tokenHintString := requester.GetRequestForm().Get("id_token_hint"); tokenHintString != "" {
tokenHint, err := h.JWTStrategy.Decode(ctx, tokenHintString)
var ve *jwt.ValidationError
if errors.As(err, &ve) && ve.Errors == jwt.ValidationErrorExpired {
if errors.As(err, &ve) && ve.Has(jwt.ValidationErrorExpired) {
// Expired ID Tokens are allowed as values to id_token_hint
} else if err != nil {
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("Unable to decode id token from 'id_token_hint' parameter because %s.", err.Error()))
Expand Down
2 changes: 1 addition & 1 deletion handler/openid/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func (v *OpenIDConnectRequestValidator) ValidatePrompt(ctx context.Context, req

tokenHint, err := v.Strategy.Decode(ctx, idTokenHint)
var ve *jwt.ValidationError
if errors.As(err, &ve) && ve.Errors == jwt.ValidationErrorExpired {
if errors.As(err, &ve) && ve.Has(jwt.ValidationErrorExpired) {
// Expired tokens are ok
} else if err != nil {
return errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Failed to validate OpenID Connect request as decoding id token from id_token_hint parameter failed.").WithWrap(err).WithDebug(err.Error()))
Expand Down
4 changes: 4 additions & 0 deletions token/jwt/validation_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,7 @@ func (e ValidationError) Error() string {
func (e *ValidationError) valid() bool {
return e.Errors == 0
}

func (e *ValidationError) Has(verr uint32) bool {
return (e.Errors & verr) != 0
}

0 comments on commit 52ee93f

Please sign in to comment.