Skip to content

Commit

Permalink
KEP-2170: Implement JobSet, PlainML, and Torch Plugins (#2308)
Browse files Browse the repository at this point in the history
* KEP-2170: Implement JobSet and PlainML Plugins

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix nil pointer exception for Trainer

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix unit tests in runtime package

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix unit tests

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix integration tests

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix lint

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Implement Torch Plugin

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Use list for the Info envs

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix golang ci

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix Torch plugin

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Use K8s sets
Update error return
Use ptr.Deref() for nil values

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Use client.Object for Build() call

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Remove DeepCopy

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Remove MLPolicy and PodGroupPolicy from the Info object

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Inline error

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Remove SDK jar file

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Add integration test for Torch plugin

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Add TODO to calculate PodGroup values in unit tests

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Revert the change to add original Runtime Policies to Info

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Create const for the DefaultJobReplicas

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Check if PodLabels is empty

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

---------

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
  • Loading branch information
andreyvelich authored Oct 31, 2024
1 parent 3f7ec16 commit 7c5ea70
Show file tree
Hide file tree
Showing 21 changed files with 1,036 additions and 574 deletions.
4 changes: 2 additions & 2 deletions pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import (
)

const (
// TrainingRuntimeKind is the GroupVersionKind Kind name for the TrainingRuntime.
// TrainingRuntimeKind is the Kind name for the TrainingRuntime.
TrainingRuntimeKind string = "TrainingRuntime"
// ClusterTrainingRuntimeKind is the GroupVersionKind Kind name for the ClusterTrainingRuntime.
// ClusterTrainingRuntimeKind is the Kind name for the ClusterTrainingRuntime.
ClusterTrainingRuntimeKind string = "ClusterTrainingRuntime"
)

Expand Down
1 change: 1 addition & 0 deletions pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
)

const (
// TrainJobKind is the Kind name for the TrainJob.
TrainJobKind string = "TrainJob"
)

Expand Down
59 changes: 59 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package constants

import (
"fmt"

batchv1 "k8s.io/api/batch/v1"
)

const (

// DefaultJobReplicas is the default value for the ReplicatedJob replicas.
DefaultJobReplicas = 1

// JobSetKind is the Kind name for the JobSet.
JobSetKind string = "JobSet"

// JobTrainerNode is the Job name for the trainer node.
JobTrainerNode string = "trainer-node"

// ContainerTrainer is the container name for the trainer.
ContainerTrainer string = "trainer"

// ContainerTrainerPort is the default port used for communication between the trainer nodes.
ContainerTrainerPort int32 = 29500

// JobInitializer is the Job name for the initializer.
JobInitializer string = "initializer"

// ContainerModelInitializer is the container name for the model initializer.
ContainerModelInitializer string = "model-initializer"

// ContainerDatasetInitializer is the container name for the dataset initializer.
ContainerDatasetInitializer string = "dataset-initializer"

// PodGroupKind is the Kind name for the PodGroup.
PodGroupKind string = "PodGroup"

// Distributed envs for torchrun. All names below share the PET_ prefix.
// Ref: https://github.com/pytorch/pytorch/blob/3a0d0885171376ed610c8175a19ba40411fc6f3f/torch/distributed/argparse_util.py#L45

// TorchEnvNumNodes is the env name for the number of training nodes.
TorchEnvNumNodes string = "PET_NNODES"

// TorchEnvNumProcPerNode is the env name for the number of procs per node (e.g. number of GPUs per Pod).
TorchEnvNumProcPerNode string = "PET_NPROC_PER_NODE"

// TorchEnvNodeRank is the env name for the node rank.
TorchEnvNodeRank string = "PET_NODE_RANK"

// TorchEnvMasterAddr is the env name for the master node address.
TorchEnvMasterAddr string = "PET_MASTER_ADDR"

// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"
)

// JobCompletionIndexFieldPath is the field path for the Job completion index annotation.
// It addresses the batchv1.JobCompletionIndexAnnotation key under metadata.annotations.
var JobCompletionIndexFieldPath = fmt.Sprintf("metadata.annotations['%s']", batchv1.JobCompletionIndexAnnotation)
55 changes: 22 additions & 33 deletions pkg/runtime.v2/core/clustertrainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

Expand All @@ -35,60 +33,52 @@ import (
)

func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper("test-runtime").
Clone()

resRequests := corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
}

cases := map[string]struct {
trainJob *kubeflowv2.TrainJob
clusterTrainingRuntime *kubeflowv2.ClusterTrainingRuntime
wantObjs []client.Object
wantError error
}{
"succeeded to build JobSet and PodGroup": {
"succeeded to build PodGroup and JobSet with NumNodes from the Runtime and container from the Trainer.": {
clusterTrainingRuntime: testingutil.MakeClusterTrainingRuntimeWrapper("test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeClusterTrainingRuntimeWrapper("test-runtime").Spec).
NumNodes(100).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
ContainerDatasetModelInitializer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
PodGroupPolicyCoschedulingSchedulingTimeout(120).
Obj(),
).Obj(),
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
Suspend(true).
UID("uid").
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("test:trainjob").
ContainerTrainer("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
Obj(),
).
Obj(),
clusterTrainingRuntime: baseRuntime.RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
ContainerImage("test:runtime").
PodGroupPolicyCoschedulingSchedulingTimeout(120).
MLPolicyNumNodes(20).
ResourceRequests(0, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
}).
ResourceRequests(1, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
}).
Obj(),
).Obj(),
wantObjs: []client.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
NumNodes(100).
ContainerTrainer("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
ContainerDatasetModelInitializer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Suspend(true).
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
ContainerImage(ptr.To("test:trainjob")).
JobCompletionMode(batchv1.IndexedCompletion).
ResourceRequests(0, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
}).
ResourceRequests(1, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
}).
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), "test-job", "uid").
Obj(),
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
MinMember(40).
SchedulingTimeout(120).
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), "test-job", "uid").
MinMember(101). // 101 replicas = 100 Trainer nodes + 1 Initializer.
MinResources(corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("60"),
corev1.ResourceCPU: resource.MustParse("101"), // Every replica has 1 CPU = 101 CPUs in total.
}).
SchedulingTimeout(120).
Obj(),
},
},
Expand All @@ -98,7 +88,6 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("test:trainjob").
Obj(),
).
Obj(),
Expand Down
31 changes: 15 additions & 16 deletions pkg/runtime.v2/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ import (
"errors"
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
Expand Down Expand Up @@ -106,26 +104,27 @@ func (r *TrainingRuntime) buildObjects(
runtime.WithMLPolicy(mlPolicy),
runtime.WithPodGroupPolicy(podGroupPolicy),
}
for idx, rJob := range jobSetTemplateSpec.Spec.ReplicatedJobs {
replicas := jobSetTemplateSpec.Spec.ReplicatedJobs[idx].Replicas * ptr.Deref(rJob.Template.Spec.Completions, 1)
opts = append(opts, runtime.WithPodSpecReplicas(rJob.Name, replicas, rJob.Template.Spec.Template.Spec))

for _, rJob := range jobSetTemplateSpec.Spec.ReplicatedJobs {
// By default every ReplicatedJob has only 1 replica.
opts = append(opts, runtime.WithPodSpecReplicas(rJob.Name, 1, rJob.Template.Spec.Template.Spec))
}
info := runtime.NewInfo(&jobsetv1alpha2.JobSet{
TypeMeta: metav1.TypeMeta{
APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(),
Kind: "JobSet",
},
Spec: *jobSetTemplateSpec.Spec.DeepCopy(),
}, opts...)

if err := r.framework.RunEnforceMLPolicyPlugins(info); err != nil {
info := runtime.NewInfo(opts...)

if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
return nil, err
}
err := r.framework.RunEnforcePodGroupPolicyPlugins(trainJob, info)
if err != nil {

if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
return nil, err
}
return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob)

jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: jobSetTemplateSpec.Spec,
}

return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
}

func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
Expand Down
Loading

0 comments on commit 7c5ea70

Please sign in to comment.