From 013425d21019994a7753d06cc093777e8eb9c743 Mon Sep 17 00:00:00 2001
From: Yuki Iwai
Date: Wed, 28 Aug 2024 07:53:44 +0900
Subject: [PATCH 01/12] KEP-2170: Implement runtime framework interfaces

Signed-off-by: Yuki Iwai
---
 .github/workflows/unittests.yaml              |   1 +
 Makefile                                      |  15 +
 cmd/training-operator.v2alpha1/main.go        |  25 +-
 go.mod                                        |  66 +--
 go.sum                                        | 155 +++---
 hack/swagger/go.mod                           |  28 +-
 hack/swagger/go.sum                           |  89 ++-
 pkg/controller.v2/setup.go                    |  10 +-
 pkg/controller.v2/trainjob_controller.go      |  16 +-
 pkg/runtime.v2/core/clustertrainingruntime.go |  62 +++
 .../core/clustertrainingruntime_test.go       | 145 +++++
 pkg/runtime.v2/core/core.go                   |  39 ++
 pkg/runtime.v2/core/registry.go               |  34 ++
 pkg/runtime.v2/core/trainingruntime.go        | 121 +++++
 pkg/runtime.v2/core/trainingruntime_test.go   | 135 +++++
 pkg/runtime.v2/framework/core/framework.go    | 126 +++++
 .../framework/core/framework_test.go          | 514 ++++++++++++++++++
 pkg/runtime.v2/framework/interface.go         |  57 ++
 .../plugins/coscheduling/coscheduling.go      | 308 +++++++++++
 .../framework/plugins/coscheduling/indexer.go |  56 ++
 .../framework/plugins/jobset/builder.go       |  83 +++
 .../framework/plugins/jobset/jobset.go        | 121 +++++
 pkg/runtime.v2/framework/plugins/mpi/mpi.go   |  60 ++
 .../framework/plugins/plainml/plainml.go      |  55 ++
 pkg/runtime.v2/framework/plugins/registry.go  |  42 ++
 .../framework/plugins/torch/torch.go          |  56 ++
 pkg/runtime.v2/indexer/indexer.go             |  45 ++
 pkg/runtime.v2/interface.go                   |  33 ++
 pkg/runtime.v2/runtime.go                     | 147 +++++
 pkg/runtime.v2/runtime_test.go                | 220 ++++++++
 pkg/util.v2/testing/client.go                 |  63 +++
 pkg/util.v2/testing/wrapper.go                | 475 ++++++++++++++++
 .../clustertrainingruntime_webhook.go         |  17 +-
 pkg/webhook.v2/setup.go                       |  14 +-
 pkg/webhook.v2/trainingruntime_webhook.go     |  17 +-
 pkg/webhook.v2/trainjob_webhook.go            |  17 +-
 test/integration/framework/framework.go       |  12 +-
 37 files changed, 3256 insertions(+), 223 deletions(-)
 create mode 100644 pkg/runtime.v2/core/clustertrainingruntime.go
 create mode 100644 pkg/runtime.v2/core/clustertrainingruntime_test.go
 create mode 100644 pkg/runtime.v2/core/core.go
 create mode 100644 pkg/runtime.v2/core/registry.go
 create mode 100644 pkg/runtime.v2/core/trainingruntime.go
 create mode 100644 pkg/runtime.v2/core/trainingruntime_test.go
 create mode 100644 pkg/runtime.v2/framework/core/framework.go
 create mode 100644 pkg/runtime.v2/framework/core/framework_test.go
 create mode 100644 pkg/runtime.v2/framework/interface.go
 create mode 100644 pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go
 create mode 100644 pkg/runtime.v2/framework/plugins/coscheduling/indexer.go
 create mode 100644 pkg/runtime.v2/framework/plugins/jobset/builder.go
 create mode 100644 pkg/runtime.v2/framework/plugins/jobset/jobset.go
 create mode 100644 pkg/runtime.v2/framework/plugins/mpi/mpi.go
 create mode 100644 pkg/runtime.v2/framework/plugins/plainml/plainml.go
 create mode 100644 pkg/runtime.v2/framework/plugins/registry.go
 create mode 100644 pkg/runtime.v2/framework/plugins/torch/torch.go
 create mode 100644 pkg/runtime.v2/indexer/indexer.go
 create mode 100644 pkg/runtime.v2/interface.go
 create mode 100644 pkg/runtime.v2/runtime.go
 create mode 100644 pkg/runtime.v2/runtime_test.go
 create mode 100644 pkg/util.v2/testing/client.go
 create mode 100644 pkg/util.v2/testing/wrapper.go

diff --git a/.github/workflows/unittests.yaml b/.github/workflows/unittests.yaml
index f5e5ea2b65..22380b5c2a 100644
--- a/.github/workflows/unittests.yaml
+++ b/.github/workflows/unittests.yaml
@@ -37,6 +37,7 @@ jobs:
       - name: Run Go test for v2
         run: |
+          make testv2
           make test-integrationv2 ENVTEST_K8S_VERSION=${{ matrix.kubernetes-version }}
 
       - name: Coveralls report
diff --git a/Makefile b/Makefile
index 5d8bfc2596..2a65e73622 100644
--- a/Makefile
+++ b/Makefile
@@ -10,6 +10,17 @@ else
 GOBIN=$(shell go env GOBIN)
 endif
 
+# Setting GREP allows macos users to install GNU grep and use the latter
+# instead of the default BSD grep.
+ifeq ($(shell command -v ggrep 2>/dev/null),)
+  GREP ?= $(shell command -v grep)
+else
+  GREP ?= $(shell command -v ggrep)
+endif
+ifeq ($(shell ${GREP} --version 2>&1 | grep -q GNU; echo $$?),1)
+  $(error !!! GNU grep is required. If on OS X, use 'brew install grep'.)
+endif
+
 # Setting SHELL to bash allows bash commands to be executed by recipes.
 # This is a requirement for 'setup-envtest.sh' in the test target.
 # Options are set to exit when a recipe line exits non-zero or a piped command fails.
@@ -80,6 +91,10 @@ test: envtest
 test-integrationv2: envtest
 	KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./test/... -coverprofile cover.out
 
+.PHONY: testv2
+testv2:
+	go test $(shell go list ./pkg/... | $(GREP) -E '.*\.v2') -coverprofile cover.out
+
 envtest:
 ifndef HAS_SETUP_ENVTEST
 	go install sigs.k8s.io/controller-runtime/tools/setup-envtest@bf15e44028f908c790721fc8fe67c7bf2d06a611 # v0.17.2
diff --git a/cmd/training-operator.v2alpha1/main.go b/cmd/training-operator.v2alpha1/main.go
index 08bb9d4791..27c3007bfe 100644
--- a/cmd/training-operator.v2alpha1/main.go
+++ b/cmd/training-operator.v2alpha1/main.go
@@ -17,6 +17,7 @@ limitations under the License.
 package main
 
 import (
+	"context"
 	"crypto/tls"
 	"errors"
 	"flag"
@@ -25,7 +26,7 @@ import (
 	zaplog "go.uber.org/zap"
 	"go.uber.org/zap/zapcore"
 
-	"k8s.io/apimachinery/pkg/runtime"
+	apiruntime "k8s.io/apimachinery/pkg/runtime"
 	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
 	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
 	ctrl "sigs.k8s.io/controller-runtime"
@@ -34,15 +35,17 @@ import (
 	metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
 	"sigs.k8s.io/controller-runtime/pkg/webhook"
 	jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
+	schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
 
 	kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
 	"github.com/kubeflow/training-operator/pkg/cert"
 	controllerv2 "github.com/kubeflow/training-operator/pkg/controller.v2"
+	runtimecore "github.com/kubeflow/training-operator/pkg/runtime.v2/core"
 	webhookv2 "github.com/kubeflow/training-operator/pkg/webhook.v2"
 )
 
 var (
-	scheme   = runtime.NewScheme()
+	scheme   = apiruntime.NewScheme()
 	setupLog = ctrl.Log.WithName("setup")
 )
 
@@ -50,6 +53,7 @@ func init() {
 	utilruntime.Must(clientgoscheme.AddToScheme(scheme))
 	utilruntime.Must(kubeflowv2.AddToScheme(scheme))
 	utilruntime.Must(jobsetv1alpha2.AddToScheme(scheme))
+	utilruntime.Must(schedulerpluginsv1alpha1.AddToScheme(scheme))
 }
 
 func main() {
@@ -127,27 +131,34 @@ func main() {
 		os.Exit(1)
 	}
 
+	ctx := ctrl.SetupSignalHandler()
+
 	setupProbeEndpoints(mgr, certsReady)
 	// Set up controllers using goroutines to start the manager quickly.
-	go setupControllers(mgr, certsReady)
+	go setupControllers(ctx, mgr, certsReady)
 
 	setupLog.Info("Starting manager")
-	if err = mgr.Start(ctrl.SetupSignalHandler()); err != nil {
+	if err = mgr.Start(ctx); err != nil {
 		setupLog.Error(err, "Could not run manager")
 		os.Exit(1)
 	}
 }
 
-func setupControllers(mgr ctrl.Manager, certsReady <-chan struct{}) {
+func setupControllers(ctx context.Context, mgr ctrl.Manager, certsReady <-chan struct{}) {
 	setupLog.Info("Waiting for certificate generation to complete")
 	<-certsReady
 	setupLog.Info("Certs ready")
-	if failedCtrlName, err := controllerv2.SetupControllers(mgr); err != nil {
+	runtimes, err := runtimecore.New(ctx, mgr.GetClient(), mgr.GetFieldIndexer())
+	if err != nil {
+		setupLog.Error(err, "Could not initialize runtimes")
+		os.Exit(1)
+	}
+	if failedCtrlName, err := controllerv2.SetupControllers(mgr, runtimes); err != nil {
 		setupLog.Error(err, "Could not create controller", "controller", failedCtrlName)
 		os.Exit(1)
 	}
-	if failedWebhook, err := webhookv2.Setup(mgr); err != nil {
+	if failedWebhook, err := webhookv2.Setup(mgr, runtimes); err != nil {
 		setupLog.Error(err, "Could not create webhook", "webhook", failedWebhook)
 		os.Exit(1)
 	}
diff --git a/go.mod b/go.mod
index eb2f0afcbc..f55eda8536 100644
--- a/go.mod
+++ b/go.mod
@@ -3,24 +3,25 @@ module github.com/kubeflow/training-operator
 go 1.22
 
 require (
-	github.com/go-logr/logr v1.4.1
+	github.com/go-logr/logr v1.4.2
 	github.com/google/go-cmp v0.6.0
-	github.com/onsi/ginkgo/v2 v2.17.1
-	github.com/onsi/gomega v1.32.0
+	github.com/onsi/ginkgo/v2 v2.19.0
+	github.com/onsi/gomega v1.33.1
 	github.com/open-policy-agent/cert-controller v0.10.1
 	github.com/prometheus/client_golang v1.18.0
-	github.com/sirupsen/logrus v1.9.0
+	github.com/sirupsen/logrus v1.9.3
 	github.com/stretchr/testify v1.9.0
 	go.uber.org/zap v1.27.0
-	k8s.io/api v0.29.3
-	k8s.io/apimachinery v0.29.3
-	k8s.io/client-go v0.29.3
-	k8s.io/code-generator v0.29.3
-	k8s.io/klog/v2 v2.110.1
+	k8s.io/api v0.29.5
+	k8s.io/apimachinery v0.29.5
+	k8s.io/client-go v0.29.5
+	k8s.io/code-generator v0.29.5
+	k8s.io/klog/v2 v2.120.1
 	k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00
-	k8s.io/utils v0.0.0-20230726121419-3b25d923346b
+	k8s.io/utils v0.0.0-20240502163921-fe8a2dddb1d0
 	sigs.k8s.io/controller-runtime v0.17.3
 	sigs.k8s.io/jobset v0.5.2
+	sigs.k8s.io/kueue v0.6.3
 	sigs.k8s.io/scheduler-plugins v0.28.9
 	sigs.k8s.io/yaml v1.4.0
 	volcano.sh/apis v1.9.0
@@ -28,24 +29,24 @@ require (
 	github.com/beorn7/perks v1.0.1 // indirect
-	github.com/cespare/xxhash/v2 v2.2.0 // indirect
+	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
-	github.com/emicklei/go-restful/v3 v3.11.0 // indirect
-	github.com/evanphx/json-patch v5.6.0+incompatible // indirect
+	github.com/emicklei/go-restful/v3 v3.12.1 // indirect
+	github.com/evanphx/json-patch v5.9.0+incompatible // indirect
 	github.com/evanphx/json-patch/v5 v5.8.0 // indirect
 	github.com/fsnotify/fsnotify v1.7.0 // indirect
 	github.com/go-logr/zapr v1.3.0 // indirect
-	github.com/go-openapi/jsonpointer v0.19.6 // indirect
-	github.com/go-openapi/jsonreference v0.20.2 // indirect
-	github.com/go-openapi/swag v0.22.3 // indirect
-	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
+	github.com/go-openapi/jsonpointer v0.21.0 // indirect
+	github.com/go-openapi/jsonreference v0.21.0 // indirect
+	github.com/go-openapi/swag v0.23.0 // indirect
+	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
 	github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/gofuzz v1.2.0 // indirect - github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect - github.com/google/uuid v1.3.1 // indirect + github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/imdario/mergo v0.3.16 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -56,30 +57,29 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.5.0 // indirect + github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.45.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/spf13/pflag v1.0.5 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect - golang.org/x/mod v0.16.0 // indirect - golang.org/x/net v0.23.0 // indirect - golang.org/x/oauth2 v0.12.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/term v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/time v0.3.0 // indirect - golang.org/x/tools v0.19.0 // indirect + golang.org/x/exp v0.0.0-20240530194437-404ba88c7ed0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/oauth2 v0.20.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/term v0.20.0 // indirect + golang.org/x/text v0.15.0 // indirect + golang.org/x/time v0.5.0 // indirect + golang.org/x/tools v0.21.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/protobuf v1.33.0 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/apiextensions-apiserver v0.29.2 // indirect - k8s.io/component-base v0.29.2 // indirect - k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01 // indirect + k8s.io/component-base v0.29.5 // indirect + k8s.io/gengo v0.0.0-20240404160639-a0386bf69313 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect ) diff --git a/go.sum b/go.sum index da8a571436..7e8820b975 100644 --- a/go.sum +++ b/go.sum @@ -1,41 +1,35 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 
v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= -github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= -github.com/evanphx/json-patch v5.6.0+incompatible h1:jBYDEEiFBPxA0v50tFdvOzQQTCvpL6mnFh5mB2/l16U= -github.com/evanphx/json-patch v5.6.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/emicklei/go-restful/v3 v3.12.1 h1:PJMDIM/ak7btuL8Ex0iYET9hxM3CI2sjZtzpL63nKAU= +github.com/emicklei/go-restful/v3 v3.12.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/evanphx/json-patch v5.9.0+incompatible h1:fBXyNpNMuTTDdquAq/uisOr2lShz4oaXpDTX2bLe7ls= +github.com/evanphx/json-patch v5.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.8.0 h1:lRj6N9Nci7MvzrXuX6HFzU8XjmhPiXPlsKEy1u0KQro= github.com/evanphx/json-patch/v5 v5.8.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= -github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg= -github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= -github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= -github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE= -github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= -github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= -github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-task/slim-sprig/v3 v3.0.0 
h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= @@ -48,11 +42,10 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= -github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= -github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 h1:k7nVchz72niMH6YLQNvHSdIE7iqsQxK1P41mySCvssg= +github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= @@ -62,7 +55,6 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -80,10 +72,10 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/onsi/ginkgo/v2 v2.17.1 
h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8= -github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs= -github.com/onsi/gomega v1.32.0 h1:JRYU78fJ1LPxlckP6Txi/EYqJvjtMrDC04/MM5XRHPk= -github.com/onsi/gomega v1.32.0/go.mod h1:a4x4gW6Pz2yK1MAmvluYme5lvYTn61afQ2ETw/8n4Lg= +github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA= +github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= +github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= +github.com/onsi/gomega v1.33.1/go.mod h1:U4R44UsT+9eLIaYRB2a5qajjtQYn0hauxvRm16AVYg0= github.com/open-policy-agent/cert-controller v0.10.1 h1:RXSYoyn8FdCenWecRP//UV5nbVfmstNpj4kHQFkvPK4= github.com/open-policy-agent/cert-controller v0.10.1/go.mod h1:4uRbBLY5DsPOog+a9pqk3JLxuuhrWsbUedQW65HcLTI= github.com/open-policy-agent/frameworks/constraint v0.0.0-20230822235116-f0b62fe1e4c4 h1:5dum5SLEz+95JDLkMls7Z7IDPjvSq3UhJSFe4f5einQ= @@ -94,27 +86,21 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify 
v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -130,57 +116,54 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/exp v0.0.0-20240530194437-404ba88c7ed0 h1:Mi0bCswbz+9cXmwFAdxoo5GPFMKONUpua6iUdtQS7lk= +golang.org/x/exp v0.0.0-20240530194437-404ba88c7ed0/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/oauth2 v0.12.0 h1:smVPGxink+n1ZI5pkQa8y6fZT0RW0MgCO5bFpepy4B4= -golang.org/x/oauth2 v0.12.0/go.mod h1:A74bZ3aGXgCY0qaIC9Ahg6Lglin4AMAco8cIv9baba4= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo= +golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200505023115-26f46d2f7ef8/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= +golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw= gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= -google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= 
-google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= @@ -193,35 +176,37 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -k8s.io/api v0.29.3 h1:2ORfZ7+bGC3YJqGpV0KSDDEVf8hdGQ6A03/50vj8pmw= -k8s.io/api v0.29.3/go.mod h1:y2yg2NTyHUUkIoTC+phinTnEa3KFM6RZ3szxt014a80= +k8s.io/api v0.29.5 h1:levS+umUigHCfI3riD36pMY1vQEbrzh4r1ivVWAhHaI= +k8s.io/api v0.29.5/go.mod h1:7b18TtPcJzdjk7w5zWyIHgoAtpGeRvGGASxlS7UZXdQ= k8s.io/apiextensions-apiserver v0.29.2 h1:UK3xB5lOWSnhaCk0RFZ0LUacPZz9RY4wi/yt2Iu+btg= k8s.io/apiextensions-apiserver v0.29.2/go.mod h1:aLfYjpA5p3OwtqNXQFkhJ56TB+spV8Gc4wfMhUA3/b8= -k8s.io/apimachinery v0.29.3 h1:2tbx+5L7RNvqJjn7RIuIKu9XTsIZ9Z5wX2G22XAa5EU= -k8s.io/apimachinery v0.29.3/go.mod h1:hx/S4V2PNW4OMg3WizRrHutyB5la0iCUbZym+W0EQIU= -k8s.io/client-go v0.29.3 h1:R/zaZbEAxqComZ9FHeQwOh3Y1ZUs7FaHKZdQtIc2WZg= -k8s.io/client-go v0.29.3/go.mod h1:tkDisCvgPfiRpxGnOORfkljmS+UrW+WtXAy2fTvXJB0= -k8s.io/code-generator v0.29.3 h1:m7E25/t9R9NvejspO2zBdyu+/Gl0Z5m7dCRc680KS14= -k8s.io/code-generator v0.29.3/go.mod h1:x47ofBhN4gxYFcxeKA1PYXeaPreAGaDN85Y/lNUsPoM= -k8s.io/component-base v0.29.2 h1:lpiLyuvPA9yV1aQwGLENYyK7n/8t6l3nn3zAtFTJYe8= -k8s.io/component-base v0.29.2/go.mod h1:BfB3SLrefbZXiBfbM+2H1dlat21Uewg/5qtKOl8degM= -k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01 h1:pWEwq4Asjm4vjW7vcsmijwBhOr1/shsbSYiWXmNGlks= -k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E= +k8s.io/apimachinery v0.29.5 h1:Hofa2BmPfpoT+IyDTlcPdCHSnHtEQMoJYGVoQpRTfv4= +k8s.io/apimachinery v0.29.5/go.mod h1:i3FJVwhvSp/6n8Fl4K97PJEP8C+MM+aoDq4+ZJBf70Y= +k8s.io/client-go v0.29.5 h1:nlASXmPQy190qTteaVP31g3c/wi2kycznkTP7Sv1zPc= +k8s.io/client-go v0.29.5/go.mod h1:aY5CnqUUvXYccJhm47XHoPcRyX6vouHdIBHaKZGTbK4= +k8s.io/code-generator v0.29.5 h1:WqSdBPVV1B3jsPnKtPS39U02zj6Q7+FsjhAj1EPBJec= +k8s.io/code-generator v0.29.5/go.mod h1:7TYnI0dYItL2cKuhhgPSuF3WED9uMdELgbVXFfn/joE= +k8s.io/component-base v0.29.5 h1:Ptj8AzG+p8c2a839XriHwxakDpZH9uvIgYz+o1agjg8= +k8s.io/component-base v0.29.5/go.mod h1:9nBUoPxW/yimISIgAG7sJDrUGJlu7t8HnDafIrOdU8Q= +k8s.io/gengo v0.0.0-20240404160639-a0386bf69313 h1:wBIDZID8ju9pwOiLlV22YYKjFGtiNSWgHf5CnKLRUuM= +k8s.io/gengo v0.0.0-20240404160639-a0386bf69313/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E= k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= -k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= -k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= +k8s.io/klog/v2 v2.120.1 
h1:QXU6cPEOIslTGvZaXvFWiP9VKyeet3sawzTOvdXb4Vw= +k8s.io/klog/v2 v2.120.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= k8s.io/kube-aggregator v0.28.1 h1:rvG4llYnQKHjj6YjjoBPEJxfD1uH0DJwkrJTNKGAaCs= k8s.io/kube-aggregator v0.28.1/go.mod h1:JaLizMe+AECSpO2OmrWVsvnG0V3dX1RpW+Wq/QHbu18= k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 h1:aVUu9fTY98ivBPKR9Y5w/AuzbMm96cd3YHRTU83I780= k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00/go.mod h1:AsvuZPBlUDVuCdzJ87iajxtXuR9oktsTctW/R9wwouA= -k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI= -k8s.io/utils v0.0.0-20230726121419-3b25d923346b/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +k8s.io/utils v0.0.0-20240502163921-fe8a2dddb1d0 h1:jgGTlFYnhF1PM1Ax/lAlxUPE+KfCIXHaathvJg1C3ak= +k8s.io/utils v0.0.0-20240502163921-fe8a2dddb1d0/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= sigs.k8s.io/controller-runtime v0.17.3 h1:65QmN7r3FWgTxDMz9fvGnO1kbf2nu+acg9p2R9oYYYk= sigs.k8s.io/controller-runtime v0.17.3/go.mod h1:N0jpP5Lo7lMTF9aL56Z/B2oWBJjey6StQM0jRbKQXtY= sigs.k8s.io/jobset v0.5.2 h1:276q5Pi/ErLYj+GQ0ydEXR6tx3LwBhEzHLQv+k8bYF4= sigs.k8s.io/jobset v0.5.2/go.mod h1:Vg99rj/6OoGvy1uvywGEHOcVLCWWJYkJtisKqdWzcFw= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= +sigs.k8s.io/kueue v0.6.3 h1:PmccdKPDFQIaAboyuSG6M0w6hXtxVA51RV+DjCUtBtQ= +sigs.k8s.io/kueue v0.6.3/go.mod h1:rliYfK/K7pJ7CT4ReV1szzciNkAo3sBn5Bmr5Sn6uCY= sigs.k8s.io/scheduler-plugins v0.28.9 h1:1/bXRoXuSUFr1FLqxrzScdyZMl/G1psuDJcDKYxTo+Q= sigs.k8s.io/scheduler-plugins v0.28.9/go.mod h1:32+kIPGT0aTRsEDzKNga7zCbcCHK0dSk5UFCY+gzCLE= sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= diff --git a/hack/swagger/go.mod b/hack/swagger/go.mod index f45fedaf0b..6635b8079c 100644 --- a/hack/swagger/go.mod +++ b/hack/swagger/go.mod @@ -4,18 +4,18 @@ go 1.22 require ( github.com/kubeflow/training-operator v0.0.0-00010101000000-000000000000 - k8s.io/klog/v2 v2.110.1 + k8s.io/klog/v2 v2.120.1 k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 ) replace github.com/kubeflow/training-operator => ../../ require ( - github.com/emicklei/go-restful/v3 v3.11.0 // indirect - github.com/go-logr/logr v1.4.1 // indirect - github.com/go-openapi/jsonpointer v0.19.6 // indirect - github.com/go-openapi/jsonreference v0.20.2 // indirect - github.com/go-openapi/swag v0.22.3 // indirect + github.com/emicklei/go-restful/v3 v3.12.1 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect @@ -25,17 +25,17 @@ require ( github.com/mailru/easyjson v0.7.7 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/sirupsen/logrus v1.9.0 // indirect - golang.org/x/net v0.23.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - google.golang.org/protobuf v1.33.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.15.0 // 
indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - k8s.io/api v0.29.3 // indirect - k8s.io/apimachinery v0.29.3 // indirect - k8s.io/utils v0.0.0-20230726121419-3b25d923346b // indirect + k8s.io/api v0.29.5 // indirect + k8s.io/apimachinery v0.29.5 // indirect + k8s.io/utils v0.0.0-20240502163921-fe8a2dddb1d0 // indirect sigs.k8s.io/controller-runtime v0.17.3 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect diff --git a/hack/swagger/go.sum b/hack/swagger/go.sum index 307010aea6..7980dca80f 100644 --- a/hack/swagger/go.sum +++ b/hack/swagger/go.sum @@ -1,20 +1,19 @@ -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= -github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= -github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= -github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= -github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE= -github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= -github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= -github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= +github.com/emicklei/go-restful/v3 v3.12.1 h1:PJMDIM/ak7btuL8Ex0iYET9hxM3CI2sjZtzpL63nKAU= +github.com/emicklei/go-restful/v3 v3.12.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf 
v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= @@ -27,19 +26,16 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= -github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 h1:k7nVchz72niMH6YLQNvHSdIE7iqsQxK1P41mySCvssg= +github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= @@ -49,26 +45,21 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/onsi/ginkgo/v2 v2.17.1 h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8= -github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs= -github.com/onsi/gomega v1.32.0 h1:JRYU78fJ1LPxlckP6Txi/EYqJvjtMrDC04/MM5XRHPk= -github.com/onsi/gomega v1.32.0/go.mod h1:a4x4gW6Pz2yK1MAmvluYme5lvYTn61afQ2ETw/8n4Lg= +github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA= +github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= +github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= +github.com/onsi/gomega v1.33.1/go.mod h1:U4R44UsT+9eLIaYRB2a5qajjtQYn0hauxvRm16AVYg0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod 
h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -82,8 +73,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -91,24 +82,24 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod 
h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= +golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= @@ -120,16 +111,16 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -k8s.io/api v0.29.3 h1:2ORfZ7+bGC3YJqGpV0KSDDEVf8hdGQ6A03/50vj8pmw= -k8s.io/api v0.29.3/go.mod h1:y2yg2NTyHUUkIoTC+phinTnEa3KFM6RZ3szxt014a80= -k8s.io/apimachinery v0.29.3 h1:2tbx+5L7RNvqJjn7RIuIKu9XTsIZ9Z5wX2G22XAa5EU= -k8s.io/apimachinery v0.29.3/go.mod h1:hx/S4V2PNW4OMg3WizRrHutyB5la0iCUbZym+W0EQIU= -k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= -k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= +k8s.io/api v0.29.5 h1:levS+umUigHCfI3riD36pMY1vQEbrzh4r1ivVWAhHaI= +k8s.io/api v0.29.5/go.mod h1:7b18TtPcJzdjk7w5zWyIHgoAtpGeRvGGASxlS7UZXdQ= +k8s.io/apimachinery v0.29.5 h1:Hofa2BmPfpoT+IyDTlcPdCHSnHtEQMoJYGVoQpRTfv4= +k8s.io/apimachinery v0.29.5/go.mod h1:i3FJVwhvSp/6n8Fl4K97PJEP8C+MM+aoDq4+ZJBf70Y= +k8s.io/klog/v2 v2.120.1 h1:QXU6cPEOIslTGvZaXvFWiP9VKyeet3sawzTOvdXb4Vw= +k8s.io/klog/v2 v2.120.1/go.mod 
h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
 k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 h1:aVUu9fTY98ivBPKR9Y5w/AuzbMm96cd3YHRTU83I780=
 k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00/go.mod h1:AsvuZPBlUDVuCdzJ87iajxtXuR9oktsTctW/R9wwouA=
-k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI=
-k8s.io/utils v0.0.0-20230726121419-3b25d923346b/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
+k8s.io/utils v0.0.0-20240502163921-fe8a2dddb1d0 h1:jgGTlFYnhF1PM1Ax/lAlxUPE+KfCIXHaathvJg1C3ak=
+k8s.io/utils v0.0.0-20240502163921-fe8a2dddb1d0/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
 sigs.k8s.io/controller-runtime v0.17.3 h1:65QmN7r3FWgTxDMz9fvGnO1kbf2nu+acg9p2R9oYYYk=
 sigs.k8s.io/controller-runtime v0.17.3/go.mod h1:N0jpP5Lo7lMTF9aL56Z/B2oWBJjey6StQM0jRbKQXtY=
 sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo=
 sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
diff --git a/pkg/controller.v2/setup.go b/pkg/controller.v2/setup.go
index 79e89fa0c5..e2fadd3a96 100644
--- a/pkg/controller.v2/setup.go
+++ b/pkg/controller.v2/setup.go
@@ -16,13 +16,17 @@ limitations under the License.
 
 package controllerv2
 
-import ctrl "sigs.k8s.io/controller-runtime"
+import (
+    ctrl "sigs.k8s.io/controller-runtime"
 
-func SetupControllers(mgr ctrl.Manager) (string, error) {
+    runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
+)
+
+func SetupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (string, error) {
     if err := NewTrainJobReconciler(
         mgr.GetClient(),
         mgr.GetEventRecorderFor("training-operator-trainjob-controller"),
-    ).SetupWithManager(mgr); err != nil {
+    ).SetupWithManager(mgr, runtimes); err != nil {
         return "TrainJob", err
     }
     return "", nil
diff --git a/pkg/controller.v2/trainjob_controller.go b/pkg/controller.v2/trainjob_controller.go
index e12cc3c2d7..ef2f3242ce 100644
--- a/pkg/controller.v2/trainjob_controller.go
+++ b/pkg/controller.v2/trainjob_controller.go
@@ -26,6 +26,7 @@ import (
     "sigs.k8s.io/controller-runtime/pkg/client"
 
     kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
+    runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
 )
 
 type TrainJobReconciler struct {
@@ -53,8 +54,15 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
     return ctrl.Result{}, nil
 }
 
-func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
-    return ctrl.NewControllerManagedBy(mgr).
-        For(&kubeflowv2.TrainJob{}).
-        Complete(r)
+func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) error {
+    b := ctrl.NewControllerManagedBy(mgr).
+        For(&kubeflowv2.TrainJob{})
+    for _, run := range runtimes {
+        for _, registrar := range run.EventHandlerRegistrars() {
+            if registrar != nil {
+                b = registrar(b, mgr.GetClient())
+            }
+        }
+    }
+    return b.Complete(r)
 }
diff --git a/pkg/runtime.v2/core/clustertrainingruntime.go b/pkg/runtime.v2/core/clustertrainingruntime.go
new file mode 100644
index 0000000000..d4908af5f0
--- /dev/null
+++ b/pkg/runtime.v2/core/clustertrainingruntime.go
@@ -0,0 +1,62 @@
+/*
+Copyright 2024 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package core
+
+import (
+    "context"
+    "errors"
+    "fmt"
+
+    "k8s.io/apimachinery/pkg/runtime/schema"
+    "sigs.k8s.io/controller-runtime/pkg/client"
+
+    kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
+    runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
+)
+
+var (
+    errorNotFoundSpecifiedClusterTrainingRuntime = errors.New("not found ClusterTrainingRuntime specified in TrainJob")
+)
+
+type ClusterTrainingRuntime struct {
+    *TrainingRuntime
+}
+
+var _ runtime.Runtime = (*ClusterTrainingRuntime)(nil)
+
+var ClusterTrainingRuntimeGroupKind = schema.GroupKind{
+    Group: kubeflowv2.GroupVersion.Group,
+    Kind:  "ClusterTrainingRuntime",
+}.String()
+
+func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndexer) (runtime.Runtime, error) {
+    return &ClusterTrainingRuntime{
+        TrainingRuntime: trainingRuntimeFactory,
+    }, nil
+}
+
+func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) {
+    var clTrainingRuntime kubeflowv2.ClusterTrainingRuntime
+    if err := r.client.Get(ctx, client.ObjectKey{Name: trainJob.Spec.TrainingRuntimeRef.Name}, &clTrainingRuntime); err != nil {
+        return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedClusterTrainingRuntime, err)
+    }
+    return r.buildObjects(ctx, trainJob, clTrainingRuntime.Spec.Template, clTrainingRuntime.Spec.MLPolicy, clTrainingRuntime.Spec.PodGroupPolicy)
+}
+
+func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
+    return nil
+}
diff --git a/pkg/runtime.v2/core/clustertrainingruntime_test.go b/pkg/runtime.v2/core/clustertrainingruntime_test.go
new file mode 100644
index 0000000000..23697b748c
--- /dev/null
+++ b/pkg/runtime.v2/core/clustertrainingruntime_test.go
@@ -0,0 +1,145 @@
+/*
+Copyright 2024 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package core
+
+import (
+    "context"
+    "testing"
+
+    "github.com/google/go-cmp/cmp"
+    "github.com/google/go-cmp/cmp/cmpopts"
+    batchv1 "k8s.io/api/batch/v1"
+    corev1 "k8s.io/api/core/v1"
+    "k8s.io/apimachinery/pkg/api/resource"
+    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+    "k8s.io/utils/ptr"
+    "sigs.k8s.io/controller-runtime/pkg/client"
+    schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
+
+    kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
+    testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing"
+)
+
+func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
+    baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper(t, "test-runtime").
+        Clone()
+
+    cases := map[string]struct {
+        trainJob               *kubeflowv2.TrainJob
+        clusterTrainingRuntime *kubeflowv2.ClusterTrainingRuntime
+        wantObjs               []client.Object
+        wantError              error
+    }{
+        "succeeded to build JobSet and PodGroup": {
+            trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job").
+                UID("uid").
+                TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind("ClusterTrainingRuntime"), "test-runtime").
+                Trainer(
+                    testingutil.MakeTrainJobTrainerWrapper(t).
+                        ContainerImage("test:trainjob").
+                        Obj(),
+                ).
+                Obj(),
+            clusterTrainingRuntime: baseRuntime.RuntimeSpec(
+                testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec).
+                    ContainerImage("test:runtime").
+                    PodGroupPolicySchedulingTimeout(120).
+                    MLPolicyNumNodes(20).
+                    ResourceRequests(0, corev1.ResourceList{
+                        corev1.ResourceCPU: resource.MustParse("1"),
+                    }).
+                    ResourceRequests(1, corev1.ResourceList{
+                        corev1.ResourceCPU: resource.MustParse("2"),
+                    }).
+                    Obj(),
+            ).Obj(),
+            wantObjs: []client.Object{
+                testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job").
+                    PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
+                    ContainerImage(ptr.To("test:trainjob")).
+                    JobCompletionMode(batchv1.IndexedCompletion).
+                    ResourceRequests(0, corev1.ResourceList{
+                        corev1.ResourceCPU: resource.MustParse("1"),
+                    }).
+                    ResourceRequests(1, corev1.ResourceList{
+                        corev1.ResourceCPU: resource.MustParse("2"),
+                    }).
+                    ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
+                    Obj(),
+                testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job").
+                    ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
+                    MinMember(40).
+                    SchedulingTimeout(120).
+                    MinResources(corev1.ResourceList{
+                        corev1.ResourceCPU: resource.MustParse("60"),
+                    }).
+                    Obj(),
+            },
+        },
+        "missing trainingRuntime resource": {
+            trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job").
+                UID("uid").
+                TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind("ClusterTrainingRuntime"), "test-runtime").
+                Trainer(
+                    testingutil.MakeTrainJobTrainerWrapper(t).
+                        ContainerImage("test:trainjob").
+                        Obj(),
+                ).
+                Obj(),
+            wantError: errorNotFoundSpecifiedClusterTrainingRuntime,
+        },
+    }
+    cmpOpts := []cmp.Option{
+        cmpopts.SortSlices(func(a, b client.Object) bool {
+            return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String()
+        }),
+        cmpopts.EquateEmpty(),
+        cmpopts.SortMaps(func(a, b string) bool { return a < b }),
+    }
+    for name, tc := range cases {
+        t.Run(name, func(t *testing.T) {
+            ctx, cancel := context.WithCancel(context.Background())
+            t.Cleanup(cancel)
+            clientBuilder := testingutil.NewClientBuilder()
+            if tc.clusterTrainingRuntime != nil {
+                clientBuilder.WithObjects(tc.clusterTrainingRuntime)
+            }
+
+            trainingRuntime, err := NewTrainingRuntime(ctx, clientBuilder.Build(), testingutil.AsIndex(clientBuilder))
+            if err != nil {
+                t.Fatal(err)
+            }
+            var ok bool
+            trainingRuntimeFactory, ok = trainingRuntime.(*TrainingRuntime)
+            if !ok {
+                t.Fatal("Failed type assertion from Runtime interface to TrainingRuntime")
+            }
+
+            clTrainingRuntime, err := NewClusterTrainingRuntime(ctx, clientBuilder.Build(), testingutil.AsIndex(clientBuilder))
+            if err != nil {
+                t.Fatal(err)
+            }
+            objs, err := clTrainingRuntime.NewObjects(ctx, tc.trainJob)
+            if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
+                t.Errorf("Unexpected error (-want,+got):\n%s", diff)
+            }
+            if diff := cmp.Diff(tc.wantObjs, objs, cmpOpts...); len(diff) != 0 {
+                t.Errorf("Unexpected objects (-want,+got):\n%s", diff)
+            }
+        })
+    }
+}
diff --git a/pkg/runtime.v2/core/core.go b/pkg/runtime.v2/core/core.go
new file mode 100644
index 0000000000..de37b3b4e2
--- /dev/null
+++ b/pkg/runtime.v2/core/core.go
@@ -0,0 +1,39 @@
+/*
+Copyright 2024 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package core
+
+import (
+    "context"
+    "fmt"
+
+    "sigs.k8s.io/controller-runtime/pkg/client"
+
+    runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
+)
+
+func New(ctx context.Context, client client.Client, indexer client.FieldIndexer) (map[string]runtime.Runtime, error) {
+    registry := NewRuntimeRegistry()
+    runtimes := make(map[string]runtime.Runtime, len(registry))
+    for name, factory := range registry {
+        r, err := factory(ctx, client, indexer)
+        if err != nil {
+            return nil, fmt.Errorf("initializing runtime %q: %w", name, err)
+        }
+        runtimes[name] = r
+    }
+    return runtimes, nil
+}
diff --git a/pkg/runtime.v2/core/registry.go b/pkg/runtime.v2/core/registry.go
new file mode 100644
index 0000000000..9e912a481d
--- /dev/null
+++ b/pkg/runtime.v2/core/registry.go
@@ -0,0 +1,34 @@
+/*
+Copyright 2024 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package core + +import ( + "context" + + "sigs.k8s.io/controller-runtime/pkg/client" + + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" +) + +type Registry map[string]func(ctx context.Context, client client.Client, indexer client.FieldIndexer) (runtime.Runtime, error) + +func NewRuntimeRegistry() Registry { + return Registry{ + TrainingRuntimeGroupKind: NewTrainingRuntime, + ClusterTrainingRuntimeGroupKind: NewClusterTrainingRuntime, + } +} diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go new file mode 100644 index 0000000000..879f95a04f --- /dev/null +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -0,0 +1,121 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package core + +import ( + "context" + "errors" + "fmt" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" + fwkcore "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/core" + fwkplugins "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins" + idxer "github.com/kubeflow/training-operator/pkg/runtime.v2/indexer" +) + +var ( + errorNotFoundSpecifiedTrainingRuntime = errors.New("not found TrainingRuntime specified in TrainJob") +) + +type TrainingRuntime struct { + framework *fwkcore.Framework + client client.Client + scheme *apiruntime.Scheme +} + +var TrainingRuntimeGroupKind = schema.GroupKind{ + Group: kubeflowv2.GroupVersion.Group, + Kind: "TrainingRuntime", +}.String() + +var _ runtime.Runtime = (*TrainingRuntime)(nil) + +var trainingRuntimeFactory *TrainingRuntime + +func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.FieldIndexer) (runtime.Runtime, error) { + if err := indexer.IndexField(ctx, &kubeflowv2.TrainJob{}, idxer.TrainJobTrainingRuntimeRefKey, idxer.IndexTrainJobTrainingRuntimes); err != nil { + return nil, fmt.Errorf("setting index on TrainingRuntime and ClusterTrainigRuntime for TrainJob: %w", err) + } + fwk, err := fwkcore.New(ctx, c, fwkplugins.NewRegistry(), indexer) + if err != nil { + return nil, err + } + trainingRuntimeFactory = &TrainingRuntime{ + framework: fwk, + client: c, + scheme: c.Scheme(), + } + return trainingRuntimeFactory, nil +} + +func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) { + var trainingRuntime kubeflowv2.TrainingRuntime + err := r.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.TrainingRuntimeRef.Name}, &trainingRuntime) + if err != nil { + return nil, fmt.Errorf("%w: %w", 
errorNotFoundSpecifiedTrainingRuntime, err) + } + return r.buildObjects(ctx, trainJob, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy) +} + +func (r *TrainingRuntime) buildObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, + mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy) ([]client.Object, error) { + opts := []runtime.InfoOption{ + runtime.WithLabels(jobSetTemplateSpec.Labels), + runtime.WithAnnotations(jobSetTemplateSpec.Annotations), + runtime.WithMLPolicy(mlPolicy), + runtime.WithPodGroupPolicy(podGroupPolicy), + } + for idx, rJob := range jobSetTemplateSpec.Spec.ReplicatedJobs { + if rJob.Replicas == 0 { + jobSetTemplateSpec.Spec.ReplicatedJobs[idx].Replicas = 1 + } + replicas := jobSetTemplateSpec.Spec.ReplicatedJobs[idx].Replicas * ptr.Deref(rJob.Template.Spec.Completions, 1) + opts = append(opts, runtime.WithPodSpecReplicas(rJob.Name, replicas, rJob.Template.Spec.Template.Spec)) + } + info := runtime.NewInfo(&jobsetv1alpha2.JobSet{ + TypeMeta: metav1.TypeMeta{ + APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(), + Kind: "JobSet", + }, + Spec: *jobSetTemplateSpec.Spec.DeepCopy(), + }, opts...) + + if err := r.framework.RunEnforceMLPolicyPlugins(info); err != nil { + return nil, err + } + err := r.framework.RunEnforcePodGroupPolicyPlugins(trainJob, info) + if err != nil { + return nil, err + } + return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob) +} + +func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { + var builders []runtime.ReconcilerBuilder + for _, ex := range r.framework.WatchExtensionPlugins() { + builders = append(builders, ex.ReconcilerBuilders()...) + } + return builders +} diff --git a/pkg/runtime.v2/core/trainingruntime_test.go b/pkg/runtime.v2/core/trainingruntime_test.go new file mode 100644 index 0000000000..244fa88128 --- /dev/null +++ b/pkg/runtime.v2/core/trainingruntime_test.go @@ -0,0 +1,135 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package core + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" +) + +func TestTrainingRuntimeNewObjects(t *testing.T) { + baseRuntime := testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test-runtime"). 
+ Clone() + + cases := map[string]struct { + trainJob *kubeflowv2.TrainJob + trainingRuntime *kubeflowv2.TrainingRuntime + wantObjs []client.Object + wantError error + }{ + "succeeded to build JobSet and PodGroup": { + trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + UID("uid"). + TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind("TrainingRuntime"), "test-runtime"). + Trainer( + testingutil.MakeTrainJobTrainerWrapper(t). + ContainerImage("test:trainjob"). + Obj(), + ). + Obj(), + trainingRuntime: baseRuntime.RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec). + ContainerImage("test:runtime"). + PodGroupPolicySchedulingTimeout(120). + MLPolicyNumNodes(20). + ResourceRequests(0, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + }). + ResourceRequests(1, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2"), + }). + Obj(), + ).Obj(), + wantObjs: []client.Object{ + testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job"). + ContainerImage(ptr.To("test:trainjob")). + JobCompletionMode(batchv1.IndexedCompletion). + ResourceRequests(0, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + }). + ResourceRequests(1, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2"), + }). + ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). + Obj(), + testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job"). + ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). + MinMember(40). + SchedulingTimeout(120). + MinResources(corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("60"), + }). + Obj(), + }, + }, + "missing trainingRuntime resource": { + trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + UID("uid"). + TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind("TrainingRuntime"), "test-runtime"). + Trainer( + testingutil.MakeTrainJobTrainerWrapper(t). + ContainerImage("test:trainjob"). + Obj(), + ). 
+ Obj(), + wantError: errorNotFoundSpecifiedTrainingRuntime, + }, + } + cmpOpts := []cmp.Option{ + cmpopts.SortSlices(func(a, b client.Object) bool { + return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String() + }), + cmpopts.EquateEmpty(), + cmpopts.SortMaps(func(a, b string) bool { return a < b }), + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + clientBuilder := testingutil.NewClientBuilder() + if tc.trainingRuntime != nil { + clientBuilder.WithObjects(tc.trainingRuntime) + } + + trainingRuntime, err := NewTrainingRuntime(ctx, clientBuilder.Build(), testingutil.AsIndex(clientBuilder)) + if err != nil { + t.Fatal(err) + } + objs, err := trainingRuntime.NewObjects(ctx, tc.trainJob) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantObjs, objs, cmpOpts...); len(diff) != 0 { + t.Errorf("Unexpected objects (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/runtime.v2/framework/core/framework.go b/pkg/runtime.v2/framework/core/framework.go new file mode 100644 index 0000000000..cb1b23d42e --- /dev/null +++ b/pkg/runtime.v2/framework/core/framework.go @@ -0,0 +1,126 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package core + +import ( + "context" + + "k8s.io/apimachinery/pkg/util/validation/field" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" + fwkplugins "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins" +) + +type Framework struct { + registry fwkplugins.Registry + plugins map[string]framework.Plugin + enforceMLPlugins []framework.EnforceMLPolicyPlugin + enforcePodGroupPolicyPlugins []framework.EnforcePodGroupPolicyPlugin + customValidationPlugins []framework.CustomValidationPlugin + watchExtensionPlugins []framework.WatchExtensionPlugin + componentBuilderPlugins []framework.ComponentBuilderPlugin +} + +func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) { + f := &Framework{ + registry: r, + } + plugins := make(map[string]framework.Plugin, len(r)) + + for name, factory := range r { + plugin, err := factory(ctx, c, indexer) + if err != nil { + return nil, err + } + plugins[name] = plugin + if p, ok := plugin.(framework.EnforceMLPolicyPlugin); ok { + f.enforceMLPlugins = append(f.enforceMLPlugins, p) + } + if p, ok := plugin.(framework.EnforcePodGroupPolicyPlugin); ok { + f.enforcePodGroupPolicyPlugins = append(f.enforcePodGroupPolicyPlugins, p) + } + if p, ok := plugin.(framework.CustomValidationPlugin); ok { + f.customValidationPlugins = append(f.customValidationPlugins, p) + } + if p, ok := plugin.(framework.WatchExtensionPlugin); ok { + f.watchExtensionPlugins = append(f.watchExtensionPlugins, p) + } + if p, ok := plugin.(framework.ComponentBuilderPlugin); ok { + f.componentBuilderPlugins = append(f.componentBuilderPlugins, p) + } + } + f.plugins = plugins + return f, nil +} + +func (f *Framework) RunEnforceMLPolicyPlugins(info *runtime.Info) error { + for _, plugin := range f.enforceMLPlugins { + if err := plugin.EnforceMLPolicy(info); err != nil { + return err + } + } + return nil +} + +func (f *Framework) RunEnforcePodGroupPolicyPlugins(trainJob *kubeflowv2.TrainJob, info *runtime.Info) error { + for _, plugin := range f.enforcePodGroupPolicyPlugins { + if err := plugin.EnforcePodGroupPolicy(trainJob, info); err != nil { + return err + } + } + return nil +} + +func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (admission.Warnings, error) { + var aggregatedWarnings admission.Warnings + var aggregatedErrors field.ErrorList + for _, plugin := range f.customValidationPlugins { + warnings, errs := plugin.Validate(oldObj, newObj) + if len(warnings) != 0 { + aggregatedWarnings = append(aggregatedWarnings, warnings...) + } + if errs != nil { + aggregatedErrors = append(aggregatedErrors, errs...) 
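+			// keep iterating so every custom validation plugin reports its field errors; they are folded into a single aggregate error after the loop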
+ } + } + if len(aggregatedErrors) == 0 { + return aggregatedWarnings, nil + } + return aggregatedWarnings, aggregatedErrors.ToAggregate() +} + +func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) { + var objs []client.Object + for _, plugin := range f.componentBuilderPlugins { + obj, err := plugin.Build(ctx, info, trainJob) + if err != nil { + return nil, err + } + if obj != nil { + objs = append(objs, obj) + } + } + return objs, nil +} + +func (f *Framework) WatchExtensionPlugins() []framework.WatchExtensionPlugin { + return f.watchExtensionPlugins +} diff --git a/pkg/runtime.v2/framework/core/framework_test.go b/pkg/runtime.v2/framework/core/framework_test.go new file mode 100644 index 0000000000..141d7995fb --- /dev/null +++ b/pkg/runtime.v2/framework/core/framework_test.go @@ -0,0 +1,514 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package core + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/meta" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" + fwkplugins "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/coscheduling" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/jobset" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/mpi" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/plainml" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/torch" + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" +) + +// TODO: We should introduce mock plugins and use plugins in this framework testing. +// After we migrate the actual plugins to mock one for testing data, +// we can delegate the actual plugin testing to each plugin directories, and implement detailed unit testing. 
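To make the TODO above concrete, one possible shape for such a mock plugin is sketched below; the package, type, and function names are illustrative and do not exist in this change. The sketch depends only on the framework.Plugin and framework.EnforceMLPolicyPlugin interfaces introduced by this patch and on the factory signature that the plugin registry stores.

```go
// Package mock is a hypothetical test-only plugin for observing framework dispatch.
package mock

import (
	"context"

	"sigs.k8s.io/controller-runtime/pkg/client"

	runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
	"github.com/kubeflow/training-operator/pkg/runtime.v2/framework"
)

const Name = "Mock"

// Mock counts how often the framework invoked each extension point so tests
// can assert on dispatch without depending on the real plugins' behavior.
type Mock struct {
	EnforceMLPolicyCalls int
}

var _ framework.Plugin = (*Mock)(nil)
var _ framework.EnforceMLPolicyPlugin = (*Mock)(nil)

func (m *Mock) Name() string { return Name }

// EnforceMLPolicy only records the call; it leaves runtime.Info untouched.
func (m *Mock) EnforceMLPolicy(info *runtime.Info) error {
	m.EnforceMLPolicyCalls++
	return nil
}

// NewFactory adapts a Mock to the factory signature the plugin registry expects,
// keeping the instance reachable so the test body can inspect its counters.
func NewFactory(m *Mock) func(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) {
	return func(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) {
		return m, nil
	}
}
```

A test could then build the framework with a registry literal such as fwkplugins.Registry{mock.Name: mock.NewFactory(m)} and assert on m.EnforceMLPolicyCalls after RunEnforceMLPolicyPlugins, instead of comparing against the concrete coscheduling, jobset, and torch plugins as the cases below do.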
+ +func TestNew(t *testing.T) { + cases := map[string]struct { + registry fwkplugins.Registry + emptyCoSchedulingIndexerTrainingRuntimeContainerRuntimeClassKey bool + emptyCoSchedulingIndexerClusterTrainingRuntimeContainerRuntimeClassKey bool + wantFramework *Framework + wantError error + }{ + "positive case": { + registry: fwkplugins.NewRegistry(), + wantFramework: &Framework{ + registry: fwkplugins.NewRegistry(), + plugins: map[string]framework.Plugin{ + coscheduling.Name: &coscheduling.CoScheduling{}, + mpi.Name: &mpi.MPI{}, + plainml.Name: &plainml.PlainML{}, + torch.Name: &torch.Torch{}, + jobset.Name: &jobset.JobSet{}, + }, + enforceMLPlugins: []framework.EnforceMLPolicyPlugin{ + &mpi.MPI{}, + &plainml.PlainML{}, + &torch.Torch{}, + }, + enforcePodGroupPolicyPlugins: []framework.EnforcePodGroupPolicyPlugin{ + &coscheduling.CoScheduling{}, + }, + customValidationPlugins: []framework.CustomValidationPlugin{ + &mpi.MPI{}, + &torch.Torch{}, + }, + watchExtensionPlugins: []framework.WatchExtensionPlugin{ + &coscheduling.CoScheduling{}, + &jobset.JobSet{}, + }, + componentBuilderPlugins: []framework.ComponentBuilderPlugin{ + &coscheduling.CoScheduling{}, + &jobset.JobSet{}, + }, + }, + }, + "indexer key for trainingRuntime and runtimeClass is an empty": { + registry: fwkplugins.Registry{ + coscheduling.Name: coscheduling.New, + }, + emptyCoSchedulingIndexerTrainingRuntimeContainerRuntimeClassKey: true, + wantError: coscheduling.ErrorCanNotSetupTrainingRuntimeRuntimeClassIndexer, + }, + "indexer key for clusterTrainingRuntime and runtimeClass is an empty": { + registry: fwkplugins.Registry{ + coscheduling.Name: coscheduling.New, + }, + emptyCoSchedulingIndexerClusterTrainingRuntimeContainerRuntimeClassKey: true, + wantError: coscheduling.ErrorCanNotSetupClusterTrainingRuntimeRuntimeClassIndexer, + }, + } + cmpOpts := []cmp.Option{ + cmp.AllowUnexported(Framework{}), + cmpopts.IgnoreUnexported(coscheduling.CoScheduling{}, mpi.MPI{}, plainml.PlainML{}, torch.Torch{}, jobset.JobSet{}), + cmpopts.IgnoreFields(coscheduling.CoScheduling{}, "client"), + cmpopts.IgnoreFields(jobset.JobSet{}, "client"), + cmpopts.IgnoreTypes(apiruntime.Scheme{}, meta.DefaultRESTMapper{}, fwkplugins.Registry{}), + cmpopts.SortMaps(func(a, b string) bool { return a < b }), + cmpopts.SortSlices(func(a, b framework.Plugin) bool { return a.Name() < b.Name() }), + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + if tc.emptyCoSchedulingIndexerTrainingRuntimeContainerRuntimeClassKey { + originTrainingRuntimeRuntimeKey := coscheduling.TrainingRuntimeContainerRuntimeClassKey + coscheduling.TrainingRuntimeContainerRuntimeClassKey = "" + t.Cleanup(func() { + coscheduling.TrainingRuntimeContainerRuntimeClassKey = originTrainingRuntimeRuntimeKey + }) + } + if tc.emptyCoSchedulingIndexerClusterTrainingRuntimeContainerRuntimeClassKey { + originClusterTrainingRuntimeKey := coscheduling.ClusterTrainingRuntimeContainerRuntimeClassKey + coscheduling.ClusterTrainingRuntimeContainerRuntimeClassKey = "" + t.Cleanup(func() { + coscheduling.ClusterTrainingRuntimeContainerRuntimeClassKey = originClusterTrainingRuntimeKey + }) + } + clientBuilder := testingutil.NewClientBuilder() + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected errors (-want,+got):\n%s", diff) + } + if diff := 
cmp.Diff(tc.wantFramework, fwk, cmpOpts...); len(diff) != 0 { + t.Errorf("Unexpected framework (-want,+got):\n%s", diff) + } + }) + } +} + +func TestRunEnforceMLPolicyPlugins(t *testing.T) { + cases := map[string]struct { + registry fwkplugins.Registry + runtimeInfo *runtime.Info + wantRuntimeInfo *runtime.Info + wantError error + }{ + "plainml MLPolicy is applied to runtime.Info": { + registry: fwkplugins.NewRegistry(), + runtimeInfo: &runtime.Info{ + Policy: runtime.Policy{ + MLPolicy: &kubeflowv2.MLPolicy{ + NumNodes: ptr.To[int32](100), + }, + }, + TotalRequests: map[string]runtime.TotalResourceRequest{ + "Coordinator": {Replicas: 1}, + "Worker": {Replicas: 10}, + }, + }, + wantRuntimeInfo: &runtime.Info{ + Policy: runtime.Policy{ + MLPolicy: &kubeflowv2.MLPolicy{ + NumNodes: ptr.To[int32](100), + }, + }, + TotalRequests: map[string]runtime.TotalResourceRequest{ + "Coordinator": {Replicas: 100}, + "Worker": {Replicas: 100}, + }, + }, + }, + "registry is empty": { + runtimeInfo: &runtime.Info{ + Policy: runtime.Policy{ + MLPolicy: &kubeflowv2.MLPolicy{ + NumNodes: ptr.To[int32](100), + }, + }, + TotalRequests: map[string]runtime.TotalResourceRequest{ + "Coordinator": {Replicas: 1}, + "Worker": {Replicas: 10}, + }, + }, + wantRuntimeInfo: &runtime.Info{ + Policy: runtime.Policy{ + MLPolicy: &kubeflowv2.MLPolicy{ + NumNodes: ptr.To[int32](100), + }, + }, + TotalRequests: map[string]runtime.TotalResourceRequest{ + "Coordinator": {Replicas: 1}, + "Worker": {Replicas: 10}, + }, + }, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + clientBuilder := testingutil.NewClientBuilder() + + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + if err != nil { + t.Fatal(err) + } + err = fwk.RunEnforceMLPolicyPlugins(tc.runtimeInfo) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got): %s", diff) + } + if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, cmpopts.EquateEmpty()); len(diff) != 0 { + t.Errorf("Unexpected runtime.Info (-want,+got): %s", diff) + } + }) + } +} + +func TestRunEnforcePodGroupPolicyPlugins(t *testing.T) { + cases := map[string]struct { + trainJob *kubeflowv2.TrainJob + registry fwkplugins.Registry + runtimeInfo *runtime.Info + wantRuntimeInfo *runtime.Info + wantError error + }{ + "coscheduling plugin is applied to runtime.Info": { + trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}}, + registry: fwkplugins.NewRegistry(), + runtimeInfo: &runtime.Info{ + PodLabels: make(map[string]string), + Policy: runtime.Policy{ + PodGroupPolicy: &kubeflowv2.PodGroupPolicy{}, + }, + }, + wantRuntimeInfo: &runtime.Info{ + PodLabels: map[string]string{ + schedulerpluginsv1alpha1.PodGroupLabel: "test-job", + }, + Policy: runtime.Policy{ + PodGroupPolicy: &kubeflowv2.PodGroupPolicy{}, + }, + }, + }, + "an empty registry": { + trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}}, + runtimeInfo: &runtime.Info{ + Policy: runtime.Policy{ + PodGroupPolicy: &kubeflowv2.PodGroupPolicy{}, + }, + }, + wantRuntimeInfo: &runtime.Info{ + Policy: runtime.Policy{ + PodGroupPolicy: &kubeflowv2.PodGroupPolicy{}, + }, + }, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + 
t.Cleanup(cancel) + clientBuilder := testingutil.NewClientBuilder() + + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + if err != nil { + t.Fatal(err) + } + err = fwk.RunEnforcePodGroupPolicyPlugins(tc.trainJob, tc.runtimeInfo) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got): %s", diff) + } + if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo); len(diff) != 0 { + t.Errorf("Unexpected runtime.Info (-want,+got): %s", diff) + } + }) + } +} + +func TestRunCustomValidationPlugins(t *testing.T) { + cases := map[string]struct { + trainJob *kubeflowv2.TrainJob + registry fwkplugins.Registry + oldObj client.Object + newObj client.Object + wantWarnings admission.Warnings + wantError error + }{ + // Need to implement more detail testing after we implement custom validator in any plugins. + "there are not any custom validations": { + trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}}, + registry: fwkplugins.NewRegistry(), + oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), + newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), + }, + "an empty registry": { + trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}}, + oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), + newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + clientBuildr := testingutil.NewClientBuilder() + + fwk, err := New(ctx, clientBuildr.Build(), tc.registry, testingutil.AsIndex(clientBuildr)) + if err != nil { + t.Fatal(err) + } + warnings, err := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj) + if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 { + t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + }) + } +} + +func TestRunComponentBuilderPlugins(t *testing.T) { + jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + ResourceRequests(0, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2"), + corev1.ResourceMemory: resource.MustParse("4Gi"), + }). + ResourceRequests(1, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }). + Clone() + jobSetWithPropagatedTrainJobParams := jobSetBase. + JobCompletionMode(batchv1.IndexedCompletion). + ContainerImage(ptr.To("foo:bar")). + ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). + Clone() + + cases := map[string]struct { + runtimeInfo *runtime.Info + trainJob *kubeflowv2.TrainJob + registry fwkplugins.Registry + wantError error + wantRuntimeInfo *runtime.Info + wantObjs []client.Object + }{ + "coscheduling and jobset are performed": { + trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + UID("uid"). + Trainer( + testingutil.MakeTrainJobTrainerWrapper(t). + ContainerImage("foo:bar"). + Obj(), + ). 
+ Obj(), + runtimeInfo: &runtime.Info{ + Obj: jobSetBase. + Obj(), + Policy: runtime.Policy{ + MLPolicy: &kubeflowv2.MLPolicy{ + NumNodes: ptr.To[int32](10), + }, + PodGroupPolicy: &kubeflowv2.PodGroupPolicy{ + PodGroupPolicySource: kubeflowv2.PodGroupPolicySource{ + Coscheduling: &kubeflowv2.CoschedulingPodGroupPolicySource{ + ScheduleTimeoutSeconds: ptr.To[int32](300), + }, + }, + }, + }, + TotalRequests: map[string]runtime.TotalResourceRequest{ + "Coordinator": { + Replicas: 1, + PodRequests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2"), + corev1.ResourceMemory: resource.MustParse("4Gi"), + }, + }, + "Worker": { + Replicas: 1, + PodRequests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + }, + }, + }, + registry: fwkplugins.NewRegistry(), + wantObjs: []client.Object{ + testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job"). + SchedulingTimeout(300). + MinMember(20). + MinResources(corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("30"), + corev1.ResourceMemory: resource.MustParse("60Gi"), + }). + ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). + Obj(), + jobSetWithPropagatedTrainJobParams. + Obj(), + }, + wantRuntimeInfo: &runtime.Info{ + Obj: jobSetWithPropagatedTrainJobParams. + Obj(), + Policy: runtime.Policy{ + MLPolicy: &kubeflowv2.MLPolicy{ + NumNodes: ptr.To[int32](10), + }, + PodGroupPolicy: &kubeflowv2.PodGroupPolicy{ + PodGroupPolicySource: kubeflowv2.PodGroupPolicySource{ + Coscheduling: &kubeflowv2.CoschedulingPodGroupPolicySource{ + ScheduleTimeoutSeconds: ptr.To[int32](300), + }, + }, + }, + }, + TotalRequests: map[string]runtime.TotalResourceRequest{ + "Coordinator": { + Replicas: 10, + PodRequests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2"), + corev1.ResourceMemory: resource.MustParse("4Gi"), + }, + }, + "Worker": { + Replicas: 10, + PodRequests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + }, + }, + }, + }, + "an empty registry": {}, + } + cmpOpts := []cmp.Option{ + cmpopts.SortSlices(func(a, b client.Object) bool { + return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String() + }), + cmpopts.EquateEmpty(), + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + clientBuilder := testingutil.NewClientBuilder() + + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + if err != nil { + t.Fatal(err) + } + if err = fwk.RunEnforceMLPolicyPlugins(tc.runtimeInfo); err != nil { + t.Fatal(err) + } + objs, err := fwk.RunComponentBuilderPlugins(ctx, tc.runtimeInfo, tc.trainJob) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected errors (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo); len(diff) != 0 { + t.Errorf("Unexpected runtime.Info (-want,+got)\n%s", diff) + } + if diff := cmp.Diff(tc.wantObjs, objs, cmpOpts...); len(diff) != 0 { + t.Errorf("Unexpected objects (-want,+got):\n%s", diff) + } + }) + } +} + +func TestRunExtensionPlugins(t *testing.T) { + cases := map[string]struct { + registry fwkplugins.Registry + wantPlugins []framework.WatchExtensionPlugin + }{ + "coscheding and jobset are performed": { + registry: 
fwkplugins.NewRegistry(), + wantPlugins: []framework.WatchExtensionPlugin{ + &coscheduling.CoScheduling{}, + &jobset.JobSet{}, + }, + }, + "an empty registry": { + wantPlugins: nil, + }, + } + cmpOpts := []cmp.Option{ + cmpopts.SortSlices(func(a, b framework.Plugin) bool { return a.Name() < b.Name() }), + cmpopts.IgnoreUnexported(coscheduling.CoScheduling{}, jobset.JobSet{}), + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + clientBuilder := testingutil.NewClientBuilder() + + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + if err != nil { + t.Fatal(err) + } + plugins := fwk.WatchExtensionPlugins() + if diff := cmp.Diff(tc.wantPlugins, plugins, cmpOpts...); len(diff) != 0 { + t.Errorf("Unexpected plugins (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/runtime.v2/framework/interface.go b/pkg/runtime.v2/framework/interface.go new file mode 100644 index 0000000000..886c1ab39c --- /dev/null +++ b/pkg/runtime.v2/framework/interface.go @@ -0,0 +1,57 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package framework + +import ( + "context" + + "k8s.io/apimachinery/pkg/util/validation/field" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" +) + +type Plugin interface { + Name() string +} + +type WatchExtensionPlugin interface { + Plugin + ReconcilerBuilders() []runtime.ReconcilerBuilder +} + +type EnforcePodGroupPolicyPlugin interface { + Plugin + EnforcePodGroupPolicy(trainJob *kubeflowv2.TrainJob, info *runtime.Info) error +} + +type EnforceMLPolicyPlugin interface { + Plugin + EnforceMLPolicy(info *runtime.Info) error +} + +type CustomValidationPlugin interface { + Plugin + Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) +} + +type ComponentBuilderPlugin interface { + Plugin + Build(ctx context.Context, info *runtime.Info, trainJob *kubeflowv2.TrainJob) (client.Object, error) +} diff --git a/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go b/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go new file mode 100644 index 0000000000..721a756dae --- /dev/null +++ b/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go @@ -0,0 +1,308 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package coscheduling + +import ( + "context" + "errors" + "fmt" + "maps" + + corev1 "k8s.io/api/core/v1" + nodev1 "k8s.io/api/node/v1" + "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/util/workqueue" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/builder" + "sigs.k8s.io/controller-runtime/pkg/client" + ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/handler" + schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" + runtimeindexer "github.com/kubeflow/training-operator/pkg/runtime.v2/indexer" +) + +type CoScheduling struct { + client client.Client + restMapper meta.RESTMapper + scheme *apiruntime.Scheme +} + +var _ framework.EnforcePodGroupPolicyPlugin = (*CoScheduling)(nil) +var _ framework.WatchExtensionPlugin = (*CoScheduling)(nil) +var _ framework.ComponentBuilderPlugin = (*CoScheduling)(nil) + +var ( + ErrorCanNotSetupTrainingRuntimeRuntimeClassIndexer = errors.New("setting index on runtimeClass for TrainingRuntime") + ErrorCanNotSetupClusterTrainingRuntimeRuntimeClassIndexer = errors.New("setting index on runtimeClass for ClusterTrainingRuntime") +) + +const Name = "CoScheduling" + +func New(ctx context.Context, c client.Client, indexer client.FieldIndexer) (framework.Plugin, error) { + if err := indexer.IndexField(ctx, &kubeflowv2.TrainingRuntime{}, TrainingRuntimeContainerRuntimeClassKey, + IndexTrainingRuntimeContainerRuntimeClass); err != nil { + return nil, fmt.Errorf("%w: %w", ErrorCanNotSetupTrainingRuntimeRuntimeClassIndexer, err) + } + if err := indexer.IndexField(ctx, &kubeflowv2.ClusterTrainingRuntime{}, ClusterTrainingRuntimeContainerRuntimeClassKey, + IndexClusterTrainingRuntimeContainerRuntimeClass); err != nil { + return nil, fmt.Errorf("%w: %w", ErrorCanNotSetupClusterTrainingRuntimeRuntimeClassIndexer, err) + } + return &CoScheduling{ + client: c, + restMapper: c.RESTMapper(), + scheme: c.Scheme(), + }, nil +} + +func (c *CoScheduling) Name() string { + return Name +} + +func (c *CoScheduling) EnforcePodGroupPolicy(trainJob *kubeflowv2.TrainJob, info *runtime.Info) error { + if info == nil || info.PodGroupPolicy == nil || trainJob == nil { + return nil + } + if info.PodLabels == nil { + info.PodLabels = make(map[string]string, 1) + } + info.PodLabels[schedulerpluginsv1alpha1.PodGroupLabel] = trainJob.Name + return nil +} + +func (c *CoScheduling) Build(ctx context.Context, info *runtime.Info, trainJob *kubeflowv2.TrainJob) (client.Object, error) { + if info == nil || info.PodGroupPolicy == nil || info.PodGroupPolicy.Coscheduling == nil || trainJob == nil { + return nil, nil + } + + var totalMembers int32 + totalResources := make(corev1.ResourceList) + for _, resourceRequests := range info.TotalRequests { + totalMembers += resourceRequests.Replicas + for resName, quantity := range resourceRequests.PodRequests { + 
quantity.Mul(int64(resourceRequests.Replicas)) + current := totalResources[resName] + current.Add(quantity) + totalResources[resName] = current + } + } + newPG := &schedulerpluginsv1alpha1.PodGroup{ + TypeMeta: metav1.TypeMeta{ + APIVersion: schedulerpluginsv1alpha1.SchemeGroupVersion.String(), + Kind: "PodGroup", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: trainJob.Name, + Namespace: trainJob.Namespace, + }, + Spec: schedulerpluginsv1alpha1.PodGroupSpec{ + ScheduleTimeoutSeconds: info.PodGroupPolicy.Coscheduling.ScheduleTimeoutSeconds, + MinMember: totalMembers, + MinResources: totalResources, + }, + } + if err := ctrlutil.SetControllerReference(trainJob, newPG, c.scheme); err != nil { + return nil, err + } + oldPG := &schedulerpluginsv1alpha1.PodGroup{} + if err := c.client.Get(ctx, client.ObjectKeyFromObject(newPG), oldPG); err != nil { + if !apierrors.IsNotFound(err) { + return nil, err + } + oldPG = nil + } + if needsCreateOrUpdate(oldPG, newPG, ptr.Deref(trainJob.Spec.Suspend, false)) { + return newPG, nil + } + return nil, nil +} + +func needsCreateOrUpdate(old, new *schedulerpluginsv1alpha1.PodGroup, suspended bool) bool { + return old == nil || + suspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)) +} + +type PodGroupRuntimeClassHandler struct { + client client.Client +} + +var _ handler.EventHandler = (*PodGroupRuntimeClassHandler)(nil) + +func (h *PodGroupRuntimeClassHandler) Create(ctx context.Context, e event.CreateEvent, q workqueue.RateLimitingInterface) { + containerRuntimeClass, ok := e.Object.(*nodev1.RuntimeClass) + if !ok { + return + } + log := ctrl.LoggerFrom(ctx).WithValues("runtimeClass", klog.KObj(containerRuntimeClass)) + if err := h.queueSuspendedTrainJob(ctx, containerRuntimeClass, q); err != nil { + log.Error(err, "could not queue suspended TrainJob to reconcile queue") + } +} + +func (h *PodGroupRuntimeClassHandler) Update(ctx context.Context, e event.UpdateEvent, q workqueue.RateLimitingInterface) { + _, ok := e.ObjectOld.(*nodev1.RuntimeClass) + if !ok { + return + } + newContainerRuntimeClass, ok := e.ObjectNew.(*nodev1.RuntimeClass) + if !ok { + return + } + log := ctrl.LoggerFrom(ctx).WithValues("runtimeClass", klog.KObj(newContainerRuntimeClass)) + if err := h.queueSuspendedTrainJob(ctx, newContainerRuntimeClass, q); err != nil { + log.Error(err, "could not queue suspended TrainJob to reconcile queue") + } +} + +func (h *PodGroupRuntimeClassHandler) Delete(ctx context.Context, e event.DeleteEvent, q workqueue.RateLimitingInterface) { + containerRuntimeClass, ok := e.Object.(*nodev1.RuntimeClass) + if !ok { + return + } + log := ctrl.LoggerFrom(ctx).WithValues("runtimeClass", klog.KObj(containerRuntimeClass)) + if err := h.queueSuspendedTrainJob(ctx, containerRuntimeClass, q); err != nil { + log.Error(err, "could not queue suspended TrainJob to reconcile queue") + } +} + +func (h *PodGroupRuntimeClassHandler) Generic(context.Context, event.GenericEvent, workqueue.RateLimitingInterface) { +} + +func (h *PodGroupRuntimeClassHandler) queueSuspendedTrainJob(ctx context.Context, runtimeClass *nodev1.RuntimeClass, q workqueue.RateLimitingInterface) error { + var trainingRuntimes kubeflowv2.TrainingRuntimeList + if err := h.client.List(ctx, &trainingRuntimes, client.MatchingFields{TrainingRuntimeContainerRuntimeClassKey: runtimeClass.Name}); err != nil { + return err + } + var clusterTrainingRuntimes kubeflowv2.ClusterTrainingRuntimeList + if err := 
h.client.List(ctx, &clusterTrainingRuntimes, client.MatchingFields{ClusterTrainingRuntimeContainerRuntimeClassKey: runtimeClass.Name}); err != nil { + return err + } + + var runtimeNames []string + for _, trainingRuntime := range trainingRuntimes.Items { + runtimeNames = append(runtimeNames, trainingRuntime.Name) + } + for _, clusterTrainingRuntime := range clusterTrainingRuntimes.Items { + runtimeNames = append(runtimeNames, clusterTrainingRuntime.Name) + } + for _, runtimeName := range runtimeNames { + var trainJobs kubeflowv2.TrainJobList + if err := h.client.List(ctx, &trainJobs, client.MatchingFields{runtimeindexer.TrainJobTrainingRuntimeRefKey: runtimeName}); err != nil { + return err + } + for _, trainJob := range trainJobs.Items { + if ptr.Deref(trainJob.Spec.Suspend, false) { + q.Add(client.ObjectKeyFromObject(&trainJob)) + } + } + } + return nil +} + +type PodGroupLimitRangeHandler struct { + client client.Client +} + +var _ handler.EventHandler = (*PodGroupLimitRangeHandler)(nil) + +func (h *PodGroupLimitRangeHandler) Create(ctx context.Context, e event.CreateEvent, q workqueue.RateLimitingInterface) { + limitRange, ok := e.Object.(*corev1.LimitRange) + if !ok { + return + } + log := ctrl.LoggerFrom(ctx).WithValues("limitRange", klog.KObj(limitRange)) + if err := h.queueSuspendedTrainJob(ctx, limitRange.Namespace, q); err != nil { + log.Error(err, "could not queue suspended TrainJob to reconcile queue") + } +} + +func (h *PodGroupLimitRangeHandler) Update(ctx context.Context, e event.UpdateEvent, q workqueue.RateLimitingInterface) { + _, ok := e.ObjectOld.(*corev1.LimitRange) + if !ok { + return + } + newLimitRange, ok := e.ObjectNew.(*corev1.LimitRange) + if !ok { + return + } + log := ctrl.LoggerFrom(ctx).WithValues("limitRange", klog.KObj(newLimitRange)) + if err := h.queueSuspendedTrainJob(ctx, newLimitRange.Namespace, q); err != nil { + log.Error(err, "could not queue suspended TrainJob to reconcile queue") + } +} + +func (h *PodGroupLimitRangeHandler) Delete(ctx context.Context, e event.DeleteEvent, q workqueue.RateLimitingInterface) { + limitRange, ok := e.Object.(*corev1.LimitRange) + if !ok { + return + } + log := ctrl.LoggerFrom(ctx).WithValues("limitRange", klog.KObj(limitRange)) + if err := h.queueSuspendedTrainJob(ctx, limitRange.Namespace, q); err != nil { + log.Error(err, "could not queue suspended TrainJob to reconcile queue") + } +} + +func (h *PodGroupLimitRangeHandler) Generic(context.Context, event.GenericEvent, workqueue.RateLimitingInterface) { +} + +func (h *PodGroupLimitRangeHandler) queueSuspendedTrainJob(ctx context.Context, ns string, q workqueue.RateLimitingInterface) error { + var trainJobs kubeflowv2.TrainJobList + if err := h.client.List(ctx, &trainJobs, client.InNamespace(ns)); err != nil { + return err + } + for _, trainJob := range trainJobs.Items { + if ptr.Deref(trainJob.Spec.Suspend, false) { + q.Add(client.ObjectKeyFromObject(&trainJob)) + } + } + return nil +} + +func (c *CoScheduling) ReconcilerBuilders() []runtime.ReconcilerBuilder { + if _, err := c.restMapper.RESTMapping( + schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"}, + schedulerpluginsv1alpha1.SchemeGroupVersion.Version, + ); err != nil { + return nil + } + return []runtime.ReconcilerBuilder{ + func(b *builder.Builder, c client.Client) *builder.Builder { + return b.Owns(&schedulerpluginsv1alpha1.PodGroup{}) + }, + func(b *builder.Builder, c client.Client) *builder.Builder { + return b.Watches(&corev1.LimitRange{}, 
&PodGroupLimitRangeHandler{ + client: c, + }) + }, + func(b *builder.Builder, c client.Client) *builder.Builder { + return b.Watches(&nodev1.RuntimeClass{}, &PodGroupRuntimeClassHandler{ + client: c, + }) + }, + } +} diff --git a/pkg/runtime.v2/framework/plugins/coscheduling/indexer.go b/pkg/runtime.v2/framework/plugins/coscheduling/indexer.go new file mode 100644 index 0000000000..04a415ae3e --- /dev/null +++ b/pkg/runtime.v2/framework/plugins/coscheduling/indexer.go @@ -0,0 +1,56 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package coscheduling + +import ( + "sigs.k8s.io/controller-runtime/pkg/client" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" +) + +var ( + TrainingRuntimeContainerRuntimeClassKey = ".trainingRuntimeSpec.jobSetTemplateSpec.replicatedJobs.podTemplateSpec.runtimeClassName" + ClusterTrainingRuntimeContainerRuntimeClassKey = ".clusterTrainingRuntimeSpec.jobSetTemplateSpec.replicatedJobs.podTemplateSpec.runtimeClassName" +) + +func IndexTrainingRuntimeContainerRuntimeClass(obj client.Object) []string { + runtime, ok := obj.(*kubeflowv2.TrainingRuntime) + if !ok { + return nil + } + var runtimeClasses []string + for _, rJob := range runtime.Spec.Template.Spec.ReplicatedJobs { + if rJob.Template.Spec.Template.Spec.RuntimeClassName != nil { + runtimeClasses = append(runtimeClasses, *rJob.Template.Spec.Template.Spec.RuntimeClassName) + } + } + return runtimeClasses +} + +func IndexClusterTrainingRuntimeContainerRuntimeClass(obj client.Object) []string { + clRuntime, ok := obj.(*kubeflowv2.ClusterTrainingRuntime) + if !ok { + return nil + } + var runtimeClasses []string + for _, rJob := range clRuntime.Spec.Template.Spec.ReplicatedJobs { + if rJob.Template.Spec.Template.Spec.RuntimeClassName != nil { + runtimeClasses = append(runtimeClasses, *rJob.Template.Spec.Template.Spec.RuntimeClassName) + } + } + return runtimeClasses +} diff --git a/pkg/runtime.v2/framework/plugins/jobset/builder.go b/pkg/runtime.v2/framework/plugins/jobset/builder.go new file mode 100644 index 0000000000..ed336edfc9 --- /dev/null +++ b/pkg/runtime.v2/framework/plugins/jobset/builder.go @@ -0,0 +1,83 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package jobset + +import ( + "maps" + + batchv1 "k8s.io/api/batch/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" +) + +type Builder struct { + *jobsetv1alpha2.JobSet +} + +func NewBuilder(objectKey client.ObjectKey, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec) *Builder { + return &Builder{ + JobSet: &jobsetv1alpha2.JobSet{ + TypeMeta: metav1.TypeMeta{ + APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(), + Kind: "JobSet", + }, + ObjectMeta: metav1.ObjectMeta{ + Namespace: objectKey.Namespace, + Name: objectKey.Name, + Labels: maps.Clone(jobSetTemplateSpec.Labels), + Annotations: maps.Clone(jobSetTemplateSpec.Annotations), + }, + Spec: *jobSetTemplateSpec.Spec.DeepCopy(), + }, + } +} + +func (b *Builder) ContainerImage(image *string) *Builder { + if image == nil || *image == "" { + return b + } + for i, rJob := range b.Spec.ReplicatedJobs { + for j := range rJob.Template.Spec.Template.Spec.Containers { + b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Image = *image + } + } + return b +} + +func (b *Builder) JobCompletionMode(mode batchv1.CompletionMode) *Builder { + for i := range b.Spec.ReplicatedJobs { + b.Spec.ReplicatedJobs[i].Template.Spec.CompletionMode = &mode + } + return b +} + +// TODO: Supporting merge labels would be great. +func (b *Builder) PodLabels(labels map[string]string) *Builder { + for i := range b.Spec.ReplicatedJobs { + b.Spec.ReplicatedJobs[i].Template.Spec.Template.Labels = labels + } + return b +} + +// TODO: Need to support all TrainJob fields. + +func (b *Builder) Build() *jobsetv1alpha2.JobSet { + return b.JobSet +} diff --git a/pkg/runtime.v2/framework/plugins/jobset/jobset.go b/pkg/runtime.v2/framework/plugins/jobset/jobset.go new file mode 100644 index 0000000000..7d53abba1f --- /dev/null +++ b/pkg/runtime.v2/framework/plugins/jobset/jobset.go @@ -0,0 +1,121 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package jobset + +import ( + "context" + "fmt" + "maps" + + batchv1 "k8s.io/api/batch/v1" + "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/builder" + "sigs.k8s.io/controller-runtime/pkg/client" + ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" +) + +type JobSet struct { + client client.Client + restMapper meta.RESTMapper + scheme *apiruntime.Scheme +} + +var _ framework.WatchExtensionPlugin = (*JobSet)(nil) +var _ framework.ComponentBuilderPlugin = (*JobSet)(nil) + +const Name = "JobSet" + +func New(_ context.Context, c client.Client, _ client.FieldIndexer) (framework.Plugin, error) { + return &JobSet{ + client: c, + restMapper: c.RESTMapper(), + scheme: c.Scheme(), + }, nil +} + +func (j *JobSet) Name() string { + return Name +} + +func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubeflowv2.TrainJob) (client.Object, error) { + if info == nil || info.Obj == nil || trainJob == nil { + return nil, fmt.Errorf("runtime info or object is missing") + } + raw, ok := info.Obj.(*jobsetv1alpha2.JobSet) + if !ok { + return nil, nil + } + jobSetBuilder := NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: info.Labels, + Annotations: info.Annotations, + }, + Spec: raw.Spec, + }) + jobSet := jobSetBuilder. + ContainerImage(trainJob.Spec.Trainer.Image). + JobCompletionMode(batchv1.IndexedCompletion). + PodLabels(info.PodLabels). + Build() + if err := ctrlutil.SetControllerReference(trainJob, jobSet, j.scheme); err != nil { + return nil, err + } + oldJobSet := &jobsetv1alpha2.JobSet{} + if err := j.client.Get(ctx, client.ObjectKeyFromObject(jobSet), oldJobSet); err != nil { + if !apierrors.IsNotFound(err) { + return nil, err + } + oldJobSet = nil + } + if err := info.Update(jobSet); err != nil { + return nil, err + } + if needsCreateOrUpdate(oldJobSet, jobSet, ptr.Deref(trainJob.Spec.Suspend, false)) { + return jobSet, nil + } + return nil, nil +} + +func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, suspended bool) bool { + return old == nil || + suspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)) +} + +func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder { + if _, err := j.restMapper.RESTMapping( + schema.GroupKind{Group: jobsetv1alpha2.GroupVersion.Group, Kind: "JobSet"}, + jobsetv1alpha2.SchemeGroupVersion.Version, + ); err != nil { + return nil + } + return []runtime.ReconcilerBuilder{ + func(b *builder.Builder, c client.Client) *builder.Builder { + return b.Owns(&jobsetv1alpha2.JobSet{}) + }, + } +} diff --git a/pkg/runtime.v2/framework/plugins/mpi/mpi.go b/pkg/runtime.v2/framework/plugins/mpi/mpi.go new file mode 100644 index 0000000000..b85a265195 --- /dev/null +++ b/pkg/runtime.v2/framework/plugins/mpi/mpi.go @@ -0,0 +1,60 @@ +/* +Copyright 2024 The Kubeflow Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mpi + +import ( + "context" + + "k8s.io/apimachinery/pkg/util/validation/field" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" +) + +type MPI struct { + client client.Client +} + +var _ framework.EnforceMLPolicyPlugin = (*MPI)(nil) +var _ framework.CustomValidationPlugin = (*MPI)(nil) + +const Name = "MPI" + +func New(_ context.Context, client client.Client, _ client.FieldIndexer) (framework.Plugin, error) { + return &MPI{ + client: client, + }, nil +} + +func (m *MPI) Name() string { + return Name +} + +func (m *MPI) EnforceMLPolicy(info *runtime.Info) error { + if info == nil || info.MLPolicy == nil || info.MLPolicy.MPI == nil { + return nil + } + // TODO: Need to implement main logic. + return nil +} + +// TODO: Need to implement validations for MPIJob. +func (m *MPI) Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) { + return nil, nil +} diff --git a/pkg/runtime.v2/framework/plugins/plainml/plainml.go b/pkg/runtime.v2/framework/plugins/plainml/plainml.go new file mode 100644 index 0000000000..3320fc7231 --- /dev/null +++ b/pkg/runtime.v2/framework/plugins/plainml/plainml.go @@ -0,0 +1,55 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package plainml + +import ( + "context" + + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" +) + +var _ framework.EnforceMLPolicyPlugin = (*PlainML)(nil) + +type PlainML struct{} + +const Name = "PlainML" + +func New(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { + return &PlainML{}, nil +} + +func (p *PlainML) Name() string { + return Name +} + +func (p *PlainML) EnforceMLPolicy(info *runtime.Info) error { + if info == nil || info.MLPolicy == nil || info.MLPolicy.Torch != nil || info.MLPolicy.MPI != nil { + return nil + } + numNodes := ptr.Deref(info.MLPolicy.NumNodes, 1) + for rName := range info.TotalRequests { + info.TotalRequests[rName] = runtime.TotalResourceRequest{ + Replicas: numNodes, + PodRequests: info.TotalRequests[rName].PodRequests, + } + } + return nil +} diff --git a/pkg/runtime.v2/framework/plugins/registry.go b/pkg/runtime.v2/framework/plugins/registry.go new file mode 100644 index 0000000000..37cc663dac --- /dev/null +++ b/pkg/runtime.v2/framework/plugins/registry.go @@ -0,0 +1,42 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plugins + +import ( + "context" + + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/coscheduling" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/jobset" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/mpi" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/plainml" + "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins/torch" +) + +type Registry map[string]func(ctx context.Context, client client.Client, indexer client.FieldIndexer) (framework.Plugin, error) + +func NewRegistry() Registry { + return Registry{ + coscheduling.Name: coscheduling.New, + mpi.Name: mpi.New, + plainml.Name: plainml.New, + torch.Name: torch.New, + jobset.Name: jobset.New, + } +} diff --git a/pkg/runtime.v2/framework/plugins/torch/torch.go b/pkg/runtime.v2/framework/plugins/torch/torch.go new file mode 100644 index 0000000000..1a0306be98 --- /dev/null +++ b/pkg/runtime.v2/framework/plugins/torch/torch.go @@ -0,0 +1,56 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package torch
+
+import (
+	"context"
+
+	"k8s.io/apimachinery/pkg/util/validation/field"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+	"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
+
+	runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
+	"github.com/kubeflow/training-operator/pkg/runtime.v2/framework"
+)
+
+type Torch struct{}
+
+var _ framework.EnforceMLPolicyPlugin = (*Torch)(nil)
+var _ framework.CustomValidationPlugin = (*Torch)(nil)
+
+const Name = "Torch"
+
+func New(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) {
+	return &Torch{}, nil
+}
+
+func (t *Torch) Name() string {
+	return Name
+}
+
+func (t *Torch) EnforceMLPolicy(info *runtime.Info) error {
+	if info == nil || info.MLPolicy == nil || info.MLPolicy.Torch == nil {
+		return nil
+	}
+	// TODO: Need to implement main logic.
+	return nil
+}
+
+// TODO: Need to implement validations for TorchJob.
+func (t *Torch) Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) {
+	return nil, nil
+}
diff --git a/pkg/runtime.v2/indexer/indexer.go b/pkg/runtime.v2/indexer/indexer.go
new file mode 100644
index 0000000000..9ba2c057f7
--- /dev/null
+++ b/pkg/runtime.v2/indexer/indexer.go
@@ -0,0 +1,45 @@
+/*
+Copyright 2024 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package indexer
+
+import (
+	"k8s.io/apimachinery/pkg/runtime/schema"
+	"k8s.io/utils/ptr"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+
+	kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
+)
+
+const (
+	TrainJobTrainingRuntimeRefKey = ".spec.trainingRuntimeRef"
+)
+
+func IndexTrainJobTrainingRuntimes(obj client.Object) []string {
+	trainJob, ok := obj.(*kubeflowv2.TrainJob)
+	if !ok {
+		return nil
+	}
+	runtimeRefGroupKind := schema.GroupKind{
+		Group: ptr.Deref(trainJob.Spec.TrainingRuntimeRef.APIGroup, ""),
+		Kind:  ptr.Deref(trainJob.Spec.TrainingRuntimeRef.Kind, ""),
+	}
+	if runtimeRefGroupKind.Group == kubeflowv2.GroupVersion.Group &&
+		(runtimeRefGroupKind.Kind == "TrainingRuntime" || runtimeRefGroupKind.Kind == "ClusterTrainingRuntime") {
+		return []string{trainJob.Spec.TrainingRuntimeRef.Name}
+	}
+	return nil
+}
diff --git a/pkg/runtime.v2/interface.go b/pkg/runtime.v2/interface.go
new file mode 100644
index 0000000000..d7b84e3f46
--- /dev/null
+++ b/pkg/runtime.v2/interface.go
@@ -0,0 +1,33 @@
+/*
+Copyright 2024 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package runtimev2 + +import ( + "context" + + "sigs.k8s.io/controller-runtime/pkg/builder" + "sigs.k8s.io/controller-runtime/pkg/client" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" +) + +type ReconcilerBuilder func(*builder.Builder, client.Client) *builder.Builder + +type Runtime interface { + NewObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) + EventHandlerRegistrars() []ReconcilerBuilder +} diff --git a/pkg/runtime.v2/runtime.go b/pkg/runtime.v2/runtime.go new file mode 100644 index 0000000000..eadec3a523 --- /dev/null +++ b/pkg/runtime.v2/runtime.go @@ -0,0 +1,147 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package runtimev2 + +import ( + "errors" + "maps" + + corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + kueuelr "sigs.k8s.io/kueue/pkg/util/limitrange" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" +) + +var ( + errorDifferentGVK = errors.New("the GroupVersionKinds are different between old and new objects") + errorObjectsAreNil = errors.New("old or new objects are nil") +) + +type Info struct { + Obj client.Object + Labels map[string]string + PodLabels map[string]string + Annotations map[string]string + PodAnnotations map[string]string + Policy + TotalRequests map[string]TotalResourceRequest +} + +type Policy struct { + MLPolicy *kubeflowv2.MLPolicy + PodGroupPolicy *kubeflowv2.PodGroupPolicy +} + +type TotalResourceRequest struct { + Replicas int32 + PodRequests corev1.ResourceList +} + +type InfoOptions struct { + podSpecReplicas []podSpecReplica + Policy + labels map[string]string + annotations map[string]string +} + +type InfoOption func(options *InfoOptions) + +var defaultOptions = InfoOptions{} + +type podSpecReplica struct { + replicas int32 + name string + podSpec corev1.PodSpec +} + +func WithPodSpecReplicas(replicaName string, replicas int32, podSpec corev1.PodSpec) InfoOption { + return func(o *InfoOptions) { + o.podSpecReplicas = append(o.podSpecReplicas, podSpecReplica{ + name: replicaName, + replicas: replicas, + podSpec: podSpec, + }) + } +} + +func WithLabels(labels map[string]string) InfoOption { + return func(o *InfoOptions) { + o.labels = maps.Clone(labels) + } +} + +func WithAnnotations(annotations map[string]string) InfoOption { + return func(o *InfoOptions) { + o.annotations = maps.Clone(annotations) + } +} + +func WithPodGroupPolicy(pgPolicy *kubeflowv2.PodGroupPolicy) InfoOption { + return func(o *InfoOptions) { + o.PodGroupPolicy = pgPolicy + } +} + +func WithMLPolicy(mlPolicy *kubeflowv2.MLPolicy) InfoOption { + return func(o *InfoOptions) { + o.MLPolicy = mlPolicy + } +} + +func NewInfo(obj client.Object, opts ...InfoOption) *Info { + options := defaultOptions + for _, opt := range opts { + opt(&options) + } + var copyObj client.Object + if obj != nil { + copyObj = obj.DeepCopyObject().(client.Object) + } + info := &Info{ + Obj: copyObj, + Labels: make(map[string]string), + 
Annotations: make(map[string]string), + TotalRequests: make(map[string]TotalResourceRequest, len(options.podSpecReplicas)), + } + for _, spec := range options.podSpecReplicas { + info.TotalRequests[spec.name] = TotalResourceRequest{ + Replicas: spec.replicas, + // TODO: Need to address LimitRange and RuntimeClass. + PodRequests: kueuelr.TotalRequests(&spec.podSpec), + } + } + if options.labels != nil { + info.Labels = options.labels + } + if options.annotations != nil { + info.Annotations = options.annotations + } + info.Policy = options.Policy + return info +} + +func (i *Info) Update(obj client.Object) error { + if obj == nil || i.Obj == nil { + return errorObjectsAreNil + } + if i.Obj.GetObjectKind().GroupVersionKind() != obj.GetObjectKind().GroupVersionKind() { + return errorDifferentGVK + } + i.Obj = obj.DeepCopyObject().(client.Object) + return nil +} diff --git a/pkg/runtime.v2/runtime_test.go b/pkg/runtime.v2/runtime_test.go new file mode 100644 index 0000000000..bcb6b5efd9 --- /dev/null +++ b/pkg/runtime.v2/runtime_test.go @@ -0,0 +1,220 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package runtimev2 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" +) + +func TestNewInfo(t *testing.T) { + jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). 
+ Clone() + + cases := map[string]struct { + obj client.Object + infoOpts []InfoOption + wantInfo *Info + }{ + "all arguments are specified": { + obj: jobSetBase.Obj(), + infoOpts: []InfoOption{ + WithLabels(map[string]string{ + "labelKey": "labelValue", + }), + WithAnnotations(map[string]string{ + "annotationKey": "annotationValue", + }), + WithPodGroupPolicy(&kubeflowv2.PodGroupPolicy{ + PodGroupPolicySource: kubeflowv2.PodGroupPolicySource{ + Coscheduling: &kubeflowv2.CoschedulingPodGroupPolicySource{ + ScheduleTimeoutSeconds: ptr.To[int32](300), + }, + }, + }), + WithMLPolicy(&kubeflowv2.MLPolicy{ + NumNodes: ptr.To[int32](100), + MLPolicySource: kubeflowv2.MLPolicySource{ + Torch: &kubeflowv2.TorchMLPolicySource{ + NumProcPerNode: ptr.To("8"), + }, + }, + }), + WithPodSpecReplicas("Leader", 1, corev1.PodSpec{ + InitContainers: []corev1.Container{{ + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("5"), + }, + }, + RestartPolicy: ptr.To(corev1.ContainerRestartPolicyAlways), + }}, + Containers: []corev1.Container{{ + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("10"), + }, + }, + }}, + }), + WithPodSpecReplicas("Worker", 10, corev1.PodSpec{ + InitContainers: []corev1.Container{{ + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("15"), + }, + }, + RestartPolicy: ptr.To(corev1.ContainerRestartPolicyAlways), + }}, + Containers: []corev1.Container{{ + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("25"), + }, + }, + }}, + }), + }, + wantInfo: &Info{ + Obj: jobSetBase.Obj(), + Labels: map[string]string{ + "labelKey": "labelValue", + }, + Annotations: map[string]string{ + "annotationKey": "annotationValue", + }, + Policy: Policy{ + MLPolicy: &kubeflowv2.MLPolicy{ + NumNodes: ptr.To[int32](100), + MLPolicySource: kubeflowv2.MLPolicySource{ + Torch: &kubeflowv2.TorchMLPolicySource{ + NumProcPerNode: ptr.To("8"), + }, + }, + }, + PodGroupPolicy: &kubeflowv2.PodGroupPolicy{ + PodGroupPolicySource: kubeflowv2.PodGroupPolicySource{ + Coscheduling: &kubeflowv2.CoschedulingPodGroupPolicySource{ + ScheduleTimeoutSeconds: ptr.To[int32](300), + }, + }, + }, + }, + TotalRequests: map[string]TotalResourceRequest{ + "Leader": { + Replicas: 1, + PodRequests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("15"), + }, + }, + "Worker": { + Replicas: 10, + PodRequests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("40"), + }, + }, + }, + }, + }, + "all arguments are not specified": { + wantInfo: &Info{}, + }, + } + cmpOpts := []cmp.Option{ + cmpopts.SortMaps(func(a, b string) bool { return a < b }), + cmpopts.EquateEmpty(), + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + info := NewInfo(tc.obj, tc.infoOpts...) + if diff := cmp.Diff(tc.wantInfo, info, cmpOpts...); len(diff) != 0 { + t.Errorf("Unexpected runtime.Info (-want,+got):\n%s", diff) + } + }) + } +} + +func TestUpdate(t *testing.T) { + jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + Clone() + + cases := map[string]struct { + info *Info + obj client.Object + wantInfo *Info + wantError error + }{ + "gvk is different between old and new objects": { + info: &Info{ + Obj: jobSetBase.Obj(), + }, + obj: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). 
+ Obj(), + wantInfo: &Info{ + Obj: jobSetBase.Obj(), + }, + wantError: errorDifferentGVK, + }, + "old object is nil": { + info: &Info{}, + obj: jobSetBase.Obj(), + wantInfo: &Info{}, + wantError: errorObjectsAreNil, + }, + "new object is nil": { + info: &Info{ + Obj: jobSetBase.Obj(), + }, + wantInfo: &Info{ + Obj: jobSetBase.Obj(), + }, + wantError: errorObjectsAreNil, + }, + "update object with the appropriate parameter": { + info: &Info{ + Obj: jobSetBase.Obj(), + }, + obj: jobSetBase.ContainerImage(ptr.To("test:latest")).Obj(), + wantInfo: &Info{ + Obj: jobSetBase.ContainerImage(ptr.To("test:latest")).Obj(), + }, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + if tc.info != nil { + err := tc.info.Update(tc.obj) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + } + if diff := cmp.Diff(tc.wantInfo, tc.info); len(diff) != 0 { + t.Errorf("Unexpected runtime.Info (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/util.v2/testing/client.go b/pkg/util.v2/testing/client.go new file mode 100644 index 0000000000..3f1209779c --- /dev/null +++ b/pkg/util.v2/testing/client.go @@ -0,0 +1,63 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testing + +import ( + "context" + "fmt" + + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" + + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" +) + +func NewClientBuilder(addToSchemes ...func(s *runtime.Scheme) error) *fake.ClientBuilder { + scm := runtime.NewScheme() + utilruntime.Must(clientgoscheme.AddToScheme(scm)) + utilruntime.Must(kubeflowv2.AddToScheme(scm)) + utilruntime.Must(jobsetv1alpha2.AddToScheme(scm)) + utilruntime.Must(schedulerpluginsv1alpha1.AddToScheme(scm)) + for i := range addToSchemes { + utilruntime.Must(addToSchemes[i](scm)) + } + return fake.NewClientBuilder(). 
+		WithScheme(scm)
+}
+
+type builderIndexer struct {
+	*fake.ClientBuilder
+}
+
+var _ client.FieldIndexer = (*builderIndexer)(nil)
+
+func (b *builderIndexer) IndexField(_ context.Context, obj client.Object, field string, extractValue client.IndexerFunc) error {
+	if obj == nil || field == "" || extractValue == nil {
+		return fmt.Errorf("error from test indexer")
+	}
+	b.ClientBuilder = b.ClientBuilder.WithIndex(obj, field, extractValue)
+	return nil
+}
+
+func AsIndex(builder *fake.ClientBuilder) client.FieldIndexer {
+	return &builderIndexer{ClientBuilder: builder}
+}
diff --git a/pkg/util.v2/testing/wrapper.go b/pkg/util.v2/testing/wrapper.go
new file mode 100644
index 0000000000..8cadbb7e00
--- /dev/null
+++ b/pkg/util.v2/testing/wrapper.go
@@ -0,0 +1,475 @@
+/*
+Copyright 2024 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package testing
+
+import (
+	"testing"
+
+	batchv1 "k8s.io/api/batch/v1"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/runtime/schema"
+	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/utils/ptr"
+	jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
+	schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
+
+	kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
+)
+
+type JobSetWrapper struct {
+	jobsetv1alpha2.JobSet
+}
+
+func MakeJobSetWrapper(t *testing.T, namespace, name string) *JobSetWrapper {
+	t.Helper()
+	return &JobSetWrapper{
+		JobSet: jobsetv1alpha2.JobSet{
+			TypeMeta: metav1.TypeMeta{
+				APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(),
+				Kind:       "JobSet",
+			},
+			ObjectMeta: metav1.ObjectMeta{
+				Namespace: namespace,
+				Name:      name,
+			},
+			Spec: jobsetv1alpha2.JobSetSpec{
+				ReplicatedJobs: []jobsetv1alpha2.ReplicatedJob{
+					{
+						Name:     "Coordinator",
+						Replicas: 1,
+						Template: batchv1.JobTemplateSpec{
+							Spec: batchv1.JobSpec{
+								Template: corev1.PodTemplateSpec{
+									Spec: corev1.PodSpec{
+										Containers: []corev1.Container{
+											{
+												Name: "trainer",
+											},
+										},
+									},
+								},
+							},
+						},
+					},
+					{
+						Name:     "Worker",
+						Replicas: 1,
+						Template: batchv1.JobTemplateSpec{
+							Spec: batchv1.JobSpec{
+								Template: corev1.PodTemplateSpec{
+									Spec: corev1.PodSpec{
+										Containers: []corev1.Container{
+											{
+												Name: "trainer",
+											},
+										},
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+}
+
+func (j *JobSetWrapper) Completions(idx int, completions int32) *JobSetWrapper {
+	if len(j.Spec.ReplicatedJobs) <= idx {
+		return j
+	}
+	j.Spec.ReplicatedJobs[idx].Template.Spec.Completions = &completions
+	return j
+}
+
+func (j *JobSetWrapper) Parallelism(idx int, parallelism int32) *JobSetWrapper {
+	if len(j.Spec.ReplicatedJobs) <= idx {
+		return j
+	}
+	j.Spec.ReplicatedJobs[idx].Template.Spec.Parallelism = &parallelism
+	return j
+}
+
+func (j *JobSetWrapper) ResourceRequests(idx int, res corev1.ResourceList) *JobSetWrapper {
+	if len(j.Spec.ReplicatedJobs) <= idx {
+		return j
+	}
+	j.Spec.ReplicatedJobs[idx].Template.Spec.Template.Spec.Containers[0].Resources.Requests = res
+	return j
+}
+
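+// JobCompletionMode sets the given CompletionMode on the JobSpec of every ReplicatedJob in the wrapped JobSet.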
+func (j *JobSetWrapper) JobCompletionMode(mode batchv1.CompletionMode) *JobSetWrapper { + for i := range j.Spec.ReplicatedJobs { + j.Spec.ReplicatedJobs[i].Template.Spec.CompletionMode = &mode + } + return j +} + +func (j *JobSetWrapper) ContainerImage(image *string) *JobSetWrapper { + if image == nil || *image == "" { + return j + } + for i, rJob := range j.Spec.ReplicatedJobs { + for k := range rJob.Template.Spec.Template.Spec.Containers { + j.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[k].Image = *image + } + } + return j +} + +func (j *JobSetWrapper) ControllerReference(gvk schema.GroupVersionKind, name, uid string) *JobSetWrapper { + j.OwnerReferences = append(j.OwnerReferences, metav1.OwnerReference{ + APIVersion: gvk.GroupVersion().String(), + Kind: gvk.Kind, + Name: name, + UID: types.UID(uid), + Controller: ptr.To(true), + BlockOwnerDeletion: ptr.To(true), + }) + return j +} + +func (j *JobSetWrapper) PodLabel(key, value string) *JobSetWrapper { + for i, rJob := range j.Spec.ReplicatedJobs { + if rJob.Template.Spec.Template.Labels == nil { + j.Spec.ReplicatedJobs[i].Template.Spec.Template.Labels = make(map[string]string, 1) + } + j.Spec.ReplicatedJobs[i].Template.Spec.Template.Labels[key] = value + } + return j +} + +func (j *JobSetWrapper) Clone() *JobSetWrapper { + return &JobSetWrapper{ + JobSet: *j.JobSet.DeepCopy(), + } +} + +func (j *JobSetWrapper) Obj() *jobsetv1alpha2.JobSet { + return &j.JobSet +} + +type TrainJobWrapper struct { + kubeflowv2.TrainJob +} + +func MakeTrainJobWrapper(t *testing.T, namespace, name string) *TrainJobWrapper { + t.Helper() + return &TrainJobWrapper{ + TrainJob: kubeflowv2.TrainJob{ + TypeMeta: metav1.TypeMeta{ + APIVersion: kubeflowv2.SchemeGroupVersion.Version, + Kind: "TrainJob", + }, + ObjectMeta: metav1.ObjectMeta{ + Namespace: namespace, + Name: name, + }, + Spec: kubeflowv2.TrainJobSpec{}, + }, + } +} + +func (t *TrainJobWrapper) UID(uid string) *TrainJobWrapper { + t.ObjectMeta.UID = types.UID(uid) + return t +} + +type TrainJobTrainerWrapper struct { + kubeflowv2.Trainer +} + +func MakeTrainJobTrainerWrapper(t *testing.T) *TrainJobTrainerWrapper { + t.Helper() + return &TrainJobTrainerWrapper{ + Trainer: kubeflowv2.Trainer{}, + } +} + +func (t *TrainJobTrainerWrapper) ContainerImage(img string) *TrainJobTrainerWrapper { + t.Image = &img + return t +} + +func (t *TrainJobTrainerWrapper) Obj() *kubeflowv2.Trainer { + return &t.Trainer +} + +func (t *TrainJobWrapper) Trainer(trainer *kubeflowv2.Trainer) *TrainJobWrapper { + t.Spec.Trainer = trainer + return t +} + +func (t *TrainJobWrapper) TrainingRuntimeRef(gvk schema.GroupVersionKind, name string) *TrainJobWrapper { + t.Spec.TrainingRuntimeRef = kubeflowv2.TrainingRuntimeRef{ + APIGroup: &gvk.Group, + Kind: &gvk.Kind, + Name: name, + } + return t +} + +func (t *TrainJobWrapper) Obj() *kubeflowv2.TrainJob { + return &t.TrainJob +} + +type TrainingRuntimeWrapper struct { + kubeflowv2.TrainingRuntime +} + +func MakeTrainingRuntimeWrapper(t *testing.T, namespace, name string) *TrainingRuntimeWrapper { + t.Helper() + return &TrainingRuntimeWrapper{ + TrainingRuntime: kubeflowv2.TrainingRuntime{ + TypeMeta: metav1.TypeMeta{ + APIVersion: kubeflowv2.SchemeGroupVersion.String(), + Kind: "TrainingRuntime", + }, + ObjectMeta: metav1.ObjectMeta{ + Namespace: namespace, + Name: name, + }, + Spec: kubeflowv2.TrainingRuntimeSpec{ + Template: kubeflowv2.JobSetTemplateSpec{ + Spec: jobsetv1alpha2.JobSetSpec{ + ReplicatedJobs: []jobsetv1alpha2.ReplicatedJob{ + { + Name: "Coordinator", 
+								Template: batchv1.JobTemplateSpec{
+									Spec: batchv1.JobSpec{
+										Template: corev1.PodTemplateSpec{
+											Spec: corev1.PodSpec{
+												Containers: []corev1.Container{{
+													Name: "trainer",
+												}},
+											},
+										},
+									},
+								},
+							},
+							{
+								Name: "Worker",
+								Template: batchv1.JobTemplateSpec{
+									Spec: batchv1.JobSpec{
+										Template: corev1.PodTemplateSpec{
+											Spec: corev1.PodSpec{
+												Containers: []corev1.Container{{
+													Name: "trainer",
+												}},
+											},
+										},
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+}
+
+func (r *TrainingRuntimeWrapper) RuntimeSpec(spec kubeflowv2.TrainingRuntimeSpec) *TrainingRuntimeWrapper {
+	r.Spec = spec
+	return r
+}
+
+func (r *TrainingRuntimeWrapper) Clone() *TrainingRuntimeWrapper {
+	return &TrainingRuntimeWrapper{
+		TrainingRuntime: *r.TrainingRuntime.DeepCopy(),
+	}
+}
+
+func (r *TrainingRuntimeWrapper) Obj() *kubeflowv2.TrainingRuntime {
+	return &r.TrainingRuntime
+}
+
+type ClusterTrainingRuntimeWrapper struct {
+	kubeflowv2.ClusterTrainingRuntime
+}
+
+func MakeClusterTrainingRuntimeWrapper(t *testing.T, name string) *ClusterTrainingRuntimeWrapper {
+	t.Helper()
+	return &ClusterTrainingRuntimeWrapper{
+		ClusterTrainingRuntime: kubeflowv2.ClusterTrainingRuntime{
+			TypeMeta: metav1.TypeMeta{
+				APIVersion: kubeflowv2.SchemeGroupVersion.String(),
+				Kind:       "ClusterTrainingRuntime",
+			},
+			ObjectMeta: metav1.ObjectMeta{
+				Name: name,
+			},
+			Spec: kubeflowv2.TrainingRuntimeSpec{
+				Template: kubeflowv2.JobSetTemplateSpec{
+					Spec: jobsetv1alpha2.JobSetSpec{
+						ReplicatedJobs: []jobsetv1alpha2.ReplicatedJob{
+							{
+								Name: "Coordinator",
+								Template: batchv1.JobTemplateSpec{
+									Spec: batchv1.JobSpec{
+										Template: corev1.PodTemplateSpec{
+											Spec: corev1.PodSpec{
+												Containers: []corev1.Container{{
+													Name: "trainer",
+												}},
+											},
+										},
+									},
+								},
+							},
+							{
+								Name: "Worker",
+								Template: batchv1.JobTemplateSpec{
+									Spec: batchv1.JobSpec{
+										Template: corev1.PodTemplateSpec{
+											Spec: corev1.PodSpec{
+												Containers: []corev1.Container{{
+													Name: "trainer",
+												}},
+											},
+										},
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+}
+
+func (r *ClusterTrainingRuntimeWrapper) RuntimeSpec(spec kubeflowv2.TrainingRuntimeSpec) *ClusterTrainingRuntimeWrapper {
+	r.Spec = spec
+	return r
+}
+
+func (r *ClusterTrainingRuntimeWrapper) Clone() *ClusterTrainingRuntimeWrapper {
+	return &ClusterTrainingRuntimeWrapper{
+		ClusterTrainingRuntime: *r.ClusterTrainingRuntime.DeepCopy(),
+	}
+}
+
+func (r *ClusterTrainingRuntimeWrapper) Obj() *kubeflowv2.ClusterTrainingRuntime {
+	return &r.ClusterTrainingRuntime
+}
+
+type TrainingRuntimeSpecWrapper struct {
+	kubeflowv2.TrainingRuntimeSpec
+}
+
+func MakeTrainingRuntimeSpecWrapper(t *testing.T, spec kubeflowv2.TrainingRuntimeSpec) *TrainingRuntimeSpecWrapper {
+	t.Helper()
+	return &TrainingRuntimeSpecWrapper{
+		TrainingRuntimeSpec: spec,
+	}
+}
+
+func (s *TrainingRuntimeSpecWrapper) ContainerImage(image string) *TrainingRuntimeSpecWrapper {
+	for i, rJob := range s.Template.Spec.ReplicatedJobs {
+		for j := range rJob.Template.Spec.Template.Spec.Containers {
+			s.Template.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Image = image
+		}
+	}
+	return s
+}
+
+func (s *TrainingRuntimeSpecWrapper) ResourceRequests(idx int, res corev1.ResourceList) *TrainingRuntimeSpecWrapper {
+	if len(s.Template.Spec.ReplicatedJobs) <= idx {
+		return s
+	}
+	s.Template.Spec.ReplicatedJobs[idx].Template.Spec.Template.Spec.Containers[0].Resources.Requests = res
+	return s
+}
+
+func (s *TrainingRuntimeSpecWrapper) PodGroupPolicySchedulingTimeout(timeout int32) *TrainingRuntimeSpecWrapper {
+	if s.PodGroupPolicy ==
nil || s.PodGroupPolicy.Coscheduling == nil { + s.PodGroupPolicy = &kubeflowv2.PodGroupPolicy{ + PodGroupPolicySource: kubeflowv2.PodGroupPolicySource{ + Coscheduling: &kubeflowv2.CoschedulingPodGroupPolicySource{ + ScheduleTimeoutSeconds: &timeout, + }, + }, + } + } + s.PodGroupPolicy.Coscheduling.ScheduleTimeoutSeconds = &timeout + return s +} + +func (s *TrainingRuntimeSpecWrapper) MLPolicyNumNodes(numNodes int32) *TrainingRuntimeSpecWrapper { + if s.MLPolicy == nil { + s.MLPolicy = &kubeflowv2.MLPolicy{} + } + s.MLPolicy.NumNodes = &numNodes + return s +} + +func (s *TrainingRuntimeSpecWrapper) Obj() kubeflowv2.TrainingRuntimeSpec { + return s.TrainingRuntimeSpec +} + +type SchedulerPluginsPodGroupWrapper struct { + schedulerpluginsv1alpha1.PodGroup +} + +func MakeSchedulerPluginsPodGroup(t *testing.T, namespace, name string) *SchedulerPluginsPodGroupWrapper { + t.Helper() + return &SchedulerPluginsPodGroupWrapper{ + PodGroup: schedulerpluginsv1alpha1.PodGroup{ + TypeMeta: metav1.TypeMeta{ + APIVersion: schedulerpluginsv1alpha1.SchemeGroupVersion.String(), + Kind: "PodGroup", + }, + ObjectMeta: metav1.ObjectMeta{ + Namespace: namespace, + Name: name, + }, + }, + } +} + +func (p *SchedulerPluginsPodGroupWrapper) SchedulingTimeout(timeout int32) *SchedulerPluginsPodGroupWrapper { + p.PodGroup.Spec.ScheduleTimeoutSeconds = &timeout + return p +} + +func (p *SchedulerPluginsPodGroupWrapper) MinMember(members int32) *SchedulerPluginsPodGroupWrapper { + p.PodGroup.Spec.MinMember = members + return p +} + +func (p *SchedulerPluginsPodGroupWrapper) MinResources(resources corev1.ResourceList) *SchedulerPluginsPodGroupWrapper { + p.PodGroup.Spec.MinResources = resources + return p +} + +func (p *SchedulerPluginsPodGroupWrapper) ControllerReference(gvk schema.GroupVersionKind, name, uid string) *SchedulerPluginsPodGroupWrapper { + p.OwnerReferences = append(p.OwnerReferences, metav1.OwnerReference{ + APIVersion: gvk.GroupVersion().String(), + Kind: gvk.Kind, + Name: name, + UID: types.UID(uid), + Controller: ptr.To(true), + BlockOwnerDeletion: ptr.To(true), + }) + return p +} + +func (p *SchedulerPluginsPodGroupWrapper) Obj() *schedulerpluginsv1alpha1.PodGroup { + return &p.PodGroup +} diff --git a/pkg/webhook.v2/clustertrainingruntime_webhook.go b/pkg/webhook.v2/clustertrainingruntime_webhook.go index 0ce728f654..a679fe66dd 100644 --- a/pkg/webhook.v2/clustertrainingruntime_webhook.go +++ b/pkg/webhook.v2/clustertrainingruntime_webhook.go @@ -19,20 +19,23 @@ package webhookv2 import ( "context" - "k8s.io/apimachinery/pkg/runtime" + apiruntime "k8s.io/apimachinery/pkg/runtime" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" ) -type ClusterTrainingRuntimeWebhook struct{} +type ClusterTrainingRuntimeWebhook struct { + runtimes map[string]runtime.Runtime +} -func setupWebhookForClusterTrainingRuntime(mgr ctrl.Manager) error { +func setupWebhookForClusterTrainingRuntime(mgr ctrl.Manager, run map[string]runtime.Runtime) error { return ctrl.NewWebhookManagedBy(mgr). For(&kubeflowv2.ClusterTrainingRuntime{}). - WithValidator(&ClusterTrainingRuntimeWebhook{}). + WithValidator(&ClusterTrainingRuntimeWebhook{runtimes: run}). 
Complete() } @@ -40,14 +43,14 @@ func setupWebhookForClusterTrainingRuntime(mgr ctrl.Manager) error { var _ webhook.CustomValidator = (*ClusterTrainingRuntimeWebhook)(nil) -func (w *ClusterTrainingRuntimeWebhook) ValidateCreate(context.Context, runtime.Object) (admission.Warnings, error) { +func (w *ClusterTrainingRuntimeWebhook) ValidateCreate(context.Context, apiruntime.Object) (admission.Warnings, error) { return nil, nil } -func (w *ClusterTrainingRuntimeWebhook) ValidateUpdate(context.Context, runtime.Object, runtime.Object) (admission.Warnings, error) { +func (w *ClusterTrainingRuntimeWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { return nil, nil } -func (w *ClusterTrainingRuntimeWebhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { +func (w *ClusterTrainingRuntimeWebhook) ValidateDelete(context.Context, apiruntime.Object) (admission.Warnings, error) { return nil, nil } diff --git a/pkg/webhook.v2/setup.go b/pkg/webhook.v2/setup.go index f7c9436ab7..a9109f2ffd 100644 --- a/pkg/webhook.v2/setup.go +++ b/pkg/webhook.v2/setup.go @@ -16,16 +16,20 @@ limitations under the License. package webhookv2 -import ctrl "sigs.k8s.io/controller-runtime" +import ( + ctrl "sigs.k8s.io/controller-runtime" -func Setup(mgr ctrl.Manager) (string, error) { - if err := setupWebhookForClusterTrainingRuntime(mgr); err != nil { + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" +) + +func Setup(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (string, error) { + if err := setupWebhookForClusterTrainingRuntime(mgr, runtimes); err != nil { return "ClusterTrainingRuntime", err } - if err := setupWebhookForTrainingRuntime(mgr); err != nil { + if err := setupWebhookForTrainingRuntime(mgr, runtimes); err != nil { return "TrainingRuntime", err } - if err := setupWebhookForTrainJob(mgr); err != nil { + if err := setupWebhookForTrainJob(mgr, runtimes); err != nil { return "TrainJob", err } return "", nil diff --git a/pkg/webhook.v2/trainingruntime_webhook.go b/pkg/webhook.v2/trainingruntime_webhook.go index a9fa897dbc..5597e5238f 100644 --- a/pkg/webhook.v2/trainingruntime_webhook.go +++ b/pkg/webhook.v2/trainingruntime_webhook.go @@ -19,20 +19,23 @@ package webhookv2 import ( "context" - "k8s.io/apimachinery/pkg/runtime" + apiruntime "k8s.io/apimachinery/pkg/runtime" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" ) -type TrainingRuntimeWebhook struct{} +type TrainingRuntimeWebhook struct { + runtimes map[string]runtime.Runtime +} -func setupWebhookForTrainingRuntime(mgr ctrl.Manager) error { +func setupWebhookForTrainingRuntime(mgr ctrl.Manager, run map[string]runtime.Runtime) error { return ctrl.NewWebhookManagedBy(mgr). For(&kubeflowv2.TrainingRuntime{}). - WithValidator(&TrainingRuntimeWebhook{}). + WithValidator(&TrainingRuntimeWebhook{runtimes: run}). 
Complete() } @@ -40,14 +43,14 @@ func setupWebhookForTrainingRuntime(mgr ctrl.Manager) error { var _ webhook.CustomValidator = (*TrainingRuntimeWebhook)(nil) -func (w *TrainingRuntimeWebhook) ValidateCreate(context.Context, runtime.Object) (admission.Warnings, error) { +func (w *TrainingRuntimeWebhook) ValidateCreate(context.Context, apiruntime.Object) (admission.Warnings, error) { return nil, nil } -func (w *TrainingRuntimeWebhook) ValidateUpdate(context.Context, runtime.Object, runtime.Object) (admission.Warnings, error) { +func (w *TrainingRuntimeWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { return nil, nil } -func (w *TrainingRuntimeWebhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { +func (w *TrainingRuntimeWebhook) ValidateDelete(context.Context, apiruntime.Object) (admission.Warnings, error) { return nil, nil } diff --git a/pkg/webhook.v2/trainjob_webhook.go b/pkg/webhook.v2/trainjob_webhook.go index 231e124f3d..cf75400c82 100644 --- a/pkg/webhook.v2/trainjob_webhook.go +++ b/pkg/webhook.v2/trainjob_webhook.go @@ -19,20 +19,23 @@ package webhookv2 import ( "context" - "k8s.io/apimachinery/pkg/runtime" + apiruntime "k8s.io/apimachinery/pkg/runtime" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" ) -type TrainJobWebhook struct{} +type TrainJobWebhook struct { + runtimes map[string]runtime.Runtime +} -func setupWebhookForTrainJob(mgr ctrl.Manager) error { +func setupWebhookForTrainJob(mgr ctrl.Manager, run map[string]runtime.Runtime) error { return ctrl.NewWebhookManagedBy(mgr). For(&kubeflowv2.TrainJob{}). - WithValidator(&TrainJobWebhook{}). + WithValidator(&TrainJobWebhook{runtimes: run}). 
Complete() } @@ -40,14 +43,14 @@ func setupWebhookForTrainJob(mgr ctrl.Manager) error { var _ webhook.CustomValidator = (*TrainJobWebhook)(nil) -func (w *TrainJobWebhook) ValidateCreate(context.Context, runtime.Object) (admission.Warnings, error) { +func (w *TrainJobWebhook) ValidateCreate(context.Context, apiruntime.Object) (admission.Warnings, error) { return nil, nil } -func (w *TrainJobWebhook) ValidateUpdate(context.Context, runtime.Object, runtime.Object) (admission.Warnings, error) { +func (w *TrainJobWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { return nil, nil } -func (w *TrainJobWebhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { +func (w *TrainJobWebhook) ValidateDelete(context.Context, apiruntime.Object) (admission.Warnings, error) { return nil, nil } diff --git a/test/integration/framework/framework.go b/test/integration/framework/framework.go index 83d85c7e4a..0a3a7fb774 100644 --- a/test/integration/framework/framework.go +++ b/test/integration/framework/framework.go @@ -39,6 +39,7 @@ import ( kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" controllerv2 "github.com/kubeflow/training-operator/pkg/controller.v2" + runtimecore "github.com/kubeflow/training-operator/pkg/runtime.v2/core" webhookv2 "github.com/kubeflow/training-operator/pkg/webhook.v2" ) @@ -89,10 +90,17 @@ func (f *Framework) RunManager(cfg *rest.Config) (context.Context, client.Client }) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "failed to create manager") - failedCtrlName, err := controllerv2.SetupControllers(mgr) + runtimes, err := runtimecore.New(ctx, mgr.GetClient(), mgr.GetFieldIndexer()) + gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred()) + gomega.ExpectWithOffset(1, runtimes).NotTo(gomega.BeNil()) + + failedCtrlName, err := controllerv2.SetupControllers(mgr, runtimes) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "controller", failedCtrlName) - failedWebhookName, err := webhookv2.Setup(mgr) + gomega.ExpectWithOffset(1, failedCtrlName).To(gomega.BeEmpty()) + + failedWebhookName, err := webhookv2.Setup(mgr, runtimes) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "webhook", failedWebhookName) + gomega.ExpectWithOffset(1, failedWebhookName).To(gomega.BeEmpty()) go func() { defer ginkgo.GinkgoRecover() From d64cb3481c3cfa5d6841942034ec4802100b213a Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Wed, 11 Sep 2024 03:03:48 +0900 Subject: [PATCH 02/12] Remove grep dependency Signed-off-by: Yuki Iwai --- Makefile | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/Makefile b/Makefile index 2a65e73622..44a011c093 100644 --- a/Makefile +++ b/Makefile @@ -10,17 +10,6 @@ else GOBIN=$(shell go env GOBIN) endif -# Setting GREP allows macos users to install GNU grep and use the latter -# instead of the default BSD grep. -ifeq ($(shell command -v ggrep 2>/dev/null),) - GREP ?= $(shell command -v grep) -else - GREP ?= $(shell command -v ggrep) -endif -ifeq ($(shell ${GREP} --version 2>&1 | grep -q GNU; echo $$?),1) - $(error !!! GNU grep is required. If on OS X, use 'brew install grep'.) -endif - # Setting SHELL to bash allows bash commands to be executed by recipes. # This is a requirement for 'setup-envtest.sh' in the test target. # Options are set to exit when a recipe line exits non-zero or a piped command fails. 
@@ -93,7 +82,7 @@ test-integrationv2: envtest .PHONY: testv2 testv2: - go test $(shell go list ./pkg/... | $(GREP) -E '.*\.v2') -coverprofile cover.out + go test ./pkg/controller.v2/... ./pkg/runtime.v2/... ./pkg/webhook.v2/... ./pkg/util.v2/... -coverprofile cover.out envtest: ifndef HAS_SETUP_ENVTEST From a0cae9ab7982f0a55e853cfedace3b799eed3ac1 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Fri, 11 Oct 2024 03:51:43 +0900 Subject: [PATCH 03/12] KEP-2170: Implement ValidateObjects interface to the runtime framework Signed-off-by: Yuki Iwai --- pkg/runtime.v2/core/clustertrainingruntime.go | 15 +++++++++++++++ pkg/runtime.v2/core/trainingruntime.go | 15 +++++++++++++++ pkg/runtime.v2/framework/core/framework.go | 4 ++-- pkg/runtime.v2/framework/core/framework_test.go | 9 +++++---- pkg/runtime.v2/interface.go | 3 +++ 5 files changed, 40 insertions(+), 6 deletions(-) diff --git a/pkg/runtime.v2/core/clustertrainingruntime.go b/pkg/runtime.v2/core/clustertrainingruntime.go index d4908af5f0..ecbd36b255 100644 --- a/pkg/runtime.v2/core/clustertrainingruntime.go +++ b/pkg/runtime.v2/core/clustertrainingruntime.go @@ -22,7 +22,9 @@ import ( "fmt" "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" @@ -60,3 +62,16 @@ func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *kubef func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { return nil } + +func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { + if err := r.client.Get(ctx, client.ObjectKey{ + Namespace: old.Namespace, + Name: old.Spec.TrainingRuntimeRef.Name, + }, &kubeflowv2.ClusterTrainingRuntime{}); err != nil { + return nil, field.ErrorList{ + field.Invalid(field.NewPath("spec", "trainingRuntimeRef"), old.Spec.TrainingRuntimeRef, + fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)), + } + } + return r.framework.RunCustomValidationPlugins(old, new) +} diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 879f95a04f..32659f4c0b 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -24,8 +24,10 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" @@ -119,3 +121,16 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { } return builders } + +func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { + if err := r.client.Get(ctx, client.ObjectKey{ + Namespace: old.Namespace, + Name: old.Spec.TrainingRuntimeRef.Name, + }, &kubeflowv2.TrainingRuntime{}); err != nil { + return nil, field.ErrorList{ + field.Invalid(field.NewPath("spec", "trainingRuntimeRef"), old.Spec.TrainingRuntimeRef, + 
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)), + } + } + return r.framework.RunCustomValidationPlugins(old, new) +} diff --git a/pkg/runtime.v2/framework/core/framework.go b/pkg/runtime.v2/framework/core/framework.go index cb1b23d42e..ae9f007493 100644 --- a/pkg/runtime.v2/framework/core/framework.go +++ b/pkg/runtime.v2/framework/core/framework.go @@ -89,7 +89,7 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(trainJob *kubeflowv2.TrainJo return nil } -func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (admission.Warnings, error) { +func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) { var aggregatedWarnings admission.Warnings var aggregatedErrors field.ErrorList for _, plugin := range f.customValidationPlugins { @@ -104,7 +104,7 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (ad if len(aggregatedErrors) == 0 { return aggregatedWarnings, nil } - return aggregatedWarnings, aggregatedErrors.ToAggregate() + return aggregatedWarnings, aggregatedErrors } func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) { diff --git a/pkg/runtime.v2/framework/core/framework_test.go b/pkg/runtime.v2/framework/core/framework_test.go index 141d7995fb..4f5085781f 100644 --- a/pkg/runtime.v2/framework/core/framework_test.go +++ b/pkg/runtime.v2/framework/core/framework_test.go @@ -28,6 +28,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -291,7 +292,7 @@ func TestRunCustomValidationPlugins(t *testing.T) { oldObj client.Object newObj client.Object wantWarnings admission.Warnings - wantError error + wantError field.ErrorList }{ // Need to implement more detail testing after we implement custom validator in any plugins. 
"there are not any custom validations": { @@ -300,7 +301,7 @@ func TestRunCustomValidationPlugins(t *testing.T) { oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), }, - "an empty registry": { + "an empty registry": { trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}}, oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), @@ -316,11 +317,11 @@ func TestRunCustomValidationPlugins(t *testing.T) { if err != nil { t.Fatal(err) } - warnings, err := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj) + warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj) if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 { t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff) } - if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + if diff := cmp.Diff(tc.wantError, errs, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 { t.Errorf("Unexpected error (-want,+got):\n%s", diff) } }) diff --git a/pkg/runtime.v2/interface.go b/pkg/runtime.v2/interface.go index d7b84e3f46..8c735ad4f1 100644 --- a/pkg/runtime.v2/interface.go +++ b/pkg/runtime.v2/interface.go @@ -19,8 +19,10 @@ package runtimev2 import ( "context" + "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" ) @@ -30,4 +32,5 @@ type ReconcilerBuilder func(*builder.Builder, client.Client) *builder.Builder type Runtime interface { NewObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) EventHandlerRegistrars() []ReconcilerBuilder + ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) } From a527f868312d252a82faa56b7f14a3273e6ce7e4 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Fri, 11 Oct 2024 04:12:54 +0900 Subject: [PATCH 04/12] KEP-2170: Expose the TrainingRuntime and ClusterTrainingRuntime Kind Signed-off-by: Yuki Iwai --- pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go | 7 +++++++ pkg/runtime.v2/core/clustertrainingruntime.go | 2 +- pkg/runtime.v2/core/clustertrainingruntime_test.go | 4 ++-- pkg/runtime.v2/core/trainingruntime.go | 2 +- pkg/runtime.v2/core/trainingruntime_test.go | 4 ++-- pkg/runtime.v2/indexer/indexer.go | 2 +- pkg/util.v2/testing/wrapper.go | 4 ++-- pkg/webhook.v2/setup.go | 5 +++-- 8 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go b/pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go index b63363be7b..318d22be0d 100644 --- a/pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go +++ b/pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go @@ -22,6 +22,13 @@ import ( jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" ) +const ( + // TrainingRuntimeKind is the GroupVersionKind Kind name for the TrainingRuntime. + TrainingRuntimeKind string = "TrainingRuntime" + // ClusterTrainingRuntimeKind is the GroupVersionKind Kind name for the ClusterTrainingRuntime. 
+ ClusterTrainingRuntimeKind string = "ClusterTrainingRuntime" +) + // +genclient // +genclient:nonNamespaced // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object diff --git a/pkg/runtime.v2/core/clustertrainingruntime.go b/pkg/runtime.v2/core/clustertrainingruntime.go index ecbd36b255..de819363b5 100644 --- a/pkg/runtime.v2/core/clustertrainingruntime.go +++ b/pkg/runtime.v2/core/clustertrainingruntime.go @@ -42,7 +42,7 @@ var _ runtime.Runtime = (*ClusterTrainingRuntime)(nil) var ClusterTrainingRuntimeGroupKind = schema.GroupKind{ Group: kubeflowv2.GroupVersion.Group, - Kind: "ClusterTrainingRuntime", + Kind: kubeflowv2.ClusterTrainingRuntimeKind, }.String() func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndexer) (runtime.Runtime, error) { diff --git a/pkg/runtime.v2/core/clustertrainingruntime_test.go b/pkg/runtime.v2/core/clustertrainingruntime_test.go index 23697b748c..1831d2ec62 100644 --- a/pkg/runtime.v2/core/clustertrainingruntime_test.go +++ b/pkg/runtime.v2/core/clustertrainingruntime_test.go @@ -47,7 +47,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { "succeeded to build JobSet and PodGroup": { trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). UID("uid"). - TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind("ClusterTrainingRuntime"), "test-runtime"). + TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime"). Trainer( testingutil.MakeTrainJobTrainerWrapper(t). ContainerImage("test:trainjob"). @@ -93,7 +93,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { "missing trainingRuntime resource": { trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). UID("uid"). - TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind("ClusterTrainingRuntime"), "test-runtime"). + TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime"). Trainer( testingutil.MakeTrainJobTrainerWrapper(t). ContainerImage("test:trainjob"). diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 32659f4c0b..179b7c2835 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -49,7 +49,7 @@ type TrainingRuntime struct { var TrainingRuntimeGroupKind = schema.GroupKind{ Group: kubeflowv2.GroupVersion.Group, - Kind: "TrainingRuntime", + Kind: kubeflowv2.TrainingRuntimeKind, }.String() var _ runtime.Runtime = (*TrainingRuntime)(nil) diff --git a/pkg/runtime.v2/core/trainingruntime_test.go b/pkg/runtime.v2/core/trainingruntime_test.go index 244fa88128..6a08415fbf 100644 --- a/pkg/runtime.v2/core/trainingruntime_test.go +++ b/pkg/runtime.v2/core/trainingruntime_test.go @@ -47,7 +47,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { "succeeded to build JobSet and PodGroup": { trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). UID("uid"). - TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind("TrainingRuntime"), "test-runtime"). + TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime"). Trainer( testingutil.MakeTrainJobTrainerWrapper(t). ContainerImage("test:trainjob"). @@ -93,7 +93,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { "missing trainingRuntime resource": { trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). UID("uid"). 
- TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind("TrainingRuntime"), "test-runtime"). + TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime"). Trainer( testingutil.MakeTrainJobTrainerWrapper(t). ContainerImage("test:trainjob"). diff --git a/pkg/runtime.v2/indexer/indexer.go b/pkg/runtime.v2/indexer/indexer.go index 9ba2c057f7..1aac8a4132 100644 --- a/pkg/runtime.v2/indexer/indexer.go +++ b/pkg/runtime.v2/indexer/indexer.go @@ -38,7 +38,7 @@ func IndexTrainJobTrainingRuntimes(obj client.Object) []string { Kind: ptr.Deref(trainJob.Spec.TrainingRuntimeRef.Kind, ""), } if runtimeRefGroupKind.Group == kubeflowv2.GroupVersion.Group && - (runtimeRefGroupKind.Kind == "TrainingRuntime" || runtimeRefGroupKind.Kind == "ClusterTrainingRuntime") { + (runtimeRefGroupKind.Kind == kubeflowv2.TrainingRuntimeKind || runtimeRefGroupKind.Kind == kubeflowv2.ClusterTrainingRuntimeKind) { return []string{trainJob.Spec.TrainingRuntimeRef.Name} } return nil diff --git a/pkg/util.v2/testing/wrapper.go b/pkg/util.v2/testing/wrapper.go index 8cadbb7e00..c7b7d06061 100644 --- a/pkg/util.v2/testing/wrapper.go +++ b/pkg/util.v2/testing/wrapper.go @@ -238,7 +238,7 @@ func MakeTrainingRuntimeWrapper(t *testing.T, namespace, name string) *TrainingR TrainingRuntime: kubeflowv2.TrainingRuntime{ TypeMeta: metav1.TypeMeta{ APIVersion: kubeflowv2.SchemeGroupVersion.String(), - Kind: "TrainingRuntime", + Kind: kubeflowv2.TrainingRuntimeKind, }, ObjectMeta: metav1.ObjectMeta{ Namespace: namespace, @@ -309,7 +309,7 @@ func MakeClusterTrainingRuntimeWrapper(t *testing.T, name string) *ClusterTraini ClusterTrainingRuntime: kubeflowv2.ClusterTrainingRuntime{ TypeMeta: metav1.TypeMeta{ APIVersion: kubeflowv2.SchemeGroupVersion.String(), - Kind: "ClusterTrainingRuntime", + Kind: kubeflowv2.ClusterTrainingRuntimeKind, }, ObjectMeta: metav1.ObjectMeta{ Name: name, diff --git a/pkg/webhook.v2/setup.go b/pkg/webhook.v2/setup.go index a9109f2ffd..6e7c7f290e 100644 --- a/pkg/webhook.v2/setup.go +++ b/pkg/webhook.v2/setup.go @@ -19,15 +19,16 @@ package webhookv2 import ( ctrl "sigs.k8s.io/controller-runtime" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" ) func Setup(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (string, error) { if err := setupWebhookForClusterTrainingRuntime(mgr, runtimes); err != nil { - return "ClusterTrainingRuntime", err + return kubeflowv2.ClusterTrainingRuntimeKind, err } if err := setupWebhookForTrainingRuntime(mgr, runtimes); err != nil { - return "TrainingRuntime", err + return kubeflowv2.TrainingRuntimeKind, err } if err := setupWebhookForTrainJob(mgr, runtimes); err != nil { return "TrainJob", err From 2aaae2b41e81160a9295f28f185c2ef51bf892ec Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Fri, 11 Oct 2024 04:15:20 +0900 Subject: [PATCH 05/12] KEP-2170: Remove unneeded scheme field from the internal TrainingRuntime Signed-off-by: Yuki Iwai --- pkg/runtime.v2/core/trainingruntime.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 179b7c2835..331380c312 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -22,7 +22,6 @@ import ( "fmt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" 
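For orientation, the exported Kind constants end up in strings like TrainingRuntimeGroupKind above. A tiny sketch of the values involved, using literal stand-ins for kubeflowv2.GroupVersion and the new constants:

    package main

    import (
        "fmt"

        "k8s.io/apimachinery/pkg/runtime/schema"
    )

    func main() {
        // Illustrative stand-ins for kubeflowv2.GroupVersion and the Kind constants.
        gv := schema.GroupVersion{Group: "kubeflow.org", Version: "v2alpha1"}
        const trainingRuntimeKind = "TrainingRuntime"

        // The GroupKind string built by variables such as TrainingRuntimeGroupKind.
        fmt.Println(schema.GroupKind{Group: gv.Group, Kind: trainingRuntimeKind}.String())
        // Output: TrainingRuntime.kubeflow.org

        // The same constant feeds trainingRuntimeRef lookups via WithKind.
        fmt.Println(gv.WithKind(trainingRuntimeKind))
        // Output: kubeflow.org/v2alpha1, Kind=TrainingRuntime
    }

Centralizing the Kind strings keeps the runtime lookup keys, the field indexer, the test wrappers, and the webhook setup from drifting apart.
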
"k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/utils/ptr" @@ -44,7 +43,6 @@ var ( type TrainingRuntime struct { framework *fwkcore.Framework client client.Client - scheme *apiruntime.Scheme } var TrainingRuntimeGroupKind = schema.GroupKind{ @@ -67,7 +65,6 @@ func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.Fie trainingRuntimeFactory = &TrainingRuntime{ framework: fwk, client: c, - scheme: c.Scheme(), } return trainingRuntimeFactory, nil } From 1c88c85a2df5cbdb2af96262b364668d392a7264 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Wed, 16 Oct 2024 22:25:45 +0900 Subject: [PATCH 06/12] Rephrase the error message Signed-off-by: Yuki Iwai --- pkg/runtime.v2/core/trainingruntime.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 331380c312..4267530964 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -37,7 +37,7 @@ import ( ) var ( - errorNotFoundSpecifiedTrainingRuntime = errors.New("not found TrainingRuntime specified in TrainJob") + errorNotFoundSpecifiedTrainingRuntime = errors.New("TrainingRuntime specified in TrainJob is not found") ) type TrainingRuntime struct { From 1c94ede05828ce1f7c17be0440e76956ae3187fe Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Wed, 16 Oct 2024 22:50:09 +0900 Subject: [PATCH 07/12] Distinguish TrainingRuntime and ClusterTrainingRuntime when creating indexes for the TrainJobs Signed-off-by: Yuki Iwai --- pkg/runtime.v2/core/trainingruntime.go | 7 +++-- .../plugins/coscheduling/coscheduling.go | 30 ++++++++++++------- pkg/runtime.v2/indexer/indexer.go | 24 ++++++++++----- 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 4267530964..00c4152b8f 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -55,8 +55,11 @@ var _ runtime.Runtime = (*TrainingRuntime)(nil) var trainingRuntimeFactory *TrainingRuntime func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.FieldIndexer) (runtime.Runtime, error) { - if err := indexer.IndexField(ctx, &kubeflowv2.TrainJob{}, idxer.TrainJobTrainingRuntimeRefKey, idxer.IndexTrainJobTrainingRuntimes); err != nil { - return nil, fmt.Errorf("setting index on TrainingRuntime and ClusterTrainigRuntime for TrainJob: %w", err) + if err := indexer.IndexField(ctx, &kubeflowv2.TrainJob{}, idxer.TrainJobTrainingRuntimeRefKey, idxer.IndexTrainJobTrainingRuntime); err != nil { + return nil, fmt.Errorf("setting index on TrainingRuntime for TrainJob: %w", err) + } + if err := indexer.IndexField(ctx, &kubeflowv2.TrainJob{}, idxer.TrainJobClusterTrainingRuntimeRefKey, idxer.IndexTrainJobClusterTrainingRuntime); err != nil { + return nil, fmt.Errorf("setting index on ClusterTrainingRuntime for TrainJob: %w", err) } fwk, err := fwkcore.New(ctx, c, fwkplugins.NewRegistry(), indexer) if err != nil { diff --git a/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go b/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go index 721a756dae..85f08a8366 100644 --- a/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go +++ b/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "maps" + "slices" corev1 "k8s.io/api/core/v1" nodev1 "k8s.io/api/node/v1" @@ -203,22 +204,29 @@ func (h *PodGroupRuntimeClassHandler) queueSuspendedTrainJob(ctx 
context.Context return err } - var runtimeNames []string + var trainJobs []kubeflowv2.TrainJob for _, trainingRuntime := range trainingRuntimes.Items { - runtimeNames = append(runtimeNames, trainingRuntime.Name) + var trainJobsWithTrainingRuntime kubeflowv2.TrainJobList + err := h.client.List(ctx, &trainJobsWithTrainingRuntime, client.MatchingFields{runtimeindexer.TrainJobTrainingRuntimeRefKey: trainingRuntime.Name}) + if err != nil { + return err + } + trainJobs = append(trainJobs, trainJobsWithTrainingRuntime.Items...) } for _, clusterTrainingRuntime := range clusterTrainingRuntimes.Items { - runtimeNames = append(runtimeNames, clusterTrainingRuntime.Name) - } - for _, runtimeName := range runtimeNames { - var trainJobs kubeflowv2.TrainJobList - if err := h.client.List(ctx, &trainJobs, client.MatchingFields{runtimeindexer.TrainJobTrainingRuntimeRefKey: runtimeName}); err != nil { + var trainJobsWithClTrainingRuntime kubeflowv2.TrainJobList + err := h.client.List(ctx, &trainJobsWithClTrainingRuntime, client.MatchingFields{runtimeindexer.TrainJobClusterTrainingRuntimeRefKey: clusterTrainingRuntime.Name}) + if err != nil { return err } - for _, trainJob := range trainJobs.Items { - if ptr.Deref(trainJob.Spec.Suspend, false) { - q.Add(client.ObjectKeyFromObject(&trainJob)) - } + trainJobs = append(trainJobs, trainJobsWithClTrainingRuntime.Items...) + } + trainJobs = slices.CompactFunc(trainJobs, func(a, b kubeflowv2.TrainJob) bool { + return a.Name == b.Name + }) + for _, trainJob := range trainJobs { + if ptr.Deref(trainJob.Spec.Suspend, false) { + q.Add(client.ObjectKeyFromObject(&trainJob)) } } return nil diff --git a/pkg/runtime.v2/indexer/indexer.go b/pkg/runtime.v2/indexer/indexer.go index 1aac8a4132..dacbfcd050 100644 --- a/pkg/runtime.v2/indexer/indexer.go +++ b/pkg/runtime.v2/indexer/indexer.go @@ -17,7 +17,6 @@ limitations under the License. 
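The handler above now issues one filtered List per runtime name instead of collecting names first. A self-contained sketch of the register-then-query pattern behind client.MatchingFields; the Pod type, index key, and fake client here are illustrative and not part of this patch:

    package main

    import (
        "context"
        "fmt"

        corev1 "k8s.io/api/core/v1"
        metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
        "k8s.io/apimachinery/pkg/runtime"
        clientgoscheme "k8s.io/client-go/kubernetes/scheme"
        "sigs.k8s.io/controller-runtime/pkg/client"
        "sigs.k8s.io/controller-runtime/pkg/client/fake"
    )

    // A hypothetical index key, in the same spirit as the per-kind keys such as
    // ".spec.trainingRuntimeRef.kind=trainingRuntime".
    const podNodeNameKey = ".spec.nodeName"

    func main() {
        scheme := runtime.NewScheme()
        _ = clientgoscheme.AddToScheme(scheme)

        // 1. Register an extractor for the key (IndexField on a real manager,
        //    WithIndex on the fake client used in tests).
        c := fake.NewClientBuilder().
            WithScheme(scheme).
            WithIndex(&corev1.Pod{}, podNodeNameKey, func(obj client.Object) []string {
                return []string{obj.(*corev1.Pod).Spec.NodeName}
            }).
            WithObjects(
                &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "a", Namespace: "default"}, Spec: corev1.PodSpec{NodeName: "node-1"}},
                &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "b", Namespace: "default"}, Spec: corev1.PodSpec{NodeName: "node-2"}},
            ).
            Build()

        // 2. List with MatchingFields; only objects whose extractor returned the
        //    requested value come back.
        var pods corev1.PodList
        _ = c.List(context.Background(), &pods, client.MatchingFields{podNodeNameKey: "node-1"})
        fmt.Println(len(pods.Items)) // 1
    }
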
package indexer import ( - "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" @@ -25,20 +24,29 @@ import ( ) const ( - TrainJobTrainingRuntimeRefKey = ".spec.trainingRuntimeRef" + TrainJobTrainingRuntimeRefKey = ".spec.trainingRuntimeRef.kind=trainingRuntime" + TrainJobClusterTrainingRuntimeRefKey = ".spec.trainingRuntimeRef.kind=clusterTrainingRuntime" ) -func IndexTrainJobTrainingRuntimes(obj client.Object) []string { +func IndexTrainJobTrainingRuntime(obj client.Object) []string { trainJob, ok := obj.(*kubeflowv2.TrainJob) if !ok { return nil } - runtimeRefGroupKind := schema.GroupKind{ - Group: ptr.Deref(trainJob.Spec.TrainingRuntimeRef.APIGroup, ""), - Kind: ptr.Deref(trainJob.Spec.TrainingRuntimeRef.Kind, ""), + if ptr.Deref(trainJob.Spec.TrainingRuntimeRef.APIGroup, "") == kubeflowv2.GroupVersion.Group && + ptr.Deref(trainJob.Spec.TrainingRuntimeRef.Kind, "") == kubeflowv2.TrainingRuntimeKind { + return []string{trainJob.Spec.TrainingRuntimeRef.Name} + } + return nil +} + +func IndexTrainJobClusterTrainingRuntime(obj client.Object) []string { + trainJob, ok := obj.(*kubeflowv2.TrainJob) + if !ok { + return nil } - if runtimeRefGroupKind.Group == kubeflowv2.GroupVersion.Group && - (runtimeRefGroupKind.Kind == kubeflowv2.TrainingRuntimeKind || runtimeRefGroupKind.Kind == kubeflowv2.ClusterTrainingRuntimeKind) { + if ptr.Deref(trainJob.Spec.TrainingRuntimeRef.APIGroup, "") == kubeflowv2.GroupVersion.Group && + ptr.Deref(trainJob.Spec.TrainingRuntimeRef.Kind, "") == kubeflowv2.ClusterTrainingRuntimeKind { return []string{trainJob.Spec.TrainingRuntimeRef.Name} } return nil From 338cf6efc355676ce87344c9586b82473dbb5a4b Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Thu, 17 Oct 2024 02:58:42 +0900 Subject: [PATCH 08/12] Propagate the TrainJob labels and annotations to the JobSet Signed-off-by: Yuki Iwai --- pkg/runtime.v2/core/trainingruntime.go | 25 +++++++++-- pkg/runtime.v2/core/trainingruntime_test.go | 33 ++++++++------ pkg/util.v2/testing/wrapper.go | 48 +++++++++++++++++++++ 3 files changed, 89 insertions(+), 17 deletions(-) diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 00c4152b8f..41dea87105 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -81,11 +81,28 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.T return r.buildObjects(ctx, trainJob, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy) } -func (r *TrainingRuntime) buildObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, - mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy) ([]client.Object, error) { +func (r *TrainingRuntime) buildObjects( + ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy, +) ([]client.Object, error) { + propagationLabels := jobSetTemplateSpec.Labels + if propagationLabels == nil && trainJob.Spec.Labels != nil { + propagationLabels = make(map[string]string, len(trainJob.Spec.Labels)) + } + for k, v := range trainJob.Spec.Labels { + // The JobSetTemplateSpec labels are overridden by the TrainJob Labels (.spec.labels). 
+ propagationLabels[k] = v + } + propagationAnnotations := jobSetTemplateSpec.Annotations + if propagationAnnotations == nil && trainJob.Spec.Annotations != nil { + propagationAnnotations = make(map[string]string, len(trainJob.Spec.Annotations)) + } + for k, v := range trainJob.Spec.Annotations { + // The JobSetTemplateSpec annotations are overridden by the TrainJob Annotations (.spec.annotations). + propagationAnnotations[k] = v + } opts := []runtime.InfoOption{ - runtime.WithLabels(jobSetTemplateSpec.Labels), - runtime.WithAnnotations(jobSetTemplateSpec.Annotations), + runtime.WithLabels(propagationLabels), + runtime.WithAnnotations(propagationAnnotations), runtime.WithMLPolicy(mlPolicy), runtime.WithPodGroupPolicy(podGroupPolicy), } diff --git a/pkg/runtime.v2/core/trainingruntime_test.go b/pkg/runtime.v2/core/trainingruntime_test.go index 6a08415fbf..ca31d8f133 100644 --- a/pkg/runtime.v2/core/trainingruntime_test.go +++ b/pkg/runtime.v2/core/trainingruntime_test.go @@ -48,27 +48,34 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). UID("uid"). TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime"). + SpecLabel("conflictLabel", "override"). + SpecAnnotation("conflictAnnotation", "override"). Trainer( testingutil.MakeTrainJobTrainerWrapper(t). ContainerImage("test:trainjob"). Obj(), ). Obj(), - trainingRuntime: baseRuntime.RuntimeSpec( - testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec). - ContainerImage("test:runtime"). - PodGroupPolicySchedulingTimeout(120). - MLPolicyNumNodes(20). - ResourceRequests(0, corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1"), - }). - ResourceRequests(1, corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2"), - }). - Obj(), - ).Obj(), + trainingRuntime: baseRuntime. + Label("conflictLabel", "overridden"). + Annotation("conflictAnnotation", "overridden"). + RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec). + ContainerImage("test:runtime"). + PodGroupPolicySchedulingTimeout(120). + MLPolicyNumNodes(20). + ResourceRequests(0, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + }). + ResourceRequests(1, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2"), + }). + Obj(), + ).Obj(), wantObjs: []client.Object{ testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + Label("conflictLabel", "override"). + Annotation("conflictAnnotation", "override"). PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job"). ContainerImage(ptr.To("test:trainjob")). JobCompletionMode(batchv1.IndexedCompletion). 
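The net effect of the propagation code above is a merge in which the TrainJob's .spec.labels and .spec.annotations win on conflicting keys. A simplified, copy-based sketch of the same idea (the patch itself writes into the template's maps in place; names and values below are illustrative):

    package main

    import (
        "fmt"
        "maps"
    )

    // mergeOverride copies base and then lets override win on conflicts,
    // mirroring how TrainJob metadata overrides the JobSetTemplateSpec's.
    func mergeOverride(base, override map[string]string) map[string]string {
        merged := make(map[string]string, len(base)+len(override))
        maps.Copy(merged, base)
        maps.Copy(merged, override)
        return merged
    }

    func main() {
        runtimeLabels := map[string]string{"team": "runtime", "conflictLabel": "overridden"}
        trainJobLabels := map[string]string{"conflictLabel": "override"}
        fmt.Println(mergeOverride(runtimeLabels, trainJobLabels))
        // map[conflictLabel:override team:runtime]
    }
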
diff --git a/pkg/util.v2/testing/wrapper.go b/pkg/util.v2/testing/wrapper.go index c7b7d06061..10441c040b 100644 --- a/pkg/util.v2/testing/wrapper.go +++ b/pkg/util.v2/testing/wrapper.go @@ -154,6 +154,22 @@ func (j *JobSetWrapper) PodLabel(key, value string) *JobSetWrapper { return j } +func (j *JobSetWrapper) Label(key, value string) *JobSetWrapper { + if j.ObjectMeta.Labels == nil { + j.ObjectMeta.Labels = make(map[string]string, 1) + } + j.ObjectMeta.Labels[key] = value + return j +} + +func (j *JobSetWrapper) Annotation(key, value string) *JobSetWrapper { + if j.ObjectMeta.Annotations == nil { + j.ObjectMeta.Annotations = make(map[string]string, 1) + } + j.ObjectMeta.Annotations[key] = value + return j +} + func (j *JobSetWrapper) Clone() *JobSetWrapper { return &JobSetWrapper{ JobSet: *j.JobSet.DeepCopy(), @@ -190,6 +206,22 @@ func (t *TrainJobWrapper) UID(uid string) *TrainJobWrapper { return t } +func (t *TrainJobWrapper) SpecLabel(key, value string) *TrainJobWrapper { + if t.Spec.Labels == nil { + t.Spec.Labels = make(map[string]string, 1) + } + t.Spec.Labels[key] = value + return t +} + +func (t *TrainJobWrapper) SpecAnnotation(key, value string) *TrainJobWrapper { + if t.Spec.Annotations == nil { + t.Spec.Annotations = make(map[string]string, 1) + } + t.Spec.Annotations[key] = value + return t +} + type TrainJobTrainerWrapper struct { kubeflowv2.Trainer } @@ -289,6 +321,22 @@ func (r *TrainingRuntimeWrapper) RuntimeSpec(spec kubeflowv2.TrainingRuntimeSpec return r } +func (r *TrainingRuntimeWrapper) Label(key, value string) *TrainingRuntimeWrapper { + if r.ObjectMeta.Labels == nil { + r.ObjectMeta.Labels = make(map[string]string, 1) + } + r.ObjectMeta.Labels[key] = value + return r +} + +func (r *TrainingRuntimeWrapper) Annotation(key, value string) *TrainingRuntimeWrapper { + if r.ObjectMeta.Annotations == nil { + r.ObjectMeta.Annotations = make(map[string]string, 1) + } + r.ObjectMeta.Annotations[key] = value + return r +} + func (r *TrainingRuntimeWrapper) Clone() *TrainingRuntimeWrapper { return &TrainingRuntimeWrapper{ TrainingRuntime: *r.TrainingRuntime.DeepCopy(), From 498ff28660d2af385913c9e511b4aeed805e05b2 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Thu, 17 Oct 2024 03:01:13 +0900 Subject: [PATCH 09/12] Remove PodAnnotations from the runtime info Signed-off-by: Yuki Iwai --- pkg/runtime.v2/runtime.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pkg/runtime.v2/runtime.go b/pkg/runtime.v2/runtime.go index eadec3a523..de816f5dbe 100644 --- a/pkg/runtime.v2/runtime.go +++ b/pkg/runtime.v2/runtime.go @@ -33,11 +33,10 @@ var ( ) type Info struct { - Obj client.Object - Labels map[string]string - PodLabels map[string]string - Annotations map[string]string - PodAnnotations map[string]string + Obj client.Object + Labels map[string]string + PodLabels map[string]string + Annotations map[string]string Policy TotalRequests map[string]TotalResourceRequest } From ba81beb7f1ac828a95ecac70d42b471dc68f80b8 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Thu, 17 Oct 2024 04:35:20 +0900 Subject: [PATCH 10/12] Implement TrainingRuntime ReplicatedJob validation Signed-off-by: Yuki Iwai --- .../core/clustertrainingruntime_test.go | 16 ++--- pkg/runtime.v2/core/trainingruntime.go | 3 - pkg/runtime.v2/core/trainingruntime_test.go | 16 ++--- pkg/runtime.v2/framework/core/framework.go | 5 +- .../framework/core/framework_test.go | 23 ++++--- pkg/runtime.v2/framework/interface.go | 2 +- pkg/runtime.v2/framework/plugins/mpi/mpi.go | 3 +- 
.../framework/plugins/torch/torch.go | 3 +- pkg/runtime.v2/runtime_test.go | 6 +- pkg/util.v2/testing/wrapper.go | 49 ++++++++------- .../clustertrainingruntime_webhook.go | 8 ++- pkg/webhook.v2/trainingruntime_webhook.go | 24 +++++++- .../trainingruntime_webhook_test.go | 60 +++++++++++++++++++ .../webhook.v2/clustertrainingruntime_test.go | 25 ++++++++ .../webhook.v2/trainingruntime_test.go | 26 ++++++++ 15 files changed, 203 insertions(+), 66 deletions(-) create mode 100644 pkg/webhook.v2/trainingruntime_webhook_test.go diff --git a/pkg/runtime.v2/core/clustertrainingruntime_test.go b/pkg/runtime.v2/core/clustertrainingruntime_test.go index 1831d2ec62..5665c10fe5 100644 --- a/pkg/runtime.v2/core/clustertrainingruntime_test.go +++ b/pkg/runtime.v2/core/clustertrainingruntime_test.go @@ -35,7 +35,7 @@ import ( ) func TestClusterTrainingRuntimeNewObjects(t *testing.T) { - baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper(t, "test-runtime"). + baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper("test-runtime"). Clone() cases := map[string]struct { @@ -45,17 +45,17 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { wantError error }{ "succeeded to build JobSet and PodGroup": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("test:trainjob"). Obj(), ). Obj(), clusterTrainingRuntime: baseRuntime.RuntimeSpec( - testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec). + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). ContainerImage("test:runtime"). PodGroupPolicySchedulingTimeout(120). MLPolicyNumNodes(20). @@ -68,7 +68,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { Obj(), ).Obj(), wantObjs: []client.Object{ - testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job"). ContainerImage(ptr.To("test:trainjob")). JobCompletionMode(batchv1.IndexedCompletion). @@ -80,7 +80,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { }). ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). Obj(), - testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job"). + testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job"). ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). MinMember(40). SchedulingTimeout(120). @@ -91,11 +91,11 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { }, }, "missing trainingRuntime resource": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("test:trainjob"). Obj(), ). 
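This patch moves the replicas handling out of buildObjects (see the trainingruntime.go hunk just below) and into webhook validation via the validateReplicatedJobs helper further down. A standalone sketch of that field-path style of per-item validation, with illustrative data rather than the exact helper:

    package main

    import (
        "fmt"

        "k8s.io/apimachinery/pkg/util/validation/field"
        jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
    )

    // validateReplicas reports every replicatedJob whose replicas is not 1,
    // anchored at its exact field path so the API error points at the offender.
    func validateReplicas(rJobs []jobsetv1alpha2.ReplicatedJob) field.ErrorList {
        path := field.NewPath("spec", "template", "spec", "replicatedJobs")
        var errs field.ErrorList
        for i, rJob := range rJobs {
            if rJob.Replicas != 1 {
                errs = append(errs, field.Invalid(path.Index(i).Child("replicas"), rJob.Replicas, "must be 1"))
            }
        }
        return errs
    }

    func main() {
        rJobs := []jobsetv1alpha2.ReplicatedJob{{Name: "node", Replicas: 2}}
        fmt.Println(validateReplicas(rJobs).ToAggregate())
        // spec.template.spec.replicatedJobs[0].replicas: Invalid value: 2: must be 1
    }
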
diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 41dea87105..1597bbf0a8 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -107,9 +107,6 @@ func (r *TrainingRuntime) buildObjects( runtime.WithPodGroupPolicy(podGroupPolicy), } for idx, rJob := range jobSetTemplateSpec.Spec.ReplicatedJobs { - if rJob.Replicas == 0 { - jobSetTemplateSpec.Spec.ReplicatedJobs[idx].Replicas = 1 - } replicas := jobSetTemplateSpec.Spec.ReplicatedJobs[idx].Replicas * ptr.Deref(rJob.Template.Spec.Completions, 1) opts = append(opts, runtime.WithPodSpecReplicas(rJob.Name, replicas, rJob.Template.Spec.Template.Spec)) } diff --git a/pkg/runtime.v2/core/trainingruntime_test.go b/pkg/runtime.v2/core/trainingruntime_test.go index ca31d8f133..a3bd63efa6 100644 --- a/pkg/runtime.v2/core/trainingruntime_test.go +++ b/pkg/runtime.v2/core/trainingruntime_test.go @@ -35,7 +35,7 @@ import ( ) func TestTrainingRuntimeNewObjects(t *testing.T) { - baseRuntime := testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test-runtime"). + baseRuntime := testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime"). Clone() cases := map[string]struct { @@ -45,13 +45,13 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { wantError error }{ "succeeded to build JobSet and PodGroup": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime"). SpecLabel("conflictLabel", "override"). SpecAnnotation("conflictAnnotation", "override"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("test:trainjob"). Obj(), ). @@ -60,7 +60,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Label("conflictLabel", "overridden"). Annotation("conflictAnnotation", "overridden"). RuntimeSpec( - testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec). + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). ContainerImage("test:runtime"). PodGroupPolicySchedulingTimeout(120). MLPolicyNumNodes(20). @@ -73,7 +73,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), ).Obj(), wantObjs: []client.Object{ - testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). Label("conflictLabel", "override"). Annotation("conflictAnnotation", "override"). PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job"). @@ -87,7 +87,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { }). ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). Obj(), - testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job"). + testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job"). ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). MinMember(40). SchedulingTimeout(120). @@ -98,11 +98,11 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { }, }, "missing trainingRuntime resource": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). 
TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("test:trainjob"). Obj(), ). diff --git a/pkg/runtime.v2/framework/core/framework.go b/pkg/runtime.v2/framework/core/framework.go index ae9f007493..8997afe467 100644 --- a/pkg/runtime.v2/framework/core/framework.go +++ b/pkg/runtime.v2/framework/core/framework.go @@ -89,7 +89,7 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(trainJob *kubeflowv2.TrainJo return nil } -func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) { +func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { var aggregatedWarnings admission.Warnings var aggregatedErrors field.ErrorList for _, plugin := range f.customValidationPlugins { @@ -101,9 +101,6 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (ad aggregatedErrors = append(aggregatedErrors, errs...) } } - if len(aggregatedErrors) == 0 { - return aggregatedWarnings, nil - } return aggregatedWarnings, aggregatedErrors } diff --git a/pkg/runtime.v2/framework/core/framework_test.go b/pkg/runtime.v2/framework/core/framework_test.go index 4f5085781f..c3b630d923 100644 --- a/pkg/runtime.v2/framework/core/framework_test.go +++ b/pkg/runtime.v2/framework/core/framework_test.go @@ -287,24 +287,21 @@ func TestRunEnforcePodGroupPolicyPlugins(t *testing.T) { func TestRunCustomValidationPlugins(t *testing.T) { cases := map[string]struct { - trainJob *kubeflowv2.TrainJob registry fwkplugins.Registry - oldObj client.Object - newObj client.Object + oldObj *kubeflowv2.TrainJob + newObj *kubeflowv2.TrainJob wantWarnings admission.Warnings wantError field.ErrorList }{ // Need to implement more detail testing after we implement custom validator in any plugins. "there are not any custom validations": { - trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}}, registry: fwkplugins.NewRegistry(), - oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), - newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), + oldObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(), + newObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(), }, "an empty registry": { - trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}}, - oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), - newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), + oldObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(), + newObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(), }, } for name, tc := range cases { @@ -329,7 +326,7 @@ func TestRunCustomValidationPlugins(t *testing.T) { } func TestRunComponentBuilderPlugins(t *testing.T) { - jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + jobSetBase := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). 
ResourceRequests(0, corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("2"), corev1.ResourceMemory: resource.MustParse("4Gi"), @@ -354,10 +351,10 @@ func TestRunComponentBuilderPlugins(t *testing.T) { wantObjs []client.Object }{ "coscheduling and jobset are performed": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("foo:bar"). Obj(), ). @@ -396,7 +393,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { }, registry: fwkplugins.NewRegistry(), wantObjs: []client.Object{ - testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job"). + testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job"). SchedulingTimeout(300). MinMember(20). MinResources(corev1.ResourceList{ diff --git a/pkg/runtime.v2/framework/interface.go b/pkg/runtime.v2/framework/interface.go index 886c1ab39c..00d613ec0a 100644 --- a/pkg/runtime.v2/framework/interface.go +++ b/pkg/runtime.v2/framework/interface.go @@ -48,7 +48,7 @@ type EnforceMLPolicyPlugin interface { type CustomValidationPlugin interface { Plugin - Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) + Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) } type ComponentBuilderPlugin interface { diff --git a/pkg/runtime.v2/framework/plugins/mpi/mpi.go b/pkg/runtime.v2/framework/plugins/mpi/mpi.go index b85a265195..9e79f6c5a1 100644 --- a/pkg/runtime.v2/framework/plugins/mpi/mpi.go +++ b/pkg/runtime.v2/framework/plugins/mpi/mpi.go @@ -23,6 +23,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" ) @@ -55,6 +56,6 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info) error { } // TODO: Need to implement validations for MPIJob. -func (m *MPI) Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) { +func (m *MPI) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { return nil, nil } diff --git a/pkg/runtime.v2/framework/plugins/torch/torch.go b/pkg/runtime.v2/framework/plugins/torch/torch.go index 1a0306be98..b9b7f10cb9 100644 --- a/pkg/runtime.v2/framework/plugins/torch/torch.go +++ b/pkg/runtime.v2/framework/plugins/torch/torch.go @@ -23,6 +23,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" ) @@ -51,6 +52,6 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info) error { } // TODO: Need to implement validateions for TorchJob. 
-func (t *Torch) Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) { +func (t *Torch) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { return nil, nil } diff --git a/pkg/runtime.v2/runtime_test.go b/pkg/runtime.v2/runtime_test.go index bcb6b5efd9..95e366ac22 100644 --- a/pkg/runtime.v2/runtime_test.go +++ b/pkg/runtime.v2/runtime_test.go @@ -32,7 +32,7 @@ import ( ) func TestNewInfo(t *testing.T) { - jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + jobSetBase := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). Clone() cases := map[string]struct { @@ -159,7 +159,7 @@ func TestNewInfo(t *testing.T) { } func TestUpdate(t *testing.T) { - jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + jobSetBase := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). Clone() cases := map[string]struct { @@ -172,7 +172,7 @@ func TestUpdate(t *testing.T) { info: &Info{ Obj: jobSetBase.Obj(), }, - obj: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + obj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). Obj(), wantInfo: &Info{ Obj: jobSetBase.Obj(), diff --git a/pkg/util.v2/testing/wrapper.go b/pkg/util.v2/testing/wrapper.go index 10441c040b..df3f8cbfce 100644 --- a/pkg/util.v2/testing/wrapper.go +++ b/pkg/util.v2/testing/wrapper.go @@ -17,8 +17,6 @@ limitations under the License. package testing import ( - "testing" - batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -35,8 +33,7 @@ type JobSetWrapper struct { jobsetv1alpha2.JobSet } -func MakeJobSetWrapper(t *testing.T, namespace, name string) *JobSetWrapper { - t.Helper() +func MakeJobSetWrapper(namespace, name string) *JobSetWrapper { return &JobSetWrapper{ JobSet: jobsetv1alpha2.JobSet{ TypeMeta: metav1.TypeMeta{ @@ -170,6 +167,13 @@ func (j *JobSetWrapper) Annotation(key, value string) *JobSetWrapper { return j } +func (j *JobSetWrapper) Replicas(replicas int32) *JobSetWrapper { + for idx := range j.Spec.ReplicatedJobs { + j.Spec.ReplicatedJobs[idx].Replicas = replicas + } + return j +} + func (j *JobSetWrapper) Clone() *JobSetWrapper { return &JobSetWrapper{ JobSet: *j.JobSet.DeepCopy(), @@ -184,8 +188,7 @@ type TrainJobWrapper struct { kubeflowv2.TrainJob } -func MakeTrainJobWrapper(t *testing.T, namespace, name string) *TrainJobWrapper { - t.Helper() +func MakeTrainJobWrapper(namespace, name string) *TrainJobWrapper { return &TrainJobWrapper{ TrainJob: kubeflowv2.TrainJob{ TypeMeta: metav1.TypeMeta{ @@ -226,8 +229,7 @@ type TrainJobTrainerWrapper struct { kubeflowv2.Trainer } -func MakeTrainJobTrainerWrapper(t *testing.T) *TrainJobTrainerWrapper { - t.Helper() +func MakeTrainJobTrainerWrapper() *TrainJobTrainerWrapper { return &TrainJobTrainerWrapper{ Trainer: kubeflowv2.Trainer{}, } @@ -264,8 +266,7 @@ type TrainingRuntimeWrapper struct { kubeflowv2.TrainingRuntime } -func MakeTrainingRuntimeWrapper(t *testing.T, namespace, name string) *TrainingRuntimeWrapper { - t.Helper() +func MakeTrainingRuntimeWrapper(namespace, name string) *TrainingRuntimeWrapper { return &TrainingRuntimeWrapper{ TrainingRuntime: kubeflowv2.TrainingRuntime{ TypeMeta: metav1.TypeMeta{ @@ -281,7 +282,8 @@ func MakeTrainingRuntimeWrapper(t *testing.T, namespace, name string) *TrainingR Spec: jobsetv1alpha2.JobSetSpec{ ReplicatedJobs: []jobsetv1alpha2.ReplicatedJob{ { - Name: "Coordinator", + Name: 
"Coordinator", + Replicas: 1, Template: batchv1.JobTemplateSpec{ Spec: batchv1.JobSpec{ Template: corev1.PodTemplateSpec{ @@ -295,7 +297,8 @@ func MakeTrainingRuntimeWrapper(t *testing.T, namespace, name string) *TrainingR }, }, { - Name: "Worker", + Name: "Worker", + Replicas: 1, Template: batchv1.JobTemplateSpec{ Spec: batchv1.JobSpec{ Template: corev1.PodTemplateSpec{ @@ -351,8 +354,7 @@ type ClusterTrainingRuntimeWrapper struct { kubeflowv2.ClusterTrainingRuntime } -func MakeClusterTrainingRuntimeWrapper(t *testing.T, name string) *ClusterTrainingRuntimeWrapper { - t.Helper() +func MakeClusterTrainingRuntimeWrapper(name string) *ClusterTrainingRuntimeWrapper { return &ClusterTrainingRuntimeWrapper{ ClusterTrainingRuntime: kubeflowv2.ClusterTrainingRuntime{ TypeMeta: metav1.TypeMeta{ @@ -367,7 +369,8 @@ func MakeClusterTrainingRuntimeWrapper(t *testing.T, name string) *ClusterTraini Spec: jobsetv1alpha2.JobSetSpec{ ReplicatedJobs: []jobsetv1alpha2.ReplicatedJob{ { - Name: "Coordinator", + Name: "Coordinator", + Replicas: 1, Template: batchv1.JobTemplateSpec{ Spec: batchv1.JobSpec{ Template: corev1.PodTemplateSpec{ @@ -381,7 +384,8 @@ func MakeClusterTrainingRuntimeWrapper(t *testing.T, name string) *ClusterTraini }, }, { - Name: "Worker", + Name: "Worker", + Replicas: 1, Template: batchv1.JobTemplateSpec{ Spec: batchv1.JobSpec{ Template: corev1.PodTemplateSpec{ @@ -421,13 +425,19 @@ type TrainingRuntimeSpecWrapper struct { kubeflowv2.TrainingRuntimeSpec } -func MakeTrainingRuntimeSpecWrapper(t *testing.T, spec kubeflowv2.TrainingRuntimeSpec) *TrainingRuntimeSpecWrapper { - t.Helper() +func MakeTrainingRuntimeSpecWrapper(spec kubeflowv2.TrainingRuntimeSpec) *TrainingRuntimeSpecWrapper { return &TrainingRuntimeSpecWrapper{ TrainingRuntimeSpec: spec, } } +func (s *TrainingRuntimeSpecWrapper) Replicas(replicas int32) *TrainingRuntimeSpecWrapper { + for idx := range s.Template.Spec.ReplicatedJobs { + s.Template.Spec.ReplicatedJobs[idx].Replicas = replicas + } + return s +} + func (s *TrainingRuntimeSpecWrapper) ContainerImage(image string) *TrainingRuntimeSpecWrapper { for i, rJob := range s.Template.Spec.ReplicatedJobs { for j := range rJob.Template.Spec.Template.Spec.Containers { @@ -475,8 +485,7 @@ type SchedulerPluginsPodGroupWrapper struct { schedulerpluginsv1alpha1.PodGroup } -func MakeSchedulerPluginsPodGroup(t *testing.T, namespace, name string) *SchedulerPluginsPodGroupWrapper { - t.Helper() +func MakeSchedulerPluginsPodGroup(namespace, name string) *SchedulerPluginsPodGroupWrapper { return &SchedulerPluginsPodGroupWrapper{ PodGroup: schedulerpluginsv1alpha1.PodGroup{ TypeMeta: metav1.TypeMeta{ diff --git a/pkg/webhook.v2/clustertrainingruntime_webhook.go b/pkg/webhook.v2/clustertrainingruntime_webhook.go index a679fe66dd..c98d71a15b 100644 --- a/pkg/webhook.v2/clustertrainingruntime_webhook.go +++ b/pkg/webhook.v2/clustertrainingruntime_webhook.go @@ -20,6 +20,7 @@ import ( "context" apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -43,8 +44,11 @@ func setupWebhookForClusterTrainingRuntime(mgr ctrl.Manager, run map[string]runt var _ webhook.CustomValidator = (*ClusterTrainingRuntimeWebhook)(nil) -func (w *ClusterTrainingRuntimeWebhook) ValidateCreate(context.Context, apiruntime.Object) (admission.Warnings, error) { - return nil, nil +func (w *ClusterTrainingRuntimeWebhook) ValidateCreate(ctx context.Context, obj 
apiruntime.Object) (admission.Warnings, error) { + clTrainingRuntime := obj.(*kubeflowv2.ClusterTrainingRuntime) + log := ctrl.LoggerFrom(ctx).WithName("clustertrainingruntime-webhook") + log.V(5).Info("Validating create", "clusterTrainingRuntime", klog.KObj(clTrainingRuntime)) + return nil, validateReplicatedJobs(clTrainingRuntime.Spec.Template.Spec.ReplicatedJobs).ToAggregate() } func (w *ClusterTrainingRuntimeWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { diff --git a/pkg/webhook.v2/trainingruntime_webhook.go b/pkg/webhook.v2/trainingruntime_webhook.go index 5597e5238f..fa6a7186db 100644 --- a/pkg/webhook.v2/trainingruntime_webhook.go +++ b/pkg/webhook.v2/trainingruntime_webhook.go @@ -20,9 +20,12 @@ import ( "context" apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" @@ -43,8 +46,25 @@ func setupWebhookForTrainingRuntime(mgr ctrl.Manager, run map[string]runtime.Run var _ webhook.CustomValidator = (*TrainingRuntimeWebhook)(nil) -func (w *TrainingRuntimeWebhook) ValidateCreate(context.Context, apiruntime.Object) (admission.Warnings, error) { - return nil, nil +func (w *TrainingRuntimeWebhook) ValidateCreate(ctx context.Context, obj apiruntime.Object) (admission.Warnings, error) { + trainingRuntime := obj.(*kubeflowv2.TrainingRuntime) + log := ctrl.LoggerFrom(ctx).WithName("trainingruntime-webhook") + log.V(5).Info("Validating create", "trainingRuntime", klog.KObj(trainingRuntime)) + return nil, validateReplicatedJobs(trainingRuntime.Spec.Template.Spec.ReplicatedJobs).ToAggregate() +} + +func validateReplicatedJobs(rJobs []jobsetv1alpha2.ReplicatedJob) field.ErrorList { + rJobsPath := field.NewPath("spec"). + Child("template"). + Child("spec"). + Child("replicatedJobs") + var allErrs field.ErrorList + for idx, rJob := range rJobs { + if rJob.Replicas != 1 { + allErrs = append(allErrs, field.Invalid(rJobsPath.Index(idx).Child("replicas"), rJob.Replicas, "always must be 1")) + } + } + return allErrs } func (w *TrainingRuntimeWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { diff --git a/pkg/webhook.v2/trainingruntime_webhook_test.go b/pkg/webhook.v2/trainingruntime_webhook_test.go new file mode 100644 index 0000000000..fbbf9a6a35 --- /dev/null +++ b/pkg/webhook.v2/trainingruntime_webhook_test.go @@ -0,0 +1,60 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package webhookv2 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "k8s.io/apimachinery/pkg/util/validation/field" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" +) + +func TestValidateReplicatedJobs(t *testing.T) { + cases := map[string]struct { + rJobs []jobsetv1alpha2.ReplicatedJob + wantError field.ErrorList + }{ + "valid replicatedJobs": { + rJobs: testingutil.MakeJobSetWrapper("ns", "valid"). + Replicas(1). + Obj().Spec.ReplicatedJobs, + }, + "invalid replicas": { + rJobs: testingutil.MakeJobSetWrapper("ns", "valid"). + Replicas(2). + Obj().Spec.ReplicatedJobs, + wantError: field.ErrorList{ + field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(0).Child("replicas"), + "2", ""), + field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(1).Child("replicas"), + "2", ""), + }, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + gotErr := validateReplicatedJobs(tc.rJobs) + if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 { + t.Errorf("validateReplicateJobs() mismatch (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/test/integration/webhook.v2/clustertrainingruntime_test.go b/test/integration/webhook.v2/clustertrainingruntime_test.go index 9ba9be69af..a2519c8ff8 100644 --- a/test/integration/webhook.v2/clustertrainingruntime_test.go +++ b/test/integration/webhook.v2/clustertrainingruntime_test.go @@ -22,9 +22,13 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" "github.com/kubeflow/training-operator/test/integration/framework" ) +const clTrainingRuntimeName = "test-clustertrainingruntime" + var _ = ginkgo.Describe("ClusterTrainingRuntime Webhook", ginkgo.Ordered, func() { var ns *corev1.Namespace @@ -49,4 +53,25 @@ var _ = ginkgo.Describe("ClusterTrainingRuntime Webhook", ginkgo.Ordered, func() } gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) }) + + ginkgo.AfterEach(func() { + gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.ClusterTrainingRuntime{})).To(gomega.Succeed()) + }) + + ginkgo.When("Creating ClusterTrainingRuntime", func() { + ginkgo.DescribeTable("", func(runtime func() *kubeflowv2.ClusterTrainingRuntime) { + gomega.Expect(k8sClient.Create(ctx, runtime())).Should(gomega.Succeed()) + }, + ginkgo.Entry("Should succeed to create ClusterTrainingRuntime", + func() *kubeflowv2.ClusterTrainingRuntime { + baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper(clTrainingRuntimeName) + return baseRuntime. + RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). + Replicas(1). + Obj()). 
+ Obj() + }), + ) + }) }) diff --git a/test/integration/webhook.v2/trainingruntime_test.go b/test/integration/webhook.v2/trainingruntime_test.go index dc2add14ce..7599e04759 100644 --- a/test/integration/webhook.v2/trainingruntime_test.go +++ b/test/integration/webhook.v2/trainingruntime_test.go @@ -21,10 +21,15 @@ import ( "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" "github.com/kubeflow/training-operator/test/integration/framework" ) +const trainingRuntimeName = "test-trainingruntime" + var _ = ginkgo.Describe("TrainingRuntime Webhook", ginkgo.Ordered, func() { var ns *corev1.Namespace @@ -49,4 +54,25 @@ var _ = ginkgo.Describe("TrainingRuntime Webhook", ginkgo.Ordered, func() { } gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) }) + + ginkgo.AfterEach(func() { + gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.TrainingRuntime{}, client.InNamespace(ns.Name))).To(gomega.Succeed()) + }) + + ginkgo.When("Creating TrainingRuntime", func() { + ginkgo.DescribeTable("Validate TrainingRuntime on creation", func(runtime func() *kubeflowv2.TrainingRuntime) { + gomega.Expect(k8sClient.Create(ctx, runtime())).Should(gomega.Succeed()) + }, + ginkgo.Entry("Should succeed to create TrainingRuntime", + func() *kubeflowv2.TrainingRuntime { + baseRuntime := testingutil.MakeTrainingRuntimeWrapper(ns.Name, trainingRuntimeName).Clone() + return baseRuntime. + RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). + Replicas(1). + Obj()). + Obj() + }), + ) + }) }) From 2f616e5a79afd3d0c39d4e447d37ee74b5f30ea4 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Thu, 17 Oct 2024 04:43:01 +0900 Subject: [PATCH 11/12] Add TODO comments Signed-off-by: Yuki Iwai --- pkg/runtime.v2/framework/plugins/jobset/jobset.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pkg/runtime.v2/framework/plugins/jobset/jobset.go b/pkg/runtime.v2/framework/plugins/jobset/jobset.go index 7d53abba1f..82eca0ef7f 100644 --- a/pkg/runtime.v2/framework/plugins/jobset/jobset.go +++ b/pkg/runtime.v2/framework/plugins/jobset/jobset.go @@ -21,6 +21,7 @@ import ( "fmt" "maps" + "github.com/go-logr/logr" batchv1 "k8s.io/api/batch/v1" "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -29,6 +30,7 @@ import ( apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" @@ -43,6 +45,7 @@ type JobSet struct { client client.Client restMapper meta.RESTMapper scheme *apiruntime.Scheme + logger logr.Logger } var _ framework.WatchExtensionPlugin = (*JobSet)(nil) @@ -50,11 +53,12 @@ var _ framework.ComponentBuilderPlugin = (*JobSet)(nil) const Name = "JobSet" -func New(_ context.Context, c client.Client, _ client.FieldIndexer) (framework.Plugin, error) { +func New(ctx context.Context, c client.Client, _ client.FieldIndexer) (framework.Plugin, error) { return &JobSet{ client: c, restMapper: c.RESTMapper(), scheme: c.Scheme(), + logger: ctrl.LoggerFrom(ctx).WithValues("pluginName", "JobSet"), }, nil } @@ -77,6 +81,7 @@ func (j *JobSet) Build(ctx context.Context, 
info *runtime.Info, trainJob *kubefl }, Spec: raw.Spec, }) + // TODO (tenzen-y): We should support all field propagation in builder. jobSet := jobSetBuilder. ContainerImage(trainJob.Spec.Trainer.Image). JobCompletionMode(batchv1.IndexedCompletion). @@ -111,7 +116,8 @@ func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder { schema.GroupKind{Group: jobsetv1alpha2.GroupVersion.Group, Kind: "JobSet"}, jobsetv1alpha2.SchemeGroupVersion.Version, ); err != nil { - return nil + // TODO (tenzen-y): After we provide the Configuration API, we should return errors based on the enabled plugins. + j.logger.Error(err, "JobSet CRDs must be installed in advance") } return []runtime.ReconcilerBuilder{ func(b *builder.Builder, c client.Client) *builder.Builder { From 0c376d3425c291f43da8c300e0742aa117390936 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Thu, 17 Oct 2024 04:43:52 +0900 Subject: [PATCH 12/12] Replace queueSuspendedTrainJob with queueSuspendedTrainJobs Signed-off-by: Yuki Iwai --- .../framework/plugins/coscheduling/coscheduling.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go b/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go index 85f08a8366..36b7d6813d 100644 --- a/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go +++ b/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go @@ -160,7 +160,7 @@ func (h *PodGroupRuntimeClassHandler) Create(ctx context.Context, e event.Create return } log := ctrl.LoggerFrom(ctx).WithValues("runtimeClass", klog.KObj(containerRuntimeClass)) - if err := h.queueSuspendedTrainJob(ctx, containerRuntimeClass, q); err != nil { + if err := h.queueSuspendedTrainJobs(ctx, containerRuntimeClass, q); err != nil { log.Error(err, "could not queue suspended TrainJob to reconcile queue") } } @@ -175,7 +175,7 @@ func (h *PodGroupRuntimeClassHandler) Update(ctx context.Context, e event.Update return } log := ctrl.LoggerFrom(ctx).WithValues("runtimeClass", klog.KObj(newContainerRuntimeClass)) - if err := h.queueSuspendedTrainJob(ctx, newContainerRuntimeClass, q); err != nil { + if err := h.queueSuspendedTrainJobs(ctx, newContainerRuntimeClass, q); err != nil { log.Error(err, "could not queue suspended TrainJob to reconcile queue") } } @@ -186,7 +186,7 @@ func (h *PodGroupRuntimeClassHandler) Delete(ctx context.Context, e event.Delete return } log := ctrl.LoggerFrom(ctx).WithValues("runtimeClass", klog.KObj(containerRuntimeClass)) - if err := h.queueSuspendedTrainJob(ctx, containerRuntimeClass, q); err != nil { + if err := h.queueSuspendedTrainJobs(ctx, containerRuntimeClass, q); err != nil { log.Error(err, "could not queue suspended TrainJob to reconcile queue") } } @@ -194,7 +194,7 @@ func (h *PodGroupRuntimeClassHandler) Delete(ctx context.Context, e event.Delete func (h *PodGroupRuntimeClassHandler) Generic(context.Context, event.GenericEvent, workqueue.RateLimitingInterface) { } -func (h *PodGroupRuntimeClassHandler) queueSuspendedTrainJob(ctx context.Context, runtimeClass *nodev1.RuntimeClass, q workqueue.RateLimitingInterface) error { +func (h *PodGroupRuntimeClassHandler) queueSuspendedTrainJobs(ctx context.Context, runtimeClass *nodev1.RuntimeClass, q workqueue.RateLimitingInterface) error { var trainingRuntimes kubeflowv2.TrainingRuntimeList if err := h.client.List(ctx, &trainingRuntimes, client.MatchingFields{TrainingRuntimeContainerRuntimeClassKey: runtimeClass.Name}); err != nil { return err