From 8e176480d727c85a264826f27e69068e83b6379f Mon Sep 17 00:00:00 2001 From: subham sarkar Date: Thu, 10 Oct 2024 18:43:23 +0530 Subject: [PATCH] Use jwt.MapClaims instead of jwt.RegisteredClaims (#3) --- credentials/jwt.go | 21 +++++++-------------- credentials/jwt_test.go | 18 +++++++++--------- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/credentials/jwt.go b/credentials/jwt.go index 1196f04..e7ba560 100644 --- a/credentials/jwt.go +++ b/credentials/jwt.go @@ -19,12 +19,6 @@ type jwtProvider struct { now func() time.Time // Function to get the current time, useful for testing } -type claims struct { - // Reason for using RegisteredClaims instead of StandardClaims - // See: https://github.com/golang-jwt/jwt/blob/62e504c2810b67f6b97313424411cfffb25e41b0/MIGRATION_GUIDE.md?plain=1#L81 - jwt.RegisteredClaims -} - func (provider *jwtProvider) Retrieve() (io.Reader, error) { expirationTime := provider.GetAppropriateExpirationTime() tokenString, err := provider.BuildClaimsToken(expirationTime, provider.creds.URL, provider.creds.ClientId, provider.creds.ClientUsername) @@ -51,14 +45,13 @@ func (provider *jwtProvider) GetAppropriateExpirationTime() time.Time { } func (provider *jwtProvider) BuildClaimsToken(expirationTime time.Time, url string, clientId string, clientUsername string) (string, error) { - claims := &claims{ - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(expirationTime), - Audience: []string{url}, - Issuer: clientId, - Subject: clientUsername, - }, + claims := jwt.MapClaims{ + "iss": clientId, + "sub": clientUsername, + "aud": url, + "exp": expirationTime.Unix(), } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) if provider.creds.ClientKey == nil { @@ -70,4 +63,4 @@ func (provider *jwtProvider) BuildClaimsToken(expirationTime time.Time, url stri return "", fmt.Errorf("jwtProvider.BuildClaimsToken() error: failed to sign token: %w", err) } return tokenString, nil -} \ No newline at end of file +} diff --git a/credentials/jwt_test.go b/credentials/jwt_test.go index 33af6ac..55f266e 100644 --- a/credentials/jwt_test.go +++ b/credentials/jwt_test.go @@ -121,7 +121,7 @@ func Test_jwtProvider_Retrieve(t *testing.T) { } gotToken := gotForm.Get("assertion") - gotClaims := &claims{} + gotClaims := &jwt.MapClaims{} _, err = jwt.ParseWithClaims(gotToken, gotClaims, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) @@ -133,16 +133,16 @@ func Test_jwtProvider_Retrieve(t *testing.T) { return } - wantClaims := &claims{ - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(fixedTime().Add(JwtExpiration)), - Audience: []string{provider.creds.URL}, - Issuer: provider.creds.ClientId, - Subject: provider.creds.ClientUsername, - }, + wantClaims := jwt.MapClaims{ + "iss": provider.creds.ClientId, + "sub": provider.creds.ClientUsername, + "aud": provider.creds.URL, + "exp": fixedTime().Add(JwtExpiration).Unix(), } - if !reflect.DeepEqual(gotClaims, wantClaims) { + (*gotClaims)["exp"] = int64((*gotClaims)["exp"].(float64)) + + if !reflect.DeepEqual(*gotClaims, wantClaims) { t.Errorf("jwtProvider.Retrieve() claims = %v, want %v", gotClaims, wantClaims) } }