From e02439b429e82dcc9aeb9bfefc32d0b07eb93fe4 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Tue, 26 Mar 2024 21:35:01 +0900 Subject: [PATCH] Implement webhook validations for the PyTorchJob Signed-off-by: Yuki Iwai --- Makefile | 5 +- cmd/training-operator.v1/main.go | 91 ++++-- go.mod | 6 +- go.sum | 10 +- manifests/base/deployment.yaml | 12 + .../base/internalcert/kustomization.yaml | 2 + manifests/base/internalcert/secret.yaml | 4 + manifests/base/kustomization.yaml | 2 + manifests/base/rbac/role.yaml | 18 ++ manifests/base/webhook/kustomization.yaml | 11 + manifests/base/webhook/kustomizeconfig.yaml | 18 ++ manifests/base/webhook/manifests.yaml | 26 ++ manifests/base/webhook/service.yaml | 9 + .../kubeflow.org/v1/pytorch_validation.go | 92 ------ .../v1/pytorch_validation_test.go | 222 --------------- pkg/cert/cert.go | 80 ++++++ .../pytorch/pytorchjob_controller.go | 7 - .../pytorchjob_controller_suite_test.go | 24 ++ pkg/webhooks/pytorch/pytorchjob_webhook.go | 146 ++++++++++ .../pytorch/pytorchjob_webhook_test.go | 269 ++++++++++++++++++ pkg/webhooks/webhooks.go | 41 +++ 21 files changed, 750 insertions(+), 345 deletions(-) create mode 100644 manifests/base/internalcert/kustomization.yaml create mode 100644 manifests/base/internalcert/secret.yaml create mode 100644 manifests/base/webhook/kustomization.yaml create mode 100644 manifests/base/webhook/kustomizeconfig.yaml create mode 100644 manifests/base/webhook/manifests.yaml create mode 100644 manifests/base/webhook/service.yaml delete mode 100644 pkg/apis/kubeflow.org/v1/pytorch_validation.go delete mode 100644 pkg/apis/kubeflow.org/v1/pytorch_validation_test.go create mode 100644 pkg/cert/cert.go create mode 100644 pkg/webhooks/pytorch/pytorchjob_webhook.go create mode 100644 pkg/webhooks/pytorch/pytorchjob_webhook_test.go create mode 100644 pkg/webhooks/webhooks.go diff --git a/Makefile b/Makefile index 90dd48e34e..f13e374d9d 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,10 @@ help: ## Display this 
help. ##@ Development manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects. - $(CONTROLLER_GEN) $(CRD_OPTIONS) rbac:roleName=training-operator webhook paths="./pkg/..." output:crd:artifacts:config=manifests/base/crds output:rbac:artifacts:config=manifests/base/rbac + $(CONTROLLER_GEN) $(CRD_OPTIONS) rbac:roleName=training-operator webhook paths="./pkg/..." \ + output:crd:artifacts:config=manifests/base/crds \ + output:rbac:artifacts:config=manifests/base/rbac \ + output:webhook:artifacts:config=manifests/base/webhook generate: controller-gen ## Generate apidoc, sdk and code containing DeepCopy, DeepCopyInto, and DeepCopyObject method implementations. $(CONTROLLER_GEN) object:headerFile="hack/boilerplate/boilerplate.go.txt" paths="./pkg/apis/..." diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go index 9acd3f52ba..897d4a018e 100644 --- a/cmd/training-operator.v1/main.go +++ b/cmd/training-operator.v1/main.go @@ -19,6 +19,7 @@ package main import ( "errors" "flag" + "net/http" "os" "strings" @@ -40,9 +41,11 @@ import ( volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/cert" "github.com/kubeflow/training-operator/pkg/config" controllerv1 "github.com/kubeflow/training-operator/pkg/controller.v1" "github.com/kubeflow/training-operator/pkg/controller.v1/common" + "github.com/kubeflow/training-operator/pkg/webhooks" //+kubebuilder:scaffold:imports ) @@ -72,8 +75,11 @@ func main() { var enabledSchemes controllerv1.EnabledSchemes var gangSchedulerName string var namespace string - var webhookServerPort int var controllerThreads int + var webhookServerPort int + var webhookServiceName string + var webhookSecretName string + flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.") flag.StringVar(&probeAddr, 
"health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") flag.BoolVar(&enableLeaderElection, "leader-elect", false, @@ -86,7 +92,6 @@ func main() { " Note: If you set another scheduler name, the training-operator assumes it's the scheduler-plugins.") flag.StringVar(&namespace, "namespace", os.Getenv(EnvKubeflowNamespace), "The namespace to monitor kubeflow jobs. If unset, it monitors all namespaces cluster-wide."+ "If set, it only monitors kubeflow jobs in the given namespace.") - flag.IntVar(&webhookServerPort, "webhook-server-port", 9443, "Endpoint port for the webhook server.") flag.IntVar(&controllerThreads, "controller-threads", 1, "Number of worker threads used by the controller.") // PyTorch related flags @@ -101,6 +106,11 @@ func main() { flag.StringVar(&config.Config.MPIKubectlDeliveryImage, "mpi-kubectl-delivery-image", config.MPIKubectlDeliveryImageDefault, "The image for mpi launcher init container") + // Cert generation flags + flag.IntVar(&webhookServerPort, "webhook-server-port", 9443, "Endpoint port for the webhook server.") + flag.StringVar(&webhookServiceName, "webhook-service-name", "training-operator-webhook-service", "Name of the Service used as part of the DNSName") + flag.StringVar(&webhookSecretName, "webhook-secret-name", "training-operator-webhook-server-secret", "Name of the Secret to store CA and server certs") + opts := zap.Options{ Development: true, StacktraceLevel: zapcore.DPanicLevel, @@ -124,9 +134,9 @@ func main() { Metrics: metricsserver.Options{ BindAddress: metricsAddr, }, - WebhookServer: &webhook.DefaultServer{Options: webhook.Options{ + WebhookServer: webhook.NewServer(webhook.Options{ Port: webhookServerPort, - }}, + }), HealthProbeBindAddress: probeAddr, LeaderElection: enableLeaderElection, LeaderElectionID: leaderElectionID, @@ -137,20 +147,23 @@ func main() { os.Exit(1) } - // Set up controllers using goroutines to start the manager quickly. 
- go setupControllers(mgr, enabledSchemes, gangSchedulerName, controllerThreads) - - //+kubebuilder:scaffold:builder - - if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil { - setupLog.Error(err, "unable to set up health check") - os.Exit(1) + certsReady := make(chan struct{}) + defer close(certsReady) + certGenerationConfig := cert.Config{ + WebhookSecretName: webhookSecretName, + WebhookServiceName: webhookServiceName, } - if err := mgr.AddReadyzCheck("readyz", healthz.Ping); err != nil { - setupLog.Error(err, "unable to set up ready check") + if err = cert.ManageCerts(mgr, certGenerationConfig, certsReady); err != nil { + setupLog.Error(err, "Unable to set up cert rotation") os.Exit(1) } + setupProbeEndpoints(mgr, certsReady) + // Set up controllers using goroutines to start the manager quickly. + go setupControllers(mgr, enabledSchemes, gangSchedulerName, controllerThreads, certsReady) + + //+kubebuilder:scaffold:builder + setupLog.Info("starting manager") if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil { setupLog.Error(err, "problem running manager") @@ -158,9 +171,12 @@ func main() { } } -func setupControllers(mgr ctrl.Manager, enabledSchemes controllerv1.EnabledSchemes, gangSchedulerName string, controllerThreads int) { - setupLog.Info("registering controllers...") +func setupControllers(mgr ctrl.Manager, enabledSchemes controllerv1.EnabledSchemes, gangSchedulerName string, controllerThreads int, certsReady <-chan struct{}) { + setupLog.Info("Waiting for certificate generation to complete") + <-certsReady + setupLog.Info("Certs ready") + setupLog.Info("registering controllers...") // Prepare GangSchedulingSetupFunc gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() if strings.EqualFold(gangSchedulerName, string(common.GangSchedulerVolcano)) { @@ -182,15 +198,52 @@ func setupControllers(mgr ctrl.Manager, enabledSchemes controllerv1.EnabledSchem } errMsg := "failed to set up controllers" for _, s := range enabledSchemes 
{ - setupFunc, supported := controllerv1.SupportedSchemeReconciler[s] - if !supported { + setupReconcilerFunc, supportedReconciler := controllerv1.SupportedSchemeReconciler[s] + if !supportedReconciler { setupLog.Error(errors.New(errMsg), "scheme is not supported", "scheme", s) os.Exit(1) } - if err := setupFunc(mgr, gangSchedulingSetupFunc, controllerThreads); err != nil { + if err := setupReconcilerFunc(mgr, gangSchedulingSetupFunc, controllerThreads); err != nil { setupLog.Error(errors.New(errMsg), "unable to create controller", "scheme", s) os.Exit(1) } + setupWebhookFunc, supportedWebhook := webhooks.SupportedSchemeWebhook[s] + if !supportedWebhook { + setupLog.Error(errors.New(errMsg), "scheme is not supported", "scheme", s) + os.Exit(1) + } + if err := setupWebhookFunc(mgr); err != nil { + setupLog.Error(errors.New(errMsg), "unable to start webhook server", "scheme", s) + os.Exit(1) + } + } +} + +func setupProbeEndpoints(mgr ctrl.Manager, certsReady <-chan struct{}) { + defer setupLog.Info("Probe endpoints are configured on healthz and readyz") + + if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil { + setupLog.Error(err, "unable to set up health check") + os.Exit(1) + } + + // Wait for the webhook server to be listening before advertising the + // training-operator replica as ready. This allows users to wait with sending the first + // requests, requiring webhooks, until the training-operator deployment is available, so + // that the early requests are not rejected during the training-operator's startup. + // We wrap the call to GetWebhookServer in a closure to delay calling + // the function, otherwise a not fully-initialized webhook server (without + // ready certs) fails the start of the manager. 
+ if err := mgr.AddReadyzCheck("readyz", func(req *http.Request) error { + select { + case <-certsReady: + return mgr.GetWebhookServer().StartedChecker()(req) + default: + return errors.New("certificates are not ready") + } + }); err != nil { + setupLog.Error(err, "unable to set up ready check") + os.Exit(1) } } diff --git a/go.mod b/go.mod index 09397ea053..1dad73fb8c 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,12 @@ require ( github.com/google/go-cmp v0.5.9 github.com/onsi/ginkgo/v2 v2.11.0 github.com/onsi/gomega v1.27.10 + github.com/open-policy-agent/cert-controller v0.10.1 github.com/prometheus/client_golang v1.16.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.2 go.uber.org/zap v1.25.0 + golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e k8s.io/api v0.28.2 k8s.io/apimachinery v0.28.2 k8s.io/client-go v0.28.2 @@ -58,8 +60,8 @@ require ( github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.10.1 // indirect github.com/spf13/pflag v1.0.5 // indirect + go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect golang.org/x/mod v0.10.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/oauth2 v0.8.0 // indirect @@ -74,7 +76,7 @@ require ( gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - k8s.io/apiextensions-apiserver v0.28.0 // indirect + k8s.io/apiextensions-apiserver v0.28.1 // indirect k8s.io/component-base v0.28.1 // indirect k8s.io/gengo v0.0.0-20220902162205-c0856e24416d // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect diff --git a/go.sum b/go.sum index 5ae70e5844..3b0556b253 100644 --- a/go.sum +++ b/go.sum @@ -89,6 +89,9 @@ github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= github.com/onsi/gomega v1.27.10 
h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= +github.com/open-policy-agent/cert-controller v0.10.1 h1:RXSYoyn8FdCenWecRP//UV5nbVfmstNpj4kHQFkvPK4= +github.com/open-policy-agent/cert-controller v0.10.1/go.mod h1:4uRbBLY5DsPOog+a9pqk3JLxuuhrWsbUedQW65HcLTI= +github.com/open-policy-agent/frameworks/constraint v0.0.0-20230822235116-f0b62fe1e4c4 h1:5dum5SLEz+95JDLkMls7Z7IDPjvSq3UhJSFe4f5einQ= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -122,6 +125,8 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= @@ -215,8 +220,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/api v0.28.2 h1:9mpl5mOb6vXZvqbQmankOfPIGiudghwCoLl1EYfUZbw= k8s.io/api v0.28.2/go.mod h1:RVnJBsjU8tcMq7C3iaRSGMeaKt2TWEUXcpIt/90fjEg= -k8s.io/apiextensions-apiserver v0.28.0 h1:CszgmBL8CizEnj4sj7/PtLGey6Na3YgWyGCPONv7E9E= -k8s.io/apiextensions-apiserver v0.28.0/go.mod h1:uRdYiwIuu0SyqJKriKmqEN2jThIJPhVmOWETm8ud1VE= 
+k8s.io/apiextensions-apiserver v0.28.1 h1:l2ThkBRjrWpw4f24uq0Da2HaEgqJZ7pcgiEUTKSmQZw= +k8s.io/apiextensions-apiserver v0.28.1/go.mod h1:sVvrI+P4vxh2YBBcm8n2ThjNyzU4BQGilCQ/JAY5kGs= k8s.io/apimachinery v0.28.2 h1:KCOJLrc6gu+wV1BYgwik4AF4vXOlVJPdiqn0yAWWwXQ= k8s.io/apimachinery v0.28.2/go.mod h1:RdzF87y/ngqk9H4z3EL2Rppv5jj95vGS/HaFXrLDApU= k8s.io/client-go v0.28.2 h1:DNoYI1vGq0slMBN/SWKMZMw0Rq+0EQW6/AK4v9+3VeY= @@ -230,6 +235,7 @@ k8s.io/gengo v0.0.0-20220902162205-c0856e24416d/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAE k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg= k8s.io/klog/v2 v2.100.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= +k8s.io/kube-aggregator v0.28.1 h1:rvG4llYnQKHjj6YjjoBPEJxfD1uH0DJwkrJTNKGAaCs= k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9 h1:LyMgNKD2P8Wn1iAwQU5OhxCKlKJy0sHc+PcDwFB24dQ= k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9/go.mod h1:wZK2AVp1uHCp4VamDVgBP2COHZjqD1T68Rf0CM3YjSM= k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 h1:qY1Ad8PODbnymg2pRbkyMT/ylpTrCM8P2RJ0yroCyIk= diff --git a/manifests/base/deployment.yaml b/manifests/base/deployment.yaml index de4dc8740e..2537c9b156 100644 --- a/manifests/base/deployment.yaml +++ b/manifests/base/deployment.yaml @@ -23,6 +23,9 @@ spec: name: training-operator ports: - containerPort: 8080 + - containerPort: 9443 + name: webhook-server + protocol: TCP env: - name: MY_POD_NAMESPACE valueFrom: @@ -34,6 +37,10 @@ spec: fieldPath: metadata.name securityContext: allowPrivilegeEscalation: false + volumeMounts: + - mountPath: /tmp/k8s-webhook-server/serving-certs + name: cert + readOnly: true livenessProbe: httpGet: path: /healthz @@ -50,3 +57,8 @@ spec: timeoutSeconds: 3 serviceAccountName: training-operator terminationGracePeriodSeconds: 10 + volumes: + - name: cert + secret: + defaultMode: 420 + secretName: training-operator-webhook-server-secret diff --git 
a/manifests/base/internalcert/kustomization.yaml b/manifests/base/internalcert/kustomization.yaml new file mode 100644 index 0000000000..97a9721bd8 --- /dev/null +++ b/manifests/base/internalcert/kustomization.yaml @@ -0,0 +1,2 @@ +resources: + - secret.yaml diff --git a/manifests/base/internalcert/secret.yaml b/manifests/base/internalcert/secret.yaml new file mode 100644 index 0000000000..68843736af --- /dev/null +++ b/manifests/base/internalcert/secret.yaml @@ -0,0 +1,4 @@ +apiVersion: v1 +kind: Secret +metadata: + name: training-operator-webhook-server-secret diff --git a/manifests/base/kustomization.yaml b/manifests/base/kustomization.yaml index 1308bb6da2..b6a710ec10 100644 --- a/manifests/base/kustomization.yaml +++ b/manifests/base/kustomization.yaml @@ -5,5 +5,7 @@ resources: - ./rbac/cluster-role-binding.yaml - ./rbac/role.yaml - ./rbac/service-account.yaml + - ./internalcert + - ./webhook - service.yaml - deployment.yaml diff --git a/manifests/base/rbac/role.yaml b/manifests/base/rbac/role.yaml index 4c77d2fae6..60cdbf2aa5 100644 --- a/manifests/base/rbac/role.yaml +++ b/manifests/base/rbac/role.yaml @@ -43,6 +43,15 @@ rules: - pods/exec verbs: - create +- apiGroups: + - "" + resources: + - secrets + verbs: + - get + - list + - update + - watch - apiGroups: - "" resources: @@ -62,6 +71,15 @@ rules: - get - list - watch +- apiGroups: + - admissionregistration.k8s.io + resources: + - validatingwebhookconfigurations + verbs: + - get + - list + - update + - watch - apiGroups: - autoscaling resources: diff --git a/manifests/base/webhook/kustomization.yaml b/manifests/base/webhook/kustomization.yaml new file mode 100644 index 0000000000..e2f0e5d625 --- /dev/null +++ b/manifests/base/webhook/kustomization.yaml @@ -0,0 +1,11 @@ +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization +resources: + - manifests.yaml + - service.yaml +namePrefix: training-operator- +commonLabels: + control-plane: kubeflow-training-operator + +configurations: + - 
kustomizeconfig.yaml diff --git a/manifests/base/webhook/kustomizeconfig.yaml b/manifests/base/webhook/kustomizeconfig.yaml new file mode 100644 index 0000000000..9190272198 --- /dev/null +++ b/manifests/base/webhook/kustomizeconfig.yaml @@ -0,0 +1,18 @@ +# the following config is for teaching kustomize where to look at when substituting vars. +# It requires kustomize v2.1.0 or newer to work properly. +nameReference: + - kind: Service + version: v1 + fieldSpecs: + - kind: ValidatingWebhookConfiguration + group: admissionregistration.k8s.io + path: webhooks/clientConfig/service/name + +namespace: + - kind: ValidatingWebhookConfiguration + group: admissionregistration.k8s.io + path: webhooks/clientConfig/service/namespace + create: true + +varReference: + - path: metadata/annotations diff --git a/manifests/base/webhook/manifests.yaml b/manifests/base/webhook/manifests.yaml new file mode 100644 index 0000000000..45dc562128 --- /dev/null +++ b/manifests/base/webhook/manifests.yaml @@ -0,0 +1,26 @@ +--- +apiVersion: admissionregistration.k8s.io/v1 +kind: ValidatingWebhookConfiguration +metadata: + name: validating-webhook-configuration +webhooks: +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-kubeflow-org-v1-pytorchjob + failurePolicy: Fail + name: vpytorchjob.kb.io + rules: + - apiGroups: + - kubeflow.org + apiVersions: + - v1 + operations: + - CREATE + - UPDATE + resources: + - pytorchjobs + sideEffects: None diff --git a/manifests/base/webhook/service.yaml b/manifests/base/webhook/service.yaml new file mode 100644 index 0000000000..75035f5d79 --- /dev/null +++ b/manifests/base/webhook/service.yaml @@ -0,0 +1,9 @@ +apiVersion: v1 +kind: Service +metadata: + name: webhook-service +spec: + ports: + - port: 443 + protocol: TCP + targetPort: 9443 diff --git a/pkg/apis/kubeflow.org/v1/pytorch_validation.go b/pkg/apis/kubeflow.org/v1/pytorch_validation.go deleted file mode 100644 index 
752b8196df..0000000000 --- a/pkg/apis/kubeflow.org/v1/pytorch_validation.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2018 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "fmt" - - apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" -) - -func ValidateV1PyTorchJob(pytorchJob *PyTorchJob) error { - if errors := apimachineryvalidation.NameIsDNS1035Label(pytorchJob.ObjectMeta.Name, false); errors != nil { - return fmt.Errorf("PyTorchJob name is invalid: %v", errors) - } - if err := validatePyTorchReplicaSpecs(pytorchJob.Spec.PyTorchReplicaSpecs); err != nil { - return err - } - if err := validateNprocPerNode(pytorchJob); err != nil { - return err - } - return nil -} - -func validateNprocPerNode(pytorchJob *PyTorchJob) error { - if pytorchJob.Spec.NprocPerNode != nil && pytorchJob.Spec.ElasticPolicy != nil && pytorchJob.Spec.ElasticPolicy.NProcPerNode != nil { - return fmt.Errorf(".spec.elasticPolicy.nProcPerNode is deprecated, use .spec.nprocPerNode instead") - } - return nil -} - -func validatePyTorchReplicaSpecs(specs map[ReplicaType]*ReplicaSpec) error { - if specs == nil { - return fmt.Errorf("PyTorchJobSpec is not valid") - } - for rType, value := range specs { - if value == nil || len(value.Template.Spec.Containers) == 0 { - return fmt.Errorf("PyTorchJobSpec is not valid: containers definition expected in %v", rType) - } - // Make sure the replica type is valid. 
- validReplicaTypes := []ReplicaType{PyTorchJobReplicaTypeMaster, PyTorchJobReplicaTypeWorker} - - isValidReplicaType := false - for _, t := range validReplicaTypes { - if t == rType { - isValidReplicaType = true - break - } - } - - if !isValidReplicaType { - return fmt.Errorf("PyTorchReplicaType is %v but must be one of %v", rType, validReplicaTypes) - } - - //Make sure the image is defined in the container - defaultContainerPresent := false - for _, container := range value.Template.Spec.Containers { - if container.Image == "" { - msg := fmt.Sprintf("PyTorchJobSpec is not valid: Image is undefined in the container of %v", rType) - return fmt.Errorf(msg) - } - if container.Name == PyTorchJobDefaultContainerName { - defaultContainerPresent = true - } - } - //Make sure there has at least one container named "pytorch" - if !defaultContainerPresent { - msg := fmt.Sprintf("PyTorchJobSpec is not valid: There is no container named %s in %v", PyTorchJobDefaultContainerName, rType) - return fmt.Errorf(msg) - } - if rType == PyTorchJobReplicaTypeMaster { - if value.Replicas != nil && int(*value.Replicas) != 1 { - return fmt.Errorf("PyTorchJobSpec is not valid: There must be only 1 master replica") - } - } - - } - - return nil - -} diff --git a/pkg/apis/kubeflow.org/v1/pytorch_validation_test.go b/pkg/apis/kubeflow.org/v1/pytorch_validation_test.go deleted file mode 100644 index 0b46da3742..0000000000 --- a/pkg/apis/kubeflow.org/v1/pytorch_validation_test.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2018 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "testing" - - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/utils/pointer" -) - -func TestValidateV1PyTorchJob(t *testing.T) { - validPyTorchReplicaSpecs := map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeMaster: { - Replicas: pointer.Int32(1), - RestartPolicy: RestartPolicyOnFailure, - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "pytorch", - Image: "docker.io/kubeflowkatib/pytorch-mnist:v1beta1-45c5727", - ImagePullPolicy: corev1.PullAlways, - Command: []string{ - "python3", - "/opt/pytorch-mnist/mnist.py", - "--epochs=1", - }, - }}, - }, - }, - }, - PyTorchJobReplicaTypeWorker: { - Replicas: pointer.Int32(1), - RestartPolicy: RestartPolicyOnFailure, - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "pytorch", - Image: "docker.io/kubeflowkatib/pytorch-mnist:v1beta1-45c5727", - ImagePullPolicy: corev1.PullAlways, - Command: []string{ - "python3", - "/opt/pytorch-mnist/mnist.py", - "--epochs=1", - }, - }}, - }, - }, - }, - } - - testCases := map[string]struct { - pytorchJob *PyTorchJob - wantErr bool - }{ - "valid PyTorchJob": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: validPyTorchReplicaSpecs, - }, - }, - wantErr: false, - }, - "pytorchJob name does not meet DNS1035": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "0-test", - }, - Spec: PyTorchJobSpec{ - 
PyTorchReplicaSpecs: validPyTorchReplicaSpecs, - }, - }, - wantErr: true, - }, - "no containers": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "image is empty": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "pytorch", - Image: "", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "pytorchJob default container name doesn't present": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "", - Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "the number of replicas in masterReplica is other than 1": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeMaster: { - Replicas: pointer.Int32(2), - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "pytorch", - Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "Spec.NprocPerNode and Spec.ElasticPolicy.NProcPerNode are set": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - 
Spec: PyTorchJobSpec{ - NprocPerNode: pointer.String("1"), - ElasticPolicy: &ElasticPolicy{ - NProcPerNode: pointer.Int32(1), - }, - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeMaster: { - Replicas: pointer.Int32(2), - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "pytorch", - Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - got := ValidateV1PyTorchJob(tc.pytorchJob) - if (got != nil) != tc.wantErr { - t.Fatalf("ValidateV1PyTorchJob() error = %v, wantErr %v", got, tc.wantErr) - } - }) - } -} diff --git a/pkg/cert/cert.go b/pkg/cert/cert.go new file mode 100644 index 0000000000..ce5003e0d4 --- /dev/null +++ b/pkg/cert/cert.go @@ -0,0 +1,80 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package cert + +import ( + "fmt" + "os" + "strings" + + cert "github.com/open-policy-agent/cert-controller/pkg/rotator" + "k8s.io/apimachinery/pkg/types" + ctrl "sigs.k8s.io/controller-runtime" +) + +const ( + certDir = "/tmp/k8s-webhook-server/serving-certs" + vwcName = "training-operator-validating-webhook-configuration" + caName = "training-operator-ca" + caOrganization = "training-operator" + defaultOperatorNamespace = "kubeflow" +) + +type Config struct { + WebhookServiceName string + WebhookSecretName string +} + +// +kubebuilder:rbac:groups="",resources=secrets,verbs=get;list;watch;update +// +kubebuilder:rbac:groups="admissionregistration.k8s.io",resources=validatingwebhookconfigurations,verbs=get;list;watch;update + +// ManageCerts creates all certs for webhooks. +func ManageCerts(mgr ctrl.Manager, cfg Config, setupFinished chan struct{}) error { + var ( + ns = getOperatorNamespace() + // DNSName is <service>.<namespace>.svc + dnsName = fmt.Sprintf("%s.%s.svc", cfg.WebhookServiceName, ns) + ) + + return cert.AddRotator(mgr, &cert.CertRotator{ + SecretKey: types.NamespacedName{ + Namespace: ns, + Name: cfg.WebhookSecretName, + }, + CertDir: certDir, + CAName: caName, + CAOrganization: caOrganization, + DNSName: dnsName, + IsReady: setupFinished, + Webhooks: []cert.WebhookInfo{{ + Type: cert.Validating, + Name: vwcName, + }}, + // When training-operator is running in the leader election mode, + // we expect the webhook server to run in both the primary and secondary instances. + RequireLeaderElection: false, + }) +} + +func getOperatorNamespace() string { + if data, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/namespace"); err == nil { + if ns := strings.TrimSpace(string(data)); len(ns) > 0 { + return ns + } + } + return defaultOperatorNamespace +} diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller.go b/pkg/controller.v1/pytorch/pytorchjob_controller.go index 151530c138..7025c24396 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller.go +++ 
b/pkg/controller.v1/pytorch/pytorchjob_controller.go @@ -132,13 +132,6 @@ func (r *PyTorchJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) return ctrl.Result{}, client.IgnoreNotFound(err) } - if err = kubeflowv1.ValidateV1PyTorchJob(pytorchjob); err != nil { - logger.Error(err, "PyTorchJob failed validation") - r.Recorder.Eventf(pytorchjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobFailedValidationReason), - "PyTorchJob failed validation because %s", err) - return ctrl.Result{}, err - } - // Check if reconciliation is needed jobKey, err := common.KeyFunc(pytorchjob) if err != nil { diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go index 0f8cf9744e..7abed7fa77 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go @@ -16,12 +16,17 @@ package pytorch import ( "context" + "crypto/tls" + "fmt" + "net" "path/filepath" "testing" + "time" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/config" "github.com/kubeflow/training-operator/pkg/controller.v1/common" + pytorchwebhook "github.com/kubeflow/training-operator/pkg/webhooks/pytorch" . 
"github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" @@ -33,6 +38,7 @@ import ( logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + "sigs.k8s.io/controller-runtime/pkg/webhook" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -62,6 +68,9 @@ var _ = BeforeSuite(func() { testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")}, ErrorIfCRDPathMissing: true, + WebhookInstallOptions: envtest.WebhookInstallOptions{ + Paths: []string{filepath.Join("..", "..", "..", "manifests", "base", "webhook")}, + }, } cfg, err := testEnv.Start() @@ -87,6 +96,12 @@ var _ = BeforeSuite(func() { Metrics: metricsserver.Options{ BindAddress: "0", }, + WebhookServer: webhook.NewServer( + webhook.Options{ + Host: testEnv.WebhookInstallOptions.LocalServingHost, + Port: testEnv.WebhookInstallOptions.LocalServingPort, + CertDir: testEnv.WebhookInstallOptions.LocalServingCertDir, + }), }) Expect(err).NotTo(gomega.HaveOccurred()) @@ -94,12 +109,21 @@ var _ = BeforeSuite(func() { r := NewReconciler(mgr, gangSchedulingSetupFunc) Expect(r.SetupWithManager(mgr, 1)).NotTo(gomega.HaveOccurred()) + Expect(pytorchwebhook.SetupWebhook(mgr)).NotTo(gomega.HaveOccurred()) go func() { defer GinkgoRecover() err = mgr.Start(testCtx) Expect(err).ToNot(HaveOccurred(), "failed to run manager") }() + + dialer := &net.Dialer{Timeout: time.Second} + addrPort := fmt.Sprintf("%s:%d", testEnv.WebhookInstallOptions.LocalServingHost, testEnv.WebhookInstallOptions.LocalServingPort) + gomega.Eventually(func(g gomega.Gomega) { + conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true}) + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(conn.Close()).NotTo(gomega.HaveOccurred()) + }).Should(gomega.Succeed()) }) var _ = AfterSuite(func() { diff --git 
a/pkg/webhooks/pytorch/pytorchjob_webhook.go b/pkg/webhooks/pytorch/pytorchjob_webhook.go new file mode 100644 index 0000000000..5c7452c1fe --- /dev/null +++ b/pkg/webhooks/pytorch/pytorchjob_webhook.go @@ -0,0 +1,146 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pytorch + +import ( + "context" + "fmt" + "strings" + + "golang.org/x/exp/slices" + + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +var ( + specPath = field.NewPath("spec") + pytorchReplicaSpecPath = specPath.Child("pytorchReplicaSpecs") +) + +type Webhook struct{} + +func SetupWebhook(mgr ctrl.Manager) error { + return ctrl.NewWebhookManagedBy(mgr). + For(&trainingoperator.PyTorchJob{}). + WithValidator(&Webhook{}). 
+ Complete() +} + +// +kubebuilder:webhook:path=/validate-kubeflow-org-v1-pytorchjob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=pytorchjobs,verbs=create;update,versions=v1,name=vpytorchjob.kb.io,admissionReviewVersions=v1 + +var _ webhook.CustomValidator = &Webhook{} + +func (w *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) { + job := obj.(*trainingoperator.PyTorchJob) + log := ctrl.LoggerFrom(ctx).WithName("pytorchjob-webhook") + log.V(5).Info("Validating create", "pytorchJob", klog.KObj(job)) + warnings, errs := validatePyTorchJob(job) + return warnings, errs.ToAggregate() +} + +func (w *Webhook) ValidateUpdate(ctx context.Context, _ runtime.Object, newObj runtime.Object) (admission.Warnings, error) { + job := newObj.(*trainingoperator.PyTorchJob) + log := ctrl.LoggerFrom(ctx).WithName("pytorchjob-webhook") + log.V(5).Info("Validating update", "pytorchJob", klog.KObj(job)) + warnings, errs := validatePyTorchJob(job) + return warnings, errs.ToAggregate() +} + +func (w *Webhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { + return nil, nil +} + +func validatePyTorchJob(job *trainingoperator.PyTorchJob) (admission.Warnings, field.ErrorList) { + var allErrs field.ErrorList + var warnings admission.Warnings + + if errors := apimachineryvalidation.NameIsDNS1035Label(job.ObjectMeta.Name, false); len(errors) != 0 { + allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), job.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) + } + ws, err := validateSpec(job.Spec) + warnings = append(warnings, ws...) + allErrs = append(allErrs, err...) 
+ return warnings, allErrs +} + +func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, field.ErrorList) { + var allErrs field.ErrorList + var warnings admission.Warnings + + if spec.NprocPerNode != nil && spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil { + elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode") + nprocPerNodePath := specPath.Child("nprocPerNode") + allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath))) + warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String())) + } + allErrs = append(allErrs, validatePyTorchReplicaSpecs(spec.PyTorchReplicaSpecs, *specPath)...) + return warnings, allErrs +} + +func validatePyTorchReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec, specPath field.Path) field.ErrorList { + var allErrs field.ErrorList + + if rSpecs == nil { + allErrs = append(allErrs, field.Required(pytorchReplicaSpecPath, "must be required")) + } + for rType, rSpec := range rSpecs { + rolePath := pytorchReplicaSpecPath.Key(string(rType)) + containersPath := rolePath.Child("template").Child("spec").Child("containers") + + // Make sure the replica type is valid. 
+ validRoleTypes := []string{ + string(trainingoperator.PyTorchJobReplicaTypeMaster), + string(trainingoperator.PyTorchJobReplicaTypeWorker), + } + if !slices.Contains(validRoleTypes, string(rType)) { + allErrs = append(allErrs, field.NotSupported(rolePath, rType, validRoleTypes)) + } + + if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { + allErrs = append(allErrs, field.Required(containersPath, "must be specified")) + } + + // Make sure the image is defined in the container + defaultContainerPresent := false + for idx, container := range rSpec.Template.Spec.Containers { + if container.Image == "" { + allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) + } + if container.Name == trainingoperator.PyTorchJobDefaultContainerName { + defaultContainerPresent = true + } + } + // Make sure there has at least one container named "pytorch" + if !defaultContainerPresent { + allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.PyTorchJobDefaultContainerName))) + } + if rType == trainingoperator.PyTorchJobReplicaTypeMaster { + if rSpec.Replicas == nil || int(*rSpec.Replicas) != 1 { + allErrs = append(allErrs, field.Forbidden(rolePath.Child("replicas"), "must be 1")) + } + } + } + return allErrs +} diff --git a/pkg/webhooks/pytorch/pytorchjob_webhook_test.go b/pkg/webhooks/pytorch/pytorchjob_webhook_test.go new file mode 100644 index 0000000000..ae7664ac95 --- /dev/null +++ b/pkg/webhooks/pytorch/pytorchjob_webhook_test.go @@ -0,0 +1,269 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pytorch + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/utils/pointer" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +func TestValidateV1PyTorchJob(t *testing.T) { + validPyTorchReplicaSpecs := map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PyTorchJobReplicaTypeMaster: { + Replicas: pointer.Int32(1), + RestartPolicy: trainingoperator.RestartPolicyOnFailure, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "pytorch", + Image: "docker.io/kubeflowkatib/pytorch-mnist:v1beta1-45c5727", + ImagePullPolicy: corev1.PullAlways, + Command: []string{ + "python3", + "/opt/pytorch-mnist/mnist.py", + "--epochs=1", + }, + }}, + }, + }, + }, + trainingoperator.PyTorchJobReplicaTypeWorker: { + Replicas: pointer.Int32(1), + RestartPolicy: trainingoperator.RestartPolicyOnFailure, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "pytorch", + Image: "docker.io/kubeflowkatib/pytorch-mnist:v1beta1-45c5727", + ImagePullPolicy: corev1.PullAlways, + Command: []string{ + "python3", + "/opt/pytorch-mnist/mnist.py", + "--epochs=1", + }, + }}, + }, + }, + }, + } + + testCases := map[string]struct { + pytorchJob *trainingoperator.PyTorchJob + wantErr 
field.ErrorList + wantWarnings admission.Warnings + }{ + "valid PyTorchJob": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + PyTorchReplicaSpecs: validPyTorchReplicaSpecs, + }, + }, + }, + "pytorchJob name does not meet DNS1035": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "0-test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + PyTorchReplicaSpecs: validPyTorchReplicaSpecs, + }, + }, + wantErr: field.ErrorList{ + field.Invalid(field.NewPath("metadata").Child("name"), "", ""), + }, + }, + "no containers": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PyTorchJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(pytorchReplicaSpecPath. + Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)). + Child("template"). + Child("spec"). + Child("containers"), ""), + field.Required(pytorchReplicaSpecPath. + Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)). + Child("template"). + Child("spec"). + Child("containers"), ""), + }, + }, + "image is empty": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PyTorchJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "pytorch", + Image: "", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(pytorchReplicaSpecPath. 
+ Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)). + Child("template"). + Child("spec"). + Child("containers"). + Index(0). + Child("image"), ""), + }, + }, + "pytorchJob default container name doesn't present": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PyTorchJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "", + Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(pytorchReplicaSpecPath. + Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)). + Child("template"). + Child("spec"). + Child("containers"), ""), + }, + }, + "the number of replicas in masterReplica is other than 1": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PyTorchJobReplicaTypeMaster: { + Replicas: pointer.Int32(2), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "pytorch", + Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Forbidden(pytorchReplicaSpecPath.Key(string(trainingoperator.PyTorchJobReplicaTypeMaster)).Child("replicas"), ""), + }, + }, + "Spec.NprocPerNode and Spec.ElasticPolicy.NProcPerNode are set": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + NprocPerNode: pointer.String("1"), + ElasticPolicy: &trainingoperator.ElasticPolicy{ + NProcPerNode: pointer.Int32(1), + }, + 
PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PyTorchJobReplicaTypeMaster: { + Replicas: pointer.Int32(1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "pytorch", + Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Forbidden(specPath.Child("elasticPolicy").Child("nProcPerNode"), ""), + }, + wantWarnings: admission.Warnings{ + fmt.Sprintf("%s is deprecated, use %s instead", + specPath.Child("elasticPolicy").Child("nProcPerNode"), specPath.Child("nprocPerNode")), + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + gotWarnings, gotError := validatePyTorchJob(tc.pytorchJob) + if diff := cmp.Diff(tc.wantWarnings, gotWarnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 { + t.Errorf("Unexpected warnings (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantErr, gotError, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 { + t.Errorf("Unexpected errors (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/webhooks/webhooks.go b/pkg/webhooks/webhooks.go new file mode 100644 index 0000000000..5e97a3d3f3 --- /dev/null +++ b/pkg/webhooks/webhooks.go @@ -0,0 +1,41 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package webhooks + +import ( + "sigs.k8s.io/controller-runtime/pkg/manager" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/webhooks/pytorch" +) + +type WebhookSetupFunc func(manager manager.Manager) error + +var ( + SupportedSchemeWebhook = map[string]WebhookSetupFunc{ + trainingoperator.PyTorchJobKind: pytorch.SetupWebhook, + trainingoperator.TFJobKind: scaffold, + trainingoperator.MXJobKind: scaffold, + trainingoperator.XGBoostJobKind: scaffold, + trainingoperator.MPIJobKind: scaffold, + trainingoperator.PaddleJobKind: scaffold, + } +) + +func scaffold(manager.Manager) error { + return nil +}