diff --git a/internal/authentication/tokengetter.go b/internal/authentication/tokengetter.go new file mode 100644 index 00000000..aed47756 --- /dev/null +++ b/internal/authentication/tokengetter.go @@ -0,0 +1,111 @@ +package authentication + +import ( + "context" + "sync" + "time" + + authenticationv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/utils/ptr" +) + +type TokenGetter struct { + client corev1.ServiceAccountsGetter + expirationDuration time.Duration + removeAfterExpiredDuration time.Duration + tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus + mu sync.RWMutex +} + +type TokenGetterOption func(*TokenGetter) + +const ( + RotationThresholdPercentage = 10 + DefaultExpirationDuration = 5 * time.Minute + DefaultRemoveAfterExpiredDuration = 90 * time.Minute +) + +// Returns a token getter that can fetch tokens given a service account. +// The token getter also caches tokens which helps reduce the number of requests to the API Server. +// In case a cached token is expiring a fresh token is created. +func NewTokenGetter(client corev1.ServiceAccountsGetter, options ...TokenGetterOption) *TokenGetter { + tokenGetter := &TokenGetter{ + client: client, + expirationDuration: DefaultExpirationDuration, + removeAfterExpiredDuration: DefaultRemoveAfterExpiredDuration, + tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{}, + } + + for _, opt := range options { + opt(tokenGetter) + } + + return tokenGetter +} + +func WithExpirationDuration(expirationDuration time.Duration) TokenGetterOption { + return func(tg *TokenGetter) { + tg.expirationDuration = expirationDuration + } +} + +func WithRemoveAfterExpiredDuration(removeAfterExpiredDuration time.Duration) TokenGetterOption { + return func(tg *TokenGetter) { + tg.removeAfterExpiredDuration = removeAfterExpiredDuration + } +} + +// Get returns a token from the cache if available and not expiring, otherwise creates a new token +func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) { + t.mu.RLock() + token, ok := t.tokens[key] + t.mu.RUnlock() + + expireTime := time.Time{} + if ok { + expireTime = token.ExpirationTimestamp.Time + } + + // Create a new token if the cached token expires within DurationPercentage of expirationDuration from now + rotationThresholdAfterNow := metav1.Now().Add(t.expirationDuration * (RotationThresholdPercentage / 100)) + if expireTime.Before(rotationThresholdAfterNow) { + var err error + token, err = t.getToken(ctx, key) + if err != nil { + return "", err + } + t.mu.Lock() + t.tokens[key] = token + t.mu.Unlock() + } + + // Delete tokens that have been expired for more than ExpiredDuration + t.reapExpiredTokens(t.removeAfterExpiredDuration) + + return token.Token, nil +} + +func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (*authenticationv1.TokenRequestStatus, error) { + req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx, + key.Name, + &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](int64(t.expirationDuration))}, + }, metav1.CreateOptions{}) + if err != nil { + return nil, err + } + return &req.Status, nil +} + +func (t *TokenGetter) reapExpiredTokens(expiredDuration time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + for key, token := range t.tokens { + if metav1.Now().Sub(token.ExpirationTimestamp.Time) > expiredDuration { + delete(t.tokens, key) + } + } +} diff --git a/internal/authentication/tokengetter_test.go b/internal/authentication/tokengetter_test.go new file mode 100644 index 00000000..5b246c36 --- /dev/null +++ b/internal/authentication/tokengetter_test.go @@ -0,0 +1,88 @@ +package authentication + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + authenticationv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/fake" + ctest "k8s.io/client-go/testing" +) + +func TestTokenGetterGet(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + fakeClient.PrependReactor("create", "serviceaccounts/token", + func(action ctest.Action) (bool, runtime.Object, error) { + act, ok := action.(ctest.CreateActionImpl) + if !ok { + return false, nil, nil + } + tokenRequest := act.GetObject().(*authenticationv1.TokenRequest) + var err error + if act.Name == "test-service-account-1" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-1", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(DefaultExpirationDuration)), + } + } + if act.Name == "test-service-account-2" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-2", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(1 * time.Second)), + } + } + if act.Name == "test-service-account-3" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-3", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(-DefaultRemoveAfterExpiredDuration)), + } + } + if act.Name == "test-service-account-4" { + tokenRequest = nil + err = fmt.Errorf("error when fetching token") + } + return true, tokenRequest, err + }) + + tg := NewTokenGetter(fakeClient.CoreV1(), + WithExpirationDuration(DefaultExpirationDuration), + WithRemoveAfterExpiredDuration(DefaultRemoveAfterExpiredDuration)) + + tests := []struct { + testName string + serviceAccountName string + namespace string + want string + errorMsg string + }{ + {"Testing getting token with fake client", "test-service-account-1", + "test-namespace-1", "test-token-1", "failed to get token"}, + {"Testing getting token from cache", "test-service-account-1", + "test-namespace-1", "test-token-1", "failed to get token"}, + {"Testing getting short lived token from fake client", "test-service-account-2", + "test-namespace-2", "test-token-2", "failed to get token"}, + {"Testing getting expired token from cache", "test-service-account-2", + "test-namespace-2", "test-token-2", "failed to refresh token"}, + {"Testing token that expired 90 minutes ago", "test-service-account-3", + "test-namespace-3", "test-token-3", "failed to get token"}, + {"Testing error when getting token from fake client", "test-service-account-4", + "test-namespace-4", "error when fetching token", "error when fetching token"}, + } + + for _, tc := range tests { + got, err := tg.Get(context.Background(), types.NamespacedName{Namespace: tc.namespace, Name: tc.serviceAccountName}) + if err != nil { + t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, err) + assert.EqualError(t, err, tc.errorMsg) + } else { + t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, got) + assert.Equal(t, tc.want, got, tc.errorMsg) + } + } +}