diff --git a/pkg/controllers/jobset_controller.go b/pkg/controllers/jobset_controller.go index 36b54187d..384095c27 100644 --- a/pkg/controllers/jobset_controller.go +++ b/pkg/controllers/jobset_controller.go @@ -388,10 +388,10 @@ func (r *JobSetReconciler) suspendJobs(ctx context.Context, js *jobset.JobSet, a // resumeJobsIfNecessary iterates through each replicatedJob, resuming any suspended jobs if the JobSet // is not suspended. func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset.JobSet, activeJobs []*batchv1.Job, replicatedJobStatuses []jobset.ReplicatedJobStatus, updateStatusOpts *statusUpdateOpts) error { - // Store node selector for each replicatedJob template. - nodeAffinities := map[string]map[string]string{} + // Store pod template for each replicatedJob. + replicatedJobTemplateMap := map[string]corev1.PodTemplateSpec{} for _, replicatedJob := range js.Spec.ReplicatedJobs { - nodeAffinities[replicatedJob.Name] = replicatedJob.Template.Spec.Template.Spec.NodeSelector + replicatedJobTemplateMap[replicatedJob.Name] = replicatedJob.Template.Spec.Template } // Map each replicatedJob to a list of its active jobs. @@ -415,7 +415,7 @@ func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset if !jobSuspended(job) { continue } - if err := r.resumeJob(ctx, job, nodeAffinities); err != nil { + if err := r.resumeJob(ctx, job, replicatedJobTemplateMap); err != nil { return err } } @@ -433,7 +433,7 @@ func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset return nil } -func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, nodeAffinities map[string]map[string]string) error { +func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, replicatedJobTemplateMap map[string]corev1.PodTemplateSpec) error { log := ctrl.LoggerFrom(ctx) // Kubernetes validates that a job template is immutable // so if the job has started i.e., startTime != nil), we must set it to nil first. @@ -443,10 +443,33 @@ func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, node return err } } + + // Get name of parent replicated job and use it to look up the pod template. + replicatedJobName := job.Labels[jobset.ReplicatedJobNameKey] + replicatedJobPodTemplate := replicatedJobTemplateMap[replicatedJobName] if job.Labels != nil && job.Labels[jobset.ReplicatedJobNameKey] != "" { - // When resuming a job, its nodeSelectors should match that of the replicatedJob template - // that it was created from, which may have been updated while it was suspended. - job.Spec.Template.Spec.NodeSelector = nodeAffinities[job.Labels[jobset.ReplicatedJobNameKey]] + // Certain fields on the Job pod template may be mutated while a JobSet is suspended, + // for integration with Kueue. Ensure these updates are propagated to the child Jobs + // when the JobSet is resumed. + // Merge values rather than overwriting them, since a different controller + // (e.g., the Job controller) may have added labels/annotations/etc to the + // Job that do not exist in the ReplicatedJob pod template. + job.Spec.Template.Labels = collections.MergeMaps( + job.Spec.Template.Labels, + replicatedJobPodTemplate.Labels, + ) + job.Spec.Template.Annotations = collections.MergeMaps( + job.Spec.Template.Annotations, + replicatedJobPodTemplate.Annotations, + ) + job.Spec.Template.Spec.NodeSelector = collections.MergeMaps( + job.Spec.Template.Spec.NodeSelector, + replicatedJobPodTemplate.Spec.NodeSelector, + ) + job.Spec.Template.Spec.Tolerations = collections.MergeSlices( + job.Spec.Template.Spec.Tolerations, + replicatedJobPodTemplate.Spec.Tolerations, + ) } else { log.Error(nil, "job missing ReplicatedJobName label") } diff --git a/pkg/util/collections/collections.go b/pkg/util/collections/collections.go index 7be453011..6e40f8195 100644 --- a/pkg/util/collections/collections.go +++ b/pkg/util/collections/collections.go @@ -47,3 +47,41 @@ func IndexOf[T comparable](slice []T, item T) int { } return -1 } + +// MergeMaps will merge the `old` and `new` maps and return the +// merged map. If a key appears in both maps, the key-value pair +// in the `new` map will overwrite the value in the `old` map. +func MergeMaps[K comparable, V any](old, new map[K]V) map[K]V { + merged := make(map[K]V) + for k, v := range old { + merged[k] = v + } + for k, v := range new { + merged[k] = v // Overwrite if duplicate + } + return merged +} + +func MergeSlices[T comparable](s1, s2 []T) []T { + mergedSet := make(map[T]bool) + + // Add elements from s1 to the set + for _, item := range s1 { + mergedSet[item] = true + } + + // Add elements from s2, only if they are not already in the set + for _, item := range s2 { + if _, exists := mergedSet[item]; !exists { + mergedSet[item] = true + } + } + + // Convert the set back into a slice + mergedSlice := make([]T, 0, len(mergedSet)) + for item := range mergedSet { + mergedSlice = append(mergedSlice, item) + } + + return mergedSlice +} diff --git a/pkg/util/collections/collections_test.go b/pkg/util/collections/collections_test.go index 5b3bec6b6..00d94226a 100644 --- a/pkg/util/collections/collections_test.go +++ b/pkg/util/collections/collections_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/exp/slices" ) func TestConcat(t *testing.T) { @@ -151,3 +152,90 @@ func TestContains(t *testing.T) { }) } } + +func TestMergeMaps(t *testing.T) { + testCases := []struct { + name string + m1 map[string]int + m2 map[string]int + expected map[string]int + }{ + { + name: "Basic merge", + m1: map[string]int{"a": 1, "b": 2}, + m2: map[string]int{"c": 3, "d": 4}, + expected: map[string]int{"a": 1, "b": 2, "c": 3, "d": 4}, + }, + { + name: "Overlapping keys", + m1: map[string]int{"a": 1, "b": 2}, + m2: map[string]int{"b": 3, "c": 4}, + expected: map[string]int{"a": 1, "b": 3, "c": 4}, // m2 value for 'b' overwrites + }, + { + name: "Empty maps", + m1: map[string]int{}, + m2: map[string]int{}, + expected: map[string]int{}, + }, + { + name: "One empty map", + m1: map[string]int{"a": 1, "b": 2}, + m2: map[string]int{}, + expected: map[string]int{"a": 1, "b": 2}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + merged := MergeMaps(tc.m1, tc.m2) + + if !reflect.DeepEqual(merged, tc.expected) { + t.Errorf("expected %v, got %v", tc.expected, merged) + } + }) + } +} + +func TestMergeSlices(t *testing.T) { + testCases := []struct { + name string + s1 []int + s2 []int + expected []int + }{ + { + name: "merge with overlapping elements should not result in duplicates", + s1: []int{1, 2, 3}, + s2: []int{3, 4, 5}, + expected: []int{1, 2, 3, 4, 5}, + }, + { + name: "empty slices", + s1: []int{}, + s2: []int{}, + expected: []int{}, + }, + { + name: "one empty slice", + s1: []int{1, 2}, + s2: []int{}, + expected: []int{1, 2}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + merged := MergeSlices(tc.s1, tc.s2) + + // Sort before comparison so slices with the same elements + // should be the same. + slices.Sort(merged) + slices.Sort(tc.expected) + + if !reflect.DeepEqual(merged, tc.expected) { + t.Errorf("Expected %v, got %v", tc.expected, merged) + } + }) + } +} diff --git a/test/integration/controller/jobset_controller_test.go b/test/integration/controller/jobset_controller_test.go index 5afa01df1..f829fe06e 100644 --- a/test/integration/controller/jobset_controller_test.go +++ b/test/integration/controller/jobset_controller_test.go @@ -132,9 +132,16 @@ var _ = ginkgo.Describe("JobSet controller", func() { updates []*update } - nodeSelectors := map[string]map[string]string{ - "replicated-job-a": {"node-selector-test-a": "node-selector-test-a"}, - "replicated-job-b": {"node-selector-test-b": "node-selector-test-b"}, + var podTemplateUpdates = &updatePodTemplateOpts{ + labels: map[string]string{"label": "value"}, + annotations: map[string]string{"annotation": "value"}, + nodeSelector: map[string]string{"node-selector-test-a": "node-selector-test-a"}, + tolerations: []corev1.Toleration{ + { + Key: "key", + Operator: corev1.TolerationOpExists, + }, + }, } ginkgo.DescribeTable("jobset is created and its jobs go through a series of updates", @@ -514,7 +521,7 @@ var _ = ginkgo.Describe("JobSet controller", func() { }, { jobSetUpdateFn: func(js *jobset.JobSet) { - updateJobSetNodeSelectors(js, nodeSelectors) + updatePodTemplates(js, podTemplateUpdates) }, checkJobSetState: func(js *jobset.JobSet) { ginkgo.By("Check ReplicatedJobStatus for suspend") @@ -542,7 +549,7 @@ var _ = ginkgo.Describe("JobSet controller", func() { { checkJobSetState: func(js *jobset.JobSet) { ginkgo.By("checking jobs have expected node selectors") - gomega.Eventually(matchJobsNodeSelectors, timeout, interval).WithArguments(js, nodeSelectors).Should(gomega.Equal(true)) + gomega.Eventually(checkPodTemplateUpdates, timeout, interval).WithArguments(js, podTemplateUpdates).Should(gomega.Equal(true)) }, jobUpdateFn: completeAllJobs, checkJobSetCondition: testutil.JobSetCompleted, @@ -1464,15 +1471,35 @@ func suspendJobSet(js *jobset.JobSet, suspend bool) { }, timeout, interval).Should(gomega.Succeed()) } -func updateJobSetNodeSelectors(js *jobset.JobSet, nodeSelectors map[string]map[string]string) { +// updatePodTemplateOpts contains pod template values +// which can be mutated on a ReplicatedJob template +// while a JobSet is suspended. +type updatePodTemplateOpts struct { + labels map[string]string + annotations map[string]string + nodeSelector map[string]string + tolerations []corev1.Toleration +} + +func updatePodTemplates(js *jobset.JobSet, opts *updatePodTemplateOpts) { gomega.Eventually(func() error { var jsGet jobset.JobSet if err := k8sClient.Get(ctx, types.NamespacedName{Name: js.Name, Namespace: js.Namespace}, &jsGet); err != nil { return err } for index := range jsGet.Spec.ReplicatedJobs { - jsGet.Spec.ReplicatedJobs[index]. - Template.Spec.Template.Spec.NodeSelector = nodeSelectors[jsGet.Spec.ReplicatedJobs[index].Name] + podTemplate := &jsGet.Spec.ReplicatedJobs[index].Template.Spec.Template + // Update labels. + podTemplate.Labels = opts.labels + + // Update annotations. + podTemplate.Annotations = opts.annotations + + // Update node selector. + podTemplate.Spec.NodeSelector = opts.nodeSelector + + // Update tolerations. + podTemplate.Spec.Tolerations = opts.tolerations } return k8sClient.Update(ctx, &jsGet) }, timeout, interval).Should(gomega.Succeed()) @@ -1496,7 +1523,7 @@ func matchJobsSuspendState(js *jobset.JobSet, suspend bool) (bool, error) { return true, nil } -func matchJobsNodeSelectors(js *jobset.JobSet, nodeSelectors map[string]map[string]string) (bool, error) { +func checkPodTemplateUpdates(js *jobset.JobSet, podTemplateUpdates *updatePodTemplateOpts) (bool, error) { var jobList batchv1.JobList if err := k8sClient.List(ctx, &jobList, client.InNamespace(js.Namespace)); err != nil { return false, err @@ -1504,21 +1531,40 @@ func matchJobsNodeSelectors(js *jobset.JobSet, nodeSelectors map[string]map[stri // Count number of updated jobs jobsUpdated := 0 for _, job := range jobList.Items { - rjobName, ok := job.Labels[jobset.ReplicatedJobNameKey] - if !ok { - return false, fmt.Errorf(fmt.Sprintf("%s job missing ReplicatedJobName label", job.Name)) + // Check label was added. + for label, value := range podTemplateUpdates.labels { + if job.Spec.Template.Labels[label] != value { + return false, fmt.Errorf("%s != %s", job.Spec.Template.Labels[label], value) + } } - if !apiequality.Semantic.DeepEqual(job.Spec.Template.Spec.NodeSelector, nodeSelectors[rjobName]) { - return false, nil + + // Check annotation was added. + for annotation, value := range podTemplateUpdates.annotations { + if job.Spec.Template.Annotations[annotation] != value { + return false, fmt.Errorf("%s != %s", job.Spec.Template.Labels[annotation], value) + } } + + // Check nodeSelector was updated. + for label, value := range podTemplateUpdates.nodeSelector { + if job.Spec.Template.Spec.NodeSelector[label] != value { + return false, fmt.Errorf("%s != %s", job.Spec.Template.Spec.NodeSelector[label], value) + } + } + + // Check tolerations were updated. + for _, toleration := range podTemplateUpdates.tolerations { + if !collections.Contains(job.Spec.Template.Spec.Tolerations, toleration) { + return false, fmt.Errorf("missing toleration %v", toleration) + } + } + jobsUpdated++ } // Calculate expected number of updated jobs wantJobsUpdated := 0 for _, rjob := range js.Spec.ReplicatedJobs { - if _, exists := nodeSelectors[rjob.Name]; exists { - wantJobsUpdated += int(rjob.Replicas) - } + wantJobsUpdated += int(rjob.Replicas) } return wantJobsUpdated == jobsUpdated, nil }