Skip to content

Commit

Permalink
KEP-2170: Add TrainJob and TrainingRuntime APIs (#2223)
Browse files Browse the repository at this point in the history
* KEP-2170: Add TrainJob and TrainingRuntime APIs

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix TrainJobList

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Register APIs with scheme

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Add SchemeGroupVersion

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix TrainingRuntimeSpec omitempty

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Generate manifests only for v1

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix pointers for APIs

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Run code-gen

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Use pointer for MPIImplementation

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Update the JobSetTemplate API

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Rename PodGroupPolicy and MLPolicy APIs

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Update comments

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

---------

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
  • Loading branch information
andreyvelich authored Aug 27, 2024
1 parent 25b8a5b commit 181191e
Show file tree
Hide file tree
Showing 7 changed files with 1,326 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ help: ## Display this help.
##@ Development

manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects.
$(CONTROLLER_GEN) $(CRD_OPTIONS) rbac:roleName=training-operator webhook paths="./pkg/..." \
$(CONTROLLER_GEN) $(CRD_OPTIONS) rbac:roleName=training-operator webhook paths="./pkg/apis/kubeflow.org/v1/..." \
output:crd:artifacts:config=manifests/base/crds \
output:rbac:artifacts:config=manifests/base/rbac \
output:webhook:artifacts:config=manifests/base/webhook
Expand Down
17 changes: 9 additions & 8 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ go 1.22
require (
github.com/go-logr/logr v1.4.1
github.com/google/go-cmp v0.6.0
github.com/onsi/ginkgo/v2 v2.14.0
github.com/onsi/gomega v1.30.0
github.com/onsi/ginkgo/v2 v2.17.1
github.com/onsi/gomega v1.32.0
github.com/open-policy-agent/cert-controller v0.10.1
github.com/prometheus/client_golang v1.18.0
github.com/sirupsen/logrus v1.9.0
Expand All @@ -19,7 +19,8 @@ require (
k8s.io/klog/v2 v2.110.1
k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00
k8s.io/utils v0.0.0-20230726121419-3b25d923346b
sigs.k8s.io/controller-runtime v0.17.2
sigs.k8s.io/controller-runtime v0.17.3
sigs.k8s.io/jobset v0.5.2
sigs.k8s.io/scheduler-plugins v0.28.9
sigs.k8s.io/yaml v1.4.0
volcano.sh/apis v1.9.0
Expand All @@ -44,8 +45,8 @@ require (
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/imdario/mergo v0.3.13 // indirect
github.com/google/uuid v1.3.1 // indirect
github.com/imdario/mergo v0.3.16 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
Expand All @@ -61,7 +62,7 @@ require (
github.com/spf13/pflag v1.0.5 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/mod v0.16.0 // indirect
golang.org/x/net v0.23.0 // indirect
golang.org/x/oauth2 v0.12.0 // indirect
Expand All @@ -76,8 +77,8 @@ require (
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/apiextensions-apiserver v0.29.0 // indirect
k8s.io/component-base v0.29.0 // indirect
k8s.io/apiextensions-apiserver v0.29.2 // indirect
k8s.io/component-base v0.29.2 // indirect
k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01 // indirect
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect
Expand Down
35 changes: 18 additions & 17 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec=
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk=
github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg=
github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
Expand All @@ -80,10 +80,10 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/onsi/ginkgo/v2 v2.14.0 h1:vSmGj2Z5YPb9JwCWT6z6ihcUvDhuXLc3sJiqd3jMKAY=
github.com/onsi/ginkgo/v2 v2.14.0/go.mod h1:JkUdW7JkN0V6rFvsHcJ478egV3XH9NxpD27Hal/PhZw=
github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8=
github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ=
github.com/onsi/ginkgo/v2 v2.17.1 h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8=
github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs=
github.com/onsi/gomega v1.32.0 h1:JRYU78fJ1LPxlckP6Txi/EYqJvjtMrDC04/MM5XRHPk=
github.com/onsi/gomega v1.32.0/go.mod h1:a4x4gW6Pz2yK1MAmvluYme5lvYTn61afQ2ETw/8n4Lg=
github.com/open-policy-agent/cert-controller v0.10.1 h1:RXSYoyn8FdCenWecRP//UV5nbVfmstNpj4kHQFkvPK4=
github.com/open-policy-agent/cert-controller v0.10.1/go.mod h1:4uRbBLY5DsPOog+a9pqk3JLxuuhrWsbUedQW65HcLTI=
github.com/open-policy-agent/frameworks/constraint v0.0.0-20230822235116-f0b62fe1e4c4 h1:5dum5SLEz+95JDLkMls7Z7IDPjvSq3UhJSFe4f5einQ=
Expand Down Expand Up @@ -130,8 +130,8 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
Expand Down Expand Up @@ -191,21 +191,20 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
k8s.io/api v0.29.3 h1:2ORfZ7+bGC3YJqGpV0KSDDEVf8hdGQ6A03/50vj8pmw=
k8s.io/api v0.29.3/go.mod h1:y2yg2NTyHUUkIoTC+phinTnEa3KFM6RZ3szxt014a80=
k8s.io/apiextensions-apiserver v0.29.0 h1:0VuspFG7Hj+SxyF/Z/2T0uFbI5gb5LRgEyUVE3Q4lV0=
k8s.io/apiextensions-apiserver v0.29.0/go.mod h1:TKmpy3bTS0mr9pylH0nOt/QzQRrW7/h7yLdRForMZwc=
k8s.io/apiextensions-apiserver v0.29.2 h1:UK3xB5lOWSnhaCk0RFZ0LUacPZz9RY4wi/yt2Iu+btg=
k8s.io/apiextensions-apiserver v0.29.2/go.mod h1:aLfYjpA5p3OwtqNXQFkhJ56TB+spV8Gc4wfMhUA3/b8=
k8s.io/apimachinery v0.29.3 h1:2tbx+5L7RNvqJjn7RIuIKu9XTsIZ9Z5wX2G22XAa5EU=
k8s.io/apimachinery v0.29.3/go.mod h1:hx/S4V2PNW4OMg3WizRrHutyB5la0iCUbZym+W0EQIU=
k8s.io/client-go v0.29.3 h1:R/zaZbEAxqComZ9FHeQwOh3Y1ZUs7FaHKZdQtIc2WZg=
k8s.io/client-go v0.29.3/go.mod h1:tkDisCvgPfiRpxGnOORfkljmS+UrW+WtXAy2fTvXJB0=
k8s.io/code-generator v0.29.3 h1:m7E25/t9R9NvejspO2zBdyu+/Gl0Z5m7dCRc680KS14=
k8s.io/code-generator v0.29.3/go.mod h1:x47ofBhN4gxYFcxeKA1PYXeaPreAGaDN85Y/lNUsPoM=
k8s.io/component-base v0.29.0 h1:T7rjd5wvLnPBV1vC4zWd/iWRbV8Mdxs+nGaoaFzGw3s=
k8s.io/component-base v0.29.0/go.mod h1:sADonFTQ9Zc9yFLghpDpmNXEdHyQmFIGbiuZbqAXQ1M=
k8s.io/component-base v0.29.2 h1:lpiLyuvPA9yV1aQwGLENYyK7n/8t6l3nn3zAtFTJYe8=
k8s.io/component-base v0.29.2/go.mod h1:BfB3SLrefbZXiBfbM+2H1dlat21Uewg/5qtKOl8degM=
k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01 h1:pWEwq4Asjm4vjW7vcsmijwBhOr1/shsbSYiWXmNGlks=
k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E=
k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y=
Expand All @@ -217,8 +216,10 @@ k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 h1:aVUu9fTY98ivBPKR9Y5w/A
k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00/go.mod h1:AsvuZPBlUDVuCdzJ87iajxtXuR9oktsTctW/R9wwouA=
k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI=
k8s.io/utils v0.0.0-20230726121419-3b25d923346b/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
sigs.k8s.io/controller-runtime v0.17.2 h1:FwHwD1CTUemg0pW2otk7/U5/i5m2ymzvOXdbeGOUvw0=
sigs.k8s.io/controller-runtime v0.17.2/go.mod h1:+MngTvIQQQhfXtwfdGw/UOQ/aIaqsYywfCINOtwMO/s=
sigs.k8s.io/controller-runtime v0.17.3 h1:65QmN7r3FWgTxDMz9fvGnO1kbf2nu+acg9p2R9oYYYk=
sigs.k8s.io/controller-runtime v0.17.3/go.mod h1:N0jpP5Lo7lMTF9aL56Z/B2oWBJjey6StQM0jRbKQXtY=
sigs.k8s.io/jobset v0.5.2 h1:276q5Pi/ErLYj+GQ0ydEXR6tx3LwBhEzHLQv+k8bYF4=
sigs.k8s.io/jobset v0.5.2/go.mod h1:Vg99rj/6OoGvy1uvywGEHOcVLCWWJYkJtisKqdWzcFw=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/scheduler-plugins v0.28.9 h1:1/bXRoXuSUFr1FLqxrzScdyZMl/G1psuDJcDKYxTo+Q=
Expand Down
39 changes: 39 additions & 0 deletions pkg/apis/kubeflow.org/v2alpha1/groupversion_info.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
Copyright 2024 The Kubeflow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package v2alpha1 contains API Schema definitions for the kubeflow.org v2alpha1 API group
// +kubebuilder:object:generate=true
// +groupName=kubeflow.org
package v2alpha1

import (
"k8s.io/apimachinery/pkg/runtime/schema"
"sigs.k8s.io/controller-runtime/pkg/scheme"
)

var (
// GroupVersion is group version used to register these objects.
GroupVersion = schema.GroupVersion{Group: "kubeflow.org", Version: "v2alpha1"}

// SchemeBuilder is used to add go types to the GroupVersionKind scheme.
SchemeBuilder = &scheme.Builder{GroupVersion: GroupVersion}

// SchemeGroupVersion is alias to GroupVersion for client-go libraries.
SchemeGroupVersion = GroupVersion

// AddToScheme adds the types in this group-version to the given scheme.
AddToScheme = SchemeBuilder.AddToScheme
)
195 changes: 195 additions & 0 deletions pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,198 @@ limitations under the License.
*/

package v2alpha1

import (
autoscalingv2 "k8s.io/api/autoscaling/v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)

// +kubebuilder:object:root=true

// ClusterTrainingRuntime represents a training runtime which can be referenced as part of
// `trainingRuntimeRef` API in TrainJob. This resource is a cluster-scoped and can be referenced
// by TrainJob that created in *any* namespace.
type ClusterTrainingRuntime struct {
metav1.TypeMeta `json:",inline"`

// Standard object's metadata.
metav1.ObjectMeta `json:"metadata,omitempty"`

// Specification of the desired ClusterTrainingRuntime.
Spec TrainingRuntimeSpec `json:"spec,omitempty"`
}

// +kubebuilder:object:root=true

// ClusterTrainingRuntimeList is a collection of cluster training runtimes.
type ClusterTrainingRuntimeList struct {
metav1.TypeMeta `json:",inline"`

// Standard list metadata.
metav1.ListMeta `json:"metadata,omitempty"`

// List of ClusterTrainingRuntimes.
Items []ClusterTrainingRuntime `json:"items"`
}

// +kubebuilder:object:root=true

// TrainingRuntime represents a training runtime which can be referenced as part of
// `trainingRuntimeRef` API in TrainJob. This resource is a namespaced-scoped and can be referenced
// by TrainJob that created in the *same* namespace as the TrainingRuntime.
type TrainingRuntime struct {
metav1.TypeMeta `json:",inline"`

// Standard object's metadata.
metav1.ObjectMeta `json:"metadata,omitempty"`

// Specification of the desired TrainingRuntime.
Spec TrainingRuntimeSpec `json:"spec,omitempty"`
}

// +kubebuilder:object:root=true

// TrainingRuntimeList is a collection of training runtimes.
type TrainingRuntimeList struct {
metav1.TypeMeta `json:",inline"`

// Standard list metadata.
metav1.ListMeta `json:"metadata,omitempty"`

// List of TrainingRuntimes.
Items []TrainingRuntime `json:"items"`
}

// TrainingRuntimeSpec represents a specification of the desired training runtime.
type TrainingRuntimeSpec struct {
// Configuration for the model training with ML-specific parameters.
MLPolicy *MLPolicy `json:"mlPolicy,omitempty"`

// Configuration for the PodGroup to enable gang-scheduling via supported plugins.
PodGroupPolicy *PodGroupPolicy `json:"podGroupPolicy,omitempty"`

// JobSet template which will be used by TrainJob.
Template JobSetTemplateSpec `json:"template"`
}

// JobSetTemplateSpec represents a template of the desired JobSet.
type JobSetTemplateSpec struct {
// Metadata for custom JobSet's labels and annotations.
// JobSet name and namespace is equal to the TrainJob's name and namespace.
metav1.ObjectMeta `json:"metadata,omitempty"`

// Specification of the desired JobSet which will be created from TrainJob.
Spec jobsetv1alpha2.JobSetSpec `json:"spec,omitempty"`
}

// PodGroupPolicy represents a PodGroup configuration for gang-scheduling.
type PodGroupPolicy struct {
// Configuration for gang-scheduling using various plugins.
PodGroupPolicySource `json:",inline"`
}

// PodGroupPolicySource represents supported plugins for gang-scheduling.
// Only one of its members may be specified.
type PodGroupPolicySource struct {
// Coscheduling plugin from the Kubernetes scheduler-plugins for gang-scheduling.
Coscheduling *CoschedulingPodGroupPolicySource `json:"coscheduling,omitempty"`

// TODO (andreyvelich): Add support for Volcano gang-scheduler.
}

// CoschedulingPodGroupPolicySource represents configuration for coscheduling plugin.
// The number of min members in the PodGroupSpec is always equal to the number of nodes.
type CoschedulingPodGroupPolicySource struct {
// Time threshold to schedule PodGroup for gang-scheduling.
// If the scheduling timeout is equal to 0, the default value is used.
// Defaults to 60 seconds.
ScheduleTimeoutSeconds *int32 `json:"scheduleTimeoutSeconds,omitempty"`
}

// MLPolicy represents configuration for the model trining with ML-specific parameters.
type MLPolicy struct {
// Number of training nodes.
// Defaults to 1.
NumNodes *int32 `json:"numNodes,omitempty"`

// Configuration for the runtime-specific parameters, such as Torch or MPI.
// Only one of its members may be specified.
MLPolicySource `json:",inline"`
}

// MLPolicySource represents the runtime-specific configuration for various technologies.
// One of the following specs can be set.
type MLPolicySource struct {
// Configuration for the PyTorch runtime.
Torch *TorchMLPolicySource `json:"torch,omitempty"`

// Configuration for the MPI Runtime.
MPI *MPIMLPolicySource `json:"mpi,omitempty"`
}

// TorchMLPolicySource represents a PyTorch runtime configuration.
type TorchMLPolicySource struct {
// Number of processes per node.
// This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
// Supported values: `auto`, `cpu`, `gpu`, or int value.
// TODO (andreyvelich): Add kubebuilder validation.
// Defaults to `auto`.
NumProcPerNode *string `json:"numProcPerNode,omitempty"`

// Elastic policy for the PyTorch training.
ElasticPolicy *TorchElasticPolicy `json:"elasticPolicy,omitempty"`
}

// TorchElasticPolicy represents a configuration for the PyTorch elastic training.
// If this policy is set, the `.spec.numNodes` parameter must be omitted, since min and max node
// is used to configure the `torchrun` CLI argument: `--nnodes=minNodes:maxNodes`.
// Only `c10d` backend is supported for the Rendezvous communication.
type TorchElasticPolicy struct {
// How many times the training job can be restarted.
// This value is inserted into the `--max-restarts` argument of the `torchrun` CLI and
// the `.spec.failurePolicy.maxRestarts` parameter of the training Job.
MaxRestarts *int32 `json:"maxRestarts,omitempty"`

// Lower limit for the number of nodes to which training job can scale down.
MinNodes *int32 `json:"minNodes,omitempty"`

// Upper limit for the number of nodes to which training job can scale up.
MaxNodes *int32 `json:"maxNodes,omitempty"`

// Specification which are used to calculate the desired number of nodes. See the individual
// metric source types for more information about how each type of metric must respond.
// The HPA will be created to perform auto-scaling.
Metrics []autoscalingv2.MetricSpec `json:"metrics,omitempty"`
}

// MPIMLPolicySource represents a MPI runtime configuration.
type MPIMLPolicySource struct {
// Number of processes per node.
// This value is equal to the number of slots for each node in the hostfile.
NumProcPerNode *int32 `json:"numProcPerNode,omitempty"`

// Implementation name for the MPI to create the appropriate hostfile.
// Defaults to OpenMPI.
MPIImplementation *MPIImplementation `json:"mpiImplementation,omitempty"`

// Directory where SSH keys are mounted.
SSHAuthMountPath *string `json:"SSHAuthMountPath,omitempty"`

// Whether to run training process on the launcher Job.
// Defaults to false.
RunLauncherAsNode *bool `json:"runLauncherAsNode,omitempty"`
}

// MPIImplementation represents one of the supported MPI implementations.
type MPIImplementation string

const (
MPIImplementationOpenMPI MPIImplementation = "OpenMPI"
MPIImplementationIntel MPIImplementation = "Intel"
MPIImplementationMPICH MPIImplementation = "MPICH"
)

func init() {
SchemeBuilder.Register(&ClusterTrainingRuntime{}, &ClusterTrainingRuntimeList{}, &TrainingRuntime{}, &TrainingRuntimeList{})
}
Loading

0 comments on commit 181191e

Please sign in to comment.