Skip to content

Commit

Permalink
Use jwt.MapClaims instead of jwt.RegisteredClaims (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsr authored Oct 10, 2024
1 parent bcc8456 commit 8e17648
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 23 deletions.
21 changes: 7 additions & 14 deletions credentials/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ type jwtProvider struct {
now func() time.Time // Function to get the current time, useful for testing
}

type claims struct {
// Reason for using RegisteredClaims instead of StandardClaims
// See: https://github.com/golang-jwt/jwt/blob/62e504c2810b67f6b97313424411cfffb25e41b0/MIGRATION_GUIDE.md?plain=1#L81
jwt.RegisteredClaims
}

func (provider *jwtProvider) Retrieve() (io.Reader, error) {
expirationTime := provider.GetAppropriateExpirationTime()
tokenString, err := provider.BuildClaimsToken(expirationTime, provider.creds.URL, provider.creds.ClientId, provider.creds.ClientUsername)
Expand All @@ -51,14 +45,13 @@ func (provider *jwtProvider) GetAppropriateExpirationTime() time.Time {
}

func (provider *jwtProvider) BuildClaimsToken(expirationTime time.Time, url string, clientId string, clientUsername string) (string, error) {
claims := &claims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
Audience: []string{url},
Issuer: clientId,
Subject: clientUsername,
},
claims := jwt.MapClaims{
"iss": clientId,
"sub": clientUsername,
"aud": url,
"exp": expirationTime.Unix(),
}

token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)

if provider.creds.ClientKey == nil {
Expand All @@ -70,4 +63,4 @@ func (provider *jwtProvider) BuildClaimsToken(expirationTime time.Time, url stri
return "", fmt.Errorf("jwtProvider.BuildClaimsToken() error: failed to sign token: %w", err)
}
return tokenString, nil
}
}
18 changes: 9 additions & 9 deletions credentials/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func Test_jwtProvider_Retrieve(t *testing.T) {
}

gotToken := gotForm.Get("assertion")
gotClaims := &claims{}
gotClaims := &jwt.MapClaims{}
_, err = jwt.ParseWithClaims(gotToken, gotClaims, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
Expand All @@ -133,16 +133,16 @@ func Test_jwtProvider_Retrieve(t *testing.T) {
return
}

wantClaims := &claims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(fixedTime().Add(JwtExpiration)),
Audience: []string{provider.creds.URL},
Issuer: provider.creds.ClientId,
Subject: provider.creds.ClientUsername,
},
wantClaims := jwt.MapClaims{
"iss": provider.creds.ClientId,
"sub": provider.creds.ClientUsername,
"aud": provider.creds.URL,
"exp": fixedTime().Add(JwtExpiration).Unix(),
}

if !reflect.DeepEqual(gotClaims, wantClaims) {
(*gotClaims)["exp"] = int64((*gotClaims)["exp"].(float64))

if !reflect.DeepEqual(*gotClaims, wantClaims) {
t.Errorf("jwtProvider.Retrieve() claims = %v, want %v", gotClaims, wantClaims)
}
}
Expand Down

0 comments on commit 8e17648

Please sign in to comment.