From a902abc15ca6ce6337cb956931128de2ec0f103b Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Tue, 2 Jul 2024 10:48:03 -0400 Subject: [PATCH] Adding a token getter to get service account tokens --- internal/authentication/tokengetter.go | 105 ++++++++++++++++++ internal/authentication/tokengetter_test.go | 112 ++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 internal/authentication/tokengetter.go create mode 100644 internal/authentication/tokengetter_test.go diff --git a/internal/authentication/tokengetter.go b/internal/authentication/tokengetter.go new file mode 100644 index 00000000..af5c2b95 --- /dev/null +++ b/internal/authentication/tokengetter.go @@ -0,0 +1,105 @@ +package authentication + +import ( + "context" + "sync" + "time" + + authenticationv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/utils/ptr" +) + +type TokenGetter struct { + client corev1.ServiceAccountsGetter + expirationSeconds int64 + rotationThreshold time.Duration + tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus + mu sync.RWMutex +} + +type TokenGetterOption func(*TokenGetter) + +// Returns a token getter that can fetch tokens given a service account. +// The token getter also caches tokens which helps reduce the number of requests to the API Server. +// In case a cached token is expiring a fresh token is created. +func NewTokenGetter(client corev1.ServiceAccountsGetter, options ...TokenGetterOption) *TokenGetter { + tokenGetter := &TokenGetter{ + client: client, + expirationSeconds: int64(5 * time.Minute), // default token ttl + rotationThreshold: 1 * time.Minute, // default rotation threshold + tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{}, + } + + for _, opt := range options { + opt(tokenGetter) + } + + return tokenGetter +} + +func WithExpirationSeconds(expirationSeconds int64) TokenGetterOption { + return func(tg *TokenGetter) { + tg.expirationSeconds = expirationSeconds + } +} + +func WithRotationThreshold(rotationThreshold time.Duration) TokenGetterOption { + return func(tg *TokenGetter) { + tg.rotationThreshold = rotationThreshold + } +} + +// Get returns a token from the cache if available and not expiring, otherwise creates a new token +func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) { + t.mu.RLock() + token, ok := t.tokens[key] + t.mu.RUnlock() + + expireTime := time.Time{} + if ok { + expireTime = token.ExpirationTimestamp.Time + } + + // Create a new token if the cached token expires within rotationThreshold seconds from now + rotationThresholdAfterNow := metav1.Now().Add(t.rotationThreshold) + if expireTime.Before(rotationThresholdAfterNow) { + var err error + token, err = t.getToken(ctx, key) + if err != nil { + return "", err + } + t.mu.Lock() + t.tokens[key] = token + t.mu.Unlock() + } + + return token.Token, nil +} + +func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (*authenticationv1.TokenRequestStatus, error) { + req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx, + key.Name, + &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](t.expirationSeconds)}, + }, metav1.CreateOptions{}) + if err != nil { + return nil, err + } + return &req.Status, nil +} + +func (t *TokenGetter) Clean(ctx context.Context, key types.NamespacedName) { + t.mu.RLock() + defer t.mu.RUnlock() + delete(t.tokens, key) +} + +func (t *TokenGetter) TokenExists(ctx context.Context, key types.NamespacedName) bool { + t.mu.RLock() + defer t.mu.RUnlock() + _, ok := t.tokens[key] + return ok +} diff --git a/internal/authentication/tokengetter_test.go b/internal/authentication/tokengetter_test.go new file mode 100644 index 00000000..bfb005fa --- /dev/null +++ b/internal/authentication/tokengetter_test.go @@ -0,0 +1,112 @@ +package authentication + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + authenticationv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/fake" + ctest "k8s.io/client-go/testing" +) + +func TestNewTokenGetterGet(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + fakeClient.PrependReactor("create", "serviceaccounts/token", + func(action ctest.Action) (bool, runtime.Object, error) { + act, ok := action.(ctest.CreateActionImpl) + if !ok { + return false, nil, nil + } + tokenRequest := act.GetObject().(*authenticationv1.TokenRequest) + var err error + if act.Name == "test-service-account-1" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-1", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(5 * time.Minute)), + } + } + if act.Name == "test-service-account-2" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-2", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(1 * time.Second)), + } + } + if act.Name == "test-service-account-3" { + tokenRequest = nil + err = fmt.Errorf("error when fetching token") + } + return true, tokenRequest, err + }) + + tg := NewTokenGetter(fakeClient.CoreV1(), + WithExpirationSeconds(int64(5*time.Minute)), + WithRotationThreshold(1*time.Minute)) + + tests := []struct { + testName string + serviceAccountName string + namespace string + want string + errorMsg string + }{ + {"Testing getting token with fake client", "test-service-account-1", + "test-namespace-1", "test-token-1", "failed to get token"}, + {"Testing getting token from cache", "test-service-account-1", + "test-namespace-1", "test-token-1", "failed to get token from cache"}, + {"Testing getting short lived token from fake client", "test-service-account-2", + "test-namespace-2", "test-token-2", "failed to get token"}, + {"Testing getting expired token from cache", "test-service-account-2", + "test-namespace-2", "test-token-2", "failed to refresh token"}, + {"Testing error when getting token from fake client", "test-service-account-3", + "test-namespace-3", "error when fetching token", "error when fetching token"}, + } + + for _, tc := range tests { + got, err := tg.Get(context.Background(), types.NamespacedName{Namespace: tc.namespace, Name: tc.serviceAccountName}) + if err != nil { + t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, err) + assert.EqualError(t, err, tc.errorMsg) + } else { + t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, got) + assert.Equal(t, tc.want, got, tc.errorMsg) + } + } +} + +func TestNewTokenGetterClean(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + fakeClient.PrependReactor("create", "serviceaccounts/token", + func(action ctest.Action) (bool, runtime.Object, error) { + act, ok := action.(ctest.CreateActionImpl) + if !ok { + return false, nil, nil + } + tokenRequest := act.GetObject().(*authenticationv1.TokenRequest) + if act.Name == "test-service-account-1" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-1", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(5 * time.Minute)), + } + } + return true, tokenRequest, nil + }) + + tg := NewTokenGetter(fakeClient.CoreV1(), + WithExpirationSeconds(int64(5*time.Minute)), + WithRotationThreshold(1*time.Minute)) + + _, err := tg.Get(context.Background(), types.NamespacedName{Namespace: "test-namespace-1", Name: "test-service-account-1"}) + if err != nil { + t.Fatalf("failed to get token: %v", err) + } + assert.True(t, tg.TokenExists(context.Background(), types.NamespacedName{Namespace: "test-namespace-1", Name: "test-service-account-1"})) + t.Logf("Testing removing token from cache") + tg.Clean(context.Background(), types.NamespacedName{Namespace: "test-namespace-1", Name: "test-service-account-1"}) + assert.False(t, tg.TokenExists(context.Background(), types.NamespacedName{Namespace: "test-namespace-1", Name: "test-service-account-1"})) +}