diff --git a/claims.go b/claims.go index b115d5e0..9dee607a 100644 --- a/claims.go +++ b/claims.go @@ -1,177 +1,16 @@ package jwt -import ( - "crypto/subtle" - "fmt" - "time" -) - -// Claims must just have a Valid method that determines -// if the token is invalid for any supported reason +// Claims represent any form of a JWT Claims Set according to +// https://datatracker.ietf.org/doc/html/rfc7519#section-4. In order to have a +// common basis for validation, it is required that an implementation is able to +// supply at least the claim names provided in +// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 namely `exp`, +// `iat`, `nbf`, `iss` and `aud`. type Claims interface { - Valid() error -} - -// RegisteredClaims are a structured version of the JWT Claims Set, -// restricted to Registered Claim Names, as referenced at -// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 -// -// This type can be used on its own, but then additional private and -// public claims embedded in the JWT will not be parsed. The typical usecase -// therefore is to embedded this in a user-defined claim type. -// -// See examples for how to use this with your own claim types. -type RegisteredClaims struct { - // the `iss` (Issuer) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.1 - Issuer string `json:"iss,omitempty"` - - // the `sub` (Subject) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.2 - Subject string `json:"sub,omitempty"` - - // the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3 - Audience ClaimStrings `json:"aud,omitempty"` - - // the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4 - ExpiresAt *NumericDate `json:"exp,omitempty"` - - // the `nbf` (Not Before) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.5 - NotBefore *NumericDate `json:"nbf,omitempty"` - - // the `iat` (Issued At) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6 - IssuedAt *NumericDate `json:"iat,omitempty"` - - // the `jti` (JWT ID) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.7 - ID string `json:"jti,omitempty"` -} - -// Valid validates time based claims "exp, iat, nbf". -// There is no accounting for clock skew. -// As well, if any of the above claims are not in the token, it will still -// be considered a valid claim. -func (c RegisteredClaims) Valid() error { - vErr := new(ValidationError) - now := TimeFunc() - - // The claims below are optional, by default, so if they are set to the - // default value in Go, let's not fail the verification for them. - if !c.VerifyExpiresAt(now, false) { - delta := now.Sub(c.ExpiresAt.Time) - vErr.Inner = fmt.Errorf("%s by %s", ErrTokenExpired, delta) - vErr.Errors |= ValidationErrorExpired - } - - if !c.VerifyIssuedAt(now, false) { - vErr.Inner = ErrTokenUsedBeforeIssued - vErr.Errors |= ValidationErrorIssuedAt - } - - if !c.VerifyNotBefore(now, false) { - vErr.Inner = ErrTokenNotValidYet - vErr.Errors |= ValidationErrorNotValidYet - } - - if vErr.valid() { - return nil - } - - return vErr -} - -// VerifyAudience compares the aud claim against cmp. -// If required is false, this method will return true if the value matches or is unset -func (c *RegisteredClaims) VerifyAudience(cmp string, req bool) bool { - return verifyAud(c.Audience, cmp, req) -} - -// VerifyExpiresAt compares the exp claim against cmp (cmp < exp). -// If req is false, it will return true, if exp is unset. -func (c *RegisteredClaims) VerifyExpiresAt(cmp time.Time, req bool) bool { - if c.ExpiresAt == nil { - return verifyExp(nil, cmp, req) - } - - return verifyExp(&c.ExpiresAt.Time, cmp, req) -} - -// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat). -// If req is false, it will return true, if iat is unset. -func (c *RegisteredClaims) VerifyIssuedAt(cmp time.Time, req bool) bool { - if c.IssuedAt == nil { - return verifyIat(nil, cmp, req) - } - - return verifyIat(&c.IssuedAt.Time, cmp, req) -} - -// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). -// If req is false, it will return true, if nbf is unset. -func (c *RegisteredClaims) VerifyNotBefore(cmp time.Time, req bool) bool { - if c.NotBefore == nil { - return verifyNbf(nil, cmp, req) - } - - return verifyNbf(&c.NotBefore.Time, cmp, req) -} - -// VerifyIssuer compares the iss claim against cmp. -// If required is false, this method will return true if the value matches or is unset -func (c *RegisteredClaims) VerifyIssuer(cmp string, req bool) bool { - return verifyIss(c.Issuer, cmp, req) -} - -// ----- helpers - -func verifyAud(aud []string, cmp string, required bool) bool { - if len(aud) == 0 { - return !required - } - // use a var here to keep constant time compare when looping over a number of claims - result := false - - var stringClaims string - for _, a := range aud { - if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { - result = true - } - stringClaims = stringClaims + a - } - - // case where "" is sent in one or many aud claims - if len(stringClaims) == 0 { - return !required - } - - return result -} - -func verifyExp(exp *time.Time, now time.Time, required bool) bool { - if exp == nil { - return !required - } - return now.Before(*exp) -} - -func verifyIat(iat *time.Time, now time.Time, required bool) bool { - if iat == nil { - return !required - } - return now.After(*iat) || now.Equal(*iat) -} - -func verifyNbf(nbf *time.Time, now time.Time, required bool) bool { - if nbf == nil { - return !required - } - return now.After(*nbf) || now.Equal(*nbf) -} - -func verifyIss(iss string, cmp string, required bool) bool { - if iss == "" { - return !required - } - if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 { - return true - } else { - return false - } + GetExpirationTime() (*NumericDate, error) + GetIssuedAt() (*NumericDate, error) + GetNotBefore() (*NumericDate, error) + GetIssuer() (string, error) + GetSubject() (string, error) + GetAudience() (ClaimStrings, error) } diff --git a/errors.go b/errors.go index 10ac8835..34f32faf 100644 --- a/errors.go +++ b/errors.go @@ -18,6 +18,7 @@ var ( ErrTokenExpired = errors.New("token is expired") ErrTokenUsedBeforeIssued = errors.New("token used before issued") ErrTokenInvalidIssuer = errors.New("token has invalid issuer") + ErrTokenInvalidSubject = errors.New("token has invalid subject") ErrTokenNotValidYet = errors.New("token is not valid yet") ErrTokenInvalidId = errors.New("token has invalid id") ErrTokenInvalidClaims = errors.New("token has invalid claims") @@ -29,11 +30,12 @@ const ( ValidationErrorUnverifiable // Token could not be verified because of signing problems ValidationErrorSignatureInvalid // Signature validation failed - // Standard Claim validation errors + // Registered Claim validation errors ValidationErrorAudience // AUD validation failed ValidationErrorExpired // EXP validation failed ValidationErrorIssuedAt // IAT validation failed ValidationErrorIssuer // ISS validation failed + ValidationErrorSubject // SUB validation failed ValidationErrorNotValidYet // NBF validation failed ValidationErrorId // JTI validation failed ValidationErrorClaimsInvalid // Generic claims validation error diff --git a/example_test.go b/example_test.go index b76699ff..650132aa 100644 --- a/example_test.go +++ b/example_test.go @@ -70,7 +70,7 @@ func ExampleNewWithClaims_customClaimsType() { //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM } -// Example creating a token using a custom claims type. The StandardClaim is embedded +// Example creating a token using a custom claims type. The RegisteredClaims is embedded // in the custom type to allow for easy encoding, parsing and validation of standard claims. func ExampleParseWithClaims_customClaimsType() { tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" @@ -93,7 +93,63 @@ func ExampleParseWithClaims_customClaimsType() { // Output: bar test } -// An example of parsing the error types using bitfield checks +// Example creating a token using a custom claims type and validation options. The RegisteredClaims is embedded +// in the custom type to allow for easy encoding, parsing and validation of standard claims. +func ExampleParseWithClaims_validationOptions() { + tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" + + type MyCustomClaims struct { + Foo string `json:"foo"` + jwt.RegisteredClaims + } + + token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte("AllYourBase"), nil + }, jwt.WithLeeway(5*time.Second)) + + if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { + fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) + } else { + fmt.Println(err) + } + + // Output: bar test +} + +type MyCustomClaims struct { + Foo string `json:"foo"` + jwt.RegisteredClaims +} + +func (m MyCustomClaims) CustomValidation() error { + if m.Foo != "bar" { + return errors.New("must be foobar") + } + + return nil +} + +// Example creating a token using a custom claims type and validation options. +// The RegisteredClaims is embedded in the custom type to allow for easy +// encoding, parsing and validation of standard claims and the function +// CustomValidation is implemented. +func ExampleParseWithClaims_customValidation() { + tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" + + token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte("AllYourBase"), nil + }, jwt.WithLeeway(5*time.Second)) + + if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { + fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) + } else { + fmt.Println(err) + } + + // Output: bar test +} + +// An example of parsing the error types using errors.Is. func ExampleParse_errorChecking() { // Token from another example. This token is expired var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c" diff --git a/map_claims.go b/map_claims.go index 2700d64a..a1e4935f 100644 --- a/map_claims.go +++ b/map_claims.go @@ -3,149 +3,109 @@ package jwt import ( "encoding/json" "errors" - "time" - // "fmt" ) // MapClaims is a claims type that uses the map[string]interface{} for JSON decoding. // This is the default claims type if you don't supply one type MapClaims map[string]interface{} -// VerifyAudience Compares the aud claim against cmp. -// If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyAudience(cmp string, req bool) bool { - var aud []string - switch v := m["aud"].(type) { - case string: - aud = append(aud, v) - case []string: - aud = v - case []interface{}: - for _, a := range v { - vs, ok := a.(string) - if !ok { - return false - } - aud = append(aud, vs) - } - } - return verifyAud(aud, cmp, req) -} +var ErrInvalidType = errors.New("invalid type for claim") -// VerifyExpiresAt compares the exp claim against cmp (cmp <= exp). -// If req is false, it will return true, if exp is unset. -func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { - cmpTime := time.Unix(cmp, 0) - - v, ok := m["exp"] - if !ok { - return !req - } +// GetExpirationTime implements the Claims interface. +func (m MapClaims) GetExpirationTime() (*NumericDate, error) { + return m.ParseNumericDate("exp") +} - switch exp := v.(type) { - case float64: - if exp == 0 { - return verifyExp(nil, cmpTime, req) - } +// GetNotBefore implements the Claims interface. +func (m MapClaims) GetNotBefore() (*NumericDate, error) { + return m.ParseNumericDate("nbf") +} - return verifyExp(&newNumericDateFromSeconds(exp).Time, cmpTime, req) - case json.Number: - v, _ := exp.Float64() +// GetIssuedAt implements the Claims interface. +func (m MapClaims) GetIssuedAt() (*NumericDate, error) { + return m.ParseNumericDate("iat") +} - return verifyExp(&newNumericDateFromSeconds(v).Time, cmpTime, req) - } +// GetAudience implements the Claims interface. +func (m MapClaims) GetAudience() (ClaimStrings, error) { + return m.ParseClaimsString("aud") +} - return false +// GetIssuer implements the Claims interface. +func (m MapClaims) GetIssuer() (string, error) { + return m.ParseString("iss") } -// VerifyIssuedAt compares the exp claim against cmp (cmp >= iat). -// If req is false, it will return true, if iat is unset. -func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { - cmpTime := time.Unix(cmp, 0) +// GetSubject implements the Claims interface. +func (m MapClaims) GetSubject() (string, error) { + return m.ParseString("sub") +} - v, ok := m["iat"] +// ParseNumericDate tries to parse a key in the map claims type as a number +// date. This will succeed, if the underlying type is either a [float64] or a +// [json.Number]. Otherwise, nil will be returned. +func (m MapClaims) ParseNumericDate(key string) (*NumericDate, error) { + v, ok := m[key] if !ok { - return !req + return nil, nil } - switch iat := v.(type) { + switch exp := v.(type) { case float64: - if iat == 0 { - return verifyIat(nil, cmpTime, req) + if exp == 0 { + return nil, nil } - return verifyIat(&newNumericDateFromSeconds(iat).Time, cmpTime, req) + return newNumericDateFromSeconds(exp), nil case json.Number: - v, _ := iat.Float64() + v, _ := exp.Float64() - return verifyIat(&newNumericDateFromSeconds(v).Time, cmpTime, req) + return newNumericDateFromSeconds(v), nil } - return false + return nil, ErrInvalidType } -// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). -// If req is false, it will return true, if nbf is unset. -func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { - cmpTime := time.Unix(cmp, 0) - - v, ok := m["nbf"] - if !ok { - return !req - } - - switch nbf := v.(type) { - case float64: - if nbf == 0 { - return verifyNbf(nil, cmpTime, req) +// ParseClaimsString tries to parse a key in the map claims type as a +// [ClaimsStrings] type, which can either be a string or an array of string. +func (m MapClaims) ParseClaimsString(key string) (ClaimStrings, error) { + var cs []string + switch v := m[key].(type) { + case string: + cs = append(cs, v) + case []string: + cs = v + case []interface{}: + for _, a := range v { + vs, ok := a.(string) + if !ok { + return nil, ErrInvalidType + } + cs = append(cs, vs) } - - return verifyNbf(&newNumericDateFromSeconds(nbf).Time, cmpTime, req) - case json.Number: - v, _ := nbf.Float64() - - return verifyNbf(&newNumericDateFromSeconds(v).Time, cmpTime, req) } - return false -} - -// VerifyIssuer compares the iss claim against cmp. -// If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { - iss, _ := m["iss"].(string) - return verifyIss(iss, cmp, req) + return cs, nil } -// Valid validates time based claims "exp, iat, nbf". -// There is no accounting for clock skew. -// As well, if any of the above claims are not in the token, it will still -// be considered a valid claim. -func (m MapClaims) Valid() error { - vErr := new(ValidationError) - now := TimeFunc().Unix() - - if !m.VerifyExpiresAt(now, false) { - // TODO(oxisto): this should be replaced with ErrTokenExpired - vErr.Inner = errors.New("Token is expired") - vErr.Errors |= ValidationErrorExpired - } - - if !m.VerifyIssuedAt(now, false) { - // TODO(oxisto): this should be replaced with ErrTokenUsedBeforeIssued - vErr.Inner = errors.New("Token used before issued") - vErr.Errors |= ValidationErrorIssuedAt - } - - if !m.VerifyNotBefore(now, false) { - // TODO(oxisto): this should be replaced with ErrTokenNotValidYet - vErr.Inner = errors.New("Token is not valid yet") - vErr.Errors |= ValidationErrorNotValidYet +// ParseString tries to parse a key in the map claims type as a [string] type. +// If the key does not exist, an empty string is returned. If the key has the +// wrong type, an error is returned. +func (m MapClaims) ParseString(key string) (string, error) { + var ( + ok bool + raw interface{} + iss string + ) + raw, ok = m[key] + if !ok { + return "", nil } - if vErr.valid() { - return nil + iss, ok = raw.(string) + if !ok { + return "", ErrInvalidType } - return vErr + return iss, nil } diff --git a/map_claims_test.go b/map_claims_test.go index 361c49d2..5c3a5c18 100644 --- a/map_claims_test.go +++ b/map_claims_test.go @@ -42,7 +42,7 @@ func TestVerifyAud(t *testing.T) { {Name: "[]String Aud without match not required", MapClaims: MapClaims{"aud": []string{"not.example.com", "example.example.com"}}, Expected: false, Required: true, Comparison: "example.com"}, // Required = false - {Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: false, Required: true, Comparison: "example.com"}, + {Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: true, Required: false, Comparison: "example.com"}, // []interface{} {Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: "example.com"}, @@ -56,10 +56,17 @@ func TestVerifyAud(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { - got := test.MapClaims.VerifyAudience(test.Comparison, test.Required) + var opts []ParserOption + + if test.Required { + opts = append(opts, WithAudience(test.Comparison)) + } + + validator := newValidator(opts...) + got := validator.Validate(test.MapClaims) - if got != test.Expected { - t.Errorf("Expected %v, got %v", test.Expected, got) + if (got == nil) != test.Expected { + t.Errorf("Expected %v, got %v", test.Expected, (got == nil)) } }) } @@ -70,9 +77,9 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) { "iat": "foo", } want := false - got := mapClaims.VerifyIssuedAt(0, false) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got := newValidator(WithIssuedAt()).Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } } @@ -81,9 +88,9 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) { "nbf": "foo", } want := false - got := mapClaims.VerifyNotBefore(0, false) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got := newValidator().Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } } @@ -92,32 +99,38 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) { "exp": "foo", } want := false - got := mapClaims.VerifyExpiresAt(0, false) + got := newValidator().Validate(mapClaims) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } } func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) { - exp := time.Now().Unix() + exp := time.Now() mapClaims := MapClaims{ - "exp": float64(exp), + "exp": float64(exp.Unix()), } want := false - got := mapClaims.VerifyExpiresAt(exp, true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got := newValidator(WithTimeFunc(func() time.Time { + return exp + })).Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } - got = mapClaims.VerifyExpiresAt(exp+1, true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got = newValidator(WithTimeFunc(func() time.Time { + return exp.Add(1 * time.Second) + })).Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } want = true - got = mapClaims.VerifyExpiresAt(exp-1, true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got = newValidator(WithTimeFunc(func() time.Time { + return exp.Add(-1 * time.Second) + })).Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } } diff --git a/parser.go b/parser.go index 2f61a69d..461b934e 100644 --- a/parser.go +++ b/parser.go @@ -9,26 +9,24 @@ import ( type Parser struct { // If populated, only these methods will be considered valid. - // - // Deprecated: In future releases, this field will not be exported anymore and should be set with an option to NewParser instead. - ValidMethods []string + validMethods []string // Use JSON Number format in JSON decoder. - // - // Deprecated: In future releases, this field will not be exported anymore and should be set with an option to NewParser instead. - UseJSONNumber bool + useJSONNumber bool // Skip claims validation during token parsing. - // - // Deprecated: In future releases, this field will not be exported anymore and should be set with an option to NewParser instead. - SkipClaimsValidation bool + skipClaimsValidation bool + + validator *validator } // NewParser creates a new Parser with the specified options func NewParser(options ...ParserOption) *Parser { - p := &Parser{} + p := &Parser{ + validator: &validator{}, + } - // loop through our parsing options and apply them + // Loop through our parsing options and apply them for _, option := range options { option(p) } @@ -49,10 +47,10 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } // Verify signing method is in the required set - if p.ValidMethods != nil { + if p.validMethods != nil { var signingMethodValid = false var alg = token.Method.Alg() - for _, m := range p.ValidMethods { + for _, m := range p.validMethods { if m == alg { signingMethodValid = true break @@ -81,9 +79,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf vErr := &ValidationError{} // Validate Claims - if !p.SkipClaimsValidation { - if err := token.Claims.Valid(); err != nil { + if !p.skipClaimsValidation { + // Make sure we have at least a default validator + if p.validator == nil { + p.validator = newValidator() + } + if err := p.validator.Validate(claims); err != nil { // If the Claims Valid returned an error, check if it is a validation error, // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set if e, ok := err.(*ValidationError); !ok { @@ -143,7 +145,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} } dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) - if p.UseJSONNumber { + if p.useJSONNumber { dec.UseNumber() } // JSON Decode. Special case for map type to avoid weird pointer behavior diff --git a/parser_option.go b/parser_option.go index 6ea6f952..0442cdcd 100644 --- a/parser_option.go +++ b/parser_option.go @@ -1,5 +1,7 @@ package jwt +import "time" + // ParserOption is used to implement functional-style options that modify the behavior of the parser. To add // new options, just create a function (ideally beginning with With or Without) that returns an anonymous function that // takes a *Parser type as input and manipulates its configuration accordingly. @@ -9,14 +11,14 @@ type ParserOption func(*Parser) // It is heavily encouraged to use this option in order to prevent attacks such as https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/. func WithValidMethods(methods []string) ParserOption { return func(p *Parser) { - p.ValidMethods = methods + p.validMethods = methods } } // WithJSONNumber is an option to configure the underlying JSON parser with UseNumber func WithJSONNumber() ParserOption { return func(p *Parser) { - p.UseJSONNumber = true + p.useJSONNumber = true } } @@ -24,6 +26,69 @@ func WithJSONNumber() ParserOption { // what you are doing. func WithoutClaimsValidation() ParserOption { return func(p *Parser) { - p.SkipClaimsValidation = true + p.skipClaimsValidation = true + } +} + +// WithLeeway returns the ParserOption for specifying the leeway window. +func WithLeeway(leeway time.Duration) ParserOption { + return func(p *Parser) { + p.validator.leeway = leeway + } +} + +// WithTimeFunc returns the ParserOption for specifying the time func. The +// primary use-case for this is testing. If you are looking for a way to account +// for clock-skew, WithLeeway should be used instead. +func WithTimeFunc(f func() time.Time) ParserOption { + return func(p *Parser) { + p.validator.timeFunc = f + } +} + +// WithIssuedAt returns the ParserOption to enable verification +// of issued-at. +func WithIssuedAt() ParserOption { + return func(p *Parser) { + p.validator.verifyIat = true + } +} + +// WithAudience configures the validator to require the specified audience in +// the `aud` claim. Validation will fail if the audience is not listed in the +// token or the `aud` claim is missing. +// +// NOTE: While the `aud` claim is OPTIONAL is a JWT, the handling of it is +// application-specific. Since this validation API is helping developers in +// writing secure application, we decided to REQUIRE the existence of the claim. +func WithAudience(aud string) ParserOption { + return func(p *Parser) { + p.validator.expectedAud = aud + } +} + +// WithIssuer configures the validator to require the specified issuer in the +// `iss` claim. Validation will fail if a different issuer is specified in the +// token or the `iss` claim is missing. +// +// NOTE: While the `iss` claim is OPTIONAL is a JWT, the handling of it is +// application-specific. Since this validation API is helping developers in +// writing secure application, we decided to REQUIRE the existence of the claim. +func WithIssuer(iss string) ParserOption { + return func(p *Parser) { + p.validator.expectedIss = iss + } +} + +// WithSubject configures the validator to require the specified subject in the +// `sub` claim. Validation will fail if a different subject is specified in the +// token or the `sub` claim is missing. +// +// NOTE: While the `sub` claim is OPTIONAL is a JWT, the handling of it is +// application-specific. Since this validation API is helping developers in +// writing secure application, we decided to REQUIRE the existence of the claim. +func WithSubject(sub string) ParserOption { + return func(p *Parser) { + p.validator.expectedSub = sub } } diff --git a/parser_test.go b/parser_test.go index 9b09b164..462dd170 100644 --- a/parser_test.go +++ b/parser_test.go @@ -152,7 +152,7 @@ var jwtTestData = []struct { false, jwt.ValidationErrorSignatureInvalid, []error{jwt.ErrTokenSignatureInvalid}, - &jwt.Parser{ValidMethods: []string{"HS256"}}, + jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})), jwt.SigningMethodRS256, }, { @@ -163,7 +163,7 @@ var jwtTestData = []struct { true, 0, nil, - &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, + jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodRS256, }, { @@ -174,7 +174,7 @@ var jwtTestData = []struct { false, jwt.ValidationErrorSignatureInvalid, []error{jwt.ErrTokenSignatureInvalid}, - &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, + jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodES256, }, { @@ -185,7 +185,7 @@ var jwtTestData = []struct { true, 0, nil, - &jwt.Parser{ValidMethods: []string{"HS256", "ES256"}}, + jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})), jwt.SigningMethodES256, }, { @@ -196,7 +196,7 @@ var jwtTestData = []struct { true, 0, nil, - &jwt.Parser{UseJSONNumber: true}, + jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, }, { @@ -207,7 +207,7 @@ var jwtTestData = []struct { false, jwt.ValidationErrorExpired, []error{jwt.ErrTokenExpired}, - &jwt.Parser{UseJSONNumber: true}, + jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, }, { @@ -218,7 +218,7 @@ var jwtTestData = []struct { false, jwt.ValidationErrorNotValidYet, []error{jwt.ErrTokenNotValidYet}, - &jwt.Parser{UseJSONNumber: true}, + jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, }, { @@ -229,7 +229,7 @@ var jwtTestData = []struct { false, jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, []error{jwt.ErrTokenNotValidYet}, - &jwt.Parser{UseJSONNumber: true}, + jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, }, { @@ -240,7 +240,7 @@ var jwtTestData = []struct { true, 0, nil, - &jwt.Parser{UseJSONNumber: true, SkipClaimsValidation: true}, + jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()), jwt.SigningMethodRS256, }, { @@ -253,7 +253,7 @@ var jwtTestData = []struct { true, 0, nil, - &jwt.Parser{UseJSONNumber: true}, + jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, }, { @@ -266,7 +266,7 @@ var jwtTestData = []struct { true, 0, nil, - &jwt.Parser{UseJSONNumber: true}, + jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, }, { @@ -279,7 +279,7 @@ var jwtTestData = []struct { true, 0, nil, - &jwt.Parser{UseJSONNumber: true}, + jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, }, { @@ -292,7 +292,7 @@ var jwtTestData = []struct { false, jwt.ValidationErrorMalformed, []error{jwt.ErrTokenMalformed}, - &jwt.Parser{UseJSONNumber: true}, + jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, }, { @@ -305,7 +305,29 @@ var jwtTestData = []struct { false, jwt.ValidationErrorMalformed, []error{jwt.ErrTokenMalformed}, - &jwt.Parser{UseJSONNumber: true}, + jwt.NewParser(jwt.WithJSONNumber()), + jwt.SigningMethodRS256, + }, + { + "RFC7519 Claims - nbf with 60s skew", + "", // autogen + defaultKeyFunc, + &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, + false, + jwt.ValidationErrorNotValidYet, + []error{jwt.ErrTokenNotValidYet}, + jwt.NewParser(jwt.WithLeeway(time.Minute)), + jwt.SigningMethodRS256, + }, + { + "RFC7519 Claims - nbf with 120s skew", + "", // autogen + defaultKeyFunc, + &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, + true, + 0, + nil, + jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), jwt.SigningMethodRS256, }, } @@ -341,7 +363,7 @@ func TestParser_Parse(t *testing.T) { var err error var parser = data.parser if parser == nil { - parser = new(jwt.Parser) + parser = jwt.NewParser() } // Figure out correct claims type switch data.claims.(type) { @@ -548,8 +570,7 @@ func TestSetPadding(t *testing.T) { // Parse the token var token *jwt.Token var err error - parser := new(jwt.Parser) - parser.SkipClaimsValidation = true + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) // Figure out correct claims type token, err = parser.ParseWithClaims(data.tokenString, jwt.MapClaims{}, data.keyfunc) diff --git a/registered_claims.go b/registered_claims.go new file mode 100644 index 00000000..77951a53 --- /dev/null +++ b/registered_claims.go @@ -0,0 +1,63 @@ +package jwt + +// RegisteredClaims are a structured version of the JWT Claims Set, +// restricted to Registered Claim Names, as referenced at +// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 +// +// This type can be used on its own, but then additional private and +// public claims embedded in the JWT will not be parsed. The typical use-case +// therefore is to embedded this in a user-defined claim type. +// +// See examples for how to use this with your own claim types. +type RegisteredClaims struct { + // the `iss` (Issuer) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.1 + Issuer string `json:"iss,omitempty"` + + // the `sub` (Subject) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.2 + Subject string `json:"sub,omitempty"` + + // the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3 + Audience ClaimStrings `json:"aud,omitempty"` + + // the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4 + ExpiresAt *NumericDate `json:"exp,omitempty"` + + // the `nbf` (Not Before) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.5 + NotBefore *NumericDate `json:"nbf,omitempty"` + + // the `iat` (Issued At) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6 + IssuedAt *NumericDate `json:"iat,omitempty"` + + // the `jti` (JWT ID) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.7 + ID string `json:"jti,omitempty"` +} + +// GetExpirationTime implements the Claims interface. +func (c RegisteredClaims) GetExpirationTime() (*NumericDate, error) { + return c.ExpiresAt, nil +} + +// GetNotBefore implements the Claims interface. +func (c RegisteredClaims) GetNotBefore() (*NumericDate, error) { + return c.NotBefore, nil +} + +// GetIssuedAt implements the Claims interface. +func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) { + return c.IssuedAt, nil +} + +// GetAudience implements the Claims interface. +func (c RegisteredClaims) GetAudience() (ClaimStrings, error) { + return c.Audience, nil +} + +// GetIssuer implements the Claims interface. +func (c RegisteredClaims) GetIssuer() (string, error) { + return c.Issuer, nil +} + +// GetSubject implements the Claims interface. +func (c RegisteredClaims) GetSubject() (string, error) { + return c.Subject, nil +} diff --git a/token.go b/token.go index 3cb0f3f0..738eef0e 100644 --- a/token.go +++ b/token.go @@ -4,7 +4,6 @@ import ( "encoding/base64" "encoding/json" "strings" - "time" ) // DecodePaddingAllowed will switch the codec used for decoding JWTs respectively. Note that the JWS RFC7515 @@ -14,11 +13,6 @@ import ( // To use the non-recommended decoding, set this boolean to `true` prior to using this package. var DecodePaddingAllowed bool -// TimeFunc provides the current time when parsing token to validate "exp" claim (expiration time). -// You can override it to use another time value. This is useful for testing or if your -// server uses a different time zone than your tokens. -var TimeFunc = time.Now - // Keyfunc will be used by the Parse methods as a callback function to supply // the key for verification. The function receives the parsed, // but unverified Token. This allows you to use properties in the diff --git a/validator.go b/validator.go new file mode 100644 index 00000000..3e512f67 --- /dev/null +++ b/validator.go @@ -0,0 +1,264 @@ +package jwt + +import ( + "crypto/subtle" + "time" +) + +// validator is the core of the new Validation API. It is automatically used by +// a [Parser] during parsing and can be modified with various parser options. +// +// Note: This struct is intentionally not exported (yet) as we want to +// internally finalize its API. In the future, we might make it publicly available. +type validator struct { + // leeway is an optional leeway that can be provided to account for clock skew. + leeway time.Duration + + // timeFunc is used to supply the current time that is needed for + // validation. If unspecified, this defaults to time.Now. + timeFunc func() time.Time + + // verifyIat specifies whether the iat (Issued At) claim will be verified. + // According to https://www.rfc-editor.org/rfc/rfc7519#section-4.1.6 this + // only specifies the age of the token, but no validation check is + // necessary. However, if wanted, it can be checked if the iat is + // unrealistic, i.e., in the future. + verifyIat bool + + // expectedAud contains the audience this token expects. Supplying an empty + // string will disable aud checking. + expectedAud string + + // expectedIss contains the issuer this token expects. Supplying an empty + // string will disable iss checking. + expectedIss string + + // expectedSub contains the subject this token expects. Supplying an empty + // string will disable sub checking. + expectedSub string +} + +// CustomClaims represents a custom claims interface, which can be built upon the integrated +// claim types, such as map claims or registered claims. +type CustomClaims interface { + // CustomValidation can be implemented by a user-specific claim to support + // additional validation steps in addition to the regular validation. + CustomValidation() error +} + +// newValidator can be used to create a stand-alone validator with the supplied +// options. This validator can then be used to validate already parsed claims. +func newValidator(opts ...ParserOption) *validator { + p := NewParser(opts...) + return p.validator +} + +// Validate validates the given claims. It will also perform any custom validation if claims implements the CustomValidator interface. +func (v *validator) Validate(claims Claims) error { + var now time.Time + vErr := new(ValidationError) + + // Check, if we have a time func + if v.timeFunc != nil { + now = v.timeFunc() + } else { + now = time.Now() + } + + // We always need to check the expiration time, but the claim itself is OPTIONAL + if !v.VerifyExpiresAt(claims, now, false) { + vErr.Inner = ErrTokenExpired + vErr.Errors |= ValidationErrorExpired + } + + // We always need to check not-before, but the claim itself is OPTIONAL + if !v.VerifyNotBefore(claims, now, false) { + vErr.Inner = ErrTokenNotValidYet + vErr.Errors |= ValidationErrorNotValidYet + } + + // Check issued-at if the option is enabled + if v.verifyIat && !v.VerifyIssuedAt(claims, now, false) { + vErr.Inner = ErrTokenUsedBeforeIssued + vErr.Errors |= ValidationErrorIssuedAt + } + + // If we have an expected audience, we also require the audience claim + if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) { + vErr.Inner = ErrTokenInvalidAudience + vErr.Errors |= ValidationErrorAudience + } + + // If we have an expected issuer, we also require the issuer claim + if v.expectedIss != "" && !v.VerifyIssuer(claims, v.expectedIss, true) { + vErr.Inner = ErrTokenInvalidIssuer + vErr.Errors |= ValidationErrorIssuer + } + + // If we have an expected subject, we also require the subject claim + if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) { + vErr.Inner = ErrTokenInvalidSubject + vErr.Errors |= ValidationErrorSubject + } + + // Finally, we want to give the claim itself some possibility to do some + // additional custom validation based on their custom claims + cvt, ok := claims.(CustomClaims) + if ok { + if err := cvt.CustomValidation(); err != nil { + vErr.Inner = err + vErr.Errors |= ValidationErrorClaimsInvalid + } + } + + if vErr.valid() { + return nil + } + + return vErr +} + +// VerifyAudience compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (v *validator) VerifyAudience(claims Claims, cmp string, req bool) bool { + aud, err := claims.GetAudience() + if err != nil { + return false + } + + return verifyAud(aud, cmp, req) +} + +// VerifyExpiresAt compares the exp claim against cmp (cmp < exp). +// If req is false, it will return true, if exp is unset. +func (v *validator) VerifyExpiresAt(claims Claims, cmp time.Time, req bool) bool { + var time *time.Time = nil + + exp, err := claims.GetExpirationTime() + if err != nil { + return false + } else if exp != nil { + time = &exp.Time + } + + return verifyExp(time, cmp, req, v.leeway) +} + +// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat). +// If req is false, it will return true, if iat is unset. +func (v *validator) VerifyIssuedAt(claims Claims, cmp time.Time, req bool) bool { + var time *time.Time = nil + + iat, err := claims.GetIssuedAt() + if err != nil { + return false + } else if iat != nil { + time = &iat.Time + } + + return verifyIat(time, cmp, req, v.leeway) +} + +// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). +// If req is false, it will return true, if nbf is unset. +func (v *validator) VerifyNotBefore(claims Claims, cmp time.Time, req bool) bool { + var time *time.Time = nil + + nbf, err := claims.GetNotBefore() + if err != nil { + return false + } else if nbf != nil { + time = &nbf.Time + } + + return verifyNbf(time, cmp, req, v.leeway) +} + +// VerifyIssuer compares the iss claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (v *validator) VerifyIssuer(claims Claims, cmp string, req bool) bool { + iss, err := claims.GetIssuer() + if err != nil { + return false + } + + return verifyIss(iss, cmp, req) +} + +// VerifySubject compares the sub claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (v *validator) VerifySubject(claims Claims, cmp string, req bool) bool { + iss, err := claims.GetSubject() + if err != nil { + return false + } + + return verifySub(iss, cmp, req) +} + +// ----- helpers + +func verifyAud(aud []string, cmp string, required bool) bool { + if len(aud) == 0 { + return !required + } + // use a var here to keep constant time compare when looping over a number of claims + result := false + + var stringClaims string + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { + result = true + } + stringClaims = stringClaims + a + } + + // case where "" is sent in one or many aud claims + if stringClaims == "" { + return !required + } + + return result +} + +func verifyExp(exp *time.Time, now time.Time, required bool, skew time.Duration) bool { + if exp == nil { + return !required + } + + return now.Before((*exp).Add(+skew)) +} + +func verifyIat(iat *time.Time, now time.Time, required bool, skew time.Duration) bool { + if iat == nil { + return !required + } + + t := iat.Add(-skew) + return !now.Before(t) +} + +func verifyNbf(nbf *time.Time, now time.Time, required bool, skew time.Duration) bool { + if nbf == nil { + return !required + } + + t := nbf.Add(-skew) + return !now.Before(t) +} + +func verifyIss(iss string, cmp string, required bool) bool { + if iss == "" { + return !required + } + + return iss == cmp +} + +func verifySub(sub string, cmp string, required bool) bool { + if sub == "" { + return !required + } + + return sub == cmp +}