Skip to content

Commit

Permalink
Adding a token getter to get service account tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
skattoju committed Jul 2, 2024
1 parent 872b7f7 commit b3d079a
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 0 deletions.
99 changes: 99 additions & 0 deletions internal/authentication/tokengetter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package authentication

import (
"context"
"sync"
"time"

authv1 "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/utils/ptr"
)

type TokenGetter struct {
client corev1.ServiceAccountsGetter
expirationSeconds int64
tokens map[types.NamespacedName]*authv1.TokenRequestStatus
tokenLocks keyLock[types.NamespacedName]
mu sync.RWMutex
}

// Returns a token getter that can fetch tokens given a service account.
// The token getter also caches tokens which helps reduce the number of requests to the API Server.
// In case a cached token is expiring a fresh token is created.
func NewTokenGetter(client corev1.ServiceAccountsGetter, expirationSeconds int64) *TokenGetter {
return &TokenGetter{
client: client,
expirationSeconds: expirationSeconds,
tokens: map[types.NamespacedName]*authv1.TokenRequestStatus{},
tokenLocks: newKeyLock[types.NamespacedName](),
}
}

type keyLock[K comparable] struct {
locks map[K]*sync.Mutex
mu sync.Mutex
}

func newKeyLock[K comparable]() keyLock[K] {
return keyLock[K]{locks: map[K]*sync.Mutex{}}
}

func (k *keyLock[K]) Lock(key K) {
k.getLock(key).Lock()
}

func (k *keyLock[K]) Unlock(key K) {
k.getLock(key).Unlock()
}

func (k *keyLock[K]) getLock(key K) *sync.Mutex {
k.mu.Lock()
defer k.mu.Unlock()

lock, ok := k.locks[key]
if !ok {
lock = &sync.Mutex{}
k.locks[key] = lock
}
return lock
}

// Returns a token from the cache if available and not expiring, otherwise creates a new token and caches it.
func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) {
t.tokenLocks.Lock(key)
defer t.tokenLocks.Unlock(key)

t.mu.RLock()
token, ok := t.tokens[key]
t.mu.RUnlock()

expireTime := time.Time{}
if ok {
expireTime = token.ExpirationTimestamp.Time
}

fiveMinutesAfterNow := metav1.Now().Add(5 * time.Minute)
if expireTime.Before(fiveMinutesAfterNow) {
var err error
token, err = t.getToken(ctx, key)
if err != nil {
return "", err
}
t.mu.Lock()
t.tokens[key] = token
t.mu.Unlock()
}

return token.Token, nil
}

func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (*authv1.TokenRequestStatus, error) {
req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx, key.Name, &authv1.TokenRequest{Spec: authv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](3600)}}, metav1.CreateOptions{})
if err != nil {
return nil, err
}
return &req.Status, nil
}
167 changes: 167 additions & 0 deletions internal/authentication/tokengetter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package authentication

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
authv1 "k8s.io/api/authentication/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/watch"
corev1 "k8s.io/client-go/applyconfigurations/core/v1"
"k8s.io/client-go/kubernetes/fake"
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
ctest "k8s.io/client-go/testing"
)

func TestNewTokenGetter(t *testing.T) {
fakeClient := fake.NewSimpleClientset()
fakeClient.PrependReactor("create", "serviceaccounts/token", func(action ctest.Action) (handled bool, ret runtime.Object, err error) {
act, ok := action.(ctest.CreateActionImpl)
if !ok {
return false, nil, nil
}
tokenRequest := act.GetObject().(*authv1.TokenRequest)
if act.Name == "test-service-account-1" {
tokenRequest.Status = authv1.TokenRequestStatus{
Token: "test-token-1",
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(5 * time.Minute)),
}
}
if act.Name == "test-service-account-2" {
tokenRequest.Status = authv1.TokenRequestStatus{
Token: "test-token-2",
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(1 * time.Second)),
}
}

return true, tokenRequest, nil
})
tg := NewTokenGetter(fakeClient.CoreV1(), int64(5*time.Minute))
t.Log("Testing NewTokenGetter with fake client")
token, err := tg.Get(context.Background(), types.NamespacedName{
Namespace: "test-namespace-1",
Name: "test-service-account-1",
})
if err != nil {
t.Fatalf("failed to get token: %v", err)
return
}
t.Log("token:", token)
if token != "test-token-1" {
t.Errorf("token does not match")
}
t.Log("Testing getting token from cache")
token, err = tg.Get(context.Background(), types.NamespacedName{
Namespace: "test-namespace-1",
Name: "test-service-account-1",
})
if err != nil {
t.Fatalf("failed to get token from cache: %v", err)
return
}
t.Log("token:", token)
if token != "test-token-1" {
t.Errorf("token does not match")
}
t.Log("Testing getting short lived token from fake client")
token, err = tg.Get(context.Background(), types.NamespacedName{
Namespace: "test-namespace-2",
Name: "test-service-account-2",
})
if err != nil {
t.Fatalf("failed to get token: %v", err)
return
}
t.Log("token:", token)
if token != "test-token-2" {
t.Errorf("token does not match")
}
//wait for token to expire
time.Sleep(1 * time.Second)
t.Log("Testing getting expired token from cache")
token, err = tg.Get(context.Background(), types.NamespacedName{
Namespace: "test-namespace-2",
Name: "test-service-account-2",
})
if err != nil {
t.Fatalf("failed to refresh token: %v", err)
return
}
t.Log("token:", token)
if token != "test-token-2" {
t.Errorf("token does not match")
}
}

type ServiceAccountsGetterImpl struct{}

func (ServiceAccountsGetterImpl) ServiceAccounts(namespace string) corev1client.ServiceAccountInterface {
return &ServiceAccountTokenInterfaceImpl{}
}

type ServiceAccountTokenInterfaceImpl struct{}

func (i ServiceAccountTokenInterfaceImpl) Apply(ctx context.Context, serviceAccount *corev1.ServiceAccountApplyConfiguration, opts metav1.ApplyOptions) (result *v1.ServiceAccount, err error) {
panic("placeholder, not implemented")
}

func (i ServiceAccountTokenInterfaceImpl) Create(ctx context.Context, serviceAccount *v1.ServiceAccount, opts metav1.CreateOptions) (*v1.ServiceAccount, error) {
panic("placeholder, not implemented")
}

func (i ServiceAccountTokenInterfaceImpl) Update(ctx context.Context, serviceAccount *v1.ServiceAccount, opts metav1.UpdateOptions) (*v1.ServiceAccount, error) {
panic("placeholder, not implemented")

}

func (i ServiceAccountTokenInterfaceImpl) Delete(ctx context.Context, name string, opts metav1.DeleteOptions) error {
panic("placeholder, not implemented")

}

func (i ServiceAccountTokenInterfaceImpl) DeleteCollection(ctx context.Context, opts metav1.DeleteOptions, listOpts metav1.ListOptions) error {
panic("placeholder, not implemented")

}

func (i ServiceAccountTokenInterfaceImpl) Get(ctx context.Context, name string, opts metav1.GetOptions) (*v1.ServiceAccount, error) {
panic("placeholder, not implemented")

}

func (i ServiceAccountTokenInterfaceImpl) List(ctx context.Context, opts metav1.ListOptions) (*v1.ServiceAccountList, error) {
panic("placeholder, not implemented")

}

func (i ServiceAccountTokenInterfaceImpl) Watch(ctx context.Context, opts metav1.ListOptions) (watch.Interface, error) {
panic("placeholder, not implemented")

}

func (i ServiceAccountTokenInterfaceImpl) Patch(ctx context.Context, name string, pt types.PatchType, data []byte, opts metav1.PatchOptions, subresources ...string) (result *v1.ServiceAccount, err error) {
panic("placeholder, not implemented")

}

func (ServiceAccountTokenInterfaceImpl) CreateToken(ctx context.Context, serviceAccountName string, tokenRequest *authv1.TokenRequest, opts metav1.CreateOptions) (*authv1.TokenRequest, error) {
err := fmt.Errorf("error when fetching token")
return nil, err
}

func TestTokenGetter_GetToken(t *testing.T) {
t.Log("Testing NewTokenGetter with test service account getter implementation")
saGetter := &ServiceAccountsGetterImpl{}
tg := NewTokenGetter(saGetter, int64(5*time.Minute))
_, err := tg.Get(context.Background(), types.NamespacedName{
Namespace: "test-namespace",
Name: "test-service-account",
})
assert.EqualError(t, err, "error when fetching token")
}

0 comments on commit b3d079a

Please sign in to comment.