From 7cc9872805a96d96de40797175397d6dd745bd1e Mon Sep 17 00:00:00 2001 From: Sid Kattoju <83437591+skattoju@users.noreply.github.com> Date: Fri, 5 Jul 2024 14:45:54 -0400 Subject: [PATCH] Fixes and Improvements for TokenGetter (#1014) --- internal/authentication/tokengetter.go | 41 ++++++++------------- internal/authentication/tokengetter_test.go | 9 ++--- 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/internal/authentication/tokengetter.go b/internal/authentication/tokengetter.go index aed477560..585fc65e6 100644 --- a/internal/authentication/tokengetter.go +++ b/internal/authentication/tokengetter.go @@ -13,19 +13,17 @@ import ( ) type TokenGetter struct { - client corev1.ServiceAccountsGetter - expirationDuration time.Duration - removeAfterExpiredDuration time.Duration - tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus - mu sync.RWMutex + client corev1.ServiceAccountsGetter + expirationDuration time.Duration + tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus + mu sync.RWMutex } type TokenGetterOption func(*TokenGetter) const ( - RotationThresholdPercentage = 10 - DefaultExpirationDuration = 5 * time.Minute - DefaultRemoveAfterExpiredDuration = 90 * time.Minute + rotationThresholdFraction = 0.1 + DefaultExpirationDuration = 5 * time.Minute ) // Returns a token getter that can fetch tokens given a service account. @@ -33,10 +31,9 @@ const ( // In case a cached token is expiring a fresh token is created. func NewTokenGetter(client corev1.ServiceAccountsGetter, options ...TokenGetterOption) *TokenGetter { tokenGetter := &TokenGetter{ - client: client, - expirationDuration: DefaultExpirationDuration, - removeAfterExpiredDuration: DefaultRemoveAfterExpiredDuration, - tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{}, + client: client, + expirationDuration: DefaultExpirationDuration, + tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{}, } for _, opt := range options { @@ -52,12 +49,6 @@ func WithExpirationDuration(expirationDuration time.Duration) TokenGetterOption } } -func WithRemoveAfterExpiredDuration(removeAfterExpiredDuration time.Duration) TokenGetterOption { - return func(tg *TokenGetter) { - tg.removeAfterExpiredDuration = removeAfterExpiredDuration - } -} - // Get returns a token from the cache if available and not expiring, otherwise creates a new token func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) { t.mu.RLock() @@ -69,8 +60,8 @@ func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string expireTime = token.ExpirationTimestamp.Time } - // Create a new token if the cached token expires within DurationPercentage of expirationDuration from now - rotationThresholdAfterNow := metav1.Now().Add(t.expirationDuration * (RotationThresholdPercentage / 100)) + // Create a new token if the cached token expires within rotationThresholdFraction of expirationDuration from now + rotationThresholdAfterNow := metav1.Now().Add(time.Duration(float64(t.expirationDuration) * (rotationThresholdFraction))) if expireTime.Before(rotationThresholdAfterNow) { var err error token, err = t.getToken(ctx, key) @@ -82,8 +73,8 @@ func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string t.mu.Unlock() } - // Delete tokens that have been expired for more than ExpiredDuration - t.reapExpiredTokens(t.removeAfterExpiredDuration) + // Delete tokens that have expired + t.reapExpiredTokens() return token.Token, nil } @@ -92,7 +83,7 @@ func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (* req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx, key.Name, &authenticationv1.TokenRequest{ - Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](int64(t.expirationDuration))}, + Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](int64(t.expirationDuration / time.Second))}, }, metav1.CreateOptions{}) if err != nil { return nil, err @@ -100,11 +91,11 @@ func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (* return &req.Status, nil } -func (t *TokenGetter) reapExpiredTokens(expiredDuration time.Duration) { +func (t *TokenGetter) reapExpiredTokens() { t.mu.Lock() defer t.mu.Unlock() for key, token := range t.tokens { - if metav1.Now().Sub(token.ExpirationTimestamp.Time) > expiredDuration { + if metav1.Now().Sub(token.ExpirationTimestamp.Time) > 0 { delete(t.tokens, key) } } diff --git a/internal/authentication/tokengetter_test.go b/internal/authentication/tokengetter_test.go index 5b246c36a..b9553cac3 100644 --- a/internal/authentication/tokengetter_test.go +++ b/internal/authentication/tokengetter_test.go @@ -40,7 +40,7 @@ func TestTokenGetterGet(t *testing.T) { if act.Name == "test-service-account-3" { tokenRequest.Status = authenticationv1.TokenRequestStatus{ Token: "test-token-3", - ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(-DefaultRemoveAfterExpiredDuration)), + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(-10 * time.Second)), } } if act.Name == "test-service-account-4" { @@ -51,8 +51,7 @@ func TestTokenGetterGet(t *testing.T) { }) tg := NewTokenGetter(fakeClient.CoreV1(), - WithExpirationDuration(DefaultExpirationDuration), - WithRemoveAfterExpiredDuration(DefaultRemoveAfterExpiredDuration)) + WithExpirationDuration(DefaultExpirationDuration)) tests := []struct { testName string @@ -67,9 +66,9 @@ func TestTokenGetterGet(t *testing.T) { "test-namespace-1", "test-token-1", "failed to get token"}, {"Testing getting short lived token from fake client", "test-service-account-2", "test-namespace-2", "test-token-2", "failed to get token"}, - {"Testing getting expired token from cache", "test-service-account-2", + {"Testing getting nearly expired token from cache", "test-service-account-2", "test-namespace-2", "test-token-2", "failed to refresh token"}, - {"Testing token that expired 90 minutes ago", "test-service-account-3", + {"Testing token that expired 10 seconds ago", "test-service-account-3", "test-namespace-3", "test-token-3", "failed to get token"}, {"Testing error when getting token from fake client", "test-service-account-4", "test-namespace-4", "error when fetching token", "error when fetching token"},