Skip to content

Commit

Permalink
Add unit tests for raycluster_webhook.go
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianZaccaria committed May 10, 2024
1 parent 4106ddb commit b80e2bd
Show file tree
Hide file tree
Showing 2 changed files with 430 additions and 32 deletions.
65 changes: 33 additions & 32 deletions pkg/controllers/raycluster_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,24 +100,26 @@ func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) err
}

// WorkerGroupSpec
if len(rayCluster.Spec.WorkerGroupSpecs) != 0 {
for i := range rayCluster.Spec.WorkerGroupSpecs {
workerSpec := &rayCluster.Spec.WorkerGroupSpecs[i]

// Append the list of environment variables for the worker container
for _, envVar := range envVarList() {
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Env = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name))
workerSpec.Template.Spec.Containers[0].Env = upsert(workerSpec.Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name))
}

// Append the CA volumes
for _, caVol := range caVolumes(rayCluster) {
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
}
// Append the CA volumes
for _, caVol := range caVolumes(rayCluster) {
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
}

// Append the certificate volume mounts
for _, mount := range certVolumeMounts() {
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].VolumeMounts = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].VolumeMounts, mount, byVolumeMountName)
}
// Append the certificate volume mounts
for _, mount := range certVolumeMounts() {
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].VolumeMounts = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].VolumeMounts, mount, byVolumeMountName)
}

// Append the create-cert Init Container
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(w.Config), withContainerName(initContainerName))
// Append the create-cert Init Container
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(w.Config), withContainerName(initContainerName))
}
}

Expand Down Expand Up @@ -387,14 +389,13 @@ func validateHeadInitContainer(rayCluster *rayv1.RayCluster, config *config.Kube
func validateWorkerInitContainer(rayCluster *rayv1.RayCluster, config *config.KubeRayConfiguration) field.ErrorList {
var allErrors field.ErrorList

if len(rayCluster.Spec.WorkerGroupSpecs) == 0 {
return allErrors
}

if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(config), byContainerName,
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "initContainers"),
"create-cert Init Container is immutable"); err != nil {
allErrors = append(allErrors, err)
for i := range rayCluster.Spec.WorkerGroupSpecs {
workerSpec := &rayCluster.Spec.WorkerGroupSpecs[i]
if err := contains(workerSpec.Template.Spec.InitContainers, rayWorkerInitContainer(config), byContainerName,
field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(i), "template", "spec", "initContainers"),
"create-cert Init Container is immutable"); err != nil {
allErrors = append(allErrors, err)
}
}

return allErrors
Expand All @@ -409,9 +410,10 @@ func validateCaVolumes(rayCluster *rayv1.RayCluster) field.ErrorList {
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
allErrors = append(allErrors, err)
}
if len(rayCluster.Spec.WorkerGroupSpecs) != 0 {
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, byVolumeName,
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "volumes"),
for i := range rayCluster.Spec.WorkerGroupSpecs {
workerSpec := &rayCluster.Spec.WorkerGroupSpecs[i]
if err := contains(workerSpec.Template.Spec.Volumes, caVol, byVolumeName,
field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(i), "template", "spec", "volumes"),
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
allErrors = append(allErrors, err)
}
Expand All @@ -438,15 +440,14 @@ func validateHeadEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList {
func validateWorkerEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList {
var allErrors field.ErrorList

if len(rayCluster.Spec.WorkerGroupSpecs) == 0 {
return allErrors
}

for _, envVar := range envVarList() {
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Env, envVar, byEnvVarName,
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "containers", strconv.Itoa(0), "env"),
"RAY_TLS related environment variables are immutable"); err != nil {
allErrors = append(allErrors, err)
for i := range rayCluster.Spec.WorkerGroupSpecs {
workerSpec := &rayCluster.Spec.WorkerGroupSpecs[i]
for _, envVar := range envVarList() {
if err := contains(workerSpec.Template.Spec.Containers[0].Env, envVar, byEnvVarName,
field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(i), "template", "spec", "containers", strconv.Itoa(0), "env"),
"RAY_TLS related environment variables are immutable"); err != nil {
allErrors = append(allErrors, err)
}
}
}

Expand Down
Loading

0 comments on commit b80e2bd

Please sign in to comment.