From fe8e1434c6aef507930e9c0e6a7f2a964fe77c2d Mon Sep 17 00:00:00 2001 From: Jaco Esterhuizen Date: Tue, 11 Jul 2017 13:04:25 +0200 Subject: [PATCH 1/3] auth/jwt: MapClaims: passing add claimsFactory type make NewParser take a claimsFactory instead of an instance of jwt.Claims use claimsFactory to create a jwt.Claims to pass in to jwt.ParseWithClaims update NewParser calls in tests to take a claimsFactory instead of a jwt.Claims instance --- auth/jwt/middleware.go | 6 ++++-- auth/jwt/middleware_test.go | 18 +++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index e7dcb9d6a..b03b9aa94 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -66,11 +66,13 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai } } +type claimsFactory func() jwt.Claims + // NewParser creates a new JWT token parsing middleware, specifying a // jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser // adds the resulting claims to endpoint context or returns error on invalid token. // Particularly useful for servers. -func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { +func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims claimsFactory) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { // tokenString is stored in the context from the transport handlers. @@ -85,7 +87,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) // of the token to identify which key to use, but the parsed token // (head and claims) is provided to the callback, providing // flexibility. - token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, newClaims(), func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if token.Method != method { return nil, ErrUnexpectedSigningMethod diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index efaaf2d14..977bef6b9 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -74,7 +74,7 @@ func TestJWTParser(t *testing.T) { return key, nil } - parser := NewParser(keys, method, jwt.MapClaims{})(e) + parser := NewParser(keys, method, func() jwt.Claims { return jwt.MapClaims{} })(e) // No Token is passed into the parser _, err := parser(context.Background(), struct{}{}) @@ -94,7 +94,7 @@ func TestJWTParser(t *testing.T) { } // Invalid Method is used in the parser - badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e) + badParser := NewParser(keys, invalidMethod, func() jwt.Claims { return jwt.MapClaims{} })(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -110,7 +110,7 @@ func TestJWTParser(t *testing.T) { return []byte("bad"), nil } - badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e) + badParser = NewParser(invalidKeys, method, func() jwt.Claims { return jwt.MapClaims{} })(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -134,7 +134,7 @@ func TestJWTParser(t *testing.T) { } // Test for malformed token error response - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey) ctx1, err = parser(ctx, struct{}{}) if want, have := ErrTokenMalformed, err; want != have { @@ -142,7 +142,7 @@ func TestJWTParser(t *testing.T) { } // Test for expired token error response - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100}) token, err := expired.SignedString(key) if err != nil { @@ -155,7 +155,7 @@ func TestJWTParser(t *testing.T) { } // Test for not activated token error response - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100}) token, err = notactive.SignedString(key) if err != nil { @@ -168,7 +168,7 @@ func TestJWTParser(t *testing.T) { } // test valid standard claims token - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { @@ -183,7 +183,7 @@ func TestJWTParser(t *testing.T) { } // test valid customized claims token - parser = NewParser(keys, method, &customClaims{})(e) + parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { @@ -204,7 +204,7 @@ func TestJWTParser(t *testing.T) { func TestIssue562(t *testing.T) { var ( kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil } - e = NewParser(kf, jwt.SigningMethodHS256, jwt.MapClaims{})(endpoint.Nop) + e = NewParser(kf, jwt.SigningMethodHS256, func() jwt.Claims { return jwt.MapClaims{} })(endpoint.Nop) key = JWTTokenContextKey val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" ctx = context.WithValue(context.Background(), key, val) From 44bb40480b7b607b18b0bf15914b2e289e45fae9 Mon Sep 17 00:00:00 2001 From: Jaco Esterhuizen Date: Wed, 12 Jul 2017 11:34:30 +0200 Subject: [PATCH 2/3] auth/jwt: MapClaims: export ClaimsFactory and provide implementations for Map and Standard claims factories --- auth/jwt/middleware.go | 16 ++++++++++++++-- auth/jwt/middleware_test.go | 16 ++++++++-------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index b03b9aa94..ff44ed763 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -66,13 +66,25 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai } } -type claimsFactory func() jwt.Claims +type ClaimsFactory func() jwt.Claims + +// MapClaimsFactory is a ClaimsFactory that returns +// an empty jwt.MapClaims. +func MapClaimsFactory() jwt.Claims { + return jwt.MapClaims{} +} + +// StandardClaimsFactory is a ClaimsFactory that returns +// an empty jwt.StandardClaims. +func StandardClaimsFactory() jwt.Claims { + return &jwt.StandardClaims{} +} // NewParser creates a new JWT token parsing middleware, specifying a // jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser // adds the resulting claims to endpoint context or returns error on invalid token. // Particularly useful for servers. -func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims claimsFactory) endpoint.Middleware { +func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { // tokenString is stored in the context from the transport handlers. diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index 977bef6b9..3278e13a7 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -74,7 +74,7 @@ func TestJWTParser(t *testing.T) { return key, nil } - parser := NewParser(keys, method, func() jwt.Claims { return jwt.MapClaims{} })(e) + parser := NewParser(keys, method, MapClaimsFactory)(e) // No Token is passed into the parser _, err := parser(context.Background(), struct{}{}) @@ -94,7 +94,7 @@ func TestJWTParser(t *testing.T) { } // Invalid Method is used in the parser - badParser := NewParser(keys, invalidMethod, func() jwt.Claims { return jwt.MapClaims{} })(e) + badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -110,7 +110,7 @@ func TestJWTParser(t *testing.T) { return []byte("bad"), nil } - badParser = NewParser(invalidKeys, method, func() jwt.Claims { return jwt.MapClaims{} })(e) + badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -134,7 +134,7 @@ func TestJWTParser(t *testing.T) { } // Test for malformed token error response - parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey) ctx1, err = parser(ctx, struct{}{}) if want, have := ErrTokenMalformed, err; want != have { @@ -142,7 +142,7 @@ func TestJWTParser(t *testing.T) { } // Test for expired token error response - parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100}) token, err := expired.SignedString(key) if err != nil { @@ -155,7 +155,7 @@ func TestJWTParser(t *testing.T) { } // Test for not activated token error response - parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100}) token, err = notactive.SignedString(key) if err != nil { @@ -168,7 +168,7 @@ func TestJWTParser(t *testing.T) { } // test valid standard claims token - parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { @@ -204,7 +204,7 @@ func TestJWTParser(t *testing.T) { func TestIssue562(t *testing.T) { var ( kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil } - e = NewParser(kf, jwt.SigningMethodHS256, func() jwt.Claims { return jwt.MapClaims{} })(endpoint.Nop) + e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop) key = JWTTokenContextKey val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" ctx = context.WithValue(context.Background(), key, val) From 37eab0a86bd1eadaa58c1d8737d779d470706d8d Mon Sep 17 00:00:00 2001 From: Jaco Esterhuizen Date: Thu, 13 Jul 2017 10:47:21 +0200 Subject: [PATCH 3/3] auth/jwt: MapClaims: add doc comment for ClaimsFactory type --- auth/jwt/middleware.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index ff44ed763..c07d6e97e 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -66,6 +66,8 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai } } +// ClaimsFactory is a factory for jwt.Claims. +// Useful in NewParser middleware. type ClaimsFactory func() jwt.Claims // MapClaimsFactory is a ClaimsFactory that returns