diff --git a/Makefile b/Makefile index c374267257..1727392003 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,10 @@ help: ## Display this help. ##@ Development manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects. - $(CONTROLLER_GEN) $(CRD_OPTIONS) rbac:roleName=training-operator webhook paths="./pkg/..." output:crd:artifacts:config=manifests/base/crds output:rbac:artifacts:config=manifests/base/rbac + $(CONTROLLER_GEN) $(CRD_OPTIONS) rbac:roleName=training-operator webhook paths="./pkg/..." \ + output:crd:artifacts:config=manifests/base/crds \ + output:rbac:artifacts:config=manifests/base/rbac \ + output:webhook:artifacts:config=manifests/base/webhook generate: controller-gen ## Generate apidoc, sdk and code containing DeepCopy, DeepCopyInto, and DeepCopyObject method implementations. $(CONTROLLER_GEN) object:headerFile="hack/boilerplate/boilerplate.go.txt" paths="./pkg/apis/..." diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go index 9acd3f52ba..13aa566daa 100644 --- a/cmd/training-operator.v1/main.go +++ b/cmd/training-operator.v1/main.go @@ -19,6 +19,7 @@ package main import ( "errors" "flag" + "net/http" "os" "strings" @@ -40,9 +41,11 @@ import ( volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/cert" "github.com/kubeflow/training-operator/pkg/config" controllerv1 "github.com/kubeflow/training-operator/pkg/controller.v1" "github.com/kubeflow/training-operator/pkg/controller.v1/common" + "github.com/kubeflow/training-operator/pkg/webhooks" //+kubebuilder:scaffold:imports ) @@ -72,8 +75,11 @@ func main() { var enabledSchemes controllerv1.EnabledSchemes var gangSchedulerName string var namespace string - var webhookServerPort int var controllerThreads int + var webhookServerPort int + var webhookServiceName string + var webhookSecretName string 
+ flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") flag.BoolVar(&enableLeaderElection, "leader-elect", false, @@ -86,7 +92,6 @@ func main() { " Note: If you set another scheduler name, the training-operator assumes it's the scheduler-plugins.") flag.StringVar(&namespace, "namespace", os.Getenv(EnvKubeflowNamespace), "The namespace to monitor kubeflow jobs. If unset, it monitors all namespaces cluster-wide."+ "If set, it only monitors kubeflow jobs in the given namespace.") - flag.IntVar(&webhookServerPort, "webhook-server-port", 9443, "Endpoint port for the webhook server.") flag.IntVar(&controllerThreads, "controller-threads", 1, "Number of worker threads used by the controller.") // PyTorch related flags @@ -101,6 +106,11 @@ func main() { flag.StringVar(&config.Config.MPIKubectlDeliveryImage, "mpi-kubectl-delivery-image", config.MPIKubectlDeliveryImageDefault, "The image for mpi launcher init container") + // Cert generation flags + flag.IntVar(&webhookServerPort, "webhook-server-port", 9443, "Endpoint port for the webhook server.") + flag.StringVar(&webhookServiceName, "webhook-service-name", "training-operator", "Name of the Service used as part of the DNSName") + flag.StringVar(&webhookSecretName, "webhook-secret-name", "training-operator-webhook-cert", "Name of the Secret to store CA and server certs") + opts := zap.Options{ Development: true, StacktraceLevel: zapcore.DPanicLevel, @@ -124,9 +134,9 @@ func main() { Metrics: metricsserver.Options{ BindAddress: metricsAddr, }, - WebhookServer: &webhook.DefaultServer{Options: webhook.Options{ + WebhookServer: webhook.NewServer(webhook.Options{ Port: webhookServerPort, - }}, + }), HealthProbeBindAddress: probeAddr, LeaderElection: enableLeaderElection, LeaderElectionID: leaderElectionID, @@ -137,20 +147,23 @@ func main() { os.Exit(1) } - // Set up 
controllers using goroutines to start the manager quickly. - go setupControllers(mgr, enabledSchemes, gangSchedulerName, controllerThreads) - - //+kubebuilder:scaffold:builder - - if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil { - setupLog.Error(err, "unable to set up health check") - os.Exit(1) + certsReady := make(chan struct{}) + defer close(certsReady) + certGenerationConfig := cert.Config{ + WebhookSecretName: webhookSecretName, + WebhookServiceName: webhookServiceName, } - if err := mgr.AddReadyzCheck("readyz", healthz.Ping); err != nil { - setupLog.Error(err, "unable to set up ready check") + if err = cert.ManageCerts(mgr, certGenerationConfig, certsReady); err != nil { + setupLog.Error(err, "Unable to set up cert rotation") os.Exit(1) } + setupProbeEndpoints(mgr, certsReady) + // Set up controllers using goroutines to start the manager quickly. + go setupControllers(mgr, enabledSchemes, gangSchedulerName, controllerThreads, certsReady) + + //+kubebuilder:scaffold:builder + setupLog.Info("starting manager") if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil { setupLog.Error(err, "problem running manager") @@ -158,9 +171,12 @@ func main() { } } -func setupControllers(mgr ctrl.Manager, enabledSchemes controllerv1.EnabledSchemes, gangSchedulerName string, controllerThreads int) { - setupLog.Info("registering controllers...") +func setupControllers(mgr ctrl.Manager, enabledSchemes controllerv1.EnabledSchemes, gangSchedulerName string, controllerThreads int, certsReady <-chan struct{}) { + setupLog.Info("Waiting for certificate generation to complete") + <-certsReady + setupLog.Info("Certs ready") + setupLog.Info("registering controllers...") // Prepare GangSchedulingSetupFunc gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() if strings.EqualFold(gangSchedulerName, string(common.GangSchedulerVolcano)) { @@ -182,15 +198,52 @@ func setupControllers(mgr ctrl.Manager, enabledSchemes controllerv1.EnabledSchem } errMsg := 
"failed to set up controllers" for _, s := range enabledSchemes { - setupFunc, supported := controllerv1.SupportedSchemeReconciler[s] - if !supported { + setupReconcilerFunc, supportedReconciler := controllerv1.SupportedSchemeReconciler[s] + if !supportedReconciler { setupLog.Error(errors.New(errMsg), "scheme is not supported", "scheme", s) os.Exit(1) } - if err := setupFunc(mgr, gangSchedulingSetupFunc, controllerThreads); err != nil { + if err := setupReconcilerFunc(mgr, gangSchedulingSetupFunc, controllerThreads); err != nil { setupLog.Error(errors.New(errMsg), "unable to create controller", "scheme", s) os.Exit(1) } + setupWebhookFunc, supportedWebhook := webhooks.SupportedSchemeWebhook[s] + if !supportedWebhook { + setupLog.Error(errors.New(errMsg), "scheme is not supported", "scheme", s) + os.Exit(1) + } + if err := setupWebhookFunc(mgr); err != nil { + setupLog.Error(errors.New(errMsg), "unable to start webhook server", "scheme", s) + os.Exit(1) + } + } +} + +func setupProbeEndpoints(mgr ctrl.Manager, certsReady <-chan struct{}) { + defer setupLog.Info("Probe endpoints are configured on healthz and readyz") + + if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil { + setupLog.Error(err, "unable to set up health check") + os.Exit(1) + } + + // Wait for the webhook server to be listening before advertising the + // training-operator replica as ready. This allows users to wait with sending the first + // requests, requiring webhooks, until the training-operator deployment is available, so + // that the early requests are not rejected during the training-operator's startup. + // We wrap the call to GetWebhookServer in a closure to delay calling + // the function, otherwise a not fully-initialized webhook server (without + // ready certs) fails the start of the manager. 
+ if err := mgr.AddReadyzCheck("readyz", func(req *http.Request) error { + select { + case <-certsReady: + return mgr.GetWebhookServer().StartedChecker()(req) + default: + return errors.New("certificates are not ready") + } + }); err != nil { + setupLog.Error(err, "unable to set up ready check") + os.Exit(1) } } diff --git a/go.mod b/go.mod index 43f4b15432..54b52e7dda 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/google/go-cmp v0.6.0 github.com/onsi/ginkgo/v2 v2.14.0 github.com/onsi/gomega v1.30.0 + github.com/open-policy-agent/cert-controller v0.10.1 github.com/prometheus/client_golang v1.18.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.4 @@ -58,6 +59,7 @@ require ( github.com/prometheus/common v0.45.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect golang.org/x/mod v0.14.0 // indirect diff --git a/go.sum b/go.sum index 3316faed5f..535095a4e3 100644 --- a/go.sum +++ b/go.sum @@ -84,6 +84,10 @@ github.com/onsi/ginkgo/v2 v2.14.0 h1:vSmGj2Z5YPb9JwCWT6z6ihcUvDhuXLc3sJiqd3jMKAY github.com/onsi/ginkgo/v2 v2.14.0/go.mod h1:JkUdW7JkN0V6rFvsHcJ478egV3XH9NxpD27Hal/PhZw= github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8= github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= +github.com/open-policy-agent/cert-controller v0.10.1 h1:RXSYoyn8FdCenWecRP//UV5nbVfmstNpj4kHQFkvPK4= +github.com/open-policy-agent/cert-controller v0.10.1/go.mod h1:4uRbBLY5DsPOog+a9pqk3JLxuuhrWsbUedQW65HcLTI= +github.com/open-policy-agent/frameworks/constraint v0.0.0-20230822235116-f0b62fe1e4c4 h1:5dum5SLEz+95JDLkMls7Z7IDPjvSq3UhJSFe4f5einQ= +github.com/open-policy-agent/frameworks/constraint v0.0.0-20230822235116-f0b62fe1e4c4/go.mod h1:54/KzLMvA5ndBVpm7B1OjLeV0cUtTLTz2bZ2OtydLpU= 
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -115,6 +119,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= @@ -205,6 +211,8 @@ k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAE k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= +k8s.io/kube-aggregator v0.28.1 h1:rvG4llYnQKHjj6YjjoBPEJxfD1uH0DJwkrJTNKGAaCs= +k8s.io/kube-aggregator v0.28.1/go.mod h1:JaLizMe+AECSpO2OmrWVsvnG0V3dX1RpW+Wq/QHbu18= k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 h1:aVUu9fTY98ivBPKR9Y5w/AuzbMm96cd3YHRTU83I780= k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00/go.mod h1:AsvuZPBlUDVuCdzJ87iajxtXuR9oktsTctW/R9wwouA= k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI= diff --git a/manifests/base/deployment.yaml b/manifests/base/deployment.yaml index de4dc8740e..b38295a4a0 100644 --- a/manifests/base/deployment.yaml +++ b/manifests/base/deployment.yaml @@ 
-23,6 +23,9 @@ spec: name: training-operator ports: - containerPort: 8080 + - containerPort: 9443 + name: webhook-server + protocol: TCP env: - name: MY_POD_NAMESPACE valueFrom: @@ -34,6 +37,10 @@ spec: fieldPath: metadata.name securityContext: allowPrivilegeEscalation: false + volumeMounts: + - mountPath: /tmp/k8s-webhook-server/serving-certs + name: cert + readOnly: true livenessProbe: httpGet: path: /healthz @@ -50,3 +57,8 @@ spec: timeoutSeconds: 3 serviceAccountName: training-operator terminationGracePeriodSeconds: 10 + volumes: + - name: cert + secret: + defaultMode: 420 + secretName: training-operator-webhook-cert diff --git a/manifests/base/kustomization.yaml b/manifests/base/kustomization.yaml index 1308bb6da2..b140be1441 100644 --- a/manifests/base/kustomization.yaml +++ b/manifests/base/kustomization.yaml @@ -5,5 +5,6 @@ resources: - ./rbac/cluster-role-binding.yaml - ./rbac/role.yaml - ./rbac/service-account.yaml + - ./webhook - service.yaml - deployment.yaml diff --git a/manifests/base/rbac/role.yaml b/manifests/base/rbac/role.yaml index 4c77d2fae6..60cdbf2aa5 100644 --- a/manifests/base/rbac/role.yaml +++ b/manifests/base/rbac/role.yaml @@ -43,6 +43,15 @@ rules: - pods/exec verbs: - create +- apiGroups: + - "" + resources: + - secrets + verbs: + - get + - list + - update + - watch - apiGroups: - "" resources: @@ -62,6 +71,15 @@ rules: - get - list - watch +- apiGroups: + - admissionregistration.k8s.io + resources: + - validatingwebhookconfigurations + verbs: + - get + - list + - update + - watch - apiGroups: - autoscaling resources: diff --git a/manifests/base/service.yaml b/manifests/base/service.yaml index d26aa20b06..4f2300aedf 100644 --- a/manifests/base/service.yaml +++ b/manifests/base/service.yaml @@ -1,4 +1,3 @@ ---- apiVersion: v1 kind: Service metadata: @@ -11,9 +10,13 @@ metadata: name: training-operator spec: ports: - - name: monitoring-port - port: 8080 - targetPort: 8080 + - name: monitoring-port + port: 8080 + targetPort: 8080 + - name: 
webhook-server + port: 443 + protocol: TCP + targetPort: 9443 selector: control-plane: kubeflow-training-operator type: ClusterIP diff --git a/manifests/base/webhook/kustomization.yaml b/manifests/base/webhook/kustomization.yaml new file mode 100644 index 0000000000..1523fc1f35 --- /dev/null +++ b/manifests/base/webhook/kustomization.yaml @@ -0,0 +1,15 @@ +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization +resources: + - manifests.yaml +commonLabels: + control-plane: kubeflow-training-operator +patches: + - path: patch.yaml + target: + group: admissionregistration.k8s.io + version: v1 + kind: ValidatingWebhookConfiguration + +configurations: + - kustomizeconfig.yaml diff --git a/manifests/base/webhook/kustomizeconfig.yaml b/manifests/base/webhook/kustomizeconfig.yaml new file mode 100644 index 0000000000..8b55ef316b --- /dev/null +++ b/manifests/base/webhook/kustomizeconfig.yaml @@ -0,0 +1,10 @@ +# the following config is for teaching kustomize where to look at when substituting vars. +# It requires kustomize v2.1.0 or newer to work properly. 
+namespace: + - kind: ValidatingWebhookConfiguration + group: admissionregistration.k8s.io + path: webhooks/clientConfig/service/namespace + create: true + +varReference: + - path: metadata/annotations diff --git a/manifests/base/webhook/manifests.yaml b/manifests/base/webhook/manifests.yaml new file mode 100644 index 0000000000..ea3fad7c4c --- /dev/null +++ b/manifests/base/webhook/manifests.yaml @@ -0,0 +1,26 @@ +--- +apiVersion: admissionregistration.k8s.io/v1 +kind: ValidatingWebhookConfiguration +metadata: + name: validating-webhook-configuration +webhooks: +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-kubeflow-org-v1-pytorchjob + failurePolicy: Fail + name: validator.pytorchjob.training-operator.kubeflow.org + rules: + - apiGroups: + - kubeflow.org + apiVersions: + - v1 + operations: + - CREATE + - UPDATE + resources: + - pytorchjobs + sideEffects: None diff --git a/manifests/base/webhook/patch.yaml b/manifests/base/webhook/patch.yaml new file mode 100644 index 0000000000..060dfc8a52 --- /dev/null +++ b/manifests/base/webhook/patch.yaml @@ -0,0 +1,6 @@ +- op: replace + path: /webhooks/0/clientConfig/service/name + value: training-operator +- op: replace + path: /metadata/name + value: validator.training-operator.kubeflow.org diff --git a/manifests/overlays/kubeflow/kustomization.yaml b/manifests/overlays/kubeflow/kustomization.yaml index 3f6f99f296..b7d06ab17f 100644 --- a/manifests/overlays/kubeflow/kustomization.yaml +++ b/manifests/overlays/kubeflow/kustomization.yaml @@ -7,3 +7,9 @@ resources: images: - name: kubeflow/training-operator newTag: "v1-855e096" +# TODO (tenzen-y): Once we support cert-manager, we need to remove this secret generation. 
+# REF: https://github.com/kubeflow/training-operator/issues/2049 +secretGenerator: + - name: training-operator-webhook-cert + options: + disableNameSuffixHash: true diff --git a/manifests/overlays/standalone/kustomization.yaml b/manifests/overlays/standalone/kustomization.yaml index 975dd75b37..65ad5d3843 100644 --- a/manifests/overlays/standalone/kustomization.yaml +++ b/manifests/overlays/standalone/kustomization.yaml @@ -7,3 +7,7 @@ resources: images: - name: kubeflow/training-operator newTag: "v1-855e096" +secretGenerator: + - name: training-operator-webhook-cert + options: + disableNameSuffixHash: true diff --git a/pkg/apis/kubeflow.org/v1/pytorch_validation.go b/pkg/apis/kubeflow.org/v1/pytorch_validation.go deleted file mode 100644 index 752b8196df..0000000000 --- a/pkg/apis/kubeflow.org/v1/pytorch_validation.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2018 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package v1 - -import ( - "fmt" - - apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" -) - -func ValidateV1PyTorchJob(pytorchJob *PyTorchJob) error { - if errors := apimachineryvalidation.NameIsDNS1035Label(pytorchJob.ObjectMeta.Name, false); errors != nil { - return fmt.Errorf("PyTorchJob name is invalid: %v", errors) - } - if err := validatePyTorchReplicaSpecs(pytorchJob.Spec.PyTorchReplicaSpecs); err != nil { - return err - } - if err := validateNprocPerNode(pytorchJob); err != nil { - return err - } - return nil -} - -func validateNprocPerNode(pytorchJob *PyTorchJob) error { - if pytorchJob.Spec.NprocPerNode != nil && pytorchJob.Spec.ElasticPolicy != nil && pytorchJob.Spec.ElasticPolicy.NProcPerNode != nil { - return fmt.Errorf(".spec.elasticPolicy.nProcPerNode is deprecated, use .spec.nprocPerNode instead") - } - return nil -} - -func validatePyTorchReplicaSpecs(specs map[ReplicaType]*ReplicaSpec) error { - if specs == nil { - return fmt.Errorf("PyTorchJobSpec is not valid") - } - for rType, value := range specs { - if value == nil || len(value.Template.Spec.Containers) == 0 { - return fmt.Errorf("PyTorchJobSpec is not valid: containers definition expected in %v", rType) - } - // Make sure the replica type is valid. 
- validReplicaTypes := []ReplicaType{PyTorchJobReplicaTypeMaster, PyTorchJobReplicaTypeWorker} - - isValidReplicaType := false - for _, t := range validReplicaTypes { - if t == rType { - isValidReplicaType = true - break - } - } - - if !isValidReplicaType { - return fmt.Errorf("PyTorchReplicaType is %v but must be one of %v", rType, validReplicaTypes) - } - - //Make sure the image is defined in the container - defaultContainerPresent := false - for _, container := range value.Template.Spec.Containers { - if container.Image == "" { - msg := fmt.Sprintf("PyTorchJobSpec is not valid: Image is undefined in the container of %v", rType) - return fmt.Errorf(msg) - } - if container.Name == PyTorchJobDefaultContainerName { - defaultContainerPresent = true - } - } - //Make sure there has at least one container named "pytorch" - if !defaultContainerPresent { - msg := fmt.Sprintf("PyTorchJobSpec is not valid: There is no container named %s in %v", PyTorchJobDefaultContainerName, rType) - return fmt.Errorf(msg) - } - if rType == PyTorchJobReplicaTypeMaster { - if value.Replicas != nil && int(*value.Replicas) != 1 { - return fmt.Errorf("PyTorchJobSpec is not valid: There must be only 1 master replica") - } - } - - } - - return nil - -} diff --git a/pkg/apis/kubeflow.org/v1/pytorch_validation_test.go b/pkg/apis/kubeflow.org/v1/pytorch_validation_test.go deleted file mode 100644 index 1be8f9922e..0000000000 --- a/pkg/apis/kubeflow.org/v1/pytorch_validation_test.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2018 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "testing" - - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/utils/ptr" -) - -func TestValidateV1PyTorchJob(t *testing.T) { - validPyTorchReplicaSpecs := map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeMaster: { - Replicas: ptr.To[int32](1), - RestartPolicy: RestartPolicyOnFailure, - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "pytorch", - Image: "docker.io/kubeflowkatib/pytorch-mnist:v1beta1-45c5727", - ImagePullPolicy: corev1.PullAlways, - Command: []string{ - "python3", - "/opt/pytorch-mnist/mnist.py", - "--epochs=1", - }, - }}, - }, - }, - }, - PyTorchJobReplicaTypeWorker: { - Replicas: ptr.To[int32](1), - RestartPolicy: RestartPolicyOnFailure, - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "pytorch", - Image: "docker.io/kubeflowkatib/pytorch-mnist:v1beta1-45c5727", - ImagePullPolicy: corev1.PullAlways, - Command: []string{ - "python3", - "/opt/pytorch-mnist/mnist.py", - "--epochs=1", - }, - }}, - }, - }, - }, - } - - testCases := map[string]struct { - pytorchJob *PyTorchJob - wantErr bool - }{ - "valid PyTorchJob": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: validPyTorchReplicaSpecs, - }, - }, - wantErr: false, - }, - "pytorchJob name does not meet DNS1035": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "0-test", - }, - Spec: PyTorchJobSpec{ - 
PyTorchReplicaSpecs: validPyTorchReplicaSpecs, - }, - }, - wantErr: true, - }, - "no containers": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "image is empty": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "pytorch", - Image: "", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "pytorchJob default container name doesn't present": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "", - Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "the number of replicas in masterReplica is other than 1": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PyTorchJobSpec{ - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeMaster: { - Replicas: ptr.To[int32](2), - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "pytorch", - Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "Spec.NprocPerNode and Spec.ElasticPolicy.NProcPerNode are set": { - pytorchJob: &PyTorchJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - 
Spec: PyTorchJobSpec{ - NprocPerNode: ptr.To("1"), - ElasticPolicy: &ElasticPolicy{ - NProcPerNode: ptr.To[int32](1), - }, - PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PyTorchJobReplicaTypeMaster: { - Replicas: ptr.To[int32](2), - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "pytorch", - Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - got := ValidateV1PyTorchJob(tc.pytorchJob) - if (got != nil) != tc.wantErr { - t.Fatalf("ValidateV1PyTorchJob() error = %v, wantErr %v", got, tc.wantErr) - } - }) - } -} diff --git a/pkg/cert/cert.go b/pkg/cert/cert.go new file mode 100644 index 0000000000..4d1593fab5 --- /dev/null +++ b/pkg/cert/cert.go @@ -0,0 +1,76 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package cert + +import ( + "fmt" + cert "github.com/open-policy-agent/cert-controller/pkg/rotator" + "k8s.io/apimachinery/pkg/types" + "os" + ctrl "sigs.k8s.io/controller-runtime" +) + +const ( + certDir = "/tmp/k8s-webhook-server/serving-certs" + vwcName = "validator.training-operator.kubeflow.org" + caName = "training-operator-ca" + caOrganization = "training-operator" + defaultOperatorNamespace = "kubeflow" +) + +type Config struct { + WebhookServiceName string + WebhookSecretName string +} + +// +kubebuilder:rbac:groups="",resources=secrets,verbs=get;list;watch;update +// +kubebuilder:rbac:groups="admissionregistration.k8s.io",resources=validatingwebhookconfigurations,verbs=get;list;watch;update + +// ManageCerts creates all certs for webhooks. +func ManageCerts(mgr ctrl.Manager, cfg Config, setupFinished chan struct{}) error { + var ( + ns = getOperatorNamespace() + // DNSName is <service name>.<namespace>.svc + dnsName = fmt.Sprintf("%s.%s.svc", cfg.WebhookServiceName, ns) + ) + + return cert.AddRotator(mgr, &cert.CertRotator{ + SecretKey: types.NamespacedName{ + Namespace: ns, + Name: cfg.WebhookSecretName, + }, + CertDir: certDir, + CAName: caName, + CAOrganization: caOrganization, + DNSName: dnsName, + IsReady: setupFinished, + Webhooks: []cert.WebhookInfo{{ + Type: cert.Validating, + Name: vwcName, + }}, + // When training-operator is running in the leader election mode, + // we expect webhook server will run in primary and secondary instance + RequireLeaderElection: false, + }) +} + +func getOperatorNamespace() string { + if ns := os.Getenv("MY_POD_NAMESPACE"); ns != "" { + return ns + } + return defaultOperatorNamespace +} diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller.go b/pkg/controller.v1/pytorch/pytorchjob_controller.go index 151530c138..7025c24396 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller.go @@ -132,13 +132,6 @@ func (r *PyTorchJobReconciler) Reconcile(ctx context.Context, req 
ctrl.Request) return ctrl.Result{}, client.IgnoreNotFound(err) } - if err = kubeflowv1.ValidateV1PyTorchJob(pytorchjob); err != nil { - logger.Error(err, "PyTorchJob failed validation") - r.Recorder.Eventf(pytorchjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobFailedValidationReason), - "PyTorchJob failed validation because %s", err) - return ctrl.Result{}, err - } - // Check if reconciliation is needed jobKey, err := common.KeyFunc(pytorchjob) if err != nil { diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go index 0f8cf9744e..35810c9d1c 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go @@ -16,15 +16,19 @@ package pytorch import ( "context" + "crypto/tls" + "fmt" + "net" "path/filepath" "testing" + "time" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/config" "github.com/kubeflow/training-operator/pkg/controller.v1/common" + pytorchwebhook "github.com/kubeflow/training-operator/pkg/webhooks/pytorch" . "github.com/onsi/ginkgo/v2" - "github.com/onsi/gomega" . 
"github.com/onsi/gomega" "k8s.io/client-go/kubernetes/scheme" ctrl "sigs.k8s.io/controller-runtime" @@ -33,6 +37,7 @@ import ( logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + "sigs.k8s.io/controller-runtime/pkg/webhook" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -62,6 +67,9 @@ var _ = BeforeSuite(func() { testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")}, ErrorIfCRDPathMissing: true, + WebhookInstallOptions: envtest.WebhookInstallOptions{ + Paths: []string{filepath.Join("..", "..", "..", "manifests", "base", "webhook", "manifests.yaml")}, + }, } cfg, err := testEnv.Start() @@ -87,19 +95,34 @@ var _ = BeforeSuite(func() { Metrics: metricsserver.Options{ BindAddress: "0", }, + WebhookServer: webhook.NewServer( + webhook.Options{ + Host: testEnv.WebhookInstallOptions.LocalServingHost, + Port: testEnv.WebhookInstallOptions.LocalServingPort, + CertDir: testEnv.WebhookInstallOptions.LocalServingCertDir, + }), }) - Expect(err).NotTo(gomega.HaveOccurred()) + Expect(err).NotTo(HaveOccurred()) gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() r := NewReconciler(mgr, gangSchedulingSetupFunc) - Expect(r.SetupWithManager(mgr, 1)).NotTo(gomega.HaveOccurred()) + Expect(r.SetupWithManager(mgr, 1)).NotTo(HaveOccurred()) + Expect(pytorchwebhook.SetupWebhook(mgr)).NotTo(HaveOccurred()) go func() { defer GinkgoRecover() err = mgr.Start(testCtx) Expect(err).ToNot(HaveOccurred(), "failed to run manager") }() + + dialer := &net.Dialer{Timeout: time.Second} + addrPort := fmt.Sprintf("%s:%d", testEnv.WebhookInstallOptions.LocalServingHost, testEnv.WebhookInstallOptions.LocalServingPort) + Eventually(func(g Gomega) { + conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true}) + g.Expect(err).NotTo(HaveOccurred()) + 
g.Expect(conn.Close()).NotTo(HaveOccurred()) + }).Should(Succeed()) }) var _ = AfterSuite(func() { diff --git a/pkg/webhooks/pytorch/pytorchjob_webhook.go b/pkg/webhooks/pytorch/pytorchjob_webhook.go new file mode 100644 index 0000000000..1dd17a3376 --- /dev/null +++ b/pkg/webhooks/pytorch/pytorchjob_webhook.go @@ -0,0 +1,145 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pytorch + +import ( + "context" + "fmt" + "slices" + "strings" + + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +var ( + specPath = field.NewPath("spec") + pytorchReplicaSpecPath = specPath.Child("pytorchReplicaSpecs") +) + +type Webhook struct{} + +func SetupWebhook(mgr ctrl.Manager) error { + return ctrl.NewWebhookManagedBy(mgr). + For(&trainingoperator.PyTorchJob{}). + WithValidator(&Webhook{}). 
	Complete()
}

// +kubebuilder:webhook:path=/validate-kubeflow-org-v1-pytorchjob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=pytorchjobs,verbs=create;update,versions=v1,name=validator.pytorchjob.training-operator.kubeflow.org,admissionReviewVersions=v1

var _ webhook.CustomValidator = &Webhook{}

// ValidateCreate implements webhook.CustomValidator. It validates a PyTorchJob
// at admission time on create and returns any deprecation warnings together
// with the aggregated validation errors.
func (w *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
	job := obj.(*trainingoperator.PyTorchJob)
	log := ctrl.LoggerFrom(ctx).WithName("pytorchjob-webhook")
	log.V(5).Info("Validating create", "pytorchJob", klog.KObj(job))
	warnings, errs := validatePyTorchJob(job)
	return warnings, errs.ToAggregate()
}

// ValidateUpdate implements webhook.CustomValidator. Only the new object is
// validated; the old object is intentionally ignored (same rules apply to
// create and update, per the kubebuilder marker above).
func (w *Webhook) ValidateUpdate(ctx context.Context, _ runtime.Object, newObj runtime.Object) (admission.Warnings, error) {
	job := newObj.(*trainingoperator.PyTorchJob)
	log := ctrl.LoggerFrom(ctx).WithName("pytorchjob-webhook")
	log.V(5).Info("Validating update", "pytorchJob", klog.KObj(job))
	warnings, errs := validatePyTorchJob(job)
	return warnings, errs.ToAggregate()
}

// ValidateDelete implements webhook.CustomValidator. Deletion needs no
// validation (the webhook is only registered for create;update anyway).
func (w *Webhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) {
	return nil, nil
}

// validatePyTorchJob validates the job's metadata and spec, returning
// deprecation warnings and the accumulated field errors.
func validatePyTorchJob(job *trainingoperator.PyTorchJob) (admission.Warnings, field.ErrorList) {
	var allErrs field.ErrorList
	var warnings admission.Warnings

	// The job name must be a DNS-1035 label (NOTE(review): presumably because
	// resource names derived from it must be valid DNS labels — confirm).
	if errors := apimachineryvalidation.NameIsDNS1035Label(job.ObjectMeta.Name, false); len(errors) != 0 {
		allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), job.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ","))))
	}
	ws, err := validateSpec(job.Spec)
	warnings = append(warnings, ws...)
	allErrs = append(allErrs, err...)
+ return warnings, allErrs +} + +func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, field.ErrorList) { + var allErrs field.ErrorList + var warnings admission.Warnings + + if spec.NprocPerNode != nil && spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil { + elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode") + nprocPerNodePath := specPath.Child("nprocPerNode") + allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath))) + warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String())) + } + allErrs = append(allErrs, validatePyTorchReplicaSpecs(spec.PyTorchReplicaSpecs)...) + return warnings, allErrs +} + +func validatePyTorchReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { + var allErrs field.ErrorList + + if rSpecs == nil { + allErrs = append(allErrs, field.Required(pytorchReplicaSpecPath, "must be required")) + } + for rType, rSpec := range rSpecs { + rolePath := pytorchReplicaSpecPath.Key(string(rType)) + containersPath := rolePath.Child("template").Child("spec").Child("containers") + + // Make sure the replica type is valid. 
+ validRoleTypes := []trainingoperator.ReplicaType{ + trainingoperator.PyTorchJobReplicaTypeMaster, + trainingoperator.PyTorchJobReplicaTypeWorker, + } + if !slices.Contains(validRoleTypes, rType) { + allErrs = append(allErrs, field.NotSupported(rolePath, rType, validRoleTypes)) + } + + if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { + allErrs = append(allErrs, field.Required(containersPath, "must be specified")) + } + + // Make sure the image is defined in the container + defaultContainerPresent := false + for idx, container := range rSpec.Template.Spec.Containers { + if container.Image == "" { + allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) + } + if container.Name == trainingoperator.PyTorchJobDefaultContainerName { + defaultContainerPresent = true + } + } + // Make sure there has at least one container named "pytorch" + if !defaultContainerPresent { + allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.PyTorchJobDefaultContainerName))) + } + if rType == trainingoperator.PyTorchJobReplicaTypeMaster { + if rSpec.Replicas == nil || int(*rSpec.Replicas) != 1 { + allErrs = append(allErrs, field.Forbidden(rolePath.Child("replicas"), "must be 1")) + } + } + } + return allErrs +} diff --git a/pkg/webhooks/pytorch/pytorchjob_webhook_test.go b/pkg/webhooks/pytorch/pytorchjob_webhook_test.go new file mode 100644 index 0000000000..8f2e492293 --- /dev/null +++ b/pkg/webhooks/pytorch/pytorchjob_webhook_test.go @@ -0,0 +1,269 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/

package pytorch

import (
	"fmt"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/util/validation/field"
	"k8s.io/utils/ptr"
	"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

	trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
)

// TestValidateV1PyTorchJob is a table-driven test for validatePyTorchJob.
// Error comparison ignores the Detail and BadValue fields, so each wantErr
// entry only pins the error type and field path.
func TestValidateV1PyTorchJob(t *testing.T) {
	// A fully valid Master+Worker replica layout reused by several cases.
	validPyTorchReplicaSpecs := map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
		trainingoperator.PyTorchJobReplicaTypeMaster: {
			Replicas:      ptr.To[int32](1),
			RestartPolicy: trainingoperator.RestartPolicyOnFailure,
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{
					Containers: []corev1.Container{{
						Name:            "pytorch",
						Image:           "docker.io/kubeflowkatib/pytorch-mnist:v1beta1-45c5727",
						ImagePullPolicy: corev1.PullAlways,
						Command: []string{
							"python3",
							"/opt/pytorch-mnist/mnist.py",
							"--epochs=1",
						},
					}},
				},
			},
		},
		trainingoperator.PyTorchJobReplicaTypeWorker: {
			Replicas:      ptr.To[int32](1),
			RestartPolicy: trainingoperator.RestartPolicyOnFailure,
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{
					Containers: []corev1.Container{{
						Name:            "pytorch",
						Image:           "docker.io/kubeflowkatib/pytorch-mnist:v1beta1-45c5727",
						ImagePullPolicy: corev1.PullAlways,
						Command: []string{
							"python3",
							"/opt/pytorch-mnist/mnist.py",
							"--epochs=1",
						},
					}},
				},
			},
		},
	}

	testCases := map[string]struct {
		pytorchJob   *trainingoperator.PyTorchJob
		wantErr      field.ErrorList
		wantWarnings admission.Warnings
	}{
		"valid PyTorchJob": {
			pytorchJob: &trainingoperator.PyTorchJob{
				ObjectMeta: metav1.ObjectMeta{
					Name: "test",
				},
				Spec: trainingoperator.PyTorchJobSpec{
					PyTorchReplicaSpecs: validPyTorchReplicaSpecs,
				},
			},
		},
		"pytorchJob name does not meet DNS1035": {
			// Leading digit violates the DNS-1035 label rule.
			pytorchJob: &trainingoperator.PyTorchJob{
				ObjectMeta: metav1.ObjectMeta{
					Name: "0-test",
				},
				Spec: trainingoperator.PyTorchJobSpec{
					PyTorchReplicaSpecs: validPyTorchReplicaSpecs,
				},
			},
			wantErr: field.ErrorList{
				field.Invalid(field.NewPath("metadata").Child("name"), "", ""),
			},
		},
		"no containers": {
			pytorchJob: &trainingoperator.PyTorchJob{
				ObjectMeta: metav1.ObjectMeta{
					Name: "test",
				},
				Spec: trainingoperator.PyTorchJobSpec{
					PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
						trainingoperator.PyTorchJobReplicaTypeWorker: {
							Template: corev1.PodTemplateSpec{
								Spec: corev1.PodSpec{
									Containers: []corev1.Container{},
								},
							},
						},
					},
				},
			},
			// Two Required errors on the same path: one from the
			// empty-containers check and one from the missing default
			// ("pytorch") container check.
			wantErr: field.ErrorList{
				field.Required(pytorchReplicaSpecPath.
					Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)).
					Child("template").
					Child("spec").
					Child("containers"), ""),
				field.Required(pytorchReplicaSpecPath.
					Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)).
					Child("template").
					Child("spec").
					Child("containers"), ""),
			},
		},
		"image is empty": {
			pytorchJob: &trainingoperator.PyTorchJob{
				ObjectMeta: metav1.ObjectMeta{
					Name: "test",
				},
				Spec: trainingoperator.PyTorchJobSpec{
					PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
						trainingoperator.PyTorchJobReplicaTypeWorker: {
							Template: corev1.PodTemplateSpec{
								Spec: corev1.PodSpec{
									Containers: []corev1.Container{
										{
											Name:  "pytorch",
											Image: "",
										},
									},
								},
							},
						},
					},
				},
			},
			wantErr: field.ErrorList{
				field.Required(pytorchReplicaSpecPath.
					Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)).
					Child("template").
					Child("spec").
					Child("containers").
					Index(0).
					Child("image"), ""),
			},
		},
		"pytorchJob default container name doesn't present": {
			pytorchJob: &trainingoperator.PyTorchJob{
				ObjectMeta: metav1.ObjectMeta{
					Name: "test",
				},
				Spec: trainingoperator.PyTorchJobSpec{
					PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
						trainingoperator.PyTorchJobReplicaTypeWorker: {
							Template: corev1.PodTemplateSpec{
								Spec: corev1.PodSpec{
									Containers: []corev1.Container{
										{
											Name:  "",
											Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
										},
									},
								},
							},
						},
					},
				},
			},
			wantErr: field.ErrorList{
				field.Required(pytorchReplicaSpecPath.
					Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)).
					Child("template").
					Child("spec").
					Child("containers"), ""),
			},
		},
		"the number of replicas in masterReplica is other than 1": {
			pytorchJob: &trainingoperator.PyTorchJob{
				ObjectMeta: metav1.ObjectMeta{
					Name: "test",
				},
				Spec: trainingoperator.PyTorchJobSpec{
					PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
						trainingoperator.PyTorchJobReplicaTypeMaster: {
							Replicas: ptr.To[int32](2),
							Template: corev1.PodTemplateSpec{
								Spec: corev1.PodSpec{
									Containers: []corev1.Container{
										{
											Name:  "pytorch",
											Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
										},
									},
								},
							},
						},
					},
				},
			},
			wantErr: field.ErrorList{
				field.Forbidden(pytorchReplicaSpecPath.Key(string(trainingoperator.PyTorchJobReplicaTypeMaster)).Child("replicas"), ""),
			},
		},
		"Spec.NprocPerNode and Spec.ElasticPolicy.NProcPerNode are set": {
			// Setting both fields is forbidden, and the deprecated elastic
			// field additionally yields a warning.
			pytorchJob: &trainingoperator.PyTorchJob{
				ObjectMeta: metav1.ObjectMeta{
					Name: "test",
				},
				Spec: trainingoperator.PyTorchJobSpec{
					NprocPerNode: ptr.To("1"),
					ElasticPolicy: &trainingoperator.ElasticPolicy{
						NProcPerNode: ptr.To[int32](1),
					},
					PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
						trainingoperator.PyTorchJobReplicaTypeMaster: {
							Replicas: ptr.To[int32](1),
							Template: corev1.PodTemplateSpec{
								Spec: corev1.PodSpec{
									Containers: []corev1.Container{
										{
											Name:  "pytorch",
											Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
										},
									},
								},
							},
						},
					},
				},
			},
			wantErr: field.ErrorList{
				field.Forbidden(specPath.Child("elasticPolicy").Child("nProcPerNode"), ""),
			},
			wantWarnings: admission.Warnings{
				fmt.Sprintf("%s is deprecated, use %s instead",
					specPath.Child("elasticPolicy").Child("nProcPerNode"), specPath.Child("nprocPerNode")),
			},
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			gotWarnings, gotError := validatePyTorchJob(tc.pytorchJob)
			if diff := cmp.Diff(tc.wantWarnings, gotWarnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 {
				t.Errorf("Unexpected warnings (-want,+got):\n%s", diff)
			}
			if diff := cmp.Diff(tc.wantErr, gotError, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 {
				t.Errorf("Unexpected errors (-want,+got):\n%s", diff)
			}
		})
	}
}
diff --git a/pkg/webhooks/webhooks.go b/pkg/webhooks/webhooks.go
new file mode 100644
index 0000000000..5e97a3d3f3
--- /dev/null
+++ b/pkg/webhooks/webhooks.go
@@ -0,0 +1,41 @@
+/*
+Copyright 2024 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/

package webhooks

import (
	"sigs.k8s.io/controller-runtime/pkg/manager"

	trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
	"github.com/kubeflow/training-operator/pkg/webhooks/pytorch"
)

// WebhookSetupFunc registers the webhook(s) for one job kind with the manager.
type WebhookSetupFunc func(manager manager.Manager) error

var (
	// SupportedSchemeWebhook maps each supported job kind to the function that
	// sets up its admission webhook. Only PyTorchJob has a real validator in
	// this map; the remaining kinds point at the no-op scaffold placeholder.
	SupportedSchemeWebhook = map[string]WebhookSetupFunc{
		trainingoperator.PyTorchJobKind: pytorch.SetupWebhook,
		trainingoperator.TFJobKind:      scaffold,
		trainingoperator.MXJobKind:      scaffold,
		trainingoperator.XGBoostJobKind: scaffold,
		trainingoperator.MPIJobKind:     scaffold,
		trainingoperator.PaddleJobKind:  scaffold,
	}
)

// scaffold is a no-op WebhookSetupFunc for kinds that do not yet have a
// webhook implementation.
func scaffold(manager.Manager) error {
	return nil
}