From 5020bec824f5c07e8fff76410a10e51c4f7337e1 Mon Sep 17 00:00:00 2001
From: modulitos
Date: Sun, 20 Oct 2024 19:55:00 -0700
Subject: [PATCH] fetch SAs from apiserver

---
 main.go                    |  1 +
 pkg/cache/cache.go         | 72 ++++++++++++++++++++--------
 pkg/cache/cache_test.go    | 97 ++++++++++++++++++++++++++++++++++----
 pkg/cache/notifications.go | 60 +++++++++++++++++++++++
 4 files changed, 201 insertions(+), 29 deletions(-)
 create mode 100644 pkg/cache/notifications.go

diff --git a/main.go b/main.go
index 17294659d..04e03f7a4 100644
--- a/main.go
+++ b/main.go
@@ -179,6 +179,7 @@ func main() {
 		saInformer,
 		cmInformer,
 		composeRoleArnCache,
+		clientset.CoreV1(),
 	)
 	stop := make(chan struct{})
 	informerFactory.Start(stop)
diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go
index 55b5885f7..33b201a62 100644
--- a/pkg/cache/cache.go
+++ b/pkg/cache/cache.go
@@ -16,19 +16,23 @@
 package cache
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"regexp"
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/aws/amazon-eks-pod-identity-webhook/pkg"
 	"github.com/prometheus/client_golang/prometheus"
 	v1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
 	coreinformers "k8s.io/client-go/informers/core/v1"
 	"k8s.io/client-go/kubernetes"
+	corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
 	"k8s.io/client-go/tools/cache"
 	"k8s.io/klog/v2"
 )
@@ -79,8 +83,7 @@ type serviceAccountCache struct {
 	composeRoleArn         ComposeRoleArn
 	defaultTokenExpiration int64
 	webhookUsage           prometheus.Gauge
-	notificationHandlers   map[string]chan struct{}
-	handlerMu              sync.Mutex
+	notifications          *notifications
 }
 
 type ComposeRoleArn struct {
@@ -155,20 +158,13 @@ func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (u
 	return false, pkg.DefaultTokenExpiration
 }
 
-func (c *serviceAccountCache) getSA(req Request) (*Entry, chan struct{}) {
+func (c *serviceAccountCache) getSA(req Request) (*Entry, <-chan struct{}) {
 	c.mu.RLock()
 	defer c.mu.RUnlock()
 	entry, ok := c.saCache[req.CacheKey()]
 	if !ok && req.RequestNotification {
 		klog.V(5).Infof("Service Account %s not found in cache, adding notification handler", req.CacheKey())
-		c.handlerMu.Lock()
-		defer c.handlerMu.Unlock()
-		notifier, found := c.notificationHandlers[req.CacheKey()]
-		if !found {
-			notifier = make(chan struct{})
-			c.notificationHandlers[req.CacheKey()] = notifier
-		}
-		return nil, notifier
+		return nil, c.notifications.create(req)
 	}
 	return entry, nil
 }
@@ -263,13 +259,7 @@ func (c *serviceAccountCache) setSA(name, namespace string, entry *Entry) {
 	klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, entry)
 	c.saCache[key] = entry
 
-	c.handlerMu.Lock()
-	defer c.handlerMu.Unlock()
-	if handler, found := c.notificationHandlers[key]; found {
-		klog.V(5).Infof("Notifying handlers for %q", key)
-		close(handler)
-		delete(c.notificationHandlers, key)
-	}
+	c.notifications.broadcast(key)
 }
 
 func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) {
@@ -279,7 +269,15 @@ func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) {
 	c.cmCache[namespace+"/"+name] = entry
 }
 
-func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenExpiration int64, saInformer coreinformers.ServiceAccountInformer, cmInformer coreinformers.ConfigMapInformer, composeRoleArn ComposeRoleArn) ServiceAccountCache {
+func New(defaultAudience,
+	prefix string,
+	defaultRegionalSTS bool,
+	defaultTokenExpiration int64,
+	saInformer coreinformers.ServiceAccountInformer,
+	cmInformer coreinformers.ConfigMapInformer,
+	composeRoleArn ComposeRoleArn,
+	SAGetter corev1.ServiceAccountsGetter,
+) ServiceAccountCache {
 	hasSynced := func() bool {
 		if cmInformer != nil {
 			return saInformer.Informer().HasSynced() && cmInformer.Informer().HasSynced()
@@ -288,6 +286,8 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
 		}
 	}
 
+	// Buffer up to 10 pending fetch requests; the single goroutine below serves them serially.
+	saFetchRequests := make(chan *Request, 10)
 	c := &serviceAccountCache{
 		saCache:                map[string]*Entry{},
 		cmCache:                map[string]*Entry{},
@@ -298,9 +298,20 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
 		defaultTokenExpiration: defaultTokenExpiration,
 		hasSynced:              hasSynced,
 		webhookUsage:           webhookUsage,
-		notificationHandlers:   map[string]chan struct{}{},
+		notifications:          newNotifications(saFetchRequests),
 	}
 
+	go func() {
+		for req := range saFetchRequests {
+			sa, err := fetchFromAPI(SAGetter, req)
+			if err != nil {
+				klog.Errorf("error fetching SA %s from the API server: %v", req.CacheKey(), err)
+				continue
+			}
+			c.addSA(sa)
+		}
+	}()
+
 	saInformer.Informer().AddEventHandler(
 		cache.ResourceEventHandlerFuncs{
 			AddFunc: func(obj interface{}) {
@@ -350,6 +361,27 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
 	return c
 }
 
+func fetchFromAPI(getter corev1.ServiceAccountsGetter, req *Request) (*v1.ServiceAccount, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	defer cancel()
+	saList, err := getter.ServiceAccounts(req.Namespace).List(
+		ctx,
+		metav1.ListOptions{},
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	// Find the requested ServiceAccount in the listing.
+	for _, sa := range saList.Items {
+		if sa.Name == req.Name {
+			return &sa, nil
+
+		}
+	}
+	return nil, fmt.Errorf("SA %s not found", req.CacheKey())
+}
+
 func (c *serviceAccountCache) populateCacheFromCM(oldCM, newCM *v1.ConfigMap) error {
 	if newCM.Name != "pod-identity-webhook" {
 		return nil
diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go
index d4a540cfd..55a1ff0d1 100644
--- a/pkg/cache/cache_test.go
+++ b/pkg/cache/cache_test.go
@@ -35,6 +35,7 @@ func TestSaCache(t *testing.T) {
 		defaultAudience:  "sts.amazonaws.com",
 		annotationPrefix: "eks.amazonaws.com",
 		webhookUsage:     prometheus.NewGauge(prometheus.GaugeOpts{}),
+		notifications:    newNotifications(make(chan *Request, 10)),
 	}
 
 	resp := cache.Get(Request{Name: "default", Namespace: "default"})
@@ -69,9 +70,9 @@ func TestNotification(t *testing.T) {
 
 	t.Run("with one notification handler", func(t *testing.T) {
 		cache := &serviceAccountCache{
-			saCache:              map[string]*Entry{},
-			notificationHandlers: map[string]chan struct{}{},
-			webhookUsage:         prometheus.NewGauge(prometheus.GaugeOpts{}),
+			saCache:       map[string]*Entry{},
+			webhookUsage:  prometheus.NewGauge(prometheus.GaugeOpts{}),
+			notifications: newNotifications(make(chan *Request, 10)),
 		}
 
 		// test that the requested SA is not in the cache
@@ -106,9 +107,9 @@ func TestNotification(t *testing.T) {
 
 	t.Run("with 10 notification handlers", func(t *testing.T) {
 		cache := &serviceAccountCache{
-			saCache:              map[string]*Entry{},
-			notificationHandlers: map[string]chan struct{}{},
-			webhookUsage:         prometheus.NewGauge(prometheus.GaugeOpts{}),
+			saCache:       map[string]*Entry{},
+			webhookUsage:  prometheus.NewGauge(prometheus.GaugeOpts{}),
+			notifications: newNotifications(make(chan *Request, 5)),
 		}
 
 		// test that the requested SA is not in the cache
@@ -153,6 +154,63 @@ func TestNotification(t *testing.T) {
 	})
 }
 
+func TestFetchFromAPIServer(t *testing.T) {
+	testSA := &v1.ServiceAccount{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      "default",
+			Namespace: "default",
+			Annotations: map[string]string{
+				"eks.amazonaws.com/role-arn":         "arn:aws:iam::111122223333:role/s3-reader",
+				"eks.amazonaws.com/token-expiration": "3600",
+			},
+		},
+	}
+	fakeSAClient := fake.NewSimpleClientset(testSA)
+
+	// use an empty informer to simulate the need to fetch the SA from the API server:
+	fakeEmptyClient := fake.NewSimpleClientset()
+	emptyInformerFactory := informers.NewSharedInformerFactory(fakeEmptyClient, 0)
+	emptyInformer := emptyInformerFactory.Core().V1().ServiceAccounts()
+
+	cache := New(
+		"sts.amazonaws.com",
+		"eks.amazonaws.com",
+		true,
+		86400,
+		emptyInformer,
+		nil,
+		ComposeRoleArn{},
+		fakeSAClient.CoreV1(),
+	)
+
+	stop := make(chan struct{})
+	emptyInformerFactory.Start(stop)
+	emptyInformerFactory.WaitForCacheSync(stop)
+	cache.Start(stop)
+	defer close(stop)
+
+	err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) {
+		return len(fakeEmptyClient.Actions()) != 0, nil
+	})
+	if err != nil {
+		t.Fatalf("informer never called client: %v", err)
+	}
+
+	resp := cache.Get(Request{Name: "default", Namespace: "default", RequestNotification: true})
+	assert.False(t, resp.FoundInCache, "Expected cache entry to not be found")
+
+	// wait for the notification while we fetch the SA from the API server:
+	select {
+	case <-resp.Notifier:
+		// expected
+		// test that the requested SA is now in the cache
+		resp := cache.Get(Request{Name: "default", Namespace: "default", RequestNotification: false})
+		assert.True(t, resp.FoundInCache, "Expected cache entry to be found in cache")
+	case <-time.After(1 * time.Second):
+		t.Fatal("timeout waiting for notification")
+	}
+}
+
 func TestNonRegionalSTS(t *testing.T) {
 	trueStr := "true"
 	falseStr := "false"
@@ -237,7 +295,16 @@ func TestNonRegionalSTS(t *testing.T) {
 
 			testComposeRoleArn := ComposeRoleArn{}
 
-			cache := New(audience, "eks.amazonaws.com", tc.defaultRegionalSTS, 86400, informer, nil, testComposeRoleArn)
+			cache := New(
+				audience,
+				"eks.amazonaws.com",
+				tc.defaultRegionalSTS,
+				86400,
+				informer,
+				nil,
+				testComposeRoleArn,
+				fakeClient.CoreV1(),
+			)
 			stop := make(chan struct{})
 			informerFactory.Start(stop)
 			informerFactory.WaitForCacheSync(stop)
@@ -295,7 +362,8 @@ func TestPopulateCacheFromCM(t *testing.T) {
 	}
 
 	c := serviceAccountCache{
-		cmCache: make(map[string]*Entry),
+		cmCache:       make(map[string]*Entry),
+		notifications: newNotifications(make(chan *Request, 10)),
 	}
 
 	{
@@ -353,6 +421,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
 		saCache:          make(map[string]*Entry),
 		annotationPrefix: "eks.amazonaws.com",
 		webhookUsage:     prometheus.NewGauge(prometheus.GaugeOpts{}),
+		notifications:    newNotifications(make(chan *Request, 10)),
 	}
 
 	c.addSA(oldSA)
@@ -416,6 +485,7 @@ func TestCachePrecedence(t *testing.T) {
 		defaultTokenExpiration: pkg.DefaultTokenExpiration,
 		annotationPrefix:       "eks.amazonaws.com",
 		webhookUsage:           prometheus.NewGauge(prometheus.GaugeOpts{}),
+		notifications:          newNotifications(make(chan *Request, 10)),
 	}
 
 	{
@@ -514,7 +584,15 @@ func TestRoleArnComposition(t *testing.T) {
 	informerFactory := informers.NewSharedInformerFactory(fakeClient, 0)
 	informer := informerFactory.Core().V1().ServiceAccounts()
 
-	cache := New(audience, "eks.amazonaws.com", true, 86400, informer, nil, testComposeRoleArn)
+	cache := New(audience,
+		"eks.amazonaws.com",
+		true,
+		86400,
+		informer,
+		nil,
+		testComposeRoleArn,
+		fakeClient.CoreV1(),
+	)
 	stop := make(chan struct{})
 	informerFactory.Start(stop)
 	informerFactory.WaitForCacheSync(stop)
@@ -613,6 +691,7 @@ func TestGetCommonConfigurations(t *testing.T) {
 				defaultAudience:  "sts.amazonaws.com",
 				annotationPrefix: "eks.amazonaws.com",
 				webhookUsage:     prometheus.NewGauge(prometheus.GaugeOpts{}),
+				notifications:    newNotifications(make(chan *Request, 10)),
 			}
 
 			if tc.serviceAccount != nil {
diff --git a/pkg/cache/notifications.go b/pkg/cache/notifications.go
new file mode 100644
index 000000000..8661d4170
--- /dev/null
+++ b/pkg/cache/notifications.go
@@ -0,0 +1,60 @@
+package cache
+
+import (
+	"sync"
+
+	"k8s.io/klog/v2"
+
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+var notificationUsage = prometheus.NewCounterVec(
+	prometheus.CounterOpts{
+		Name: "pod_identity_cache_notifications",
+		Help: "Counter of SA notifications",
+	},
+	[]string{"method"},
+)
+
+func init() {
+	prometheus.MustRegister(notificationUsage)
+}
+
+type notifications struct {
+	handlers      map[string]chan struct{}
+	mu            sync.Mutex
+	fetchRequests chan<- *Request
+}
+
+func newNotifications(saFetchRequests chan<- *Request) *notifications {
+	return &notifications{
+		handlers:      map[string]chan struct{}{},
+		fetchRequests: saFetchRequests,
+	}
+}
+
+func (n *notifications) create(req Request) <-chan struct{} {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	notificationUsage.WithLabelValues("used").Inc()
+	notifier, found := n.handlers[req.CacheKey()]
+	if !found {
+		notifier = make(chan struct{})
+		n.handlers[req.CacheKey()] = notifier
+		notificationUsage.WithLabelValues("created").Inc()
+		n.fetchRequests <- &req
+	}
+	return notifier
+}
+
+func (n *notifications) broadcast(key string) {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	if handler, found := n.handlers[key]; found {
+		klog.V(5).Infof("Notifying handlers for %q", key)
+		notificationUsage.WithLabelValues("broadcast").Inc()
+		close(handler)
+		delete(n.handlers, key)
+	}
+}
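
Usage note (not part of the diff above): the flow this patch adds starts at the
caller of the cache. A miss in getSA with RequestNotification set registers a
per-key channel via notifications.create, which also enqueues the request on
the buffered saFetchRequests channel (only the first create for a given key
enqueues, so concurrent requests for the same SA are deduplicated); the worker
goroutine started in New resolves the request against the API server with
fetchFromAPI and stores the result through addSA, whose setSA path closes the
channel via notifications.broadcast. Below is a minimal caller-side sketch of
that flow. Only the Request fields and the FoundInCache/Notifier response
fields exercised in TestFetchFromAPIServer are taken from the repository; the
package name, the lookupSA helper, and the one-second timeout are illustrative
assumptions, not code from this patch:

	package webhook

	import (
		"time"

		"github.com/aws/amazon-eks-pod-identity-webhook/pkg/cache"
	)

	// lookupSA asks the cache for a ServiceAccount entry and, on a miss,
	// waits briefly for the background API-server fetch to populate it.
	func lookupSA(c cache.ServiceAccountCache, name, ns string) bool {
		resp := c.Get(cache.Request{Name: name, Namespace: ns, RequestNotification: true})
		if resp.FoundInCache {
			return true
		}
		select {
		case <-resp.Notifier:
			// broadcast closed the channel once the fetched SA was cached;
			// re-read without registering another notification.
			resp = c.Get(cache.Request{Name: name, Namespace: ns, RequestNotification: false})
			return resp.FoundInCache
		case <-time.After(time.Second):
			// The API fetch failed or is still in flight; the caller can
			// fall back to mutating the pod without IRSA settings.
			return false
		}
	}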