diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 8e3d8591d6..c40216a23d 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -83,6 +83,13 @@ type Controller interface { GetLogger() logr.Logger } +// PrestartHookable is implemented by controllers that support registering prestart hooks that run +// after caches have been synced (and optionally, leader election), but before their manage reconcile loop. +type PrestartHookable interface { + // Registers a prestart hook with the controller. + PrestartHook(func(ctx context.Context) error) error +} + // New returns a new Controller registered with the Manager. The Manager will ensure that shared Caches have // been synced before the Controller is Started. func New(name string, mgr manager.Manager, options Options) (Controller, error) { diff --git a/pkg/internal/controller/controller.go b/pkg/internal/controller/controller.go index 3732eea16e..38adbeb76d 100644 --- a/pkg/internal/controller/controller.go +++ b/pkg/internal/controller/controller.go @@ -92,6 +92,10 @@ type Controller struct { // RecoverPanic indicates whether the panic caused by reconcile should be recovered. RecoverPanic bool + + // prestartHooks are functions that are run after caches have been synced, but before the reconcile loop has + // been started. This allows for work to be done after winning a leader election. + prestartHooks []func(ctx context.Context) error } // watchDescription contains all the information necessary to start a watch. @@ -223,6 +227,18 @@ func (c *Controller) Start(ctx context.Context) error { // which won't be garbage collected if we hold a reference to it. c.startWatches = nil + c.LogConstructor(nil).Info("Running Prestart Hooks") + for _, hook := range c.prestartHooks { + if err := hook(ctx); err != nil { + err := fmt.Errorf("failed to run prestart hook: %w", err) + c.LogConstructor(nil).Error(err, "Could not run prestart hook") + return err + } + } + + // All the prestart hooks have been run, clear the slice to free the underlying resources. + c.prestartHooks = nil + // Launch workers to process resources c.LogConstructor(nil).Info("Starting workers", "worker count", c.MaxConcurrentReconciles) wg.Add(c.MaxConcurrentReconciles) @@ -354,6 +370,19 @@ func (c *Controller) InjectFunc(f inject.Func) error { return nil } +// PrestartHook implements controller.PrestartHookable. +func (c *Controller) PrestartHook(hook func(context.Context) error) error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.Started { + c.prestartHooks = append(c.prestartHooks, hook) + return nil + } + + return errors.New("controller has already been added") +} + // updateMetrics updates prometheus metrics within the controller. func (c *Controller) updateMetrics(reconcileTime time.Duration) { ctrlmetrics.ReconcileTime.WithLabelValues(c.Name).Observe(reconcileTime.Seconds()) diff --git a/pkg/internal/controller/controller_test.go b/pkg/internal/controller/controller_test.go index cb50ec999b..926b7344fc 100644 --- a/pkg/internal/controller/controller_test.go +++ b/pkg/internal/controller/controller_test.go @@ -455,6 +455,69 @@ var _ = Describe("controller", func() { }) }) + Describe("PrestartHook", func() { + It("should register multiple prestart hooks", func() { + fn1 := func(ctx context.Context) error { + return nil + } + fn2 := func(ctx context.Context) error { + return nil + } + + Expect(ctrl.PrestartHook(fn1)).ShouldNot(HaveOccurred()) + Expect(ctrl.PrestartHook(fn2)).ShouldNot(HaveOccurred()) + Expect(ctrl.prestartHooks).Should(HaveLen(2)) + }) + + It("should call prestart hooks before reconciler", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := make(chan struct{}) + fn1 := func(ctx context.Context) error { + Consistently(reconciled).ShouldNot(Receive()) + close(ch) + return nil + } + + Expect(ctrl.PrestartHook(fn1)).ShouldNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + Expect(ctrl.Start(ctx)).To(Succeed()) + }() + Eventually(ch).Should(BeClosed()) + }) + + It("should return an error if called after start", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fn1 := func(ctx context.Context) error { + return nil + } + + go func() { + defer GinkgoRecover() + Expect(ctrl.Start(ctx)).To(Succeed()) + }() + + Eventually(func() bool { return ctrl.Started }).Should(BeTrue()) + Expect(ctrl.PrestartHook(fn1)).Should(HaveOccurred()) + }) + + It("should stop controller if hook returns error", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fn1 := func(ctx context.Context) error { + return errors.New("hook error") + } + + Expect(ctrl.PrestartHook(fn1)).ShouldNot(HaveOccurred()) + Expect(ctrl.Start(ctx)).Should(MatchError(ContainSubstring("hook error"))) + }) + }) + Describe("Processing queue items from a Controller", func() { It("should call Reconciler if an item is enqueued", func() { ctx, cancel := context.WithCancel(context.Background())