Skip to content

Commit

Permalink
Merge pull request #564 from go-kit/issue-562
Browse files Browse the repository at this point in the history
auth/jwt: prevent concurrent reads and writes on MapClaims
  • Loading branch information
peterbourgon authored Jul 17, 2017
2 parents 9813199 + b42a850 commit 0d70b13
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
20 changes: 18 additions & 2 deletions auth/jwt/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,27 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai
}
}

// ClaimsFactory is a factory for jwt.Claims.
// Useful in NewParser middleware.
type ClaimsFactory func() jwt.Claims

// MapClaimsFactory is a ClaimsFactory that returns
// an empty jwt.MapClaims.
func MapClaimsFactory() jwt.Claims {
return jwt.MapClaims{}
}

// StandardClaimsFactory is a ClaimsFactory that returns
// an empty jwt.StandardClaims.
func StandardClaimsFactory() jwt.Claims {
return &jwt.StandardClaims{}
}

// NewParser creates a new JWT token parsing middleware, specifying a
// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
// adds the resulting claims to endpoint context or returns error on invalid token.
// Particularly useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
// tokenString is stored in the context from the transport handlers.
Expand All @@ -85,7 +101,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims)
// of the token to identify which key to use, but the parsed token
// (head and claims) is provided to the callback, providing
// flexibility.
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, newClaims(), func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if token.Method != method {
return nil, ErrUnexpectedSigningMethod
Expand Down
36 changes: 28 additions & 8 deletions auth/jwt/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jwt

import (
"context"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -73,7 +74,7 @@ func TestJWTParser(t *testing.T) {
return key, nil
}

parser := NewParser(keys, method, jwt.MapClaims{})(e)
parser := NewParser(keys, method, MapClaimsFactory)(e)

// No Token is passed into the parser
_, err := parser(context.Background(), struct{}{})
Expand All @@ -93,7 +94,7 @@ func TestJWTParser(t *testing.T) {
}

// Invalid Method is used in the parser
badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e)
badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -109,7 +110,7 @@ func TestJWTParser(t *testing.T) {
return []byte("bad"), nil
}

badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e)
badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -133,15 +134,15 @@ func TestJWTParser(t *testing.T) {
}

// Test for malformed token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenMalformed, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
}

// Test for expired token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100})
token, err := expired.SignedString(key)
if err != nil {
Expand All @@ -154,7 +155,7 @@ func TestJWTParser(t *testing.T) {
}

// Test for not activated token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100})
token, err = notactive.SignedString(key)
if err != nil {
Expand All @@ -167,7 +168,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid standard claims token
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -182,7 +183,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid customized claims token
parser = NewParser(keys, method, &customClaims{})(e)
parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -199,3 +200,22 @@ func TestJWTParser(t *testing.T) {
t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty)
}
}

func TestIssue562(t *testing.T) {
var (
kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop)
key = JWTTokenContextKey
val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
ctx = context.WithValue(context.Background(), key, val)
)
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
e(ctx, struct{}{}) // fatal error: concurrent map read and map write
}()
}
wg.Wait()
}

0 comments on commit 0d70b13

Please sign in to comment.