Skip to content

Commit

Permalink
Refactor source/handler/predicate packages to remove dep injection
Browse files Browse the repository at this point in the history
Signed-off-by: Vince Prignano <vince@prigna.com>
  • Loading branch information
vincepri committed Jan 18, 2023
1 parent 16f7965 commit d6a053f
Show file tree
Hide file tree
Showing 28 changed files with 217 additions and 698 deletions.
6 changes: 3 additions & 3 deletions examples/builtins/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ func main() {
}

// Watch ReplicaSets and enqueue ReplicaSet object key
if err := c.Watch(&source.Kind{Type: &appsv1.ReplicaSet{}}, &handler.EnqueueRequestForObject{}); err != nil {
if err := c.Watch(source.Kind(mgr.GetCache(), &appsv1.ReplicaSet{}), &handler.EnqueueRequestForObject{}); err != nil {
entryLog.Error(err, "unable to watch ReplicaSets")
os.Exit(1)
}

// Watch Pods and enqueue owning ReplicaSet key
if err := c.Watch(&source.Kind{Type: &corev1.Pod{}},
&handler.EnqueueRequestForOwner{OwnerType: &appsv1.ReplicaSet{}, IsController: true}); err != nil {
if err := c.Watch(source.Kind(mgr.GetCache(), &corev1.Pod{}),
handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &appsv1.ReplicaSet{}, handler.OnlyControllerOwner())); err != nil {
entryLog.Error(err, "unable to watch Pods")
os.Exit(1)
}
Expand Down
2 changes: 1 addition & 1 deletion hack/test-all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ if [[ -n ${ARTIFACTS:-} ]]; then
fi

result=0
go test -race ${P_FLAG} ${MOD_OPT} ./... ${GINKGO_ARGS} || result=$?
go test -v -race ${P_FLAG} ${MOD_OPT} ./... --ginkgo.fail-fast ${GINKGO_ARGS} || result=$?

if [[ -n ${ARTIFACTS:-} ]]; then
mkdir -p ${ARTIFACTS}
Expand Down
40 changes: 18 additions & 22 deletions pkg/builder/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"strings"

"github.com/go-logr/logr"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/klog/v2"

Expand Down Expand Up @@ -197,18 +196,16 @@ func (blder *Builder) Build(r reconcile.Reconciler) (controller.Controller, erro
return blder.ctrl, nil
}

func (blder *Builder) project(obj client.Object, proj objectProjection) (client.Object, error) {
func (blder *Builder) project(obj client.Object, proj objectProjection) (source.Source, error) {
src := source.Kind(blder.mgr.GetCache(), obj)
switch proj {
case projectAsNormal:
return obj, nil
return src, nil
case projectAsMetadata:
metaObj := &metav1.PartialObjectMetadata{}
gvk, err := getGvk(obj, blder.mgr.GetScheme())
if err != nil {
return nil, fmt.Errorf("unable to determine GVK of %T for a metadata-only watch: %w", obj, err)
if err := source.KindAsPartialMetadata(src, blder.mgr.GetScheme()); err != nil {
return nil, err
}
metaObj.SetGroupVersionKind(gvk)
return metaObj, nil
return src, nil
default:
panic(fmt.Sprintf("unexpected projection type %v on type %T, should not be possible since this is an internal field", proj, obj))
}
Expand All @@ -217,11 +214,10 @@ func (blder *Builder) project(obj client.Object, proj objectProjection) (client.
func (blder *Builder) doWatch() error {
// Reconcile type
if blder.forInput.object != nil {
typeForSrc, err := blder.project(blder.forInput.object, blder.forInput.objectProjection)
src, err := blder.project(blder.forInput.object, blder.forInput.objectProjection)
if err != nil {
return err
}
src := &source.Kind{Type: typeForSrc}
hdler := &handler.EnqueueRequestForObject{}
allPredicates := append(blder.globalPredicates, blder.forInput.predicates...)
if err := blder.ctrl.Watch(src, hdler, allPredicates...); err != nil {
Expand All @@ -234,15 +230,15 @@ func (blder *Builder) doWatch() error {
return errors.New("Owns() can only be used together with For()")
}
for _, own := range blder.ownsInput {
typeForSrc, err := blder.project(own.object, own.objectProjection)
src, err := blder.project(own.object, own.objectProjection)
if err != nil {
return err
}
src := &source.Kind{Type: typeForSrc}
hdler := &handler.EnqueueRequestForOwner{
OwnerType: blder.forInput.object,
IsController: true,
}
hdler := handler.EnqueueRequestForOwner(
blder.mgr.GetScheme(), blder.mgr.GetRESTMapper(),
blder.forInput.object,
handler.OnlyControllerOwner(),
)
allPredicates := append([]predicate.Predicate(nil), blder.globalPredicates...)
allPredicates = append(allPredicates, own.predicates...)
if err := blder.ctrl.Watch(src, hdler, allPredicates...); err != nil {
Expand All @@ -259,12 +255,12 @@ func (blder *Builder) doWatch() error {
allPredicates = append(allPredicates, w.predicates...)

// If the source of this watch is of type *source.Kind, project it.
if srckind, ok := w.src.(*source.Kind); ok {
typeForSrc, err := blder.project(srckind.Type, w.objectProjection)
if err != nil {
return err
if srckind, ok := w.src.(source.SyncingSource); ok {
if w.objectProjection == projectAsMetadata {
if err := source.KindAsPartialMetadata(srckind, blder.mgr.GetScheme()); err != nil {
return err
}
}
srckind.Type = typeForSrc
}

if err := blder.ctrl.Watch(w.src, w.eventhandler, allPredicates...); err != nil {
Expand Down
18 changes: 10 additions & 8 deletions pkg/builder/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ var _ = Describe("application", func() {
Expect(err).NotTo(HaveOccurred())

instance, err := ControllerManagedBy(m).
Watches(&source.Kind{Type: &appsv1.ReplicaSet{}}, &handler.EnqueueRequestForObject{}).
Watches(source.Kind(m.GetCache(), &appsv1.ReplicaSet{}), &handler.EnqueueRequestForObject{}).
Build(noop)
Expect(err).To(MatchError(ContainSubstring("one of For() or Named() must be called")))
Expect(instance).To(BeNil())
Expand Down Expand Up @@ -157,7 +157,7 @@ var _ = Describe("application", func() {

instance, err := ControllerManagedBy(m).
Named("my_controller").
Watches(&source.Kind{Type: &appsv1.ReplicaSet{}}, &handler.EnqueueRequestForObject{}).
Watches(source.Kind(m.GetCache(), &appsv1.ReplicaSet{}), &handler.EnqueueRequestForObject{}).
Build(noop)
Expect(err).NotTo(HaveOccurred())
Expect(instance).NotTo(BeNil())
Expand Down Expand Up @@ -369,8 +369,9 @@ var _ = Describe("application", func() {
bldr := ControllerManagedBy(m).
For(&appsv1.Deployment{}).
Watches( // Equivalent of Owns
&source.Kind{Type: &appsv1.ReplicaSet{}},
&handler.EnqueueRequestForOwner{OwnerType: &appsv1.Deployment{}, IsController: true})
source.Kind(m.GetCache(), &appsv1.ReplicaSet{}),
handler.EnqueueRequestForOwner(m.GetScheme(), m.GetRESTMapper(), &appsv1.Deployment{}, handler.OnlyControllerOwner()),
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -384,10 +385,11 @@ var _ = Describe("application", func() {
bldr := ControllerManagedBy(m).
Named("Deployment").
Watches( // Equivalent of For
&source.Kind{Type: &appsv1.Deployment{}}, &handler.EnqueueRequestForObject{}).
source.Kind(m.GetCache(), &appsv1.Deployment{}), &handler.EnqueueRequestForObject{}).
Watches( // Equivalent of Owns
&source.Kind{Type: &appsv1.ReplicaSet{}},
&handler.EnqueueRequestForOwner{OwnerType: &appsv1.Deployment{}, IsController: true})
source.Kind(m.GetCache(), &appsv1.ReplicaSet{}),
handler.EnqueueRequestForOwner(m.GetScheme(), m.GetRESTMapper(), &appsv1.Deployment{}, handler.OnlyControllerOwner()),
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -481,7 +483,7 @@ var _ = Describe("application", func() {
bldr := ControllerManagedBy(mgr).
For(&appsv1.Deployment{}, OnlyMetadata).
Owns(&appsv1.ReplicaSet{}, OnlyMetadata).
Watches(&source.Kind{Type: &appsv1.StatefulSet{}},
Watches(source.Kind(mgr.GetCache(), &appsv1.StatefulSet{}),
handler.EnqueueRequestsFromMapFunc(func(o client.Object) []reconcile.Request {
defer GinkgoRecover()

Expand Down
6 changes: 3 additions & 3 deletions pkg/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type Cluster interface {
// GetConfig returns an initialized Config
GetConfig() *rest.Config

// GetCache returns a cache.Cache
GetCache() cache.Cache

// GetScheme returns an initialized Scheme
GetScheme() *runtime.Scheme

Expand All @@ -57,9 +60,6 @@ type Cluster interface {
// GetFieldIndexer returns a client.FieldIndexer configured with the client
GetFieldIndexer() client.FieldIndexer

// GetCache returns a cache.Cache
GetCache() cache.Cache

// GetEventRecorderFor returns a new EventRecorder for the provided name
GetEventRecorderFor(name string) record.EventRecorder

Expand Down
1 change: 0 additions & 1 deletion pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ func NewUnmanaged(name string, mgr manager.Manager, options Options) (Controller
},
MaxConcurrentReconciles: options.MaxConcurrentReconciles,
CacheSyncTimeout: options.CacheSyncTimeout,
SetFields: mgr.SetFields,
Name: name,
LogConstructor: options.LogConstructor,
RecoverPanic: options.RecoverPanic,
Expand Down
9 changes: 5 additions & 4 deletions pkg/controller/controller_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ var _ = Describe("controller", func() {
Expect(err).NotTo(HaveOccurred())

By("Watching Resources")
err = instance.Watch(&source.Kind{Type: &appsv1.ReplicaSet{}}, &handler.EnqueueRequestForOwner{
OwnerType: &appsv1.Deployment{},
})
err = instance.Watch(
source.Kind(cm.GetCache(), &appsv1.ReplicaSet{}),
handler.EnqueueRequestForOwner(cm.GetScheme(), cm.GetRESTMapper(), &appsv1.Deployment{}),
)
Expect(err).NotTo(HaveOccurred())

err = instance.Watch(&source.Kind{Type: &appsv1.Deployment{}}, &handler.EnqueueRequestForObject{})
err = instance.Watch(source.Kind(cm.GetCache(), &appsv1.Deployment{}), &handler.EnqueueRequestForObject{})
Expect(err).NotTo(HaveOccurred())

err = cm.GetClient().Get(ctx, types.NamespacedName{Name: "foo"}, &corev1.Namespace{})
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ var _ = Describe("controller.Controller", func() {
It("should not leak goroutines when stopped", func() {
currentGRs := goleak.IgnoreCurrent()

ctx, cancel := context.WithCancel(context.Background())
watchChan := make(chan event.GenericEvent, 1)
watch := &source.Channel{Source: watchChan}
watchChan <- event.GenericEvent{Object: &corev1.Pod{}}
Expand All @@ -114,7 +115,6 @@ var _ = Describe("controller.Controller", func() {
Expect(c.Watch(watch, &handler.EnqueueRequestForObject{})).To(Succeed())
Expect(err).NotTo(HaveOccurred())

ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
Expect(m.Start(ctx)).To(Succeed())
Expand Down
6 changes: 3 additions & 3 deletions pkg/controller/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func ExampleController() {
}

// Watch for Pod create / update / delete events and call Reconcile
err = c.Watch(&source.Kind{Type: &corev1.Pod{}}, &handler.EnqueueRequestForObject{})
err = c.Watch(source.Kind(mgr.GetCache(), &corev1.Pod{}), &handler.EnqueueRequestForObject{})
if err != nil {
log.Error(err, "unable to watch pods")
os.Exit(1)
Expand Down Expand Up @@ -108,7 +108,7 @@ func ExampleController_unstructured() {
Version: "v1",
})
// Watch for Pod create / update / delete events and call Reconcile
err = c.Watch(&source.Kind{Type: u}, &handler.EnqueueRequestForObject{})
err = c.Watch(source.Kind(mgr.GetCache(), u), &handler.EnqueueRequestForObject{})
if err != nil {
log.Error(err, "unable to watch pods")
os.Exit(1)
Expand Down Expand Up @@ -139,7 +139,7 @@ func ExampleNewUnmanaged() {
os.Exit(1)
}

if err := c.Watch(&source.Kind{Type: &corev1.Pod{}}, &handler.EnqueueRequestForObject{}); err != nil {
if err := c.Watch(source.Kind(mgr.GetCache(), &corev1.Pod{}), &handler.EnqueueRequestForObject{}); err != nil {
log.Error(err, "unable to watch pods")
os.Exit(1)
}
Expand Down
11 changes: 0 additions & 11 deletions pkg/handler/enqueue_mapped.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
"sigs.k8s.io/controller-runtime/pkg/runtime/inject"
)

// MapFunc is the signature required for enqueueing requests from a generic function.
Expand Down Expand Up @@ -85,13 +84,3 @@ func (e *enqueueRequestsFromMapFunc) mapAndEnqueue(q workqueue.RateLimitingInter
}
}
}

// EnqueueRequestsFromMapFunc can inject fields into the mapper.

// InjectFunc implements inject.Injector.
func (e *enqueueRequestsFromMapFunc) InjectFunc(f inject.Func) error {
if f == nil {
return nil
}
return f(e.toRequests)
}
Loading

0 comments on commit d6a053f

Please sign in to comment.