From 799345ab07f0117de2e34c262c69ffde0411cb3e Mon Sep 17 00:00:00 2001 From: Shingo Omura Date: Mon, 20 Jan 2025 14:13:13 +0900 Subject: [PATCH] Implement SparkApplication integration (without dynamic scaling) Signed-off-by: Shingo Omura --- .../manager/controller_manager_config.yaml | 2 + config/components/manager/manager.yaml | 2 +- config/components/rbac/role.yaml | 25 ++ config/components/webhook/manifests.yaml | 39 ++ pkg/controller/jobs/jobs.go | 1 + .../sparkapplication_controller.go | 377 ++++++++++++++++++ .../sparkapplication_controller_test.go | 373 +++++++++++++++++ .../sparkapplication_webhook.go | 137 +++++++ .../sparkapplication_webhook_test.go | 289 ++++++++++++++ .../testingjobs/sparkapplication/wrappers.go | 201 ++++++++++ 10 files changed, 1445 insertions(+), 1 deletion(-) create mode 100644 pkg/controller/jobs/sparkapplication/sparkapplication_controller.go create mode 100644 pkg/controller/jobs/sparkapplication/sparkapplication_controller_test.go create mode 100644 pkg/controller/jobs/sparkapplication/sparkapplication_webhook.go create mode 100644 pkg/controller/jobs/sparkapplication/sparkapplication_webhook_test.go create mode 100644 pkg/util/testingjobs/sparkapplication/wrappers.go diff --git a/config/components/manager/controller_manager_config.yaml b/config/components/manager/controller_manager_config.yaml index 618353c0fe0..89ac859be00 100644 --- a/config/components/manager/controller_manager_config.yaml +++ b/config/components/manager/controller_manager_config.yaml @@ -53,6 +53,8 @@ integrations: - "kubeflow.org/tfjob" - "kubeflow.org/xgboostjob" - "workload.codeflare.dev/appwrapper" + - "codeflare.dev/appwrapper" + - "sparkoperator.k8s.io/sparkapplication" # - "pod" # - "deployment" # requires enabling pod integration # - "statefulset" # requires enabling pod integration diff --git a/config/components/manager/manager.yaml b/config/components/manager/manager.yaml index 7b3fd277916..4648f0c1687 100644 --- 
a/config/components/manager/manager.yaml +++ b/config/components/manager/manager.yaml @@ -24,7 +24,7 @@ spec: - /manager args: - "--zap-log-level=2" - imagePullPolicy: Always + imagePullPolicy: IfNotPresent image: controller:latest name: manager securityContext: diff --git a/config/components/rbac/role.yaml b/config/components/rbac/role.yaml index 15483e27dcd..3ff99d95c00 100644 --- a/config/components/rbac/role.yaml +++ b/config/components/rbac/role.yaml @@ -324,6 +324,31 @@ rules: - get - list - watch +- apiGroups: + - sparkoperator.k8s.io + resources: + - sparkapplications + verbs: + - get + - list + - patch + - update + - watch +- apiGroups: + - sparkoperator.k8s.io + resources: + - sparkapplications/finalizers + verbs: + - get + - update +- apiGroups: + - sparkoperator.k8s.io + resources: + - sparkapplications/status + verbs: + - get + - patch + - update - apiGroups: - workload.codeflare.dev resources: diff --git a/config/components/webhook/manifests.yaml b/config/components/webhook/manifests.yaml index f9fe92e071a..9a5689993cd 100644 --- a/config/components/webhook/manifests.yaml +++ b/config/components/webhook/manifests.yaml @@ -253,6 +253,25 @@ webhooks: resources: - rayjobs sideEffects: None +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /mutate-sparkoperator-k8s-io-v1beta2-sparkapplication + failurePolicy: Fail + name: msparkapplication.kb.io + rules: + - apiGroups: + - sparkoperator.k8s.io + apiVersions: + - v1beta2 + operations: + - CREATE + resources: + - sparkapplications + sideEffects: None - admissionReviewVersions: - v1 clientConfig: @@ -596,6 +615,26 @@ webhooks: resources: - rayjobs sideEffects: None +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-sparkoperator-k8s-io-v1beta2-sparkapplication + failurePolicy: Fail + name: vsparkapplication.kb.io + rules: + - apiGroups: + - sparkoperator.k8s.io + 
apiVersions: + - v1beta2 + operations: + - CREATE + - UPDATE + resources: + - sparkapplications + sideEffects: None - admissionReviewVersions: - v1 clientConfig: diff --git a/pkg/controller/jobs/jobs.go b/pkg/controller/jobs/jobs.go index bfd87210f3f..61ea32c22af 100644 --- a/pkg/controller/jobs/jobs.go +++ b/pkg/controller/jobs/jobs.go @@ -28,5 +28,6 @@ import ( _ "sigs.k8s.io/kueue/pkg/controller/jobs/pod" _ "sigs.k8s.io/kueue/pkg/controller/jobs/raycluster" _ "sigs.k8s.io/kueue/pkg/controller/jobs/rayjob" + _ "sigs.k8s.io/kueue/pkg/controller/jobs/sparkapplication" _ "sigs.k8s.io/kueue/pkg/controller/jobs/statefulset" ) diff --git a/pkg/controller/jobs/sparkapplication/sparkapplication_controller.go b/pkg/controller/jobs/sparkapplication/sparkapplication_controller.go new file mode 100644 index 00000000000..5ef79f516a4 --- /dev/null +++ b/pkg/controller/jobs/sparkapplication/sparkapplication_controller.go @@ -0,0 +1,377 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package sparkapplication
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/resource"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/runtime"
+	"k8s.io/apimachinery/pkg/runtime/schema"
+	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
+	"k8s.io/apimachinery/pkg/util/validation/field"
+	"k8s.io/utils/ptr"
+
+	"sigs.k8s.io/controller-runtime/pkg/client"
+	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
+	"sigs.k8s.io/kueue/pkg/controller/jobframework"
+	"sigs.k8s.io/kueue/pkg/podset"
+
+	kfsparkapi "github.com/kubeflow/spark-operator/api/v1beta2"
+	kfsparkcommon "github.com/kubeflow/spark-operator/pkg/common"
+	kfsparkutil "github.com/kubeflow/spark-operator/pkg/util"
+)
+
+var (
+	gvk = kfsparkapi.SchemeGroupVersion.WithKind("SparkApplication")
+
+	FrameworkName = "sparkoperator.k8s.io/sparkapplication"
+)
+
+func init() {
+	utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
+		SetupIndexes:           SetupIndexes,
+		NewJob:                 NewJob,
+		NewReconciler:          NewReconciler,
+		SetupWebhook:           SetupSparkApplicationWebhook,
+		JobType:                &kfsparkapi.SparkApplication{},
+		AddToScheme:            kfsparkapi.AddToScheme,
+		IsManagingObjectsOwner: isSparkApplication,
+	}))
+}
+
+// +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update;patch
+// +kubebuilder:rbac:groups=sparkoperator.k8s.io,resources=sparkapplications,verbs=get;list;watch;update;patch
+// +kubebuilder:rbac:groups=sparkoperator.k8s.io,resources=sparkapplications/status,verbs=get;update;patch
+// +kubebuilder:rbac:groups=sparkoperator.k8s.io,resources=sparkapplications/finalizers,verbs=get;update
+// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete
+// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch
+// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/finalizers,verbs=update
+
+func NewJob()
jobframework.GenericJob {
+	return &SparkApplication{}
+}
+
+var NewReconciler = jobframework.NewGenericReconcilerFactory(NewJob)
+
+func isSparkApplication(owner *metav1.OwnerReference) bool {
+	return owner.Kind == "SparkApplication" && strings.HasPrefix(owner.APIVersion, kfsparkapi.GroupVersion.Group)
+}
+
+type SparkApplication kfsparkapi.SparkApplication
+
+var _ jobframework.GenericJob = (*SparkApplication)(nil)
+
+func fromObject(obj runtime.Object) *SparkApplication {
+	return (*SparkApplication)(obj.(*kfsparkapi.SparkApplication))
+}
+
+func (s *SparkApplication) Finished() (string, bool, bool) {
+	failed := s.Status.AppState.State == kfsparkapi.ApplicationStateFailed
+	success := s.Status.AppState.State == kfsparkapi.ApplicationStateCompleted
+	finished := success || failed
+	if finished {
+		if failed {
+			return s.Status.AppState.ErrorMessage, false, true
+		}
+		return "", true, true
+	}
+	return "", false, false
+}
+
+func (s *SparkApplication) GVK() schema.GroupVersionKind {
+	return gvk
+}
+
+func (s *SparkApplication) IsActive() bool {
+	// A SparkApplication is active while it has NOT terminated yet.
+	return !kfsparkutil.IsTerminated((*kfsparkapi.SparkApplication)(s))
+}
+
+func (s *SparkApplication) IsSuspended() bool {
+	return s.Spec.Suspend
+}
+
+func (s *SparkApplication) Object() client.Object {
+	return (*kfsparkapi.SparkApplication)(s)
+}
+
+func (s *SparkApplication) PodSets() ([]kueue.PodSet, error) {
+	driverPodSet, err := s.driverPodSet()
+	if err != nil {
+		return nil, err
+	}
+	executorPodSet, err := s.executorPodSet()
+	if err != nil {
+		return nil, err
+	}
+	return []kueue.PodSet{*driverPodSet, *executorPodSet}, nil
+}
+
+func (s *SparkApplication) PodsReady() bool {
+	driverReady := kfsparkutil.IsDriverRunning(s.asSparkApp())
+	if !driverReady {
+		return false
+	}
+	for _, executorState := range s.Status.ExecutorState {
+		if executorState != kfsparkapi.ExecutorStateRunning {
+			return false
+		}
+	}
+	return true
+}
+
+func (s *SparkApplication) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool
{
+	if len(podSetsInfo) == 0 {
+		return false
+	}
+	changed := false
+	for _, info := range podSetsInfo {
+		if info.Name == kfsparkcommon.SparkRoleDriver {
+			ps, err := s.driverPodSet()
+			if err != nil {
+				continue
+			}
+			changed = podset.RestorePodSpec(&ps.Template.ObjectMeta, &ps.Template.Spec, info) || changed
+			continue
+		}
+		if info.Name == kfsparkcommon.SparkRoleExecutor {
+			ps, err := s.executorPodSet()
+			if err != nil {
+				continue
+			}
+			changed = podset.RestorePodSpec(&ps.Template.ObjectMeta, &ps.Template.Spec, info) || changed
+			continue
+		}
+	}
+	return changed
+}
+
+func (s *SparkApplication) RunWithPodSetsInfo(podSetsInfo []podset.PodSetInfo) error {
+	s.Spec.Suspend = false
+	for _, info := range podSetsInfo {
+		if info.Name == kfsparkcommon.SparkRoleDriver {
+			ps, err := s.driverPodSet()
+			if err != nil {
+				return err
+			}
+			if err := podset.Merge(&ps.Template.ObjectMeta, &ps.Template.Spec, info); err != nil {
+				return err
+			}
+			continue
+		}
+		if info.Name == kfsparkcommon.SparkRoleExecutor {
+			ps, err := s.executorPodSet()
+			if err != nil {
+				return err
+			}
+			if err := podset.Merge(&ps.Template.ObjectMeta, &ps.Template.Spec, info); err != nil {
+				return err
+			}
+			continue
+		}
+	}
+	return nil
+}
+
+func (s *SparkApplication) Suspend() {
+	s.Spec.Suspend = true
+}
+
+func SetupIndexes(ctx context.Context, indexer client.FieldIndexer) error {
+	return jobframework.SetupWorkloadOwnerIndex(ctx, indexer, gvk)
+}
+
+func (s *SparkApplication) asSparkApp() *kfsparkapi.SparkApplication {
+	return (*kfsparkapi.SparkApplication)(s)
+}
+
+func (s *SparkApplication) driverPodSet() (*kueue.PodSet, error) {
+	podTemplate, containerIndex, err := sparkPodSpecToPodSetTemplate(s.Spec.Driver.SparkPodSpec, field.NewPath("spec", "driver"), kfsparkcommon.SparkDriverContainerName)
+	if err != nil {
+		return nil, err
+	}
+	applySparkSpecFieldsToPodTemplate(s.Spec, containerIndex, podTemplate)
+
+	// DriverSpec
+	// TODO: apply other fields in the spec.
+	// But, there are no problems because these fields are not interesting for PodSet.
+	if s.Spec.Driver.CoreRequest != nil {
+		if q, err := resource.ParseQuantity(*s.Spec.Driver.CoreRequest); err != nil {
+			return nil, fmt.Errorf("spec.driver.coreRequest=%s can't parse: %w", *s.Spec.Driver.CoreRequest, err)
+		} else {
+			if podTemplate.Spec.Containers[containerIndex].Resources.Requests == nil {
+				podTemplate.Spec.Containers[containerIndex].Resources.Requests = corev1.ResourceList{}
+			}
+			podTemplate.Spec.Containers[containerIndex].Resources.Requests[corev1.ResourceCPU] = q
+		}
+	}
+	if s.Spec.Driver.PriorityClassName != nil {
+		podTemplate.Spec.PriorityClassName = *s.Spec.Driver.PriorityClassName
+	}
+
+	return &kueue.PodSet{
+		Name:            kfsparkcommon.SparkRoleDriver,
+		Count:           int32(1),
+		Template:        *podTemplate,
+		TopologyRequest: jobframework.PodSetTopologyRequest(&podTemplate.ObjectMeta, nil, nil, nil),
+	}, nil
+}
+
+func (s *SparkApplication) executorPodSet() (*kueue.PodSet, error) {
+	podTemplate, containerIndex, err := sparkPodSpecToPodSetTemplate(s.Spec.Executor.SparkPodSpec, field.NewPath("spec", "executor"), kfsparkcommon.SparkExecutorContainerName)
+	if err != nil {
+		return nil, err
+	}
+
+	applySparkSpecFieldsToPodTemplate(s.Spec, containerIndex, podTemplate)
+
+	// Executor Spec
+	// TODO: apply other fields in the spec.
+	// But, there are no problems because these fields are not interesting for PodSet.
+	if s.Spec.Executor.CoreRequest != nil {
+		if q, err := resource.ParseQuantity(*s.Spec.Executor.CoreRequest); err != nil {
+			return nil, fmt.Errorf("spec.executor.coreRequest=%s can't parse: %w", *s.Spec.Executor.CoreRequest, err)
+		} else {
+			if podTemplate.Spec.Containers[containerIndex].Resources.Requests == nil {
+				podTemplate.Spec.Containers[containerIndex].Resources.Requests = corev1.ResourceList{}
+			}
+			podTemplate.Spec.Containers[containerIndex].Resources.Requests[corev1.ResourceCPU] = q
+		}
+	}
+	if s.Spec.Executor.PriorityClassName != nil {
+		podTemplate.Spec.PriorityClassName = *s.Spec.Executor.PriorityClassName
+	}
+	return &kueue.PodSet{
+		Name:            kfsparkcommon.SparkRoleExecutor,
+		Template:        *podTemplate,
+		Count:           s.executorCount(),
+		TopologyRequest: jobframework.PodSetTopologyRequest(&podTemplate.ObjectMeta, ptr.To(kfsparkcommon.LabelSparkExecutorID), nil, nil),
+	}, nil
+}
+
+func sparkPodSpecToPodSetTemplate(sparkPodSpec kfsparkapi.SparkPodSpec, field *field.Path, targetContainerName string) (*corev1.PodTemplateSpec, int, error) {
+	podTemplate := &corev1.PodTemplateSpec{}
+
+	if sparkPodSpec.Template != nil {
+		podTemplate = sparkPodSpec.Template
+	}
+
+	containerIndex := findContainer(&podTemplate.Spec, targetContainerName)
+	if containerIndex == -1 {
+		podTemplate.Spec.Containers = append(podTemplate.Spec.Containers, corev1.Container{
+			Name: targetContainerName,
+		})
+		containerIndex = len(podTemplate.Spec.Containers) - 1
+	}
+
+	if sparkPodSpec.CoreLimit != nil {
+		if q, err := resource.ParseQuantity(*sparkPodSpec.CoreLimit); err != nil {
+			return nil, -1, fmt.Errorf("%s=%s can't parse: %w", field.Child("coreLimit"), *sparkPodSpec.CoreLimit, err)
+		} else {
+			if podTemplate.Spec.Containers[containerIndex].Resources.Limits == nil {
+				podTemplate.Spec.Containers[containerIndex].Resources.Limits = corev1.ResourceList{}
+			}
+			podTemplate.Spec.Containers[containerIndex].Resources.Limits[corev1.ResourceCPU] = q
+		}
+	}
+
+	if sparkPodSpec.Memory != nil {
+		if q, err := resource.ParseQuantity(*sparkPodSpec.Memory); err != nil {
+			return nil, -1, fmt.Errorf("%s=%s can't parse: %w", field.Child("memory"), *sparkPodSpec.Memory, err)
+		} else {
+			if podTemplate.Spec.Containers[containerIndex].Resources.Requests
== nil {
+				podTemplate.Spec.Containers[containerIndex].Resources.Requests = corev1.ResourceList{}
+			}
+			podTemplate.Spec.Containers[containerIndex].Resources.Requests[corev1.ResourceMemory] = q
+		}
+	}
+
+	if sparkPodSpec.GPU != nil {
+		qStr := fmt.Sprintf("%d", sparkPodSpec.GPU.Quantity)
+		if q, err := resource.ParseQuantity(qStr); err != nil {
+			return nil, -1, fmt.Errorf("%s=%s can't parse: %w", field.Child("gpu", "quantity"), qStr, err)
+		} else {
+			if podTemplate.Spec.Containers[containerIndex].Resources.Limits == nil {
+				podTemplate.Spec.Containers[containerIndex].Resources.Limits = corev1.ResourceList{}
+			}
+			podTemplate.Spec.Containers[containerIndex].Resources.Limits[corev1.ResourceName(sparkPodSpec.GPU.Name)] = q
+		}
+	}
+
+	if sparkPodSpec.Image != nil {
+		podTemplate.Spec.Containers[containerIndex].Image = *sparkPodSpec.Image
+	}
+
+	for k, v := range sparkPodSpec.Labels {
+		if podTemplate.Labels == nil {
+			podTemplate.Labels = map[string]string{}
+		}
+		podTemplate.Labels[k] = v
+	}
+
+	for k, v := range sparkPodSpec.Annotations {
+		if podTemplate.Annotations == nil {
+			podTemplate.Annotations = map[string]string{}
+		}
+		podTemplate.Annotations[k] = v
+	}
+
+	if sparkPodSpec.Affinity != nil {
+		podTemplate.Spec.Affinity = sparkPodSpec.Affinity
+	}
+
+	if sparkPodSpec.Tolerations != nil {
+		podTemplate.Spec.Tolerations = sparkPodSpec.Tolerations
+	}
+
+	if len(sparkPodSpec.Sidecars) > 0 {
+		podTemplate.Spec.Containers = append(podTemplate.Spec.Containers, sparkPodSpec.Sidecars...)
+	}
+
+	if len(sparkPodSpec.InitContainers) > 0 {
+		podTemplate.Spec.InitContainers = append(podTemplate.Spec.InitContainers, sparkPodSpec.InitContainers...)
+	}
+
+	if sparkPodSpec.NodeSelector != nil {
+		podTemplate.Spec.NodeSelector = sparkPodSpec.NodeSelector
+	}
+
+	// TODO: apply other fields in SparkPodSpec.
+	// But, there are no problems because these fields are not interesting for PodSet.
+ + return podTemplate, containerIndex, nil +} + +func applySparkSpecFieldsToPodTemplate(spec kfsparkapi.SparkApplicationSpec, targetContainerIndex int, podTemplate *corev1.PodTemplateSpec) { + if spec.Image != nil && podTemplate.Spec.Containers[targetContainerIndex].Image == "" { + podTemplate.Spec.Containers[targetContainerIndex].Image = *spec.Image + } + if spec.NodeSelector != nil && podTemplate.Spec.NodeSelector == nil { + podTemplate.Spec.NodeSelector = spec.NodeSelector + } + + // TODO: apply other fields in SparkApplicationSpec. + // But, there is no problems because these fields are not interested in PodSet. +} + +func (s *SparkApplication) executorCount() int32 { + return ptr.Deref(s.Spec.Executor.Instances, 1) +} + +func findContainer(podSpec *corev1.PodSpec, containerName string) int { + for i, c := range podSpec.Containers { + if c.Name == containerName { + return i + } + } + return -1 +} diff --git a/pkg/controller/jobs/sparkapplication/sparkapplication_controller_test.go b/pkg/controller/jobs/sparkapplication/sparkapplication_controller_test.go new file mode 100644 index 00000000000..f63151e3dab --- /dev/null +++ b/pkg/controller/jobs/sparkapplication/sparkapplication_controller_test.go @@ -0,0 +1,373 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sparkapplication + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/record" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + + kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/pkg/constants" + "sigs.k8s.io/kueue/pkg/controller/jobframework" + utiltesting "sigs.k8s.io/kueue/pkg/util/testing" + podtesting "sigs.k8s.io/kueue/pkg/util/testingjobs/pod" + sparkapptesting "sigs.k8s.io/kueue/pkg/util/testingjobs/sparkapplication" + + kfsparkapi "github.com/kubeflow/spark-operator/api/v1beta2" + kfsparkcommon "github.com/kubeflow/spark-operator/pkg/common" +) + +var ( + baseAppWrapper = sparkapptesting.MakeSparkApplication("job", "ns"). + Queue("queue"). + DynamicAllocation(false). 
+ ExecutorInstances(1) + expectedDriverPodWrapper = &podtesting.PodWrapper{Pod: corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "spec-driver-labels": "spec-driver-labels", + }, + Annotations: map[string]string{ + "spec-driver-annotations": "spec-driver-annotations", + }, + }, + Spec: corev1.PodSpec{ + PriorityClassName: *baseAppWrapper.Spec.Driver.PriorityClassName, + Containers: []corev1.Container{{ + Name: kfsparkcommon.SparkDriverContainerName, + Image: *baseAppWrapper.Spec.Image, + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse(*baseAppWrapper.Spec.Driver.CoreRequest), + corev1.ResourceMemory: resource.MustParse(*baseAppWrapper.Spec.Driver.Memory), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse(*baseAppWrapper.Spec.Driver.CoreLimit), + }, + }, + }}, + Volumes: baseAppWrapper.Spec.Volumes, + }, + }} + expectedExecutorPodWrapper = &podtesting.PodWrapper{Pod: corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "spec-executor-template-labels": "spec-executor-template-labels", + "spec-executor-labels": "spec-executor-labels", + }, + Annotations: map[string]string{ + "spec-executor-annotations": "spec-executor-annotations", + }, + }, + Spec: corev1.PodSpec{ + Affinity: baseAppWrapper.Spec.Executor.Affinity, + NodeSelector: baseAppWrapper.Spec.Executor.NodeSelector, + Tolerations: baseAppWrapper.Spec.Executor.Tolerations, + InitContainers: baseAppWrapper.Spec.Executor.InitContainers, + PriorityClassName: *baseAppWrapper.Spec.Executor.PriorityClassName, + Containers: append([]corev1.Container{{ + Name: kfsparkcommon.SparkExecutorContainerName, + Image: *baseAppWrapper.Spec.Executor.Image, + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse(*baseAppWrapper.Spec.Executor.CoreRequest), + corev1.ResourceMemory: 
resource.MustParse(*baseAppWrapper.Spec.Executor.Memory), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse(*baseAppWrapper.Spec.Executor.CoreLimit), + corev1.ResourceName(baseAppWrapper.Spec.Executor.GPU.Name): resource.MustParse(fmt.Sprintf("%d", baseAppWrapper.Spec.Executor.GPU.Quantity)), + }, + }, + }}, baseAppWrapper.Spec.Executor.Sidecars...), + }, + }} +) + +func TestPodSets(t *testing.T) { + testCases := map[string]struct { + sparkApp *kfsparkapi.SparkApplication + wantPodSets []kueue.PodSet + wantErr error + }{ + "normal case": { + sparkApp: baseAppWrapper.Clone().Obj(), + wantPodSets: []kueue.PodSet{{ + Name: kfsparkcommon.SparkRoleDriver, + Count: int32(1), + Template: toPodTemplateSpec(expectedDriverPodWrapper), + }, { + Name: kfsparkcommon.SparkRoleExecutor, + Count: *baseAppWrapper.Spec.Executor.Instances, + Template: toPodTemplateSpec(expectedExecutorPodWrapper.Clone()), + }}, + }, + "with required topology annotation": { + sparkApp: baseAppWrapper.Clone(). + PodAnnotation(kfsparkcommon.SparkRoleDriver, kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block"). 
+ PodAnnotation(kfsparkcommon.SparkRoleExecutor, kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block").Obj(), + wantPodSets: []kueue.PodSet{{ + Name: kfsparkcommon.SparkRoleDriver, + Count: int32(1), + Template: toPodTemplateSpec( + expectedDriverPodWrapper.Clone().Annotation(kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block"), + ), + TopologyRequest: &kueue.PodSetTopologyRequest{ + Required: ptr.To("cloud.com/block"), + }, + }, { + Name: kfsparkcommon.SparkRoleExecutor, + Count: *baseAppWrapper.Spec.Executor.Instances, + Template: toPodTemplateSpec( + expectedExecutorPodWrapper.Clone().Annotation(kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block"), + ), + TopologyRequest: &kueue.PodSetTopologyRequest{ + Required: ptr.To("cloud.com/block"), + PodIndexLabel: ptr.To(kfsparkcommon.LabelSparkExecutorID), + }, + }}, + }, + "with preferred topology annotation": { + sparkApp: baseAppWrapper.Clone(). + PodAnnotation(kfsparkcommon.SparkRoleDriver, kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block"). 
+ PodAnnotation(kfsparkcommon.SparkRoleExecutor, kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block").Obj(), + wantPodSets: []kueue.PodSet{{ + Name: kfsparkcommon.SparkRoleDriver, + Count: int32(1), + Template: toPodTemplateSpec( + expectedDriverPodWrapper.Clone().Annotation(kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block"), + ), + TopologyRequest: &kueue.PodSetTopologyRequest{ + Preferred: ptr.To("cloud.com/block"), + }, + }, { + Name: kfsparkcommon.SparkRoleExecutor, + Count: *baseAppWrapper.Spec.Executor.Instances, + Template: toPodTemplateSpec( + expectedExecutorPodWrapper.Clone().Annotation(kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block"), + ), + TopologyRequest: &kueue.PodSetTopologyRequest{ + Preferred: ptr.To("cloud.com/block"), + PodIndexLabel: ptr.To(kfsparkcommon.LabelSparkExecutorID), + }, + }}, + }, + "invalid - can't parse resource": { + sparkApp: baseAppWrapper.Clone().CoreLimit(kfsparkcommon.SparkRoleDriver, ptr.To("malformat")).Obj(), + wantErr: fmt.Errorf(`spec.driver.coreLimit=malformat can't parse: %w`, resource.ErrFormatWrong), + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + gotPodSets, gotErr := fromObject(tc.sparkApp).PodSets() + if gotErr != nil { + if tc.wantErr == nil { + t.Errorf("PodSets expected error (-want,+got):\n-nil\n+%s", gotErr.Error()) + } + if diff := cmp.Diff(tc.wantErr.Error(), gotErr.Error()); diff != "" { + t.Errorf("PodSets returned error (-want,+got):\n%s", diff) + } + return + } + if diff := cmp.Diff(tc.wantPodSets, gotPodSets); diff != "" { + t.Errorf("pod sets mismatch (-want +got):\n%s", diff) + } + }) + } +} + +var ( + jobCmpOpts = []cmp.Option{ + cmpopts.EquateEmpty(), + cmpopts.IgnoreFields(kfsparkapi.SparkApplication{}, "TypeMeta", "ObjectMeta"), + } + workloadCmpOpts = []cmp.Option{ + cmpopts.EquateEmpty(), + cmpopts.IgnoreFields(kueue.Workload{}, "TypeMeta"), + cmpopts.IgnoreFields(metav1.ObjectMeta{}, "Name", "Labels", 
"ResourceVersion", "OwnerReferences", "Finalizers"), + cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime"), + cmpopts.IgnoreFields(kueue.PodSet{}, "Template"), + } +) + +func TestReconciler(t *testing.T) { + baseWPCWrapper := utiltesting.MakeWorkloadPriorityClass("test-wpc"). + PriorityValue(100) + driverPCWrapper := utiltesting.MakePriorityClass("driver-priority-class"). + PriorityValue(200) + executorPCWrapper := utiltesting.MakePriorityClass("executor-priority-class").PriorityValue(100) + + testNamespace := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ns", + Labels: map[string]string{ + "kubernetes.io/metadata.name": "ns", + }, + }, + } + + testCases := map[string]struct { + reconcilerOptions []jobframework.Option + sparkApp *kfsparkapi.SparkApplication + priorityClasses []client.Object + wantSparkApp *kfsparkapi.SparkApplication + wantWorkloads []kueue.Workload + wantErr error + }{ + "workload is created with queue and priorityClass": { + sparkApp: baseAppWrapper.Clone().Obj(), + wantSparkApp: baseAppWrapper.Clone().Suspend(true).Obj(), + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("job", "ns"). + Queue("queue"). + PriorityClass(driverPCWrapper.Name). + Priority(driverPCWrapper.Value). + PriorityClassSource(constants.PodPriorityClassSource). + PodSets( + *utiltesting.MakePodSet(kfsparkcommon.SparkRoleDriver, 1).Obj(), + *utiltesting.MakePodSet(kfsparkcommon.SparkRoleExecutor, int(*baseAppWrapper.Spec.Executor.Instances)).Obj(), + ).Obj(), + }, + priorityClasses: []client.Object{ + driverPCWrapper.Obj(), executorPCWrapper.Obj(), + }, + }, + "workload is created with queue, priorityClass and workloadPriorityClass": { + sparkApp: baseAppWrapper.Clone().WorkloadPriorityClass(baseWPCWrapper.Name).Obj(), + wantSparkApp: baseAppWrapper.Clone().WorkloadPriorityClass(baseWPCWrapper.Name).Suspend(true).Obj(), + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("job", "ns"). + Queue("queue"). 
+ PriorityClass(baseWPCWrapper.Name). + Priority(baseWPCWrapper.Value). + PriorityClassSource(constants.WorkloadPriorityClassSource). + PodSets( + *utiltesting.MakePodSet(kfsparkcommon.SparkRoleDriver, 1).Obj(), + *utiltesting.MakePodSet(kfsparkcommon.SparkRoleExecutor, int(*baseAppWrapper.Spec.Executor.Instances)).Obj(), + ).Obj(), + }, + priorityClasses: []client.Object{ + baseWPCWrapper.Obj(), driverPCWrapper.Obj(), executorPCWrapper.Obj(), + }, + }, + "workload is created with queue, priorityClass and required topology request": { + sparkApp: baseAppWrapper.Clone(). + PodAnnotation(kfsparkcommon.SparkRoleDriver, kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block"). + PodAnnotation(kfsparkcommon.SparkRoleExecutor, kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block").Obj(), + wantSparkApp: baseAppWrapper.Clone(). + Suspend(true). + PodAnnotation(kfsparkcommon.SparkRoleDriver, kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block"). + PodAnnotation(kfsparkcommon.SparkRoleExecutor, kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block").Obj(), + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("job", "ns"). + Queue("queue"). + PriorityClass(driverPCWrapper.Name). + Priority(driverPCWrapper.Value). + PriorityClassSource(constants.PodPriorityClassSource). + PodSets( + *utiltesting.MakePodSet(kfsparkcommon.SparkRoleDriver, 1).RequiredTopologyRequest("cloud.com/block").Obj(), + *utiltesting.MakePodSet(kfsparkcommon.SparkRoleExecutor, int(*baseAppWrapper.Spec.Executor.Instances)). + RequiredTopologyRequest("cloud.com/block").PodIndexLabel(ptr.To(kfsparkcommon.LabelSparkExecutorID)).Obj(), + ).Obj(), + }, + priorityClasses: []client.Object{ + driverPCWrapper.Obj(), executorPCWrapper.Obj(), + }, + }, + "workload is created with queue, priorityClass and preferred topology request": { + sparkApp: baseAppWrapper.Clone(). 
+ PodAnnotation(kfsparkcommon.SparkRoleDriver, kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block"). + PodAnnotation(kfsparkcommon.SparkRoleExecutor, kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block").Obj(), + wantSparkApp: baseAppWrapper.Clone(). + Suspend(true). + PodAnnotation(kfsparkcommon.SparkRoleDriver, kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block"). + PodAnnotation(kfsparkcommon.SparkRoleExecutor, kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block").Obj(), + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("job", "ns"). + Queue("queue"). + PriorityClass(driverPCWrapper.Name). + Priority(driverPCWrapper.Value). + PriorityClassSource(constants.PodPriorityClassSource). + PodSets( + *utiltesting.MakePodSet(kfsparkcommon.SparkRoleDriver, 1).PreferredTopologyRequest("cloud.com/block").Obj(), + *utiltesting.MakePodSet(kfsparkcommon.SparkRoleExecutor, int(*baseAppWrapper.Spec.Executor.Instances)). + PreferredTopologyRequest("cloud.com/block").PodIndexLabel(ptr.To(kfsparkcommon.LabelSparkExecutorID)).Obj(), + ).Obj(), + }, + priorityClasses: []client.Object{ + driverPCWrapper.Obj(), executorPCWrapper.Obj(), + }, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + ctx, _ := utiltesting.ContextWithLog(t) + clientBuilder := utiltesting.NewClientBuilder(kfsparkapi.AddToScheme) + if err := SetupIndexes(ctx, utiltesting.AsIndexer(clientBuilder)); err != nil { + t.Fatalf("Could not setup indexes: %v", err) + } + + objs := append(tc.priorityClasses, tc.sparkApp, testNamespace) + kClient := clientBuilder.WithObjects(objs...).Build() + recorder := record.NewBroadcaster().NewRecorder(kClient.Scheme(), corev1.EventSource{Component: "test"}) + + reconciler := NewReconciler(kClient, recorder, tc.reconcilerOptions...) 
+ + jobKey := client.ObjectKeyFromObject(tc.sparkApp) + _, err := reconciler.Reconcile(ctx, reconcile.Request{ + NamespacedName: jobKey, + }) + if diff := cmp.Diff(tc.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Errorf("Reconcile returned error (-want,+got):\n%s", diff) + } + + var gotSparkApp kfsparkapi.SparkApplication + if err := kClient.Get(ctx, jobKey, &gotSparkApp); err != nil { + t.Fatalf("Could not get Job after reconcile: %v", err) + } + if diff := cmp.Diff(tc.wantSparkApp, &gotSparkApp, jobCmpOpts...); diff != "" { + t.Errorf("SparkApplication after reconcile (-want,+got):\n%s", diff) + } + var gotWorkloads kueue.WorkloadList + if err := kClient.List(ctx, &gotWorkloads); err != nil { + t.Fatalf("Could not get Workloads after reconcile: %v", err) + } + if diff := cmp.Diff(tc.wantWorkloads, gotWorkloads.Items, workloadCmpOpts...); diff != "" { + t.Errorf("Workloads after reconcile (-want,+got):\n%s", diff) + } + }) + } +} + +func toPodTemplateSpec(pw *podtesting.PodWrapper) corev1.PodTemplateSpec { + return corev1.PodTemplateSpec{ + ObjectMeta: pw.ObjectMeta, + Spec: pw.Spec, + } +} diff --git a/pkg/controller/jobs/sparkapplication/sparkapplication_webhook.go b/pkg/controller/jobs/sparkapplication/sparkapplication_webhook.go new file mode 100644 index 00000000000..b55b635c377 --- /dev/null +++ b/pkg/controller/jobs/sparkapplication/sparkapplication_webhook.go @@ -0,0 +1,137 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sparkapplication + +import ( + "context" + + apivalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobframework/webhook" + "sigs.k8s.io/kueue/pkg/features" + "sigs.k8s.io/kueue/pkg/queue" + + kfsparkapi "github.com/kubeflow/spark-operator/api/v1beta2" +) + +type SparkApplicationWebhook struct { + client client.Client + manageJobsWithoutQueueName bool + managedJobsNamespaceSelector labels.Selector + queues *queue.Manager +} + +// SetupSparkApplicationWebhook configures the webhook for SparkApplication. +func SetupSparkApplicationWebhook(mgr ctrl.Manager, opts ...jobframework.Option) error { + options := jobframework.ProcessOptions(opts...) + wh := &SparkApplicationWebhook{ + client: mgr.GetClient(), + manageJobsWithoutQueueName: options.ManageJobsWithoutQueueName, + managedJobsNamespaceSelector: options.ManagedJobsNamespaceSelector, + queues: options.Queues, + } + obj := &kfsparkapi.SparkApplication{} + return webhook.WebhookManagedBy(mgr). + For(obj). + WithMutationHandler(webhook.WithLosslessDefaulter(mgr.GetScheme(), obj, wh)). + WithValidator(wh). 
+ Complete() +} + +// +kubebuilder:webhook:path=/mutate-sparkoperator-k8s-io-v1beta2-sparkapplication,mutating=true,failurePolicy=fail,sideEffects=None,groups=sparkoperator.k8s.io,resources=sparkapplications,verbs=create,versions=v1beta2,name=msparkapplication.kb.io,admissionReviewVersions=v1 + +var _ admission.CustomDefaulter = &SparkApplicationWebhook{} + +// Default implements webhook.CustomDefaulter so a webhook will be registered for the type +func (wh *SparkApplicationWebhook) Default(ctx context.Context, obj runtime.Object) error { + sparkApp := fromObject(obj) + log := ctrl.LoggerFrom(ctx).WithName("sparkapplication-webhook") + log.V(5).Info("Applying defaults") + jobframework.ApplyDefaultLocalQueue(sparkApp.Object(), wh.queues.DefaultLocalQueueExist) + return jobframework.ApplyDefaultForSuspend(ctx, sparkApp, wh.client, wh.manageJobsWithoutQueueName, wh.managedJobsNamespaceSelector) +} + +// +kubebuilder:webhook:path=/validate-sparkoperator-k8s-io-v1beta2-sparkapplication,mutating=false,failurePolicy=fail,sideEffects=None,groups=sparkoperator.k8s.io,resources=sparkapplications,verbs=create;update,versions=v1beta2,name=vsparkapplication.kb.io,admissionReviewVersions=v1 + +var _ admission.CustomValidator = &SparkApplicationWebhook{} + +// ValidateCreate implements webhook.CustomValidator so a webhook will be registered for the type +func (wh *SparkApplicationWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) { + sparkApplication := fromObject(obj) + log := ctrl.LoggerFrom(ctx).WithName("sparkapplication-webhook") + log.Info("Validating create") + return nil, wh.validateCreate(sparkApplication).ToAggregate() +} + +// ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type +func (wh *SparkApplicationWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) { + oldSparkApp := fromObject(oldObj) + newSparkApp := fromObject(newObj) + log 
:= ctrl.LoggerFrom(ctx).WithName("sparkapplication-webhook") + log.Info("Validating update") + return nil, wh.validateUpdate(oldSparkApp, newSparkApp).ToAggregate() +} + +func (wh *SparkApplicationWebhook) validateUpdate(oldSparkApp, newSparkApp *SparkApplication) field.ErrorList { + var allErrs field.ErrorList + allErrs = append(allErrs, jobframework.ValidateJobOnUpdate(oldSparkApp, newSparkApp)...) + if !newSparkApp.IsSuspended() { + allErrs = append(allErrs, apivalidation.ValidateImmutableField( + newSparkApp.Spec, oldSparkApp.Spec, + field.NewPath("spec"), + )...) + } + allErrs = append(allErrs, wh.validateCreate(newSparkApp)...) + return allErrs +} + +func (wh *SparkApplicationWebhook) validateCreate(sparkApplication *SparkApplication) field.ErrorList { + var allErrs field.ErrorList + allErrs = append(allErrs, jobframework.ValidateJobOnCreate(sparkApplication)...) + + if wh.manageJobsWithoutQueueName || jobframework.QueueName(sparkApplication) != "" || features.Enabled(features.LocalQueueDefaulting) { + // We can't support dynamically scaling jobs yet: + // https://github.com/kubernetes-sigs/kueue/issues/77 + dynamicAllocationPath := field.NewPath("spec", "dynamicAllocation") + if sparkApplication.Spec.DynamicAllocation == nil { + allErrs = append(allErrs, field.Required(dynamicAllocationPath, "a kueue managed job should disable dynamicAllocation explicitly")) + } else { + if sparkApplication.Spec.DynamicAllocation.Enabled { + allErrs = append(allErrs, field.Invalid(dynamicAllocationPath.Child("enabled"), sparkApplication.Spec.DynamicAllocation.Enabled, "a kueue managed job should disable dynamicAllocation explicitly")) + } + } + + executorInstancesPath := field.NewPath("spec", "executor", "instances") + if sparkApplication.Spec.Executor.Instances == nil { + allErrs = append(allErrs, field.Required(executorInstancesPath, "a kueue managed job should set the number of executors explicitly")) + } + } + + return allErrs +} + +// ValidateDelete implements 
webhook.CustomValidator so a webhook will be registered for the type +func (wh *SparkApplicationWebhook) ValidateDelete(_ context.Context, _ runtime.Object) (admission.Warnings, error) { + return nil, nil +} diff --git a/pkg/controller/jobs/sparkapplication/sparkapplication_webhook_test.go b/pkg/controller/jobs/sparkapplication/sparkapplication_webhook_test.go new file mode 100644 index 00000000000..8f8d42ddf0a --- /dev/null +++ b/pkg/controller/jobs/sparkapplication/sparkapplication_webhook_test.go @@ -0,0 +1,289 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sparkapplication + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + + apivalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/util/validation/field" + + "sigs.k8s.io/kueue/pkg/cache" + "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/features" + "sigs.k8s.io/kueue/pkg/queue" + utiltesting "sigs.k8s.io/kueue/pkg/util/testing" + + kfsparkapi "github.com/kubeflow/spark-operator/api/v1beta2" + sparkapptesting "sigs.k8s.io/kueue/pkg/util/testingjobs/sparkapplication" +) + +func TestDefaults(t *testing.T) { + testCases := map[string]struct { + sparkApp *kfsparkapi.SparkApplication + manageJobsWithoutQueueName bool + localQueueDefaulting bool + defaultLqExist bool + want *kfsparkapi.SparkApplication + }{ + "unmanaged": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Obj(), + want: sparkapptesting.MakeSparkApplication("job", "ns").Obj(), + }, + "manageJobsWithoutQueueName=true": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Obj(), + want: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Obj(), + manageJobsWithoutQueueName: true, + }, + "sparkapplication with queue": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Queue("queue").Obj(), + want: sparkapptesting.MakeSparkApplication("job", "ns").Queue("queue").Suspend(true).Obj(), + }, + "LocalQueueDefaulting enabled, DefaultLocalQueue exists, sparkapplication without queue": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "default").Obj(), + want: sparkapptesting.MakeSparkApplication("job", "default").Queue("default").Suspend(true).Obj(), + localQueueDefaulting: true, + defaultLqExist: true, + }, + "LocalQueueDefaulting enabled, DefaultLocalQueue exists, sparkapplication with queue": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "default").Queue("queue").Obj(), + want: sparkapptesting.MakeSparkApplication("job", "default").Queue("queue").Suspend(true).Obj(), 
+ localQueueDefaulting: true, + defaultLqExist: true, + }, + "LocalQueueDefaulting enabled, DefaultLocalQueue does not exists, sparkapplication without queue": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "default").Obj(), + want: sparkapptesting.MakeSparkApplication("job", "default").Obj(), + localQueueDefaulting: true, + defaultLqExist: false, + }, + "LocalQueueDefaulting enabled, DefaultLocalQueue does not exists, sparkapplication with queue": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "default").Queue("queue").Obj(), + want: sparkapptesting.MakeSparkApplication("job", "default").Queue("queue").Suspend(true).Obj(), + localQueueDefaulting: true, + defaultLqExist: false, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.ManagedJobsNamespaceSelector, false) + features.SetFeatureGateDuringTest(t, features.LocalQueueDefaulting, tc.localQueueDefaulting) + ctx, _ := utiltesting.ContextWithLog(t) + + builder := utiltesting.NewClientBuilder() + cli := builder.Build() + cqCache := cache.New(cli) + queueManager := queue.NewManager(cli, cqCache) + if tc.defaultLqExist { + if err := queueManager.AddLocalQueue(ctx, utiltesting.MakeLocalQueue("default", "default"). 
+ ClusterQueue("cluster-queue").Obj()); err != nil { + t.Fatalf("failed to create default local queue: %s", err) + } + } + + wh := &SparkApplicationWebhook{ + client: cli, + manageJobsWithoutQueueName: tc.manageJobsWithoutQueueName, + queues: queueManager, + } + if err := wh.Default(ctx, tc.sparkApp); err != nil { + t.Errorf("failed to set defaults for sparkoperator.k8s.io/v1beta2/sparkapplication: %s", err) + } + if diff := cmp.Diff(tc.want, tc.sparkApp); len(diff) != 0 { + t.Errorf("Default() mismatch (-want,+got):\n%s", diff) + } + }) + } +} + +func TestValidateCreate(t *testing.T) { + testCases := map[string]struct { + sparkApp *kfsparkapi.SparkApplication + manageJobsWithoutQueueName bool + wantErr error + }{ + "valid - unmanaged": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Obj(), + }, + "valid - managed": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + }, + "invalid - without dynamicAllocation": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").ExecutorInstances(1).Obj(), + wantErr: field.ErrorList{ + field.Required(field.NewPath("spec", "dynamicAllocation"), "a kueue managed job should disable dynamicAllocation explicitly"), + }.ToAggregate(), + }, + "invalid - with dynamicAllocation.enabled=true": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(true).ExecutorInstances(1).Obj(), + wantErr: field.ErrorList{ + field.Invalid(field.NewPath("spec", "dynamicAllocation", "enabled"), true, "a kueue managed job should disable dynamicAllocation explicitly"), + }.ToAggregate(), + }, + "invalid - without executor.instances": { + sparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).Obj(), + wantErr: field.ErrorList{ + field.Required(field.NewPath("spec", "executor", 
"instances"), "a kueue managed job should set the number of executors explicitly"), + }.ToAggregate(), + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + wh := &SparkApplicationWebhook{} + _, result := wh.ValidateCreate(context.Background(), tc.sparkApp) + if diff := cmp.Diff(tc.wantErr, result); diff != "" { + t.Errorf("ValidateCreate() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestValidateUpdate(t *testing.T) { + testCases := map[string]struct { + oldSparkApp *kfsparkapi.SparkApplication + newSparkApp *kfsparkapi.SparkApplication + wantErr error + }{ + "valid - update unrelated metadata": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Label("key", "value").Obj(), + wantErr: nil, + }, + + // Suspended + "valid - when suspended, nothing changed": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + wantErr: nil, + }, + "valid - when suspended, update Queue": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue2").DynamicAllocation(false).ExecutorInstances(1).Obj(), + wantErr: nil, + }, + "valid - when suspend, unset Queue": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", 
"ns").Suspend(true).DynamicAllocation(false).ExecutorInstances(1).Obj(), + wantErr: nil, + }, + "valid - when suspended, update executor.instances": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(2).Obj(), + wantErr: nil, + }, + "invalid - when suspended, unset executor.instances": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).Obj(), + wantErr: field.ErrorList{ + field.Required(field.NewPath("spec", "executor", "instances"), "a kueue managed job should set the number of executors explicitly"), + }.ToAggregate(), + }, + "invalid - when suspended, enable dynamicAllocation": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(true).ExecutorInstances(1).Obj(), + wantErr: field.ErrorList{ + field.Invalid(field.NewPath("spec", "dynamicAllocation", "enabled"), true, "a kueue managed job should disable dynamicAllocation explicitly"), + }.ToAggregate(), + }, + "invalid - when suspended, unset dynamicAllocation": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(true).Queue("queue").ExecutorInstances(1).Obj(), + wantErr: field.ErrorList{ + field.Required(field.NewPath("spec", "dynamicAllocation"), "a kueue managed job should disable 
dynamicAllocation explicitly"), + }.ToAggregate(), + }, + + // Unsuspended + "valid - when unsuspended, nothing changed": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + wantErr: nil, + }, + "invalid - when unsuspended, update Queue": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue2").DynamicAllocation(false).ExecutorInstances(1).Obj(), + wantErr: field.ErrorList{ + field.Invalid(field.NewPath("metadata", "labels").Key(constants.QueueLabel), "queue2", apivalidation.FieldImmutableErrorMsg), + }.ToAggregate(), + }, + "invalid - when unsuspend, unset Queue": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).DynamicAllocation(false).ExecutorInstances(1).Obj(), + wantErr: field.ErrorList{ + field.Invalid(field.NewPath("metadata", "labels").Key(constants.QueueLabel), "", apivalidation.FieldImmutableErrorMsg), + }.ToAggregate(), + }, + "invalid - when unsuspended, update executor.instances": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(2).Obj(), + wantErr: apivalidation.ValidateImmutableField( + sparkapptesting.MakeSparkApplication("job", 
"ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(2).Obj().Spec, + sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj().Spec, + field.NewPath("spec"), + ).ToAggregate(), + }, + "invalid - when unsuspended, unset executor.instances": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).Obj(), + wantErr: append( + apivalidation.ValidateImmutableField( + sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).Obj().Spec, + sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj().Spec, + field.NewPath("spec"), + ), + field.Required(field.NewPath("spec", "executor", "instances"), "a kueue managed job should set the number of executors explicitly"), + ).ToAggregate(), + }, + "invalid - when unsuspended, enable dynamicAllocation": { + oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(), + newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(true).ExecutorInstances(1).Obj(), + wantErr: append( + apivalidation.ValidateImmutableField( + sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(true).ExecutorInstances(1).Obj().Spec, + sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj().Spec, + field.NewPath("spec"), + ), + field.Invalid(field.NewPath("spec", "dynamicAllocation", "enabled"), true, "a kueue managed job should disable dynamicAllocation explicitly"), + 
).ToAggregate(),
+ },
+ "invalid - when unsuspended, unset dynamicAllocation": {
+ oldSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj(),
+ newSparkApp: sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").ExecutorInstances(1).Obj(),
+ wantErr: append(
+ apivalidation.ValidateImmutableField(
+ sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").ExecutorInstances(1).Obj().Spec,
+ sparkapptesting.MakeSparkApplication("job", "ns").Suspend(false).Queue("queue").DynamicAllocation(false).ExecutorInstances(1).Obj().Spec,
+ field.NewPath("spec"),
+ ),
+ field.Required(field.NewPath("spec", "dynamicAllocation"), "a kueue managed job should disable dynamicAllocation explicitly"),
+ ).ToAggregate(),
+ },
+ }
+ for name, tc := range testCases {
+ t.Run(name, func(t *testing.T) {
+ wh := &SparkApplicationWebhook{}
+ _, result := wh.ValidateUpdate(context.Background(), tc.oldSparkApp, tc.newSparkApp)
+ if diff := cmp.Diff(tc.wantErr, result); diff != "" {
+ t.Errorf("ValidateUpdate() mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/util/testingjobs/sparkapplication/wrappers.go b/pkg/util/testingjobs/sparkapplication/wrappers.go
new file mode 100644
index 00000000000..7d31fc8af12
--- /dev/null
+++ b/pkg/util/testingjobs/sparkapplication/wrappers.go
@@ -0,0 +1,201 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sparkapplication + +import ( + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/utils/ptr" + + "sigs.k8s.io/kueue/pkg/controller/constants" + + kfsparkapi "github.com/kubeflow/spark-operator/api/v1beta2" + kfsparkcommon "github.com/kubeflow/spark-operator/pkg/common" +) + +// SparkApplicationWrapper wraps a SparkApplication. +type SparkApplicationWrapper struct{ kfsparkapi.SparkApplication } + +// MakeSparkApplication creates a wrapper for a suspended SparkApplication +func MakeSparkApplication(name, ns string) *SparkApplicationWrapper { + return &SparkApplicationWrapper{kfsparkapi.SparkApplication{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: ns, + Annotations: make(map[string]string, 1), + }, + Spec: kfsparkapi.SparkApplicationSpec{ + Suspend: false, + Image: ptr.To("spark:3.5.3"), + Mode: kfsparkapi.DeployModeCluster, + Type: kfsparkapi.SparkApplicationTypeScala, + MainClass: ptr.To("local:///opt/spark/examples/jars/spark-examples.jar"), + SparkVersion: "3.5.3", + Driver: kfsparkapi.DriverSpec{ + CoreRequest: ptr.To("1"), + PriorityClassName: ptr.To("driver-priority-class"), + SparkPodSpec: kfsparkapi.SparkPodSpec{ + CoreLimit: ptr.To("1"), + Memory: ptr.To("256Mi"), + SchedulerName: ptr.To("test"), + ServiceAccount: ptr.To("test"), + Labels: map[string]string{ + "spec-driver-labels": "spec-driver-labels", + }, + Annotations: map[string]string{ + "spec-driver-annotations": "spec-driver-annotations", + }, + }, + }, + Executor: kfsparkapi.ExecutorSpec{ + Instances: nil, + CoreRequest: ptr.To("1"), + PriorityClassName: ptr.To("executor-priority-class"), + SparkPodSpec: kfsparkapi.SparkPodSpec{ + Template: &corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "spec-executor-template-labels": "spec-executor-template-labels", + }, + }, + }, + CoreLimit: ptr.To("1"), + 
Memory: ptr.To("256Mi"), + GPU: &kfsparkapi.GPUSpec{ + Name: "nvidia.com/gpu", + Quantity: 1, + }, + Image: ptr.To("executor-image"), + Labels: map[string]string{ + "spec-executor-labels": "spec-executor-labels", + }, + Annotations: map[string]string{ + "spec-executor-annotations": "spec-executor-annotations", + }, + Affinity: &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: []corev1.PreferredSchedulingTerm{{ + Weight: 1, + Preference: corev1.NodeSelectorTerm{ + MatchFields: []corev1.NodeSelectorRequirement{{ + Key: "test", + Operator: corev1.NodeSelectorOpExists, + }}, + }, + }}, + }, + }, + Tolerations: []corev1.Toleration{{ + Key: "test", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }}, + SchedulerName: ptr.To("test"), + Sidecars: []corev1.Container{{Name: "sidecar", Image: "test"}}, + InitContainers: []corev1.Container{{Name: "initContainer", Image: "test"}}, + NodeSelector: map[string]string{ + "spec-driver-node-selector": "spec-driver-node-selector", + }, + ServiceAccount: ptr.To("test"), + }, + }, + }, + }} +} + +// Obj returns the inner Job. 
+func (w *SparkApplicationWrapper) Obj() *kfsparkapi.SparkApplication {
+ return &w.SparkApplication
+}
+
+// Suspend updates the suspend status of the SparkApplication
+func (w *SparkApplicationWrapper) Suspend(s bool) *SparkApplicationWrapper {
+ w.Spec.Suspend = s
+ return w
+}
+
+// Queue updates the queue name of the SparkApplication
+func (w *SparkApplicationWrapper) Queue(queue string) *SparkApplicationWrapper {
+ if w.Labels == nil {
+ w.Labels = make(map[string]string)
+ }
+ w.Labels[constants.QueueLabel] = queue
+ return w
+}
+
+// ExecutorInstances updates the number of executor instances of the SparkApplication
+func (w *SparkApplicationWrapper) ExecutorInstances(n int32) *SparkApplicationWrapper {
+ w.Spec.Executor.Instances = &n
+ return w
+}
+
+// DynamicAllocation explicitly sets whether dynamic allocation is enabled for the SparkApplication
+func (w *SparkApplicationWrapper) DynamicAllocation(enabled bool) *SparkApplicationWrapper {
+ w.Spec.DynamicAllocation = &kfsparkapi.DynamicAllocation{Enabled: enabled}
+ return w
+}
+
+// Label sets the label key and value
+func (w *SparkApplicationWrapper) Label(key, value string) *SparkApplicationWrapper {
+ if w.Labels == nil {
+ w.Labels = make(map[string]string)
+ }
+ w.Labels[key] = value
+ return w
+}
+
+// WorkloadPriorityClass updates the workload priority class label of the SparkApplication.
+func (w *SparkApplicationWrapper) WorkloadPriorityClass(wpc string) *SparkApplicationWrapper {
+ if w.Labels == nil {
+ w.Labels = make(map[string]string)
+ }
+ w.Labels[constants.WorkloadPriorityClassLabel] = wpc
+ return w
+}
+
+// PodAnnotation sets the annotation on the driver or executor spec. 
+func (w *SparkApplicationWrapper) PodAnnotation(role, key, value string) *SparkApplicationWrapper { + switch role { + case kfsparkcommon.SparkRoleDriver: + if w.Spec.Driver.Annotations == nil { + w.Spec.Driver.Annotations = make(map[string]string, 1) + } + w.Spec.Driver.Annotations[key] = value + case kfsparkcommon.SparkRoleExecutor: + if w.Spec.Executor.Annotations == nil { + w.Spec.Executor.Annotations = make(map[string]string, 1) + } + w.Spec.Executor.Annotations[key] = value + } + return w +} + +// CoreLimit sets the driver or executor spec.CoreLimit. +func (w *SparkApplicationWrapper) CoreLimit(role string, quantity *string) *SparkApplicationWrapper { + switch role { + case kfsparkcommon.SparkRoleDriver: + w.Spec.Driver.CoreLimit = quantity + case kfsparkcommon.SparkRoleExecutor: + w.Spec.Executor.CoreLimit = quantity + } + return w +} + +// Clone clones the SparkApplicationWrapper +func (w *SparkApplicationWrapper) Clone() *SparkApplicationWrapper { + return &SparkApplicationWrapper{SparkApplication: *w.DeepCopy()} +}