diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index e7dcb9d6a..c07d6e97e 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -66,11 +66,27 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai } } +// ClaimsFactory is a factory for jwt.Claims. +// Useful in NewParser middleware. +type ClaimsFactory func() jwt.Claims + +// MapClaimsFactory is a ClaimsFactory that returns +// an empty jwt.MapClaims. +func MapClaimsFactory() jwt.Claims { + return jwt.MapClaims{} +} + +// StandardClaimsFactory is a ClaimsFactory that returns +// an empty jwt.StandardClaims. +func StandardClaimsFactory() jwt.Claims { + return &jwt.StandardClaims{} +} + // NewParser creates a new JWT token parsing middleware, specifying a // jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser // adds the resulting claims to endpoint context or returns error on invalid token. // Particularly useful for servers. -func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { +func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { // tokenString is stored in the context from the transport handlers. @@ -85,7 +101,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) // of the token to identify which key to use, but the parsed token // (head and claims) is provided to the callback, providing // flexibility. - token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, newClaims(), func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if token.Method != method { return nil, ErrUnexpectedSigningMethod diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index efaaf2d14..3278e13a7 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -74,7 +74,7 @@ func TestJWTParser(t *testing.T) { return key, nil } - parser := NewParser(keys, method, jwt.MapClaims{})(e) + parser := NewParser(keys, method, MapClaimsFactory)(e) // No Token is passed into the parser _, err := parser(context.Background(), struct{}{}) @@ -94,7 +94,7 @@ func TestJWTParser(t *testing.T) { } // Invalid Method is used in the parser - badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e) + badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -110,7 +110,7 @@ func TestJWTParser(t *testing.T) { return []byte("bad"), nil } - badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e) + badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -134,7 +134,7 @@ func TestJWTParser(t *testing.T) { } // Test for malformed token error response - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey) ctx1, err = parser(ctx, struct{}{}) if want, have := ErrTokenMalformed, err; want != have { @@ -142,7 +142,7 @@ func TestJWTParser(t *testing.T) { } // Test for expired token error response - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100}) token, err := expired.SignedString(key) if err != nil { @@ -155,7 +155,7 @@ func TestJWTParser(t *testing.T) { } // Test for not activated token error response - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100}) token, err = notactive.SignedString(key) if err != nil { @@ -168,7 +168,7 @@ func TestJWTParser(t *testing.T) { } // test valid standard claims token - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { @@ -183,7 +183,7 @@ func TestJWTParser(t *testing.T) { } // test valid customized claims token - parser = NewParser(keys, method, &customClaims{})(e) + parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { @@ -204,7 +204,7 @@ func TestJWTParser(t *testing.T) { func TestIssue562(t *testing.T) { var ( kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil } - e = NewParser(kf, jwt.SigningMethodHS256, jwt.MapClaims{})(endpoint.Nop) + e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop) key = JWTTokenContextKey val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" ctx = context.WithValue(context.Background(), key, val)