diff --git a/pkg/controllers/raycluster_webhook.go b/pkg/controllers/raycluster_webhook.go index d1500559a..ac4fbb745 100644 --- a/pkg/controllers/raycluster_webhook.go +++ b/pkg/controllers/raycluster_webhook.go @@ -100,24 +100,26 @@ func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) err } // WorkerGroupSpec - if len(rayCluster.Spec.WorkerGroupSpecs) != 0 { + for i := range rayCluster.Spec.WorkerGroupSpecs { + workerSpec := &rayCluster.Spec.WorkerGroupSpecs[i] + // Append the list of environment variables for the worker container for _, envVar := range envVarList() { - rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Env = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name)) + workerSpec.Template.Spec.Containers[0].Env = upsert(workerSpec.Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name)) } - // Append the CA volumes - for _, caVol := range caVolumes(rayCluster) { - rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, withVolumeName(caVol.Name)) - } + // Append the CA volumes + for _, caVol := range caVolumes(rayCluster) { + rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, withVolumeName(caVol.Name)) + } - // Append the certificate volume mounts - for _, mount := range certVolumeMounts() { - rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].VolumeMounts = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].VolumeMounts, mount, byVolumeMountName) - } + // Append the certificate volume mounts + for _, mount := range certVolumeMounts() { + rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].VolumeMounts = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].VolumeMounts, mount, byVolumeMountName) + } - // Append the create-cert Init Container - rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(w.Config), withContainerName(initContainerName)) + // Append the create-cert Init Container + rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(w.Config), withContainerName(initContainerName)) } } @@ -387,14 +389,13 @@ func validateHeadInitContainer(rayCluster *rayv1.RayCluster, config *config.Kube func validateWorkerInitContainer(rayCluster *rayv1.RayCluster, config *config.KubeRayConfiguration) field.ErrorList { var allErrors field.ErrorList - if len(rayCluster.Spec.WorkerGroupSpecs) == 0 { - return allErrors - } - - if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(config), byContainerName, - field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "initContainers"), - "create-cert Init Container is immutable"); err != nil { - allErrors = append(allErrors, err) + for i := range rayCluster.Spec.WorkerGroupSpecs { + workerSpec := &rayCluster.Spec.WorkerGroupSpecs[i] + if err := contains(workerSpec.Template.Spec.InitContainers, rayWorkerInitContainer(config), byContainerName, + field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(i), "template", "spec", "initContainers"), + "create-cert Init Container is immutable"); err != nil { + allErrors = append(allErrors, err) + } } return allErrors @@ -409,9 +410,10 @@ func validateCaVolumes(rayCluster *rayv1.RayCluster) field.ErrorList { "ca-vol and server-cert Secret volumes are immutable"); err != nil { allErrors = append(allErrors, err) } - if len(rayCluster.Spec.WorkerGroupSpecs) != 0 { - if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, byVolumeName, - field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "volumes"), + for i := range rayCluster.Spec.WorkerGroupSpecs { + workerSpec := &rayCluster.Spec.WorkerGroupSpecs[i] + if err := contains(workerSpec.Template.Spec.Volumes, caVol, byVolumeName, + field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(i), "template", "spec", "volumes"), "ca-vol and server-cert Secret volumes are immutable"); err != nil { allErrors = append(allErrors, err) } @@ -438,15 +440,14 @@ func validateHeadEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList { func validateWorkerEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList { var allErrors field.ErrorList - if len(rayCluster.Spec.WorkerGroupSpecs) == 0 { - return allErrors - } - - for _, envVar := range envVarList() { - if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Env, envVar, byEnvVarName, - field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "containers", strconv.Itoa(0), "env"), - "RAY_TLS related environment variables are immutable"); err != nil { - allErrors = append(allErrors, err) + for i := range rayCluster.Spec.WorkerGroupSpecs { + workerSpec := &rayCluster.Spec.WorkerGroupSpecs[i] + for _, envVar := range envVarList() { + if err := contains(workerSpec.Template.Spec.Containers[0].Env, envVar, byEnvVarName, + field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(i), "template", "spec", "containers", strconv.Itoa(0), "env"), + "RAY_TLS related environment variables are immutable"); err != nil { + allErrors = append(allErrors, err) + } } } diff --git a/pkg/controllers/raycluster_webhook_test.go b/pkg/controllers/raycluster_webhook_test.go new file mode 100644 index 000000000..b889101b4 --- /dev/null +++ b/pkg/controllers/raycluster_webhook_test.go @@ -0,0 +1,393 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package controllers + +import ( + "testing" + + "github.com/onsi/gomega" + . "github.com/project-codeflare/codeflare-common/support" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + + "github.com/project-codeflare/codeflare-operator/pkg/config" +) + +var ( + namespace = "test-namespace" + + rcWebhook = &rayClusterWebhook{ + Config: &config.KubeRayConfiguration{}, + } +) + +func TestRayClusterWebhookDefault(t *testing.T) { + test := NewTest(t) + + validRayCluster := &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-valid-raycluster", + Namespace: namespace, + }, + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{}, + }, + }, + RayStartParams: map[string]string{}, + }, + WorkerGroupSpecs: []rayv1.WorkerGroupSpec{ + { + GroupName: "worker-group-1", + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "worker-container-1", + }, + }, + }, + }, + RayStartParams: map[string]string{}, + }, + { + GroupName: "worker-group-2", + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "worker-container-2", + }, + }, + }, + }, + RayStartParams: map[string]string{}, + }, + }, + }, + } + + // Create the RayClusters + test.Client().Ray().RayV1().RayClusters(namespace).Create(test.Ctx(), validRayCluster, metav1.CreateOptions{}) + + // Call to default function is made + err := rcWebhook.Default(test.Ctx(), runtime.Object(validRayCluster)) + t.Run("Expected no errors on call to Default function", func(t *testing.T) { + g := gomega.NewWithT(t) + g.Expect(err).To(gomega.BeNil()) + }) + + t.Run("Expected required OAuth proxy container for the head group", func(t *testing.T) { + g := gomega.NewWithT(t) + g.Expect(func() bool { + for _, container := range validRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + if container.Name == oauthProxyContainerName { + return true + } + } + return false + }()).To(gomega.BeTrue(), "Expected the OAuth proxy container to be present in the list of head group containers") + }) + + t.Run("Expected required OAuth proxy TLS secret volume for the head group", func(t *testing.T) { + g := gomega.NewWithT(t) + g.Expect(func() bool { + for _, volume := range validRayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes { + if volume.Name == oauthProxyVolumeName { + return true + } + } + return false + }()).To(gomega.BeTrue(), "Expected the OAuth proxy TLS secret volume to be present in the list of head group volumes") + }) + + t.Run("Expected required service account name for the head group", func(t *testing.T) { + g := gomega.NewWithT(t) + expectedServiceAccountName := validRayCluster.Name + "-oauth-proxy" + actualServiceAccountName := validRayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName + g.Expect(actualServiceAccountName).To(gomega.Equal(expectedServiceAccountName), "Expected the service account name to be set correctly") + }) + + t.Run("Expected required environment variables for the head group", func(t *testing.T) { + g := gomega.NewWithT(t) + expectedEnvVars := map[string]bool{ + "MY_POD_IP": false, + "RAY_USE_TLS": false, + "RAY_TLS_SERVER_CERT": false, + "RAY_TLS_SERVER_KEY": false, + "RAY_TLS_CA_CERT": false, + } + for _, envVar := range validRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[0].Env { + if _, found := expectedEnvVars[envVar.Name]; found { + expectedEnvVars[envVar.Name] = true + } + } + for _, found := range expectedEnvVars { + g.Expect(found).To(gomega.BeTrue(), "Expected required environment variables to be present in the head group") + } + }) + + t.Run("Expected required create-cert init container for the head group", func(t *testing.T) { + g := gomega.NewWithT(t) + g.Expect(func() bool { + for _, initContainer := range validRayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers { + if initContainer.Name == initContainerName { + return true + } + } + return false + }()).To(gomega.BeTrue(), "Expected the create-cert init container to be present in the list of head group init containers") + }) + + t.Run("Expected required CA volumes for the head group", func(t *testing.T) { + g := gomega.NewWithT(t) + expectedCaVolumes := map[string]bool{ + "ca-vol": false, + "server-cert": false, + } + for _, caVolume := range validRayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes { + if _, found := expectedCaVolumes[caVolume.Name]; found { + expectedCaVolumes[caVolume.Name] = true + } + } + for _, found := range expectedCaVolumes { + g.Expect(found).To(gomega.BeTrue(), "Expected required CA volumes to be present in the head group") + } + }) + + t.Run("Expected required certificate volume mounts for the head group", func(t *testing.T) { + g := gomega.NewWithT(t) + expectedCertVolumeMounts := map[string]bool{ + "ca-vol": false, + "server-cert": false, + } + for _, certVolumeMount := range validRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[0].VolumeMounts { + if _, found := expectedCertVolumeMounts[certVolumeMount.Name]; found { + expectedCertVolumeMounts[certVolumeMount.Name] = true + } + } + for _, found := range expectedCertVolumeMounts { + g.Expect(found).To(gomega.BeTrue(), "Expected required certificate volume mounts to be present in the head group") + } + }) + + t.Run("Expected required environment variables for each worker group", func(t *testing.T) { + g := gomega.NewWithT(t) + + checkEnvVars := func(envs []corev1.EnvVar) { + expectedEnvVars := map[string]bool{ + "MY_POD_IP": false, + "RAY_USE_TLS": false, + "RAY_TLS_SERVER_CERT": false, + "RAY_TLS_SERVER_KEY": false, + "RAY_TLS_CA_CERT": false, + } + for _, envVar := range envs { + if _, found := expectedEnvVars[envVar.Name]; found { + expectedEnvVars[envVar.Name] = true + } + } + for _, found := range expectedEnvVars { + g.Expect(found).To(gomega.BeTrue(), "Expected required environment variables to be present in each worker group") + } + } + + // Check each worker group independently + for _, workerSpec := range validRayCluster.Spec.WorkerGroupSpecs { + workerEnvVars := workerSpec.Template.Spec.Containers[0].Env + checkEnvVars(workerEnvVars) + } + }) + + t.Run("Expected required CA Volumes for each worker group", func(t *testing.T) { + g := gomega.NewWithT(t) + + checkCaVolumes := func(caVolumes []corev1.Volume) { + expectedCaVolumes := map[string]bool{ + "ca-vol": false, + "server-cert": false, + } + for _, caVolume := range caVolumes { + if _, found := expectedCaVolumes[caVolume.Name]; found { + expectedCaVolumes[caVolume.Name] = true + } + } + for _, found := range expectedCaVolumes { + g.Expect(found).To(gomega.BeTrue(), "Expected required CA volumes to be present in each worker group") + } + } + + // Check each worker group independently + for _, workerSpec := range validRayCluster.Spec.WorkerGroupSpecs { + workerCaVolumes := workerSpec.Template.Spec.Volumes + checkCaVolumes(workerCaVolumes) + } + }) + + t.Run("Expected required certificate volume mounts for each worker group", func(t *testing.T) { + g := gomega.NewWithT(t) + + checkCertVolumeMounts := func(certVolumeMounts []corev1.VolumeMount) { + expectedCertVolumeMounts := map[string]bool{ + "ca-vol": false, + "server-cert": false, + } + for _, certVolumeMount := range certVolumeMounts { + if _, found := expectedCertVolumeMounts[certVolumeMount.Name]; found { + expectedCertVolumeMounts[certVolumeMount.Name] = true + } + } + for _, found := range expectedCertVolumeMounts { + g.Expect(found).To(gomega.BeTrue(), "Expected required certificate volume mounts to be present in the worker group") + } + } + + // Check each worker group independently + for _, workerSpec := range validRayCluster.Spec.WorkerGroupSpecs { + workerCertVolumeMounts := workerSpec.Template.Spec.Containers[0].VolumeMounts + checkCertVolumeMounts(workerCertVolumeMounts) + } + }) + + t.Run("Expected required init container for each worker group", func(t *testing.T) { + g := gomega.NewWithT(t) + + checkInitContainers := func(initContainers []corev1.Container) { + expectedInitContainers := map[string]bool{ + initContainerName: false, + } + for _, initContainer := range initContainers { + if _, found := expectedInitContainers[initContainer.Name]; found { + expectedInitContainers[initContainer.Name] = true + } + } + for _, found := range expectedInitContainers { + g.Expect(found).To(gomega.BeTrue(), "Expected required init container to be present in each worker group") + } + } + + // Check each worker group independently + for _, workerSpec := range validRayCluster.Spec.WorkerGroupSpecs { + workerInitContainers := workerSpec.Template.Spec.InitContainers + checkInitContainers(workerInitContainers) + } + }) +} + +func TestValidateCreate(t *testing.T) { + test := NewTest(t) + + validRayCluster := &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-valid-raycluster", + Namespace: namespace, + }, + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: oauthProxyContainerName, + Image: "registry.redhat.io/openshift4/ose-oauth-proxy@sha256:1ea6a01bf3e63cdcf125c6064cbd4a4a270deaf0f157b3eabb78f60556840366", + Ports: []corev1.ContainerPort{ + {ContainerPort: 8443, Name: "oauth-proxy"}, + }, + Env: []corev1.EnvVar{ + { + Name: "COOKIE_SECRET", + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: "test-valid-raycluster" + "-oauth-config", + }, + Key: "cookie_secret", + }, + }, + }, + }, + Args: []string{ + "--https-address=:8443", + "--provider=openshift", + "--openshift-service-account=" + "test-valid-raycluster" + "-oauth-proxy", + "--upstream=http://localhost:8265", + "--tls-cert=/etc/tls/private/tls.crt", + "--tls-key=/etc/tls/private/tls.key", + "--cookie-secret=$(COOKIE_SECRET)", + "--openshift-delegate-urls={\"/\":{\"resource\":\"pods\",\"namespace\":\"" + namespace + "\",\"verb\":\"get\"}}", + }, + VolumeMounts: []corev1.VolumeMount{ + { + Name: oauthProxyVolumeName, + MountPath: "/etc/tls/private", + ReadOnly: true, + }, + }, + }, + }, + Volumes: []corev1.Volume{ + { + Name: oauthProxyVolumeName, + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "test-valid-raycluster" + "-proxy-tls-secret", + }, + }, + }, + }, + ServiceAccountName: "test-valid-raycluster" + "-oauth-proxy", + }, + }, + RayStartParams: map[string]string{}, + }, + }, + } + + // Create the RayClusters + test.Client().Ray().RayV1().RayClusters(namespace).Create(test.Ctx(), validRayCluster, metav1.CreateOptions{}) + + // Call to ValidateCreate function is made + warnings, err := rcWebhook.ValidateCreate(test.Ctx(), runtime.Object(validRayCluster)) + t.Run("Expected no warnings or errors on call to ValidateCreate function", func(t *testing.T) { + g := gomega.NewWithT(t) + g.Expect(warnings).To(gomega.BeNil()) + g.Expect(err).To(gomega.BeNil()) + }) + + // No need for below tests as the above one covers it all? + // t.Run("Expected enableIngress to be either nil or false", func(t *testing.T) { + // g := gomega.NewWithT(t) + // g.Expect(validRayCluster.Spec.HeadGroupSpec.EnableIngress).Should(gomega.Satisfy(func(enableIngress *bool) bool { + // return enableIngress == nil || !*enableIngress + // }), "Expected EnableIngress to be either nil or false") + // }) + + // t.Run("Expected OAuth proxy container for head group", func(t *testing.T) { + // g := gomega.NewWithT(t) + // g.Expect(validRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers).Should(gomega.ContainElement(gomega.WithTransform( + // func(container corev1.Container) string { + // return container.Name + // }, + // gomega.Equal(oauthProxyContainerName), + // )), "Expected required OAuth proxy container in the head group") + // }) +}