diff --git a/pkg/controller/controller_test.go b/pkg/controller/controller_test.go index 0454cb4b90..237565fa07 100644 --- a/pkg/controller/controller_test.go +++ b/pkg/controller/controller_test.go @@ -79,7 +79,7 @@ var _ = Describe("controller.Controller", func() { ctx, cancel := context.WithCancel(context.Background()) watchChan := make(chan event.GenericEvent, 1) - watch := source.Channel(watchChan, &handler.EnqueueRequestForObject{}) + watch := source.Channel(source.NewChannelBroadcaster(watchChan), &handler.EnqueueRequestForObject{}) watchChan <- event.GenericEvent{Object: &corev1.Pod{}} reconcileStarted := make(chan struct{}) diff --git a/pkg/internal/controller/controller_test.go b/pkg/internal/controller/controller_test.go index 2e1842d907..fb91dff3ca 100644 --- a/pkg/internal/controller/controller_test.go +++ b/pkg/internal/controller/controller_test.go @@ -227,7 +227,7 @@ var _ = Describe("controller", func() { } ins := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ GenericFunc: func(ctx context.Context, evt event.GenericEvent, q workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -248,7 +248,7 @@ var _ = Describe("controller", func() { <-processed }) - It("should error when channel source is not specified", func() { + It("should error when ChannelBroadcaster is not specified", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -257,7 +257,7 @@ var _ = Describe("controller", func() { e := ctrl.Start(ctx) Expect(e).To(HaveOccurred()) - Expect(e.Error()).To(ContainSubstring("must specify Channel.Source")) + Expect(e.Error()).To(ContainSubstring("must create Channel with a non-nil Broadcaster")) }) It("should call Start on sources with the appropriate EventHandler, Queue, and Predicates", func() { diff --git a/pkg/source/example_test.go b/pkg/source/example_test.go index b596ff0a0a..6ba5acacdc 100644 --- a/pkg/source/example_test.go +++ b/pkg/source/example_test.go @@ -44,7 +44,7 @@ func ExampleChannel() { err := ctrl.Watch( source.Channel( - events, + source.NewChannelBroadcaster(events), &handler.EnqueueRequestForObject{}, ), ) diff --git a/pkg/source/internal/channel.go b/pkg/source/internal/channel.go index 2f1dad3316..252bbb500b 100644 --- a/pkg/source/internal/channel.go +++ b/pkg/source/internal/channel.go @@ -23,7 +23,6 @@ import ( "sync" "k8s.io/client-go/util/workqueue" - "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/predicate" @@ -33,23 +32,18 @@ import ( // (e.g. GitHub Webhook callback). Channel requires the user to wire the external // source (e.g. http handler) to write GenericEvents to the underlying channel. type Channel[T any] struct { - // once ensures the event distribution goroutine will be performed only once - once sync.Once - - // source is the source channel to fetch GenericEvents - Source <-chan event.TypedGenericEvent[T] + // Broadcaster contains the source channel for events. + Broadcaster *ChannelBroadcaster[T] Handler handler.TypedEventHandler[T] Predicates []predicate.TypedPredicate[T] - BufferSize *int - - // dest is the destination channels of the added event handlers - dest []chan event.TypedGenericEvent[T] + DestBufferSize int - // destLock is to ensure the destination channels are safely added/removed - destLock sync.Mutex + mu sync.Mutex + // isStarted is true if the source has been started. A source can only be started once. + isStarted bool } func (cs *Channel[T]) String() string { @@ -62,89 +56,72 @@ func (cs *Channel[T]) Start( queue workqueue.RateLimitingInterface, ) error { // Source should have been specified by the user. - if cs.Source == nil { - return fmt.Errorf("must specify Channel.Source") + if cs.Broadcaster == nil { + return fmt.Errorf("must create Channel with a non-nil Broadcaster") } if cs.Handler == nil { - return errors.New("must specify Channel.Handler") + return errors.New("must create Channel with a non-nil Handler") } - - if cs.BufferSize == nil { - cs.BufferSize = ptr.To(1024) + if cs.DestBufferSize == 0 { + return errors.New("must create Channel with a >0 DestBufferSize") } - dst := make(chan event.TypedGenericEvent[T], *cs.BufferSize) - - cs.destLock.Lock() - cs.dest = append(cs.dest, dst) - cs.destLock.Unlock() + cs.mu.Lock() + defer cs.mu.Unlock() + if cs.isStarted { + return fmt.Errorf("cannot start an already started Channel source") + } + cs.isStarted = true - cs.once.Do(func() { - // Distribute GenericEvents to all EventHandler / Queue pairs Watching this source - go cs.syncLoop(ctx) - }) + // Create a destination channel for the event handler + // and add it to the list of destinations + destination := make(chan event.TypedGenericEvent[T], cs.DestBufferSize) + cs.Broadcaster.AddListener(destination) go func() { - for evt := range dst { - shouldHandle := true - for _, p := range cs.Predicates { - if !p.Generic(evt) { - shouldHandle = false - break - } - } - - if shouldHandle { - func() { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - cs.Handler.Generic(ctx, evt, queue) - }() - } - } + // Remove the listener and wait for the broadcaster + // to stop sending events to the destination channel. + defer cs.Broadcaster.RemoveListener(destination) + + cs.processReceivedEvents( + ctx, + destination, + queue, + cs.Handler, + cs.Predicates, + ) }() return nil } -func (cs *Channel[T]) doStop() { - cs.destLock.Lock() - defer cs.destLock.Unlock() - - for _, dst := range cs.dest { - close(dst) - } -} - -func (cs *Channel[T]) distribute(evt event.TypedGenericEvent[T]) { - cs.destLock.Lock() - defer cs.destLock.Unlock() - - for _, dst := range cs.dest { - // We cannot make it under goroutine here, or we'll meet the - // race condition of writing message to closed channels. - // To avoid blocking, the dest channels are expected to be of - // proper buffer size. If we still see it blocked, then - // the controller is thought to be in an abnormal state. - dst <- evt - } -} - -func (cs *Channel[T]) syncLoop(ctx context.Context) { +func (cs *Channel[T]) processReceivedEvents( + ctx context.Context, + destination <-chan event.TypedGenericEvent[T], + queue workqueue.RateLimitingInterface, + eventHandler handler.TypedEventHandler[T], + predicates []predicate.TypedPredicate[T], +) { +eventloop: for { select { case <-ctx.Done(): - // Close destination channels - cs.doStop() return - case evt, stillOpen := <-cs.Source: + case event, stillOpen := <-destination: if !stillOpen { - // if the source channel is closed, we're never gonna get - // anything more on it, so stop & bail - cs.doStop() return } - cs.distribute(evt) + + // Check predicates against the event first + // and continue the outer loop if any of them fail. + for _, p := range predicates { + if !p.Generic(event) { + continue eventloop + } + } + + // Call the event handler with the event. + eventHandler.Generic(ctx, event, queue) } } } diff --git a/pkg/source/internal/channel_broadcast.go b/pkg/source/internal/channel_broadcast.go new file mode 100644 index 0000000000..60c41298e0 --- /dev/null +++ b/pkg/source/internal/channel_broadcast.go @@ -0,0 +1,189 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import ( + "sync" + + "sigs.k8s.io/controller-runtime/pkg/event" +) + +// ChannelBroadcaster is a wrapper around a channel that allows multiple listeners to all +// receive the events from the channel. +type ChannelBroadcaster[T any] struct { + Source <-chan event.TypedGenericEvent[T] + + mu sync.Mutex + rcCount uint + managementCh chan managementMsg[T] + doneCh chan struct{} +} + +type managementOperation bool + +const ( + addChannel managementOperation = true + removeChannel managementOperation = false +) + +type managementMsg[T any] struct { + operation managementOperation + ch chan event.TypedGenericEvent[T] +} + +// AddListener adds a new listener to the ChannelBroadcaster. Each listener +// will receive all events from the source channel. All listeners have to be +// removed using RemoveListener before the ChannelBroadcaster can be garbage +// collected. +func (sc *ChannelBroadcaster[T]) AddListener(ch chan event.TypedGenericEvent[T]) { + var managementCh chan managementMsg[T] + var doneCh chan struct{} + isFirst := false + func() { + sc.mu.Lock() + defer sc.mu.Unlock() + + isFirst = sc.rcCount == 0 + sc.rcCount++ + + if isFirst { + sc.managementCh = make(chan managementMsg[T]) + sc.doneCh = make(chan struct{}) + } + + managementCh = sc.managementCh + doneCh = sc.doneCh + }() + + if isFirst { + go startLoop(sc.Source, managementCh, doneCh) + } + + // If the goroutine is not yet stopped, send a message to add the + // destination channel. The routine might be stopped already because + // the source channel was closed. + select { + case <-doneCh: + default: + managementCh <- managementMsg[T]{ + operation: addChannel, + ch: ch, + } + } +} + +func startLoop[T any]( + source <-chan event.TypedGenericEvent[T], + managementCh chan managementMsg[T], + doneCh chan struct{}, +) { + defer close(doneCh) + + var destinations []chan event.TypedGenericEvent[T] + + // Close all remaining destinations in case the Source channel is closed. + defer func() { + for _, dst := range destinations { + close(dst) + } + }() + + // Wait for the first destination to be added before starting the loop. + for len(destinations) == 0 { + managementMsg := <-managementCh + if managementMsg.operation == addChannel { + destinations = append(destinations, managementMsg.ch) + } + } + + for { + select { + case msg := <-managementCh: + + switch msg.operation { + case addChannel: + destinations = append(destinations, msg.ch) + case removeChannel: + SearchLoop: + for i, dst := range destinations { + if dst == msg.ch { + destinations = append(destinations[:i], destinations[i+1:]...) + close(dst) + break SearchLoop + } + } + + if len(destinations) == 0 { + return + } + } + + case evt, stillOpen := <-source: + if !stillOpen { + return + } + + for _, dst := range destinations { + // We cannot make it under goroutine here, or we'll meet the + // race condition of writing message to closed channels. + // To avoid blocking, the dest channels are expected to be of + // proper buffer size. If we still see it blocked, then + // the controller is thought to be in an abnormal state. + dst <- evt + } + } + } +} + +// RemoveListener removes a listener from the ChannelBroadcaster. The listener +// will no longer receive events from the source channel. If this is the last +// listener, this function will block until the ChannelBroadcaster's is stopped. +func (sc *ChannelBroadcaster[T]) RemoveListener(ch chan event.TypedGenericEvent[T]) { + var managementCh chan managementMsg[T] + var doneCh chan struct{} + isLast := false + func() { + sc.mu.Lock() + defer sc.mu.Unlock() + + sc.rcCount-- + isLast = sc.rcCount == 0 + + managementCh = sc.managementCh + doneCh = sc.doneCh + }() + + // If the goroutine is not yet stopped, send a message to remove the + // destination channel. The routine might be stopped already because + // the source channel was closed. + select { + case <-doneCh: + default: + managementCh <- managementMsg[T]{ + operation: removeChannel, + ch: ch, + } + } + + // Wait for the doneCh to be closed (in case we are the last one) + if isLast { + <-doneCh + } + + // Wait for the destination channel to be closed. + <-ch +} diff --git a/pkg/source/internal/informer.go b/pkg/source/internal/informer.go index db77a37dab..c3f6252a63 100644 --- a/pkg/source/internal/informer.go +++ b/pkg/source/internal/informer.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "sync" "k8s.io/client-go/util/workqueue" "sigs.k8s.io/controller-runtime/pkg/cache" @@ -33,6 +34,10 @@ type Informer struct { Informer cache.Informer Handler handler.EventHandler Predicates []predicate.Predicate + + mu sync.Mutex + // isStarted is true if the source has been started. A source can only be started once. + isStarted bool } // Start is internal and should be called only by the Controller to register an EventHandler with the Informer @@ -40,12 +45,19 @@ type Informer struct { func (is *Informer) Start(ctx context.Context, queue workqueue.RateLimitingInterface) error { // Informer should have been specified by the user. if is.Informer == nil { - return fmt.Errorf("must specify Informer.Informer") + return fmt.Errorf("must create Informer with a non-nil Informer") } if is.Handler == nil { return errors.New("must specify Informer.Handler") } + is.mu.Lock() + defer is.mu.Unlock() + if is.isStarted { + return fmt.Errorf("cannot start an already started Informer source") + } + is.isStarted = true + _, err := is.Informer.AddEventHandler(NewEventHandler(ctx, queue, is.Handler, is.Predicates).HandlerFuncs()) if err != nil { return err diff --git a/pkg/source/internal/kind.go b/pkg/source/internal/kind.go index ae22d0ad9c..b7fdf8d791 100644 --- a/pkg/source/internal/kind.go +++ b/pkg/source/internal/kind.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "reflect" + "sync" "time" "k8s.io/apimachinery/pkg/api/meta" @@ -46,6 +47,9 @@ type Kind[T client.Object] struct { Predicates []predicate.TypedPredicate[T] + mu sync.RWMutex + isStarted bool + // startedErr may contain an error if one was encountered during startup. If its closed and does not // contain an error, startup and syncing finished. startedErr chan error @@ -56,14 +60,21 @@ type Kind[T client.Object] struct { // to enqueue reconcile.Requests. func (ks *Kind[T]) Start(ctx context.Context, queue workqueue.RateLimitingInterface) error { if isNil(ks.Type) { - return fmt.Errorf("must create Kind with a non-nil object") + return fmt.Errorf("must create Kind with a non-nil Type") } if isNil(ks.Cache) { - return fmt.Errorf("must create Kind with a non-nil cache") + return fmt.Errorf("must create Kind with a non-nil Cache") } if isNil(ks.Handler) { - return errors.New("must create Kind with non-nil handler") + return errors.New("must create Kind with a non-nil Handler") + } + + ks.mu.Lock() + defer ks.mu.Unlock() + if ks.isStarted { + return fmt.Errorf("cannot start an already started Kind source") } + ks.isStarted = true // cache.GetInformer will block until its context is cancelled if the cache was already started and it can not // sync that informer (most commonly due to RBAC issues). diff --git a/pkg/source/source.go b/pkg/source/source.go index 08b5abed63..a9d4277187 100644 --- a/pkg/source/source.go +++ b/pkg/source/source.go @@ -23,10 +23,10 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/predicate" internal "sigs.k8s.io/controller-runtime/pkg/source/internal" "sigs.k8s.io/controller-runtime/pkg/cache" - "sigs.k8s.io/controller-runtime/pkg/predicate" ) // Source is a source of events (e.g. Create, Update, Delete operations on Kubernetes Objects, Webhook callbacks, etc) @@ -69,7 +69,7 @@ type Informer = internal.Informer var _ Source = &internal.Informer{} type channelOpts[T any] struct { - bufferSize *int + bufferSize int predicates []predicate.TypedPredicate[T] } @@ -87,24 +87,35 @@ func WithPredicates[T any](p ...predicate.TypedPredicate[T]) ChannelOpt[T] { // default, the buffer size is 1024. func WithBufferSize[T any](bufferSize int) ChannelOpt[T] { return func(c *channelOpts[T]) { - c.bufferSize = &bufferSize + c.bufferSize = bufferSize + } +} + +// NewChannelBroadcaster creates a new ChannelBroadcaster for the given channel. +// A ChannelBroadcaster is a wrapper around a channel that allows multiple listeners to all +// receive the events from the channel. +func NewChannelBroadcaster[T any](source <-chan event.TypedGenericEvent[T]) *internal.ChannelBroadcaster[T] { + return &internal.ChannelBroadcaster[T]{ + Source: source, } } // Channel is used to provide a source of events originating outside the cluster // (e.g. GitHub Webhook callback). Channel requires the user to wire the external // source (e.g. http handler) to write GenericEvents to the underlying channel. -func Channel[T any](source <-chan event.TypedGenericEvent[T], handler handler.TypedEventHandler[T], opts ...ChannelOpt[T]) Source { - c := &channelOpts[T]{} +func Channel[T any](broadcaster *internal.ChannelBroadcaster[T], handler handler.TypedEventHandler[T], opts ...ChannelOpt[T]) Source { + c := &channelOpts[T]{ + bufferSize: 1024, + } for _, opt := range opts { opt(c) } return &internal.Channel[T]{ - Source: source, - Handler: handler, - BufferSize: c.bufferSize, - Predicates: c.predicates, + Broadcaster: broadcaster, + Handler: handler, + DestBufferSize: c.bufferSize, + Predicates: c.predicates, } } diff --git a/pkg/source/source_test.go b/pkg/source/source_test.go index d30d5ae5c7..59f9bc5f78 100644 --- a/pkg/source/source_test.go +++ b/pkg/source/source_test.go @@ -184,20 +184,20 @@ var _ = Describe("Source", func() { instance := source.Kind(nil, &corev1.Pod{}, nil) err := instance.Start(ctx, nil) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("must create Kind with a non-nil cache")) + Expect(err.Error()).To(ContainSubstring("must create Kind with a non-nil Cache")) }) It("should return an error from Start if a type was not provided", func() { instance := source.Kind[client.Object](ic, nil, nil) err := instance.Start(ctx, nil) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("must create Kind with a non-nil object")) + Expect(err.Error()).To(ContainSubstring("must create Kind with a non-nil Type")) }) It("should return an error from Start if a handler was not provided", func() { instance := source.Kind(ic, &corev1.Pod{}, nil) err := instance.Start(ctx, nil) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("must create Kind with non-nil handler")) + Expect(err.Error()).To(ContainSubstring("must create Kind with a non-nil Handler")) }) It("should return an error if syncing fails", func() { @@ -295,7 +295,7 @@ var _ = Describe("Source", func() { q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "test") instance := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ CreateFunc: func(context.Context, event.CreateEvent, workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -337,7 +337,7 @@ var _ = Describe("Source", func() { q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "test") // Add a handler to get distribution blocked instance := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ CreateFunc: func(context.Context, event.CreateEvent, workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -395,7 +395,7 @@ var _ = Describe("Source", func() { q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "test") // Add a handler to get distribution blocked instance := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ CreateFunc: func(context.Context, event.CreateEvent, workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -438,7 +438,7 @@ var _ = Describe("Source", func() { processed := make(chan struct{}) defer close(processed) src := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ CreateFunc: func(context.Context, event.CreateEvent, workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -467,11 +467,11 @@ var _ = Describe("Source", func() { Eventually(processed).Should(Receive()) Consistently(processed).ShouldNot(Receive()) }) - It("should get error if no source specified", func() { + It("should get error if no Broadcaster specified", func() { q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "test") instance := source.Channel[string](nil, nil /*no source specified*/) err := instance.Start(ctx, q) - Expect(err).To(Equal(fmt.Errorf("must specify Channel.Source"))) + Expect(err).To(Equal(fmt.Errorf("must create Channel with a non-nil Broadcaster"))) }) }) })