From bb9304596cadd9d3ebdafd68e93ccf63a40163b1 Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Tue, 2 Jul 2024 10:48:03 -0400 Subject: [PATCH] Adding a token getter to get service account tokens --- internal/authentication/tokengetter.go | 99 ++++++++++++++++++ internal/authentication/tokengetter_test.go | 106 ++++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 internal/authentication/tokengetter.go create mode 100644 internal/authentication/tokengetter_test.go diff --git a/internal/authentication/tokengetter.go b/internal/authentication/tokengetter.go new file mode 100644 index 00000000..cf5ee8d9 --- /dev/null +++ b/internal/authentication/tokengetter.go @@ -0,0 +1,99 @@ +package authentication + +import ( + "context" + "sync" + "time" + + authenticationv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/utils/ptr" +) + +type TokenGetter struct { + client corev1.ServiceAccountsGetter + expirationSeconds int64 + tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus + tokenLocks keyLock[types.NamespacedName] + mu sync.RWMutex +} + +// Returns a token getter that can fetch tokens given a service account. +// The token getter also caches tokens which helps reduce the number of requests to the API Server. +// In case a cached token is expiring a fresh token is created. +func NewTokenGetter(client corev1.ServiceAccountsGetter, expirationSeconds int64) *TokenGetter { + return &TokenGetter{ + client: client, + expirationSeconds: expirationSeconds, + tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{}, + tokenLocks: newKeyLock[types.NamespacedName](), + } +} + +type keyLock[K comparable] struct { + locks map[K]*sync.Mutex + mu sync.Mutex +} + +func newKeyLock[K comparable]() keyLock[K] { + return keyLock[K]{locks: map[K]*sync.Mutex{}} +} + +func (k *keyLock[K]) Lock(key K) { + k.getLock(key).Lock() +} + +func (k *keyLock[K]) Unlock(key K) { + k.getLock(key).Unlock() +} + +func (k *keyLock[K]) getLock(key K) *sync.Mutex { + k.mu.Lock() + defer k.mu.Unlock() + + lock, ok := k.locks[key] + if !ok { + lock = &sync.Mutex{} + k.locks[key] = lock + } + return lock +} + +// Returns a token from the cache if available and not expiring, otherwise creates a new token and caches it. +func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) { + t.tokenLocks.Lock(key) + defer t.tokenLocks.Unlock(key) + + t.mu.RLock() + token, ok := t.tokens[key] + t.mu.RUnlock() + + expireTime := time.Time{} + if ok { + expireTime = token.ExpirationTimestamp.Time + } + + fiveMinutesAfterNow := metav1.Now().Add(5 * time.Minute) + if expireTime.Before(fiveMinutesAfterNow) { + var err error + token, err = t.getToken(ctx, key) + if err != nil { + return "", err + } + t.mu.Lock() + t.tokens[key] = token + t.mu.Unlock() + } + + return token.Token, nil +} + +func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (*authenticationv1.TokenRequestStatus, error) { + req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx, key.Name, &authenticationv1.TokenRequest{Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](3600)}}, metav1.CreateOptions{}) + if err != nil { + return nil, err + } + return &req.Status, nil +} diff --git a/internal/authentication/tokengetter_test.go b/internal/authentication/tokengetter_test.go new file mode 100644 index 00000000..52b61c68 --- /dev/null +++ b/internal/authentication/tokengetter_test.go @@ -0,0 +1,106 @@ +package authentication + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + authenticationv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/fake" + ctest "k8s.io/client-go/testing" +) + +func TestNewTokenGetter(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + fakeClient.PrependReactor("create", "serviceaccounts/token", func(action ctest.Action) (bool, runtime.Object, error) { + act, ok := action.(ctest.CreateActionImpl) + if !ok { + return false, nil, nil + } + tokenRequest := act.GetObject().(*authenticationv1.TokenRequest) + var err error + if act.Name == "test-service-account-1" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-1", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(5 * time.Minute)), + } + } + if act.Name == "test-service-account-2" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-2", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(1 * time.Second)), + } + } + if act.Name == "test-service-account-3" { + tokenRequest = nil + err = fmt.Errorf("error when fetching token") + } + return true, tokenRequest, err + }) + tg := NewTokenGetter(fakeClient.CoreV1(), int64(5*time.Minute)) + t.Log("Testing NewTokenGetter with fake client") + token, err := tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace-1", + Name: "test-service-account-1", + }) + if err != nil { + t.Fatalf("failed to get token: %v", err) + return + } + t.Log("token:", token) + if token != "test-token-1" { + t.Errorf("token does not match") + } + t.Log("Testing getting token from cache") + token, err = tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace-1", + Name: "test-service-account-1", + }) + if err != nil { + t.Fatalf("failed to get token from cache: %v", err) + return + } + t.Log("token:", token) + if token != "test-token-1" { + t.Errorf("token does not match") + } + t.Log("Testing getting short lived token from fake client") + token, err = tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace-2", + Name: "test-service-account-2", + }) + if err != nil { + t.Fatalf("failed to get token: %v", err) + return + } + t.Log("token:", token) + if token != "test-token-2" { + t.Errorf("token does not match") + } + //wait for token to expire + time.Sleep(1 * time.Second) + t.Log("Testing getting expired token from cache") + token, err = tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace-2", + Name: "test-service-account-2", + }) + if err != nil { + t.Fatalf("failed to refresh token: %v", err) + return + } + t.Log("token:", token) + if token != "test-token-2" { + t.Errorf("token does not match") + } + t.Log("Testing error when getting token from fake client") + _, err = tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace-3", + Name: "test-service-account-3", + }) + assert.EqualError(t, err, "error when fetching token") +}