Skip to content

Commit

Permalink
Adding v2 trainjob validation webhook
Browse files Browse the repository at this point in the history
fixing runtime

Signed-off-by: Akshay Chitneni <achitneni@apple.com>
  • Loading branch information
Akshay Chitneni committed Oct 25, 2024
1 parent 9ed4112 commit 20136ef
Show file tree
Hide file tree
Showing 18 changed files with 421 additions and 76 deletions.
16 changes: 3 additions & 13 deletions pkg/controller.v2/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ package controllerv2

import (
"context"
"errors"
"fmt"
runtimeUtils "github.com/kubeflow/training-operator/pkg/util.v2/runtime"

"github.com/go-logr/logr"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
Expand All @@ -34,8 +33,6 @@ import (
jobruntimes "github.com/kubeflow/training-operator/pkg/runtime.v2"
)

var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")

type TrainJobReconciler struct {
log logr.Logger
client client.Client
Expand Down Expand Up @@ -73,10 +70,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
log := ctrl.LoggerFrom(ctx)

runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtimeRefGK := runtimeUtils.RuntimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtime, ok := r.runtimes[runtimeRefGK]
if !ok {
return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
return fmt.Errorf("%w: %s", runtimeUtils.ErrorUnsupportedRuntime, runtimeRefGK)
}
objs, err := runtime.NewObjects(ctx, trainJob)
if err != nil {
Expand Down Expand Up @@ -117,13 +114,6 @@ func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *k
return nil
}

func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Kind: ptr.Deref(runtimeRef.Kind, ""),
}
}

func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
b := ctrl.NewControllerManagedBy(mgr).
For(&kubeflowv2.TrainJob{})
Expand Down
13 changes: 8 additions & 5 deletions pkg/runtime.v2/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,17 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu
}

func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
clusterTrainingRuntime := &kubeflowv2.ClusterTrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
}, &kubeflowv2.ClusterTrainingRuntime{}); err != nil {
Namespace: new.Namespace,
Name: new.Spec.RuntimeRef.Name,
}, clusterTrainingRuntime); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "RuntimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.getRuntimeInfo(ctx, new, clusterTrainingRuntime.Spec.Template, clusterTrainingRuntime.Spec.MLPolicy,
clusterTrainingRuntime.Spec.PodGroupPolicy)
return r.framework.RunCustomValidationPlugins(old, new, info)
}
36 changes: 23 additions & 13 deletions pkg/runtime.v2/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.T
func (r *TrainingRuntime) buildObjects(
ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy,
) ([]client.Object, error) {

info := r.getRuntimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy)
if err := r.framework.RunEnforceMLPolicyPlugins(info); err != nil {
return nil, err
}
err := r.framework.RunEnforcePodGroupPolicyPlugins(trainJob, info)
if err != nil {
return nil, err
}
return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob)
}

func (r *TrainingRuntime) getRuntimeInfo(
ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy) *runtime.Info {

propagationLabels := jobSetTemplateSpec.Labels
if propagationLabels == nil && trainJob.Spec.Labels != nil {
propagationLabels = make(map[string]string, len(trainJob.Spec.Labels))
Expand Down Expand Up @@ -118,14 +133,7 @@ func (r *TrainingRuntime) buildObjects(
Spec: *jobSetTemplateSpec.Spec.DeepCopy(),
}, opts...)

if err := r.framework.RunEnforceMLPolicyPlugins(info); err != nil {
return nil, err
}
err := r.framework.RunEnforcePodGroupPolicyPlugins(trainJob, info)
if err != nil {
return nil, err
}
return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob)
return info
}

func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
Expand All @@ -137,14 +145,16 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
}

func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
trainingRuntime := &kubeflowv2.TrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
}, &kubeflowv2.TrainingRuntime{}); err != nil {
Namespace: new.Namespace,
Name: new.Spec.RuntimeRef.Name,
}, trainingRuntime); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.getRuntimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy)
return r.framework.RunCustomValidationPlugins(old, new, info)
}
5 changes: 3 additions & 2 deletions pkg/runtime.v2/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(trainJob *kubeflowv2.TrainJo
return nil
}

func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob,
runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) {
var aggregatedWarnings admission.Warnings
var aggregatedErrors field.ErrorList
for _, plugin := range f.customValidationPlugins {
warnings, errs := plugin.Validate(oldObj, newObj)
warnings, errs := plugin.Validate(oldObj, newObj, runtimeInfo)
if len(warnings) != 0 {
aggregatedWarnings = append(aggregatedWarnings, warnings...)
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/runtime.v2/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func TestNew(t *testing.T) {
customValidationPlugins: []framework.CustomValidationPlugin{
&mpi.MPI{},
&torch.Torch{},
&jobset.JobSet{},
},
watchExtensionPlugins: []framework.WatchExtensionPlugin{
&coscheduling.CoScheduling{},
Expand Down Expand Up @@ -314,7 +315,8 @@ func TestRunCustomValidationPlugins(t *testing.T) {
if err != nil {
t.Fatal(err)
}
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj)
runtimeInfo := runtime.NewInfo(testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test").Obj())
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj, runtimeInfo)
if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 {
t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime.v2/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type EnforceMLPolicyPlugin interface {

type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
Validate(oldObj, newObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList)
}

type ComponentBuilderPlugin interface {
Expand Down
115 changes: 115 additions & 0 deletions pkg/runtime.v2/framework/plugins/jobset/jobset.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package jobset
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/validation/field"
"maps"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

"github.com/go-logr/logr"
batchv1 "k8s.io/api/batch/v1"
Expand Down Expand Up @@ -50,6 +52,7 @@ type JobSet struct {

var _ framework.WatchExtensionPlugin = (*JobSet)(nil)
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil)
var _ framework.CustomValidationPlugin = (*JobSet)(nil)

const Name = "JobSet"

Expand Down Expand Up @@ -140,3 +143,115 @@ func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder {
},
}
}

func (j *JobSet) Validate(oldObj, newObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) {

var allErrs field.ErrorList
specPath := field.NewPath("spec")

jobSet, ok := runtimeInfo.Obj.(*jobsetv1alpha2.JobSet)
if !ok {
return nil, nil
}

if newObj.Spec.ModelConfig != nil {
// validate `model-initializer` container in the `Initializer` Job
if newObj.Spec.ModelConfig.Input != nil {
modelConfigInputPath := specPath.Child("modelConfig").Child("input")
if len(jobSet.Spec.ReplicatedJobs) == 0 {
allErrs = append(allErrs, field.Invalid(modelConfigInputPath, newObj.Spec.ModelConfig.Input, "trainingRuntime should have replicated jobs configured with model config input set"))
} else {
initializerJobFound := false
modelInitializerContainerFound := false
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == "Initializer" {
initializerJobFound = true
for _, container := range job.Template.Spec.Template.Spec.Containers {
if container.Name == "model-initializer" {
modelInitializerContainerFound = true
}
}
}
}
if !initializerJobFound {
allErrs = append(allErrs, field.Invalid(modelConfigInputPath, newObj.Spec.ModelConfig.Input, "trainingRuntime should have replicated job configured with name - Initializer"))
} else if !modelInitializerContainerFound {
allErrs = append(allErrs, field.Invalid(modelConfigInputPath, newObj.Spec.ModelConfig.Input, "trainingRuntime with replicated job initializer should have container with name - model-initializer"))
}
}
}

// validate `model-exporter` container in the `Exporter` Job
if newObj.Spec.ModelConfig.Output != nil {
modelConfigOutputPath := specPath.Child("modelConfig").Child("output")
if len(jobSet.Spec.ReplicatedJobs) == 0 {
allErrs = append(allErrs, field.Invalid(modelConfigOutputPath, newObj.Spec.ModelConfig.Output, "trainingRuntime should have replicated jobs configured with model config output set"))
} else {
exporterJobFound := false
modelExporterContainerFound := false
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == "Exporter" {
exporterJobFound = true
for _, container := range job.Template.Spec.Template.Spec.Containers {
if container.Name == "model-exporter" {
modelExporterContainerFound = true
}
}
}
}
if !exporterJobFound {
allErrs = append(allErrs, field.Invalid(modelConfigOutputPath, newObj.Spec.ModelConfig.Output, "trainingRuntime should have replicated job configured with name - Exporter"))
} else if !modelExporterContainerFound {
allErrs = append(allErrs, field.Invalid(modelConfigOutputPath, newObj.Spec.ModelConfig.Output, "trainingRuntime with replicated job Exporter should have container with name - model-exporter"))
}
}
}
}

if len(newObj.Spec.PodSpecOverrides) != 0 {
podSpecOverridesPath := specPath.Child("podSpecOverrides")
jobsMap := map[string]bool{}
for _, job := range jobSet.Spec.ReplicatedJobs {
jobsMap[job.Name] = true
}
// validate if jobOverrides are valid
for idx, override := range newObj.Spec.PodSpecOverrides {
for _, job := range override.TargetJobs {
if _, found := jobsMap[job.Name]; !found {
allErrs = append(allErrs, field.Invalid(podSpecOverridesPath, newObj.Spec.PodSpecOverrides, fmt.Sprintf("job: %s, configured in the podOverride should be present in the referenced training runtime", job)))
}
}
if len(override.Containers) != 0 {
// validate if containerOverrides are valid
containerMap := map[string]bool{}
for _, job := range jobSet.Spec.ReplicatedJobs {
for _, container := range job.Template.Spec.Template.Spec.Containers {
containerMap[container.Name] = true
}
}
containerOverridePath := podSpecOverridesPath.Index(idx)
for _, container := range override.Containers {
if _, found := containerMap[container.Name]; !found {
allErrs = append(allErrs, field.Invalid(containerOverridePath, override.Containers, fmt.Sprintf("container: %s, configured in the containerOverride should be present in the referenced training runtime", container.Name)))
}
}
}
if len(override.InitContainers) != 0 {
// validate if initContainerOverrides are valid
initContainerMap := map[string]bool{}
for _, job := range jobSet.Spec.ReplicatedJobs {
for _, initContainer := range job.Template.Spec.Template.Spec.InitContainers {
initContainerMap[initContainer.Name] = true
}
}
initContainerOverridePath := podSpecOverridesPath.Index(idx)
for _, container := range override.Containers {
if _, found := initContainerMap[container.Name]; !found {
allErrs = append(allErrs, field.Invalid(initContainerOverridePath, override.InitContainers, fmt.Sprintf("initContainer: %s, configured in the initContainerOverride should be present in the referenced training runtime", container.Name)))
}
}
}
}
}
return nil, allErrs
}
16 changes: 13 additions & 3 deletions pkg/runtime.v2/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mpi

import (
"context"
"strconv"

"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -55,7 +56,16 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info) error {
return nil
}

// TODO: Need to implement validations for MPIJob.
func (m *MPI) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
func (m *MPI) Validate(oldJobObj, newJobObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
specPath := field.NewPath("spec")
if newJobObj.Spec.Trainer != nil {
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
if runtimeInfo.MLPolicy.MPI != nil {
if _, err := strconv.Atoi(*newJobObj.Spec.Trainer.NumProcPerNode); err != nil {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value"))
}
}
}
return nil, allErrs
}
21 changes: 18 additions & 3 deletions pkg/runtime.v2/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package torch

import (
"context"
"k8s.io/utils/strings/slices"
"strconv"

"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -51,7 +53,20 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info) error {
return nil
}

// TODO: Need to implement validateions for TorchJob.
func (t *Torch) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
func (t *Torch) Validate(oldObj, newObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
specPath := field.NewPath("spec")

if newObj.Spec.Trainer != nil {
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
if runtimeInfo.MLPolicy.Torch != nil {
allowedStringValList := []string{"auto", "cpu", "gpu"}
if !slices.Contains(allowedStringValList, *newObj.Spec.Trainer.NumProcPerNode) {
if _, err := strconv.Atoi(*newObj.Spec.Trainer.NumProcPerNode); err != nil {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newObj.Spec.Trainer.NumProcPerNode, "should have an int value or auto/cpu/gpu"))
}
}
}
}
return nil, allErrs
}
17 changes: 17 additions & 0 deletions pkg/util.v2/runtime/runtime.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package runtime

import (
"errors"
kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/utils/ptr"
)

var ErrorUnsupportedRuntime = errors.New("the specified runtime is not supported")

func RuntimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Kind: ptr.Deref(runtimeRef.Kind, ""),
}
}
Loading

0 comments on commit 20136ef

Please sign in to comment.