From 20136ef4ce87efe2790e2f863441702c12a3ca9b Mon Sep 17 00:00:00 2001 From: Akshay Chitneni Date: Fri, 4 Oct 2024 08:02:40 -0700 Subject: [PATCH] Adding v2 trainjob validation webhook fixing runtime Signed-off-by: Akshay Chitneni --- pkg/controller.v2/trainjob_controller.go | 16 +-- pkg/runtime.v2/core/clustertrainingruntime.go | 13 +- pkg/runtime.v2/core/trainingruntime.go | 36 +++-- pkg/runtime.v2/framework/core/framework.go | 5 +- .../framework/core/framework_test.go | 4 +- pkg/runtime.v2/framework/interface.go | 2 +- .../framework/plugins/jobset/jobset.go | 115 +++++++++++++++ pkg/runtime.v2/framework/plugins/mpi/mpi.go | 16 ++- .../framework/plugins/torch/torch.go | 21 ++- pkg/util.v2/runtime/runtime.go | 17 +++ pkg/util.v2/testing/wrapper.go | 4 + pkg/webhook.v2/setup.go | 2 +- pkg/webhook.v2/trainjob_webhook.go | 37 +++-- .../controller.v2/trainjob_controller_test.go | 61 +++++--- test/integration/framework/framework.go | 12 +- .../webhook.v2/clustertrainingruntime_test.go | 2 +- .../webhook.v2/trainingruntime_test.go | 2 +- test/integration/webhook.v2/trainjob_test.go | 132 +++++++++++++++++- 18 files changed, 421 insertions(+), 76 deletions(-) create mode 100644 pkg/util.v2/runtime/runtime.go diff --git a/pkg/controller.v2/trainjob_controller.go b/pkg/controller.v2/trainjob_controller.go index 95a34048e0..4bf42ade19 100644 --- a/pkg/controller.v2/trainjob_controller.go +++ b/pkg/controller.v2/trainjob_controller.go @@ -18,14 +18,13 @@ package controllerv2 import ( "context" - "errors" "fmt" + runtimeUtils "github.com/kubeflow/training-operator/pkg/util.v2/runtime" "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/tools/record" "k8s.io/klog/v2" - "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/apiutil" @@ -34,8 +33,6 @@ import ( jobruntimes "github.com/kubeflow/training-operator/pkg/runtime.v2" ) -var errorUnsupportedRuntime = errors.New("the specified runtime is not supported") - type TrainJobReconciler struct { log logr.Logger client client.Client @@ -73,10 +70,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error { log := ctrl.LoggerFrom(ctx) - runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String() + runtimeRefGK := runtimeUtils.RuntimeRefToGroupKind(trainJob.Spec.RuntimeRef).String() runtime, ok := r.runtimes[runtimeRefGK] if !ok { - return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK) + return fmt.Errorf("%w: %s", runtimeUtils.ErrorUnsupportedRuntime, runtimeRefGK) } objs, err := runtime.NewObjects(ctx, trainJob) if err != nil { @@ -117,13 +114,6 @@ func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *k return nil } -func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind { - return schema.GroupKind{ - Group: ptr.Deref(runtimeRef.APIGroup, ""), - Kind: ptr.Deref(runtimeRef.Kind, ""), - } -} - func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error { b := ctrl.NewControllerManagedBy(mgr). 
For(&kubeflowv2.TrainJob{}) diff --git a/pkg/runtime.v2/core/clustertrainingruntime.go b/pkg/runtime.v2/core/clustertrainingruntime.go index 35c35fe0c9..23f8063b69 100644 --- a/pkg/runtime.v2/core/clustertrainingruntime.go +++ b/pkg/runtime.v2/core/clustertrainingruntime.go @@ -64,14 +64,17 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu } func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { + clusterTrainingRuntime := &kubeflowv2.ClusterTrainingRuntime{} if err := r.client.Get(ctx, client.ObjectKey{ - Namespace: old.Namespace, - Name: old.Spec.RuntimeRef.Name, - }, &kubeflowv2.ClusterTrainingRuntime{}); err != nil { + Namespace: new.Namespace, + Name: new.Spec.RuntimeRef.Name, + }, clusterTrainingRuntime); err != nil { return nil, field.ErrorList{ - field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef, + field.Invalid(field.NewPath("spec", "RuntimeRef"), new.Spec.RuntimeRef, fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)), } } - return r.framework.RunCustomValidationPlugins(old, new) + info := r.getRuntimeInfo(ctx, new, clusterTrainingRuntime.Spec.Template, clusterTrainingRuntime.Spec.MLPolicy, + clusterTrainingRuntime.Spec.PodGroupPolicy) + return r.framework.RunCustomValidationPlugins(old, new, info) } diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 621d4eb533..460ad9aecb 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -84,6 +84,21 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.T func (r *TrainingRuntime) buildObjects( ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy, ) ([]client.Object, error) { + + info := r.getRuntimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy) + if err := r.framework.RunEnforceMLPolicyPlugins(info); err != nil { + return nil, err + } + err := r.framework.RunEnforcePodGroupPolicyPlugins(trainJob, info) + if err != nil { + return nil, err + } + return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob) +} + +func (r *TrainingRuntime) getRuntimeInfo( + ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy) *runtime.Info { + propagationLabels := jobSetTemplateSpec.Labels if propagationLabels == nil && trainJob.Spec.Labels != nil { propagationLabels = make(map[string]string, len(trainJob.Spec.Labels)) @@ -118,14 +133,7 @@ func (r *TrainingRuntime) buildObjects( Spec: *jobSetTemplateSpec.Spec.DeepCopy(), }, opts...) 
- if err := r.framework.RunEnforceMLPolicyPlugins(info); err != nil { - return nil, err - } - err := r.framework.RunEnforcePodGroupPolicyPlugins(trainJob, info) - if err != nil { - return nil, err - } - return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob) + return info } func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { @@ -137,14 +145,16 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { } func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { + trainingRuntime := &kubeflowv2.TrainingRuntime{} if err := r.client.Get(ctx, client.ObjectKey{ - Namespace: old.Namespace, - Name: old.Spec.RuntimeRef.Name, - }, &kubeflowv2.TrainingRuntime{}); err != nil { + Namespace: new.Namespace, + Name: new.Spec.RuntimeRef.Name, + }, trainingRuntime); err != nil { return nil, field.ErrorList{ - field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef, + field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef, fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)), } } - return r.framework.RunCustomValidationPlugins(old, new) + info := r.getRuntimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy) + return r.framework.RunCustomValidationPlugins(old, new, info) } diff --git a/pkg/runtime.v2/framework/core/framework.go b/pkg/runtime.v2/framework/core/framework.go index 8997afe467..d09a20f7f9 100644 --- a/pkg/runtime.v2/framework/core/framework.go +++ b/pkg/runtime.v2/framework/core/framework.go @@ -89,11 +89,12 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(trainJob *kubeflowv2.TrainJo return nil } -func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { +func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob, + runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) { var aggregatedWarnings admission.Warnings var aggregatedErrors field.ErrorList for _, plugin := range f.customValidationPlugins { - warnings, errs := plugin.Validate(oldObj, newObj) + warnings, errs := plugin.Validate(oldObj, newObj, runtimeInfo) if len(warnings) != 0 { aggregatedWarnings = append(aggregatedWarnings, warnings...) 
} diff --git a/pkg/runtime.v2/framework/core/framework_test.go b/pkg/runtime.v2/framework/core/framework_test.go index 0a1edb266f..194b79b6ca 100644 --- a/pkg/runtime.v2/framework/core/framework_test.go +++ b/pkg/runtime.v2/framework/core/framework_test.go @@ -80,6 +80,7 @@ func TestNew(t *testing.T) { customValidationPlugins: []framework.CustomValidationPlugin{ &mpi.MPI{}, &torch.Torch{}, + &jobset.JobSet{}, }, watchExtensionPlugins: []framework.WatchExtensionPlugin{ &coscheduling.CoScheduling{}, @@ -314,7 +315,8 @@ func TestRunCustomValidationPlugins(t *testing.T) { if err != nil { t.Fatal(err) } - warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj) + runtimeInfo := runtime.NewInfo(testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test").Obj()) + warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj, runtimeInfo) if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 { t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff) } diff --git a/pkg/runtime.v2/framework/interface.go b/pkg/runtime.v2/framework/interface.go index 00d613ec0a..57279bc21b 100644 --- a/pkg/runtime.v2/framework/interface.go +++ b/pkg/runtime.v2/framework/interface.go @@ -48,7 +48,7 @@ type EnforceMLPolicyPlugin interface { type CustomValidationPlugin interface { Plugin - Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) + Validate(oldObj, newObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) } type ComponentBuilderPlugin interface { diff --git a/pkg/runtime.v2/framework/plugins/jobset/jobset.go b/pkg/runtime.v2/framework/plugins/jobset/jobset.go index 9ff369c61c..e7dbed8a7a 100644 --- a/pkg/runtime.v2/framework/plugins/jobset/jobset.go +++ b/pkg/runtime.v2/framework/plugins/jobset/jobset.go @@ -19,7 +19,9 @@ package jobset import ( "context" "fmt" + "k8s.io/apimachinery/pkg/util/validation/field" "maps" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" "github.com/go-logr/logr" batchv1 "k8s.io/api/batch/v1" @@ -50,6 +52,7 @@ type JobSet struct { var _ framework.WatchExtensionPlugin = (*JobSet)(nil) var _ framework.ComponentBuilderPlugin = (*JobSet)(nil) +var _ framework.CustomValidationPlugin = (*JobSet)(nil) const Name = "JobSet" @@ -140,3 +143,115 @@ func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder { }, } } + +func (j *JobSet) Validate(oldObj, newObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) { + + var allErrs field.ErrorList + specPath := field.NewPath("spec") + + jobSet, ok := runtimeInfo.Obj.(*jobsetv1alpha2.JobSet) + if !ok { + return nil, nil + } + + if newObj.Spec.ModelConfig != nil { + // validate `model-initializer` container in the `Initializer` Job + if newObj.Spec.ModelConfig.Input != nil { + modelConfigInputPath := specPath.Child("modelConfig").Child("input") + if len(jobSet.Spec.ReplicatedJobs) == 0 { + allErrs = append(allErrs, field.Invalid(modelConfigInputPath, newObj.Spec.ModelConfig.Input, "trainingRuntime should have replicated jobs configured with model config input set")) + } else { + initializerJobFound := false + modelInitializerContainerFound := false + for _, job := range jobSet.Spec.ReplicatedJobs { + if job.Name == "Initializer" { + initializerJobFound = true + for _, container := range job.Template.Spec.Template.Spec.Containers { + if container.Name == "model-initializer" { + modelInitializerContainerFound = true + } + } + } + } + if 
!initializerJobFound { + allErrs = append(allErrs, field.Invalid(modelConfigInputPath, newObj.Spec.ModelConfig.Input, "trainingRuntime should have replicated job configured with name - Initializer")) + } else if !modelInitializerContainerFound { + allErrs = append(allErrs, field.Invalid(modelConfigInputPath, newObj.Spec.ModelConfig.Input, "trainingRuntime with replicated job initializer should have container with name - model-initializer")) + } + } + } + + // validate `model-exporter` container in the `Exporter` Job + if newObj.Spec.ModelConfig.Output != nil { + modelConfigOutputPath := specPath.Child("modelConfig").Child("output") + if len(jobSet.Spec.ReplicatedJobs) == 0 { + allErrs = append(allErrs, field.Invalid(modelConfigOutputPath, newObj.Spec.ModelConfig.Output, "trainingRuntime should have replicated jobs configured with model config output set")) + } else { + exporterJobFound := false + modelExporterContainerFound := false + for _, job := range jobSet.Spec.ReplicatedJobs { + if job.Name == "Exporter" { + exporterJobFound = true + for _, container := range job.Template.Spec.Template.Spec.Containers { + if container.Name == "model-exporter" { + modelExporterContainerFound = true + } + } + } + } + if !exporterJobFound { + allErrs = append(allErrs, field.Invalid(modelConfigOutputPath, newObj.Spec.ModelConfig.Output, "trainingRuntime should have replicated job configured with name - Exporter")) + } else if !modelExporterContainerFound { + allErrs = append(allErrs, field.Invalid(modelConfigOutputPath, newObj.Spec.ModelConfig.Output, "trainingRuntime with replicated job Exporter should have container with name - model-exporter")) + } + } + } + } + + if len(newObj.Spec.PodSpecOverrides) != 0 { + podSpecOverridesPath := specPath.Child("podSpecOverrides") + jobsMap := map[string]bool{} + for _, job := range jobSet.Spec.ReplicatedJobs { + jobsMap[job.Name] = true + } + // validate if jobOverrides are valid + for idx, override := range newObj.Spec.PodSpecOverrides { + for _, job := range override.TargetJobs { + if _, found := jobsMap[job.Name]; !found { + allErrs = append(allErrs, field.Invalid(podSpecOverridesPath, newObj.Spec.PodSpecOverrides, fmt.Sprintf("job: %s, configured in the podOverride should be present in the referenced training runtime", job.Name))) + } + } + if len(override.Containers) != 0 { + // validate if containerOverrides are valid + containerMap := map[string]bool{} + for _, job := range jobSet.Spec.ReplicatedJobs { + for _, container := range job.Template.Spec.Template.Spec.Containers { + containerMap[container.Name] = true + } + } + containerOverridePath := podSpecOverridesPath.Index(idx) + for _, container := range override.Containers { + if _, found := containerMap[container.Name]; !found { + allErrs = append(allErrs, field.Invalid(containerOverridePath, override.Containers, fmt.Sprintf("container: %s, configured in the containerOverride should be present in the referenced training runtime", container.Name))) + } + } + } + if len(override.InitContainers) != 0 { + // validate if initContainerOverrides are valid + initContainerMap := map[string]bool{} + for _, job := range jobSet.Spec.ReplicatedJobs { + for _, initContainer := range job.Template.Spec.Template.Spec.InitContainers { + initContainerMap[initContainer.Name] = true + } + } + initContainerOverridePath := podSpecOverridesPath.Index(idx) + for _, container := range override.InitContainers { + if _, found := initContainerMap[container.Name]; !found { + allErrs = append(allErrs, field.Invalid(initContainerOverridePath, 
override.InitContainers, fmt.Sprintf("initContainer: %s, configured in the initContainerOverride should be present in the referenced training runtime", container.Name))) + } + } + } + } + } + return nil, allErrs +} diff --git a/pkg/runtime.v2/framework/plugins/mpi/mpi.go b/pkg/runtime.v2/framework/plugins/mpi/mpi.go index 9e79f6c5a1..6fc12de0c6 100644 --- a/pkg/runtime.v2/framework/plugins/mpi/mpi.go +++ b/pkg/runtime.v2/framework/plugins/mpi/mpi.go @@ -18,6 +18,7 @@ package mpi import ( "context" + "strconv" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" @@ -55,7 +56,16 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info) error { return nil } -// TODO: Need to implement validations for MPIJob. -func (m *MPI) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { - return nil, nil +func (m *MPI) Validate(oldJobObj, newJobObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) { + var allErrs field.ErrorList + specPath := field.NewPath("spec") + if newJobObj.Spec.Trainer != nil { + numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode") + if runtimeInfo.MLPolicy.MPI != nil { + if _, err := strconv.Atoi(*newJobObj.Spec.Trainer.NumProcPerNode); err != nil { + allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value")) + } + } + } + return nil, allErrs } diff --git a/pkg/runtime.v2/framework/plugins/torch/torch.go b/pkg/runtime.v2/framework/plugins/torch/torch.go index b9b7f10cb9..7b3c48c0fa 100644 --- a/pkg/runtime.v2/framework/plugins/torch/torch.go +++ b/pkg/runtime.v2/framework/plugins/torch/torch.go @@ -18,6 +18,8 @@ package torch import ( "context" + "k8s.io/utils/strings/slices" + "strconv" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" @@ -51,7 +53,20 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info) error { return nil } -// TODO: Need to implement validateions for TorchJob. 
-func (t *Torch) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { - return nil, nil +func (t *Torch) Validate(oldObj, newObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) { + var allErrs field.ErrorList + specPath := field.NewPath("spec") + + if newObj.Spec.Trainer != nil { + numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode") + if runtimeInfo.MLPolicy.Torch != nil { + allowedStringValList := []string{"auto", "cpu", "gpu"} + if !slices.Contains(allowedStringValList, *newObj.Spec.Trainer.NumProcPerNode) { + if _, err := strconv.Atoi(*newObj.Spec.Trainer.NumProcPerNode); err != nil { + allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newObj.Spec.Trainer.NumProcPerNode, "should have an int value or auto/cpu/gpu")) + } + } + } + } + return nil, allErrs } diff --git a/pkg/util.v2/runtime/runtime.go b/pkg/util.v2/runtime/runtime.go new file mode 100644 index 0000000000..c355f81714 --- /dev/null +++ b/pkg/util.v2/runtime/runtime.go @@ -0,0 +1,17 @@ +package runtime + +import ( + "errors" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/utils/ptr" +) + +var ErrorUnsupportedRuntime = errors.New("the specified runtime is not supported") + +func RuntimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind { + return schema.GroupKind{ + Group: ptr.Deref(runtimeRef.APIGroup, ""), + Kind: ptr.Deref(runtimeRef.Kind, ""), + } +} diff --git a/pkg/util.v2/testing/wrapper.go b/pkg/util.v2/testing/wrapper.go index 4ffad88742..4bb07a3c8e 100644 --- a/pkg/util.v2/testing/wrapper.go +++ b/pkg/util.v2/testing/wrapper.go @@ -258,6 +258,10 @@ func (t *TrainJobWrapper) ManagedBy(m string) *TrainJobWrapper { t.Spec.ManagedBy = &m return t } +func (t *TrainJobWrapper) ModelConfig(config *kubeflowv2.ModelConfig) *TrainJobWrapper { + t.Spec.ModelConfig = config + return t +} func (t *TrainJobWrapper) Obj() *kubeflowv2.TrainJob { return &t.TrainJob diff --git a/pkg/webhook.v2/setup.go b/pkg/webhook.v2/setup.go index 6e7c7f290e..682cf7be30 100644 --- a/pkg/webhook.v2/setup.go +++ b/pkg/webhook.v2/setup.go @@ -31,7 +31,7 @@ func Setup(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (string, error return kubeflowv2.TrainingRuntimeKind, err } if err := setupWebhookForTrainJob(mgr, runtimes); err != nil { - return "TrainJob", err + return kubeflowv2.TrainJobKind, err } return "", nil } diff --git a/pkg/webhook.v2/trainjob_webhook.go b/pkg/webhook.v2/trainjob_webhook.go index cf75400c82..fe728d850d 100644 --- a/pkg/webhook.v2/trainjob_webhook.go +++ b/pkg/webhook.v2/trainjob_webhook.go @@ -18,21 +18,23 @@ package webhookv2 import ( "context" - + "fmt" apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" - runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" + jobRuntime "github.com/kubeflow/training-operator/pkg/runtime.v2" + runtimeUtils "github.com/kubeflow/training-operator/pkg/util.v2/runtime" ) type TrainJobWebhook struct { - runtimes map[string]runtime.Runtime + runtimes map[string]jobRuntime.Runtime } -func setupWebhookForTrainJob(mgr ctrl.Manager, run map[string]runtime.Runtime) error { +func setupWebhookForTrainJob(mgr ctrl.Manager, run map[string]jobRuntime.Runtime) 
error { return ctrl.NewWebhookManagedBy(mgr). For(&kubeflowv2.TrainJob{}). WithValidator(&TrainJobWebhook{runtimes: run}). @@ -43,12 +45,31 @@ func setupWebhookForTrainJob(mgr ctrl.Manager, run map[string]runtime.Runtime) e var _ webhook.CustomValidator = (*TrainJobWebhook)(nil) -func (w *TrainJobWebhook) ValidateCreate(context.Context, apiruntime.Object) (admission.Warnings, error) { - return nil, nil +func (w *TrainJobWebhook) ValidateCreate(ctx context.Context, obj apiruntime.Object) (admission.Warnings, error) { + trainJob := obj.(*kubeflowv2.TrainJob) + log := ctrl.LoggerFrom(ctx).WithName("trainJob-webhook") + log.V(5).Info("Validating create", "TrainJob", klog.KObj(trainJob)) + runtimeRefGK := runtimeUtils.RuntimeRefToGroupKind(trainJob.Spec.RuntimeRef).String() + runtime, ok := w.runtimes[runtimeRefGK] + if !ok { + return nil, fmt.Errorf("%w: %s", runtimeUtils.ErrorUnsupportedRuntime, runtimeRefGK) + } + warnings, errorList := runtime.ValidateObjects(ctx, nil, trainJob) + return warnings, errorList.ToAggregate() } -func (w *TrainJobWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { - return nil, nil +func (w *TrainJobWebhook) ValidateUpdate(ctx context.Context, oldObj apiruntime.Object, newObj apiruntime.Object) (admission.Warnings, error) { + oldTrainJob := oldObj.(*kubeflowv2.TrainJob) + newTrainJob := newObj.(*kubeflowv2.TrainJob) + log := ctrl.LoggerFrom(ctx).WithName("trainJob-webhook") + log.V(5).Info("Validating update", "TrainJob", klog.KObj(newTrainJob)) + runtimeRefGK := runtimeUtils.RuntimeRefToGroupKind(newTrainJob.Spec.RuntimeRef).String() + runtime, ok := w.runtimes[runtimeRefGK] + if !ok { + return nil, fmt.Errorf("%w: %s", runtimeUtils.ErrorUnsupportedRuntime, runtimeRefGK) + } + warnings, errorList := runtime.ValidateObjects(ctx, oldTrainJob, newTrainJob) + return warnings, errorList.ToAggregate() } func (w *TrainJobWebhook) ValidateDelete(context.Context, apiruntime.Object) (admission.Warnings, error) { diff --git a/test/integration/controller.v2/trainjob_controller_test.go b/test/integration/controller.v2/trainjob_controller_test.go index 098ae39c39..fa229dfa98 100644 --- a/test/integration/controller.v2/trainjob_controller_test.go +++ b/test/integration/controller.v2/trainjob_controller_test.go @@ -41,7 +41,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, true) }) ginkgo.AfterAll(func() { fwk.Teardown() @@ -246,11 +246,11 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ordered, func() { var ns *corev1.Namespace - + runtimeName := "training-runtime" ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, false) }) ginkgo.AfterAll(func() { fwk.Teardown() @@ -267,8 +267,37 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord }, } gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) + + baseRuntimeWrapper := testingutil.MakeTrainingRuntimeWrapper(ns.Name, runtimeName) + baseClusterRuntimeWrapper := testingutil.MakeClusterTrainingRuntimeWrapper(runtimeName) + trainingRuntime := baseRuntimeWrapper.RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntimeWrapper.Spec). + Replicas(1). 
+ Obj()).Obj() + clusterTrainingRuntime := baseClusterRuntimeWrapper.RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntimeWrapper.Spec). + Replicas(1). + Obj()).Obj() + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).To(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, clusterTrainingRuntime)).To(gomega.Succeed()) + gomega.Eventually(func() error { + err := k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime) + if err != nil { + return err + } + return nil + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Eventually(func() error { + err := k8sClient.Get(ctx, client.ObjectKeyFromObject(clusterTrainingRuntime), clusterTrainingRuntime) + if err != nil { + return err + } + return nil + }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) ginkgo.AfterEach(func() { + gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.TrainingRuntime{}, client.InNamespace(ns.Name))).Should(gomega.Succeed()) + gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.ClusterTrainingRuntime{})).Should(gomega.Succeed()) gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.TrainJob{}, client.InNamespace(ns.Name))).Should(gomega.Succeed()) }) @@ -280,7 +309,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "managed-by-trainjob-controller"). ManagedBy("kubeflow.org/trainjob-controller"). - RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). Obj() }, gomega.Succeed()), @@ -288,7 +317,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "managed-by-trainjob-controller"). ManagedBy("kueue.x-k8s.io/multikueue"). - RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). Obj() }, gomega.Succeed()), @@ -296,7 +325,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "invalid-managed-by"). ManagedBy("invalid"). - RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). Obj() }, testingutil.BeInvalidError()), @@ -310,53 +339,53 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "null-suspend"). ManagedBy("kueue.x-k8s.io/multikueue"). - RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), runtimeName). Obj() }, func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "null-suspend"). ManagedBy("kueue.x-k8s.io/multikueue"). - RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), runtimeName). Suspend(false). 
Obj() }), ginkgo.Entry("Should succeed to default managedBy=kubeflow.org/trainjob-controller", func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "null-managed-by"). - RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). Suspend(true). Obj() }, func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "null-managed-by"). ManagedBy("kubeflow.org/trainjob-controller"). - RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). Suspend(true). Obj() }), ginkgo.Entry("Should succeed to default runtimeRef.apiGroup", func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "empty-api-group"). - RuntimeRef(schema.GroupVersionKind{Group: "", Version: "", Kind: kubeflowv2.TrainingRuntimeKind}, "testing"). + RuntimeRef(schema.GroupVersionKind{Group: "", Version: "", Kind: kubeflowv2.TrainingRuntimeKind}, runtimeName). Obj() }, func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "empty-api-group"). ManagedBy("kubeflow.org/trainjob-controller"). - RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). Suspend(false). Obj() }), ginkgo.Entry("Should succeed to default runtimeRef.kind", func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "empty-kind"). - RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(""), "testing"). + RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(""), runtimeName). Obj() }, func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "empty-kind"). ManagedBy("kubeflow.org/trainjob-controller"). - RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), runtimeName). Suspend(false). Obj() }), @@ -373,7 +402,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "valid-managed-by"). ManagedBy("kubeflow.org/trainjob-controller"). - RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "testing"). + RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). Obj() }, func(job *kubeflowv2.TrainJob) *kubeflowv2.TrainJob { @@ -384,7 +413,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord ginkgo.Entry("Should fail to update runtimeRef", func() *kubeflowv2.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "valid-runtimeref"). - RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), "testing"). + RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). 
Obj() }, func(job *kubeflowv2.TrainJob) *kubeflowv2.TrainJob { diff --git a/test/integration/framework/framework.go b/test/integration/framework/framework.go index a86c433d7e..8f09876e99 100644 --- a/test/integration/framework/framework.go +++ b/test/integration/framework/framework.go @@ -20,6 +20,7 @@ import ( "context" "crypto/tls" "fmt" + controllerv2 "github.com/kubeflow/training-operator/pkg/controller.v2" "net" "path/filepath" "time" @@ -40,7 +41,6 @@ import ( schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" - controllerv2 "github.com/kubeflow/training-operator/pkg/controller.v2" runtimecore "github.com/kubeflow/training-operator/pkg/runtime.v2/core" webhookv2 "github.com/kubeflow/training-operator/pkg/webhook.v2" ) @@ -70,7 +70,7 @@ func (f *Framework) Init() *rest.Config { return cfg } -func (f *Framework) RunManager(cfg *rest.Config) (context.Context, client.Client) { +func (f *Framework) RunManager(cfg *rest.Config, startControllers bool) (context.Context, client.Client) { webhookInstallOpts := &f.testEnv.WebhookInstallOptions gomega.ExpectWithOffset(1, kubeflowv2.AddToScheme(scheme.Scheme)).NotTo(gomega.HaveOccurred()) gomega.ExpectWithOffset(1, jobsetv1alpha2.AddToScheme(scheme.Scheme)).NotTo(gomega.HaveOccurred()) @@ -100,9 +100,11 @@ func (f *Framework) RunManager(cfg *rest.Config) (context.Context, client.Client gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred()) gomega.ExpectWithOffset(1, runtimes).NotTo(gomega.BeNil()) - failedCtrlName, err := controllerv2.SetupControllers(mgr, runtimes) - gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "controller", failedCtrlName) - gomega.ExpectWithOffset(1, failedCtrlName).To(gomega.BeEmpty()) + if startControllers { + failedCtrlName, err := controllerv2.SetupControllers(mgr, runtimes) + gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "controller", failedCtrlName) + gomega.ExpectWithOffset(1, failedCtrlName).To(gomega.BeEmpty()) + } failedWebhookName, err := webhookv2.Setup(mgr, runtimes) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "webhook", failedWebhookName) diff --git a/test/integration/webhook.v2/clustertrainingruntime_test.go b/test/integration/webhook.v2/clustertrainingruntime_test.go index a2519c8ff8..d0c740a43a 100644 --- a/test/integration/webhook.v2/clustertrainingruntime_test.go +++ b/test/integration/webhook.v2/clustertrainingruntime_test.go @@ -35,7 +35,7 @@ var _ = ginkgo.Describe("ClusterTrainingRuntime Webhook", ginkgo.Ordered, func() ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, false) }) ginkgo.AfterAll(func() { fwk.Teardown() diff --git a/test/integration/webhook.v2/trainingruntime_test.go b/test/integration/webhook.v2/trainingruntime_test.go index 7599e04759..d647987fb2 100644 --- a/test/integration/webhook.v2/trainingruntime_test.go +++ b/test/integration/webhook.v2/trainingruntime_test.go @@ -36,7 +36,7 @@ var _ = ginkgo.Describe("TrainingRuntime Webhook", ginkgo.Ordered, func() { ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, false) }) ginkgo.AfterAll(func() { fwk.Teardown() diff --git a/test/integration/webhook.v2/trainjob_test.go b/test/integration/webhook.v2/trainjob_test.go index a8578f007b..38c25ffc17 100644 --- 
a/test/integration/webhook.v2/trainjob_test.go +++ b/test/integration/webhook.v2/trainjob_test.go @@ -17,21 +17,29 @@ limitations under the License. package webhookv2 import ( + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" + "github.com/kubeflow/training-operator/test/integration/framework" + "github.com/kubeflow/training-operator/test/util" "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/kubeflow/training-operator/test/integration/framework" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" ) var _ = ginkgo.Describe("TrainJob Webhook", ginkgo.Ordered, func() { var ns *corev1.Namespace + var trainingRuntime *kubeflowv2.TrainingRuntime + var clusterTrainingRuntime *kubeflowv2.ClusterTrainingRuntime + runtimeName := "training-runtime" + jobName := "train-job" ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, false) }) ginkgo.AfterAll(func() { fwk.Teardown() @@ -48,5 +56,123 @@ var _ = ginkgo.Describe("TrainJob Webhook", ginkgo.Ordered, func() { }, } gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) + + baseRuntimeWrapper := testingutil.MakeTrainingRuntimeWrapper(ns.Name, runtimeName) + baseClusterRuntimeWrapper := testingutil.MakeClusterTrainingRuntimeWrapper(runtimeName) + trainingRuntime = baseRuntimeWrapper.RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntimeWrapper.Spec). + Replicas(1). + Obj()).Obj() + clusterTrainingRuntime = baseClusterRuntimeWrapper.RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntimeWrapper.Spec). + Replicas(1). + Obj()).Obj() + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).To(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, clusterTrainingRuntime)).To(gomega.Succeed()) + gomega.Eventually(func() error { + err := k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime) + if err != nil { + return err + } + return nil + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Eventually(func() error { + err := k8sClient.Get(ctx, client.ObjectKeyFromObject(clusterTrainingRuntime), clusterTrainingRuntime) + if err != nil { + return err + } + return nil + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.AfterEach(func() { + gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.TrainingRuntime{}, client.InNamespace(ns.Name))).To(gomega.Succeed()) + gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.ClusterTrainingRuntime{})).To(gomega.Succeed()) + gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.TrainJob{}, client.InNamespace(ns.Name))).To(gomega.Succeed()) + }) + + ginkgo.When("Creating TrainJob", func() { + ginkgo.DescribeTable("Validate TrainJob on creation", func(trainJob func() *kubeflowv2.TrainJob, errorMatcher gomega.OmegaMatcher) { + gomega.Expect(k8sClient.Create(ctx, trainJob())).Should(errorMatcher) + }, + ginkgo.Entry("Should succeed in creating trainJob with namespace scoped trainingRuntime", + func() *kubeflowv2.TrainJob { + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). 
+ Obj() + }, + gomega.Succeed()), + ginkgo.Entry("Should fail in creating trainJob referencing trainingRuntime not present in the namespace", + func() *kubeflowv2.TrainJob { + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "invalid"). + Obj() + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should succeed in creating trainJob with cluster scoped clusterTrainingRuntime", + func() *kubeflowv2.TrainJob { + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), runtimeName). + Obj() + }, + gomega.Succeed()), + ginkgo.Entry("Should fail in creating trainJob with pre-trained model config when referencing a trainingRuntime without an initializer", + func() *kubeflowv2.TrainJob { + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). + ModelConfig(&kubeflowv2.ModelConfig{Input: &kubeflowv2.InputModel{}}). + Obj() + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should fail in creating trainJob with output model config when referencing a trainingRuntime without an exporter", + func() *kubeflowv2.TrainJob { + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). + ModelConfig(&kubeflowv2.ModelConfig{Output: &kubeflowv2.OutputModel{}}). + Obj() + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should fail in creating trainJob with podSpecOverrides when the referenced trainingRuntime doesn't have the job specified in the override", + func() *kubeflowv2.TrainJob { + trainJob := testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). + Obj() + trainJob.Spec.PodSpecOverrides = []kubeflowv2.PodSpecOverride{ + {TargetJobs: []kubeflowv2.PodSpecOverrideTargetJob{{Name: "valid"}, {Name: "invalid"}}}, + } + return trainJob + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should fail in creating trainJob with invalid trainer config for mpi runtime", + func() *kubeflowv2.TrainJob { + trainingRuntime.Spec.MLPolicy = &kubeflowv2.MLPolicy{MLPolicySource: kubeflowv2.MLPolicySource{MPI: &kubeflowv2.MPIMLPolicySource{}}} + gomega.Expect(k8sClient.Update(ctx, trainingRuntime)).To(gomega.Succeed()) + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). + Trainer(&kubeflowv2.Trainer{NumProcPerNode: ptr.To("invalid")}). + Obj() + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should fail in creating trainJob with invalid trainer config for torch runtime", + func() *kubeflowv2.TrainJob { + trainingRuntime.Spec.MLPolicy = &kubeflowv2.MLPolicy{MLPolicySource: kubeflowv2.MLPolicySource{Torch: &kubeflowv2.TorchMLPolicySource{}}} + gomega.Expect(k8sClient.Update(ctx, trainingRuntime)).To(gomega.Succeed()) + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). + Trainer(&kubeflowv2.Trainer{NumProcPerNode: ptr.To("invalid")}). 
+ Obj() + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should succeed in creating trainJob with valid trainer config for torch runtime", + func() *kubeflowv2.TrainJob { + trainingRuntime.Spec.MLPolicy = &kubeflowv2.MLPolicy{MLPolicySource: kubeflowv2.MLPolicySource{Torch: &kubeflowv2.TorchMLPolicySource{}}} + gomega.Expect(k8sClient.Update(ctx, trainingRuntime)).To(gomega.Succeed()) + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), runtimeName). + Trainer(&kubeflowv2.Trainer{NumProcPerNode: ptr.To("auto")}). + Obj() + }, + gomega.Succeed()), + ) }) })