From 4571b221c11c34284af74f5cc0e471dfa2e2ec5d Mon Sep 17 00:00:00 2001
From: lucaswzhang
Date: Mon, 20 Dec 2021 17:32:34 +0800
Subject: [PATCH 1/3] add unit test for tf

add amend function
add test exit code
add scale up and down cases
add norm test
fix test exit code case
fix import order
fix exit code issue
---
 pkg/common/util/v1/testutil/pod.go          | 116 ++--
 pkg/common/util/v1/testutil/service.go      |  66 +-
 pkg/common/util/v1/testutil/tfjob.go        |   2 +-
 pkg/common/util/v1/testutil/util.go         |  26 +-
 pkg/controller.v1/tensorflow/job_test.go    | 525 +++++++++++++++
 pkg/controller.v1/tensorflow/pod_test.go    | 540 ++++++++++++++++
 pkg/controller.v1/tensorflow/status_test.go | 609 ++++++++++++++++++
 pkg/controller.v1/tensorflow/suite_test.go  |  51 +-
 .../tensorflow/tfjob_controller_test.go     | 328 ++++++++++
 pkg/controller.v1/tensorflow/util_test.go   |  74 +++
 10 files changed, 2245 insertions(+), 92 deletions(-)
 create mode 100644 pkg/controller.v1/tensorflow/job_test.go
 create mode 100644 pkg/controller.v1/tensorflow/pod_test.go
 create mode 100644 pkg/controller.v1/tensorflow/status_test.go
 create mode 100644 pkg/controller.v1/tensorflow/tfjob_controller_test.go
 create mode 100644 pkg/controller.v1/tensorflow/util_test.go

diff --git a/pkg/common/util/v1/testutil/pod.go b/pkg/common/util/v1/testutil/pod.go
index adce63fa32..0ab1e73848 100644
--- a/pkg/common/util/v1/testutil/pod.go
+++ b/pkg/common/util/v1/testutil/pod.go
@@ -15,81 +15,99 @@ package testutil

 import (
+	"context"
 	"fmt"
-	"testing"
+	"time"

-	v1 "k8s.io/api/core/v1"
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+	. "github.com/onsi/gomega"
+	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
-	"k8s.io/client-go/tools/cache"
-
-	tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/client"
 )

 const (
-	// labels for pods and servers.
- tfReplicaTypeLabel = "replica-type" - tfReplicaIndexLabel = "replica-index" + DummyContainerName = "dummy" + DummyContainerImage = "dummy/dummy:latest" ) -var ( - controllerKind = tfv1.GroupVersion.WithKind(TFJobKind) -) +func NewBasePod(name string, job metav1.Object, refs []metav1.OwnerReference) *corev1.Pod { -func NewBasePod(name string, tfJob *tfv1.TFJob) *v1.Pod { - return &v1.Pod{ + return &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: name, - Labels: GenLabels(tfJob.Name), - Namespace: tfJob.Namespace, - OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)}, + Labels: map[string]string{}, + Namespace: job.GetNamespace(), + OwnerReferences: refs, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: DummyContainerName, + Image: DummyContainerImage, + }, + }, }, } } -func NewPod(tfJob *tfv1.TFJob, typ string, index int) *v1.Pod { - pod := NewBasePod(fmt.Sprintf("%s-%d", typ, index), tfJob) - pod.Labels[tfReplicaTypeLabel] = typ - pod.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index) +func NewPod(job metav1.Object, typ string, index int, refs []metav1.OwnerReference) *corev1.Pod { + pod := NewBasePod(fmt.Sprintf("%s-%s-%d", job.GetName(), typ, index), job, refs) + pod.Labels[commonv1.ReplicaTypeLabelDeprecated] = typ + pod.Labels[commonv1.ReplicaTypeLabel] = typ + pod.Labels[commonv1.ReplicaIndexLabelDeprecated] = fmt.Sprintf("%d", index) + pod.Labels[commonv1.ReplicaIndexLabel] = fmt.Sprintf("%d", index) return pod } -// create count pods with the given phase for the given tfJob -func NewPodList(count int32, status v1.PodPhase, tfJob *tfv1.TFJob, typ string, start int32) []*v1.Pod { - pods := []*v1.Pod{} +// NewPodList create count pods with the given phase for the given tfJob +func NewPodList(count int32, status corev1.PodPhase, job metav1.Object, typ string, start int32, refs []metav1.OwnerReference) []*corev1.Pod { + pods := []*corev1.Pod{} for i := int32(0); i < count; i++ { - newPod := NewPod(tfJob, typ, int(start+i)) - newPod.Status = v1.PodStatus{Phase: status} + newPod := NewPod(job, typ, int(start+i), refs) + newPod.Status = corev1.PodStatus{Phase: status} pods = append(pods, newPod) } return pods } -func SetPodsStatuses(podIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32, t *testing.T) { +func SetPodsStatusesV2(client client.Client, job metav1.Object, typ string, + pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32, + refs []metav1.OwnerReference, basicLabels map[string]string) { + timeout := 10 * time.Second + interval := 1000 * time.Millisecond var index int32 - for _, pod := range NewPodList(pendingPods, v1.PodPending, tfJob, typ, index) { - if err := podIndexer.Add(pod); err != nil { - t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) - } - } - index += pendingPods - for i, pod := range NewPodList(activePods, v1.PodRunning, tfJob, typ, index) { - if restartCounts != nil { - pod.Status.ContainerStatuses = []v1.ContainerStatus{{RestartCount: restartCounts[i]}} - } - if err := podIndexer.Add(pod); err != nil { - t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) - } + taskMap := map[corev1.PodPhase]int32{ + corev1.PodFailed: failedPods, + corev1.PodPending: pendingPods, + corev1.PodRunning: activePods, + corev1.PodSucceeded: succeededPods, } - index += activePods - for _, pod := range NewPodList(succeededPods, v1.PodSucceeded, tfJob, typ, index) { - if err := 
podIndexer.Add(pod); err != nil { - t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) - } - } - index += succeededPods - for _, pod := range NewPodList(failedPods, v1.PodFailed, tfJob, typ, index) { - if err := podIndexer.Add(pod); err != nil { - t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) + ctx := context.Background() + + for podPhase, desiredCount := range taskMap { + for i, pod := range NewPodList(desiredCount, podPhase, job, typ, index, refs) { + for k, v := range basicLabels { + pod.Labels[k] = v + } + _ = client.Create(ctx, pod) + launcherKey := types.NamespacedName{ + Namespace: metav1.NamespaceDefault, + Name: pod.GetName(), + } + Eventually(func() error { + po := &corev1.Pod{} + if err := client.Get(ctx, launcherKey, po); err != nil { + return err + } + po.Status.Phase = podPhase + if podPhase == corev1.PodRunning && restartCounts != nil { + po.Status.ContainerStatuses = []corev1.ContainerStatus{{RestartCount: restartCounts[i]}} + } + return client.Status().Update(ctx, po) + }, timeout, interval).Should(BeNil()) } + index += desiredCount } } diff --git a/pkg/common/util/v1/testutil/service.go b/pkg/common/util/v1/testutil/service.go index 2bf6448f5f..e2451f45e6 100644 --- a/pkg/common/util/v1/testutil/service.go +++ b/pkg/common/util/v1/testutil/service.go @@ -15,48 +15,72 @@ package testutil import ( + "context" "fmt" - "testing" - v1 "k8s.io/api/core/v1" + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/tools/cache" + "sigs.k8s.io/controller-runtime/pkg/client" +) - tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" +const ( + DummyPortName = "dummy" + DummyPort int32 = 1221 ) -func NewBaseService(name string, tfJob *tfv1.TFJob, t *testing.T) *v1.Service { - return &v1.Service{ +func NewBaseService(name string, job metav1.Object, refs []metav1.OwnerReference) *corev1.Service { + return &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: name, - Labels: GenLabels(tfJob.Name), - Namespace: tfJob.Namespace, - OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)}, + Labels: map[string]string{}, + Namespace: job.GetNamespace(), + OwnerReferences: refs, + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: DummyPortName, + Port: DummyPort, + }, + }, }, } } -func NewService(tfJob *tfv1.TFJob, typ string, index int, t *testing.T) *v1.Service { - service := NewBaseService(fmt.Sprintf("%s-%d", typ, index), tfJob, t) - service.Labels[tfReplicaTypeLabel] = typ - service.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index) - return service +func NewService(job metav1.Object, typ string, index int, refs []metav1.OwnerReference) *corev1.Service { + svc := NewBaseService(fmt.Sprintf("%s-%s-%d", job.GetName(), typ, index), job, refs) + svc.Labels[commonv1.ReplicaTypeLabelDeprecated] = typ + svc.Labels[commonv1.ReplicaTypeLabel] = typ + svc.Labels[commonv1.ReplicaIndexLabelDeprecated] = fmt.Sprintf("%d", index) + svc.Labels[commonv1.ReplicaIndexLabel] = fmt.Sprintf("%d", index) + return svc } // NewServiceList creates count pods with the given phase for the given tfJob -func NewServiceList(count int32, tfJob *tfv1.TFJob, typ string, t *testing.T) []*v1.Service { - services := []*v1.Service{} +func NewServiceList(count int32, job metav1.Object, typ string, refs []metav1.OwnerReference) []*corev1.Service { + 
services := []*corev1.Service{} for i := int32(0); i < count; i++ { - newService := NewService(tfJob, typ, int(i), t) + newService := NewService(job, typ, int(i), refs) services = append(services, newService) } return services } -func SetServices(serviceIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, activeWorkerServices int32, t *testing.T) { - for _, service := range NewServiceList(activeWorkerServices, tfJob, typ, t) { - if err := serviceIndexer.Add(service); err != nil { - t.Errorf("unexpected error when adding service %v", err) +func SetServicesV2(client client.Client, job metav1.Object, typ string, activeWorkerServices int32, + refs []metav1.OwnerReference, basicLabels map[string]string) { + ctx := context.Background() + for _, svc := range NewServiceList(activeWorkerServices, job, typ, refs) { + for k, v := range basicLabels { + svc.Labels[k] = v + } + err := client.Create(ctx, svc) + if errors.IsAlreadyExists(err) { + return + } else { + Expect(err).To(BeNil()) } } } diff --git a/pkg/common/util/v1/testutil/tfjob.go b/pkg/common/util/v1/testutil/tfjob.go index b3c8b666e3..e5376f3771 100644 --- a/pkg/common/util/v1/testutil/tfjob.go +++ b/pkg/common/util/v1/testutil/tfjob.go @@ -17,10 +17,10 @@ package testutil import ( "time" + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" ) diff --git a/pkg/common/util/v1/testutil/util.go b/pkg/common/util/v1/testutil/util.go index 5337ad04f2..e8333afbea 100644 --- a/pkg/common/util/v1/testutil/util.go +++ b/pkg/common/util/v1/testutil/util.go @@ -15,16 +15,16 @@ package testutil import ( - "strings" "testing" - common "github.com/kubeflow/common/pkg/apis/common/v1" - tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/tools/cache" + + tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" ) const ( @@ -43,21 +43,13 @@ var ( ControllerName = "training-operator" ) -func GenLabels(jobName string) map[string]string { - return map[string]string{ - LabelGroupName: GroupName, - JobNameLabel: strings.Replace(jobName, "/", "-", -1), - DeprecatedLabelTFJobName: strings.Replace(jobName, "/", "-", -1), - } -} - -func GenOwnerReference(tfjob *tfv1.TFJob) *metav1.OwnerReference { +func GenOwnerReference(job metav1.Object, apiVersion string, kind string) *metav1.OwnerReference { boolPtr := func(b bool) *bool { return &b } controllerRef := &metav1.OwnerReference{ - APIVersion: tfv1.GroupVersion.Version, - Kind: TFJobKind, - Name: tfjob.Name, - UID: tfjob.UID, + APIVersion: apiVersion, + Kind: kind, + Name: job.GetName(), + UID: job.GetUID(), BlockOwnerDeletion: boolPtr(true), Controller: boolPtr(true), } @@ -85,7 +77,7 @@ func GetKey(tfJob *tfv1.TFJob, t *testing.T) string { return key } -func CheckCondition(tfJob *tfv1.TFJob, condition common.JobConditionType, reason string) bool { +func CheckCondition(tfJob *tfv1.TFJob, condition commonv1.JobConditionType, reason string) bool { for _, v := range tfJob.Status.Conditions { if v.Type == condition && v.Status == v1.ConditionTrue && v.Reason == reason { return true diff --git a/pkg/controller.v1/tensorflow/job_test.go 
b/pkg/controller.v1/tensorflow/job_test.go new file mode 100644 index 0000000000..b22d77d5b2 --- /dev/null +++ b/pkg/controller.v1/tensorflow/job_test.go @@ -0,0 +1,525 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "context" + "fmt" + "time" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/controller.v1/common" + commonutil "github.com/kubeflow/common/pkg/util" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/uuid" + "sigs.k8s.io/controller-runtime/pkg/client" + + tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" + "github.com/kubeflow/training-operator/pkg/common/util/v1/testutil" +) + +var _ = Describe("TFJob controller", func() { + // Define utility constants for object names and testing timeouts/durations and intervals. + const ( + timeout = 10 * time.Second + interval = 1000 * time.Millisecond + ) + + Context("Test Add TFJob", func() { + It("should get the exact TFJob", func() { + By("submitting an TFJob") + + testJobName := "test-case-12" + testNamespace := metav1.NamespaceDefault + + decoyJobName := "decoy-case-34" + + ctx := context.Background() + + tfJob := testutil.NewTFJob(1, 0) + tfJob.SetName(testJobName) + tfJob.SetNamespace(testNamespace) + + decoyJob := testutil.NewTFJob(2, 3) + decoyJob.SetName(decoyJobName) + decoyJob.SetNamespace(testNamespace) + + Expect(testK8sClient.Create(ctx, tfJob)).Should(Succeed()) + Expect(testK8sClient.Create(ctx, decoyJob)).Should(Succeed()) + + key := types.NamespacedName{ + Namespace: testNamespace, + Name: testJobName, + } + Eventually(func() error { + job := &tfv1.TFJob{} + return reconciler.Get(ctx, key, job) + }, timeout, interval).Should(BeNil()) + + Expect(testK8sClient.Delete(ctx, tfJob)).Should(Succeed()) + Expect(testK8sClient.Delete(ctx, decoyJob)).Should(Succeed()) + }) + }) + + Context("Test Copy Labels and Annotation", func() { + It("should copy labels and annotation from the spec to generated Pods", func() { + ctx := context.Background() + testAnnotationKey := "annotation1" + testAnnotationVal := "1" + testLabelKey := "label1" + testLabelVal := "1" + + testJobName := "test-copy-labels-anno" + tfjob := testutil.NewTFJob(1, 0) + tfjob.SetName(testJobName) + annotations := map[string]string{ + testAnnotationKey: testAnnotationVal, + } + labels := map[string]string{ + testLabelKey: testLabelVal, + } + tfjob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].Template.Labels = labels + tfjob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].Template.Annotations = annotations + + By("submitting an TFJob with specific labels and annotations") + Expect(testK8sClient.Create(ctx, tfjob)).Should(Succeed()) + + Eventually(func() error { + pod := &corev1.Pod{} + key := types.NamespacedName{ + Namespace: metav1.NamespaceDefault, + Name: 
common.GenGeneralName(tfjob.Name, "worker", "0"),
+				}
+				err := testK8sClient.Get(ctx, key, pod)
+				if err != nil {
+					return err
+				}
+
+				if pod.Annotations == nil {
+					return fmt.Errorf("annotations of %s/%s are nil", pod.GetNamespace(), pod.GetName())
+				}
+				if val, exist := pod.Annotations[testAnnotationKey]; exist {
+					if val != testAnnotationVal {
+						return fmt.Errorf("annotation %s does not match %s", testAnnotationKey, testAnnotationVal)
+					}
+				} else {
+					return fmt.Errorf("annotation %s not found", testAnnotationKey)
+				}
+
+				if pod.Labels == nil {
+					return fmt.Errorf("labels of %s/%s are nil", pod.GetNamespace(), pod.GetName())
+				}
+				if val, exist := pod.Labels[testLabelKey]; exist {
+					if val != testLabelVal {
+						return fmt.Errorf("label %s does not match %s", testLabelKey, testLabelVal)
+					}
+				} else {
+					return fmt.Errorf("label %s not found", testLabelKey)
+				}
+
+				return nil
+			}, timeout, interval).Should(BeNil())
+		})
+	})
+
+	Context("Test Delete Pods and Services", func() {
+		It("should clean associated Pods and Services according to clean policy", func() {
+			type testCase struct {
+				description string
+				tfJob       *tfv1.TFJob
+
+				pendingWorkerPods   int32
+				activeWorkerPods    int32
+				succeededWorkerPods int32
+				failedWorkerPods    int32
+
+				pendingPSPods   int32
+				activePSPods    int32
+				succeededPSPods int32
+				failedPSPods    int32
+
+				activeWorkerServices int32
+				activePSServices     int32
+
+				expectedPodRemaining int
+			}
+
+			testCases := []testCase{
+				{
+					description: "4 workers and 2 ps are running, policy is All",
+					tfJob:       testutil.NewTFJobWithCleanPolicy(0, 4, 2, commonv1.CleanPodPolicyAll),
+
+					pendingWorkerPods:   0,
+					activeWorkerPods:    4,
+					succeededWorkerPods: 0,
+					failedWorkerPods:    0,
+
+					pendingPSPods:   0,
+					activePSPods:    2,
+					succeededPSPods: 0,
+					failedPSPods:    0,
+
+					activeWorkerServices: 4,
+					activePSServices:     2,
+
+					expectedPodRemaining: 0,
+				},
+				{
+					description: "4 workers and 2 ps are running, policy is Running",
+					tfJob:       testutil.NewTFJobWithCleanPolicy(0, 4, 2, commonv1.CleanPodPolicyRunning),
+
+					pendingWorkerPods:   0,
+					activeWorkerPods:    4,
+					succeededWorkerPods: 0,
+					failedWorkerPods:    0,
+
+					pendingPSPods:   0,
+					activePSPods:    2,
+					succeededPSPods: 0,
+					failedPSPods:    0,
+
+					activeWorkerServices: 4,
+					activePSServices:     2,
+
+					expectedPodRemaining: 0,
+				},
+				{
+					description: "4 workers and 2 ps have succeeded, policy is Running",
+					tfJob:       testutil.NewTFJobWithCleanPolicy(0, 4, 2, commonv1.CleanPodPolicyRunning),
+
+					pendingWorkerPods:   0,
+					activeWorkerPods:    0,
+					succeededWorkerPods: 4,
+					failedWorkerPods:    0,
+
+					pendingPSPods:   0,
+					activePSPods:    0,
+					succeededPSPods: 2,
+					failedPSPods:    0,
+
+					activeWorkerServices: 4,
+					activePSServices:     2,
+
+					expectedPodRemaining: 6,
+				},
+				{
+					description: "4 workers and 2 ps have succeeded, policy is None",
+					tfJob:       testutil.NewTFJobWithCleanPolicy(0, 4, 2, commonv1.CleanPodPolicyNone),
+
+					pendingWorkerPods:   0,
+					activeWorkerPods:    0,
+					succeededWorkerPods: 4,
+					failedWorkerPods:    0,
+
+					pendingPSPods:   0,
+					activePSPods:    0,
+					succeededPSPods: 2,
+					failedPSPods:    0,
+
+					activeWorkerServices: 4,
+					activePSServices:     2,
+
+					expectedPodRemaining: 6,
+				},
+			}
+
+			jobNameTemplate := "test-del-pod-svc-%d"
+			for idx, tc := range testCases {
+				By(fmt.Sprintf("preparing cases %s", tc.description))
+				ctx := context.Background()
+				tc.tfJob.SetName(fmt.Sprintf(jobNameTemplate, idx))
+				tc.tfJob.SetUID(uuid.NewUUID())
+				Expect(commonutil.UpdateJobConditions(&tc.tfJob.Status, commonv1.JobSucceeded, tfJobSucceededReason, "")).Should(Succeed())
+
+				refs := 
[]metav1.OwnerReference{ + *reconciler.GenOwnerReference(tc.tfJob), + } + + basicLabels := reconciler.GenLabels(tc.tfJob.GetName()) + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: basicLabels, + }) + Expect(err).Should(BeNil()) + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + + By("creating Services and Pods with designed phases") + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, + tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, + nil, refs, basicLabels) + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelPS, + tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, + nil, refs, basicLabels) + + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, tc.activeWorkerServices, refs, basicLabels) + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelPS, tc.activePSServices, refs, basicLabels) + + podList := &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt)).Should(Succeed()) + Expect(len(podList.Items)).To(Equal( + int(tc.pendingPSPods + tc.activePSPods + tc.failedPSPods + tc.succeededPSPods + + tc.pendingWorkerPods + tc.activeWorkerPods + tc.failedWorkerPods + tc.succeededWorkerPods))) + + By("calling ReconcileJob") + _ = reconciler.ReconcileJobs(tc.tfJob, tc.tfJob.Spec.TFReplicaSpecs, tc.tfJob.Status, &tc.tfJob.Spec.RunPolicy) + + podList = &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt, client.InNamespace(tc.tfJob.GetNamespace()))).Should(Succeed()) + podRemainingCount := len(podList.Items) + Expect(podRemainingCount).To(Equal(tc.expectedPodRemaining)) + + svcList := &corev1.ServiceList{} + Expect(testK8sClient.List(ctx, svcList, listOpt)).Should(Succeed()) + svcRemainingCount := len(svcList.Items) + Expect(svcRemainingCount).To(Equal(tc.expectedPodRemaining)) + } + }) + }) + + Context("Test Active Deadline Seconds", func() { + It("clean desired Pods and Services according to TFJob config", func() { + type testCase struct { + description string + tfJob *tfv1.TFJob + + pendingWorkerPods int32 + activeWorkerPods int32 + succeededWorkerPods int32 + failedWorkerPods int32 + + pendingPSPods int32 + activePSPods int32 + succeededPSPods int32 + failedPSPods int32 + + activeWorkerServices int32 + activePSServices int32 + + expectedPodRemaining int + } + + ads2 := int64(2) + adsTest2 := &ads2 + testCases := []testCase{ + { + description: "4 workers and 2 ps is running, ActiveDeadlineSeconds unset", + tfJob: testutil.NewTFJobWithActiveDeadlineSeconds(0, 4, 2, nil), + + pendingWorkerPods: 0, + activeWorkerPods: 4, + succeededWorkerPods: 0, + failedWorkerPods: 0, + + pendingPSPods: 0, + activePSPods: 2, + succeededPSPods: 0, + failedPSPods: 0, + + activeWorkerServices: 4, + activePSServices: 2, + + expectedPodRemaining: 6, + }, + { + description: "4 workers and 2 ps is running, ActiveDeadlineSeconds is 2", + tfJob: testutil.NewTFJobWithActiveDeadlineSeconds(0, 4, 2, adsTest2), + + pendingWorkerPods: 0, + activeWorkerPods: 4, + succeededWorkerPods: 0, + failedWorkerPods: 0, + + pendingPSPods: 0, + activePSPods: 2, + succeededPSPods: 0, + failedPSPods: 0, + + activeWorkerServices: 4, + activePSServices: 2, + + expectedPodRemaining: 0, + }, + } + jobNameTemplate := "test-ads-%d" + for idx, tc := range testCases { + By(fmt.Sprintf("preparing cases %s", tc.description)) + ctx := context.Background() + tc.tfJob.SetName(fmt.Sprintf(jobNameTemplate, idx)) + tc.tfJob.SetUID(uuid.NewUUID()) 
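+				// Note: these jobs are driven through ReconcileJobs directly rather
+				// than submitted via the API server, so the UID assigned above is
+				// what makes the OwnerReferences generated below valid.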
+
+				refs := []metav1.OwnerReference{
+					*reconciler.GenOwnerReference(tc.tfJob),
+				}
+
+				basicLabels := reconciler.GenLabels(tc.tfJob.GetName())
+				selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
+					MatchLabels: basicLabels,
+				})
+				Expect(err).Should(BeNil())
+				listOpt := client.MatchingLabelsSelector{
+					Selector: selector,
+				}
+
+				By("creating Services and Pods with designed phases")
+				testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelWorker,
+					tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods,
+					nil, refs, basicLabels)
+				testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelPS,
+					tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods,
+					nil, refs, basicLabels)
+
+				testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, tc.activeWorkerServices, refs, basicLabels)
+				testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelPS, tc.activePSServices, refs, basicLabels)
+
+				podList := &corev1.PodList{}
+				Expect(testK8sClient.List(ctx, podList, listOpt)).Should(Succeed())
+				Expect(len(podList.Items)).To(Equal(
+					int(tc.pendingPSPods + tc.activePSPods + tc.failedPSPods + tc.succeededPSPods +
+						tc.pendingWorkerPods + tc.activeWorkerPods + tc.failedWorkerPods + tc.succeededWorkerPods)))
+
+				By("waiting enough time")
+				now := metav1.Now()
+				tc.tfJob.Status.StartTime = &now
+				ads := tc.tfJob.Spec.RunPolicy.ActiveDeadlineSeconds
+				if ads != nil {
+					dur := time.Second * time.Duration(*ads)
+					time.Sleep(dur)
+				}
+
+				By("calling ReconcileJob")
+				_ = reconciler.ReconcileJobs(tc.tfJob, tc.tfJob.Spec.TFReplicaSpecs, tc.tfJob.Status, &tc.tfJob.Spec.RunPolicy)
+
+				podList = &corev1.PodList{}
+				Expect(testK8sClient.List(ctx, podList, listOpt, client.InNamespace(tc.tfJob.GetNamespace()))).Should(Succeed())
+				podRemainingCount := len(podList.Items)
+				Expect(podRemainingCount).To(Equal(tc.expectedPodRemaining))
+
+				svcList := &corev1.ServiceList{}
+				Expect(testK8sClient.List(ctx, svcList, listOpt)).Should(Succeed())
+				svcRemainingCount := len(svcList.Items)
+				Expect(svcRemainingCount).To(Equal(tc.expectedPodRemaining))
+			}
+		})
+	})
+
+	Context("Test Backoff For On Failure", func() {
+		It("should delete Pods and Services when the backoff limit is reached", func() {
+			type testCase struct {
+				description string
+				tfJob       *tfv1.TFJob
+
+				pendingWorkerPods   int32
+				activeWorkerPods    int32
+				succeededWorkerPods int32
+				failedWorkerPods    int32
+
+				restartCounts []int32
+
+				pendingPSPods   int32
+				activePSPods    int32
+				succeededPSPods int32
+				failedPSPods    int32
+
+				activeWorkerServices int32
+				activePSServices     int32
+
+				expectedPodRemaining int
+			}
+
+			backoffLimit4 := int32(4)
+			backoffLimitTest4 := &backoffLimit4
+			testCases := []testCase{
+				{
+					description: "4 workers each having 1 restartCount and 2 ps are running, backoffLimit 4",
+					tfJob:       testutil.NewTFJobWithBackoffLimit(0, 4, 2, backoffLimitTest4),
+
+					pendingWorkerPods:   0,
+					activeWorkerPods:    4,
+					succeededWorkerPods: 0,
+					failedWorkerPods:    0,
+
+					restartCounts: []int32{1, 1, 1, 1},
+
+					pendingPSPods:   0,
+					activePSPods:    2,
+					succeededPSPods: 0,
+					failedPSPods:    0,
+
+					activeWorkerServices: 4,
+					activePSServices:     2,
+
+					expectedPodRemaining: 0,
+				},
+			}
+
+			jobNameTemplate := "test-bof-%d"
+			for idx, tc := range testCases {
+				By(fmt.Sprintf("preparing cases %s", tc.description))
+				ctx := context.Background()
+				tc.tfJob.SetName(fmt.Sprintf(jobNameTemplate, idx))
+				tc.tfJob.SetUID(uuid.NewUUID())
+
+				refs := []metav1.OwnerReference{
+					
*reconciler.GenOwnerReference(tc.tfJob), + } + + basicLabels := reconciler.GenLabels(tc.tfJob.GetName()) + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: basicLabels, + }) + Expect(err).Should(BeNil()) + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + + By("creating Services and Pods with designed phases") + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, + tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, + tc.restartCounts, refs, basicLabels) + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelPS, + tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, + tc.restartCounts, refs, basicLabels) + + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, tc.activeWorkerServices, refs, basicLabels) + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelPS, tc.activePSServices, refs, basicLabels) + + podList := &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt)).Should(Succeed()) + Expect(len(podList.Items)).To(Equal( + int(tc.pendingPSPods + tc.activePSPods + tc.failedPSPods + tc.succeededPSPods + + tc.pendingWorkerPods + tc.activeWorkerPods + tc.failedWorkerPods + tc.succeededWorkerPods))) + + By("calling ReconcileJob") + _ = reconciler.ReconcileJobs(tc.tfJob, tc.tfJob.Spec.TFReplicaSpecs, tc.tfJob.Status, &tc.tfJob.Spec.RunPolicy) + + podList = &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt, client.InNamespace(tc.tfJob.GetNamespace()))).Should(Succeed()) + podRemainingCount := len(podList.Items) + Expect(podRemainingCount).To(Equal(tc.expectedPodRemaining)) + + svcList := &corev1.ServiceList{} + Expect(testK8sClient.List(ctx, svcList, listOpt)).Should(Succeed()) + svcRemainingCount := len(svcList.Items) + Expect(svcRemainingCount).To(Equal(tc.expectedPodRemaining)) + } + }) + }) + +}) diff --git a/pkg/controller.v1/tensorflow/pod_test.go b/pkg/controller.v1/tensorflow/pod_test.go new file mode 100644 index 0000000000..a853ac03f2 --- /dev/null +++ b/pkg/controller.v1/tensorflow/pod_test.go @@ -0,0 +1,540 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "context" + "fmt" + "os" + "time" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/core" + . "github.com/onsi/ginkgo" + . 
"github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/uuid" + "sigs.k8s.io/controller-runtime/pkg/client" + + tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" + "github.com/kubeflow/training-operator/pkg/common/util/v1/testutil" +) + +var _ = Describe("TFJob controller", func() { + const ( + timeout = 10 * time.Second + interval = 1000 * time.Millisecond + ) + + Context("Test ClusterSpec", func() { + It("should generate desired cluster spec", func() { + type tc struct { + tfJob *tfv1.TFJob + rt string + index string + customClusterDomain string + expectedClusterSpec string + } + testCase := []tc{ + { + tfJob: testutil.NewTFJobWithNamespace(1, 0, "ns0"), + rt: "worker", + index: "0", + customClusterDomain: "", + expectedClusterSpec: "", + }, + { + tfJob: testutil.NewTFJobWithNamespace(1, 0, "ns1"), + rt: "worker", + index: "0", + customClusterDomain: "tf.training.com", + expectedClusterSpec: "", + }, + { + tfJob: testutil.NewTFJobWithNamespace(1, 1, "ns2"), + rt: "worker", + index: "0", + customClusterDomain: "tf.training.org", + expectedClusterSpec: `{"cluster":{"ps":["` + testutil.TestTFJobName + + `-ps-0.ns2.svc.tf.training.org:2222"],"worker":["` + testutil.TestTFJobName + + `-worker-0.ns2.svc.tf.training.org:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, + }, + { + tfJob: testutil.NewTFJobWithEvaluatorAndNamespace(1, 1, 1, "ns3"), + rt: "worker", + index: "0", + customClusterDomain: "tf.training.io", + expectedClusterSpec: `{"cluster":{"evaluator":["` + testutil.TestTFJobName + + `-evaluator-0.ns3.svc.tf.training.io:2222"],"ps":["` + testutil.TestTFJobName + + `-ps-0.ns3.svc.tf.training.io:2222"],"worker":["` + testutil.TestTFJobName + + `-worker-0.ns3.svc.tf.training.io:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, + }, + { + tfJob: testutil.NewTFJobWithEvaluatorAndNamespace(1, 1, 1, "ns3"), + rt: "worker", + index: "0", + customClusterDomain: "", + expectedClusterSpec: `{"cluster":{"evaluator":["` + testutil.TestTFJobName + + `-evaluator-0.ns3.svc:2222"],"ps":["` + testutil.TestTFJobName + + `-ps-0.ns3.svc:2222"],"worker":["` + testutil.TestTFJobName + + `-worker-0.ns3.svc:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, + }, + } + + for _, c := range testCase { + c.tfJob.SetName("test-tfjob") + c.tfJob.SetUID(uuid.NewUUID()) + _ = os.Setenv(EnvCustomClusterDomain, c.customClusterDomain) + + podTemplate := c.tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].Template.DeepCopy() + + podTemplate.Name = core.GenGeneralName(c.tfJob.GetName(), c.rt, c.index) + + if podTemplate.Labels == nil { + podTemplate.Labels = map[string]string{} + } + + jobName := c.tfJob.GetName() + labels := reconciler.GenLabels(jobName) + labels[commonv1.ReplicaTypeLabelDeprecated] = c.rt + labels[commonv1.ReplicaTypeLabel] = c.rt + labels[commonv1.ReplicaIndexLabelDeprecated] = c.index + labels[commonv1.ReplicaIndexLabel] = c.index + + Expect(reconciler.SetClusterSpec(c.tfJob, podTemplate, c.rt, c.index)).Should(Succeed()) + + if c.expectedClusterSpec == "" { + Expect(len(podTemplate.Spec.Containers[0].Env)).Should(Equal(0)) + } else { + actual := podTemplate.Spec.Containers[0].Env[0].Value + reconciler.Log.Info("printing cluster spec", "expected", c.expectedClusterSpec, "actual pod", podTemplate) + Expect(actual).Should(Equal(c.expectedClusterSpec)) + } + } + }) + }) + + 
Context("Test IsDistributed", func() { + It("should returns correctly", func() { + type tc struct { + tfJob *tfv1.TFJob + expected bool + } + testCase := []tc{ + { + tfJob: testutil.NewTFJob(1, 0), + expected: false, + }, + { + tfJob: testutil.NewTFJob(1, 1), + expected: true, + }, + { + tfJob: testutil.NewTFJob(0, 1), + expected: false, + }, + { + tfJob: testutil.NewTFJobWithChief(1, 0), + expected: true, + }, + } + for _, c := range testCase { + Expect(isDistributed(c.tfJob)).To(Equal(c.expected)) + } + }) + }) + + Context("Test Restart Policy", func() { + It("should assign proper restart policy to pod", func() { + type tc struct { + tfJob *tfv1.TFJob + expectedRestartPolicy corev1.RestartPolicy + expectedType commonv1.ReplicaType + } + testCase := []tc{ + func() tc { + tfJob := testutil.NewTFJob(1, 0) + specRestartPolicy := commonv1.RestartPolicyExitCode + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: corev1.RestartPolicyNever, + expectedType: tfv1.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := testutil.NewTFJob(1, 0) + specRestartPolicy := commonv1.RestartPolicyNever + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: corev1.RestartPolicyNever, + expectedType: tfv1.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := testutil.NewTFJob(1, 0) + specRestartPolicy := commonv1.RestartPolicyAlways + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: corev1.RestartPolicyAlways, + expectedType: tfv1.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := testutil.NewTFJob(1, 0) + specRestartPolicy := commonv1.RestartPolicyOnFailure + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: corev1.RestartPolicyOnFailure, + expectedType: tfv1.TFReplicaTypeWorker, + } + }(), + } + for _, c := range testCase { + spec := c.tfJob.Spec.TFReplicaSpecs[c.expectedType] + podTemplate := spec.Template + setRestartPolicy(&podTemplate, spec) + Expect(podTemplate.Spec.RestartPolicy).To(Equal(c.expectedRestartPolicy)) + } + }) + }) + + Context("Test Exit Code", func() { + It("should delete designated Pod", func() { + By("Creating TFJob \"test-exit-code\" with 1 worker only") + ctx := context.Background() + + tfJob := testutil.NewTFJob(1, 0) + tfJob.SetName("test-exit-code") + tfJob.SetUID(uuid.NewUUID()) + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = commonv1.RestartPolicyExitCode + + refs := []metav1.OwnerReference{ + *reconciler.GenOwnerReference(tfJob), + } + By("creating worker Pod") + pod := testutil.NewPod(tfJob, testutil.LabelWorker, 0, refs) + basicLabels := reconciler.GenLabels(tfJob.GetName()) + for k, v := range basicLabels { + pod.Labels[k] = v + } + Expect(testK8sClient.Create(ctx, pod)).Should(Succeed()) + + po := &corev1.Pod{} + key := types.NamespacedName{Namespace: metav1.NamespaceDefault, Name: pod.GetName()} + Expect(testK8sClient.Get(ctx, key, po)).Should(Succeed()) + po.Status.Phase = corev1.PodFailed + po.Spec.Containers = append(po.Spec.Containers, corev1.Container{}) + po.Status.ContainerStatuses = append(po.Status.ContainerStatuses, corev1.ContainerStatus{ + Name: tfv1.DefaultContainerName, + State: corev1.ContainerState{ + Terminated: &corev1.ContainerStateTerminated{ + ExitCode: 130, + }, + }, + 
})
+			Expect(testK8sClient.Status().Update(ctx, po)).Should(Succeed())
+
+			_ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy)
+
+			Eventually(func() bool {
+				noPod := &corev1.Pod{}
+				err := testK8sClient.Get(ctx, key, noPod)
+				if err == nil {
+					reconciler.Log.Info("still got pod", "jobName", tfJob.GetName(), "pod", noPod)
+					return noPod.GetDeletionTimestamp() != nil
+				}
+				return errors.IsNotFound(err)
+			}, timeout, interval).Should(BeTrue())
+		})
+	})
+
+	Describe("Test Scale Down", func() {
+		It("should delete redundant Pods", func() {
+			ctx := context.Background()
+
+			tfJob := testutil.NewTFJob(2, 0)
+			tfJob.SetName("test-scale-down")
+			tfJob.SetUID(uuid.NewUUID())
+			tfJob.Spec.EnableDynamicWorker = true
+
+			refs := []metav1.OwnerReference{*reconciler.GenOwnerReference(tfJob)}
+
+			pods := []*corev1.Pod{
+				testutil.NewPod(tfJob, testutil.LabelWorker, 0, refs),
+				testutil.NewPod(tfJob, testutil.LabelWorker, 1, refs),
+				testutil.NewPod(tfJob, testutil.LabelWorker, 2, refs),
+			}
+
+			for i := range pods {
+				pod := pods[i]
+				for k, v := range reconciler.GenLabels(tfJob.GetName()) {
+					pod.Labels[k] = v
+				}
+				Expect(testK8sClient.Create(ctx, pod)).Should(Succeed())
+			}
+
+			// Ensure the created Pods are all in cache
+			Eventually(func() error {
+				podList := &corev1.PodList{}
+				selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
+					MatchLabels: reconciler.GenLabels(tfJob.GetName()),
+				})
+				if err != nil {
+					return err
+				}
+				listOpt := client.MatchingLabelsSelector{
+					Selector: selector,
+				}
+				err = testK8sClient.List(ctx, podList, listOpt)
+				if err != nil {
+					return err
+				}
+				if len(podList.Items) != 3 {
+					return fmt.Errorf("expecting %d Pods but got %d", 3, len(podList.Items))
+				}
+				return nil
+			}, timeout, interval).Should(BeNil())
+
+			_ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy)
+
+			noKey := types.NamespacedName{
+				Namespace: metav1.NamespaceDefault,
+				Name:      pods[2].GetName(),
+			}
+			Eventually(func() bool {
+				noPod := &corev1.Pod{}
+				err := testK8sClient.Get(ctx, noKey, noPod)
+				if err == nil {
+					return false
+				}
+				return errors.IsNotFound(err)
+			}, timeout, interval).Should(BeTrue())
+		})
+	})
+
+	Describe("Test Scale Up", func() {
+		It("should create missing Pods", func() {
+			ctx := context.Background()
+
+			tfJob := testutil.NewTFJob(3, 0)
+			tfJob.SetName("test-scale-up")
+			tfJob.SetUID(uuid.NewUUID())
+			tfJob.Spec.EnableDynamicWorker = true
+
+			refs := []metav1.OwnerReference{*reconciler.GenOwnerReference(tfJob)}
+
+			pods := []*corev1.Pod{
+				testutil.NewPod(tfJob, testutil.LabelWorker, 0, refs),
+			}
+
+			for i := range pods {
+				pod := pods[i]
+				for k, v := range reconciler.GenLabels(tfJob.GetName()) {
+					pod.Labels[k] = v
+				}
+				Expect(testK8sClient.Create(ctx, pod)).Should(Succeed())
+			}
+
+			// Ensure the created Pods are all in cache
+			Eventually(func() error {
+				podList := &corev1.PodList{}
+				selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
+					MatchLabels: reconciler.GenLabels(tfJob.GetName()),
+				})
+				if err != nil {
+					return err
+				}
+				listOpt := client.MatchingLabelsSelector{
+					Selector: selector,
+				}
+				err = testK8sClient.List(ctx, podList, listOpt)
+				if err != nil {
+					return err
+				}
+				if len(podList.Items) != 1 {
+					return fmt.Errorf("before reconciling, expecting %d Pods but got %d", 1, len(podList.Items))
+				}
+				return nil
+			}, timeout, interval).Should(BeNil())
+
+			_ = 
reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy) + + // Check if there are two more Pods created + Eventually(func() error { + podList := &corev1.PodList{} + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: reconciler.GenLabels(tfJob.GetName()), + }) + if err != nil { + return err + } + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + err = testK8sClient.List(ctx, podList, listOpt) + if err != nil { + return err + } + if len(podList.Items) != 3 { + return fmt.Errorf("after reconciling, expecting %d Pods while got %d", 3, len(podList.Items)) + } + return nil + }, timeout, interval).Should(BeNil()) + }) + }) + + Describe("TestIsWorker0Completed", func() { + It("should match expected result", func() { + newInt32 := func(in int32) *int32 { + return &in + } + tests := []struct { + // worker failed, succeeded, running num + workers [3]int32 + tfJob *tfv1.TFJob + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec + expected bool + expectedErr bool + }{ + { + workers: [3]int32{0, 0, 1}, + tfJob: testutil.NewTFJobV2(1, 1, 0, 0, 0), + expected: false, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeWorker: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + tfv1.TFReplicaTypePS: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{0, 1, 0}, + tfJob: testutil.NewTFJobV2(1, 0, 0, 0, 0), + expected: true, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeWorker: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{0, 0, 0}, + tfJob: testutil.NewTFJobV2(0, 0, 1, 0, 0), + expected: true, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeMaster: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{0, 0, 0}, + tfJob: testutil.NewTFJobV2(0, 0, 0, 1, 0), + expected: true, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeChief: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{1, 1, 0}, + tfJob: testutil.NewTFJobV2(2, 0, 0, 0, 0), + expected: true, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeWorker: { + Replicas: newInt32(2), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{1, 0, 1}, + tfJob: testutil.NewTFJobV2(2, 0, 0, 0, 0), + expected: false, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeWorker: { + Replicas: newInt32(2), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + } + + jobNameTemplate := "test-worker0-complete-%d" + for i, tt := range tests { + tt.tfJob.SetName(fmt.Sprintf(jobNameTemplate, i)) + tt.tfJob.SetUID(uuid.NewUUID()) + // only related to worker status + initializeReplicaStatuses(&tt.tfJob.Status, tfv1.TFReplicaTypeWorker) + // set status and add pod to indexer + setStatusForTest(tt.tfJob, tfv1.TFReplicaTypeWorker, tt.workers[0], tt.workers[1], tt.workers[2], false, true, testK8sClient) + + // Adding this section to make sure all pods are created and cached + Eventually(func() error { + podList := &corev1.PodList{} + selector, err 
:= metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
+					MatchLabels: reconciler.GenLabels(tt.tfJob.GetName()),
+				})
+				if err != nil {
+					return err
+				}
+				listOpt := client.MatchingLabelsSelector{
+					Selector: selector,
+				}
+				err = testK8sClient.List(context.Background(), podList, listOpt)
+				if err != nil {
+					return err
+				}
+				totalExpectedPodCount := tt.workers[0] + tt.workers[1] + tt.workers[2]
+				if len(podList.Items) != int(totalExpectedPodCount) {
+					return fmt.Errorf("pod number (%d) for %s does not match expected pod number %d",
+						len(podList.Items), tt.tfJob.GetName(), totalExpectedPodCount)
+				}
+				return nil
+			}, timeout, interval).Should(BeNil())
+
+			got, err := reconciler.IsWorker0Completed(tt.tfJob, tt.replicas)
+
+			if err != nil {
+				Expect(tt.expectedErr).To(BeTrue())
+			} else {
+				Expect(got).To(Equal(tt.expected))
+			}
+		}
+	})
+})
diff --git a/pkg/controller.v1/tensorflow/status_test.go b/pkg/controller.v1/tensorflow/status_test.go
new file mode 100644
index 0000000000..00963d90b5
--- /dev/null
+++ b/pkg/controller.v1/tensorflow/status_test.go
@@ -0,0 +1,609 @@
+// Copyright 2021 The Kubeflow Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+	"github.com/kubeflow/common/pkg/util"
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/apimachinery/pkg/util/uuid"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+
+	tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
+	"github.com/kubeflow/training-operator/pkg/common/util/v1/testutil"
+)
+
+var _ = Describe("TFJob controller", func() {
+	// Define utility constants for object names and testing timeouts/durations and intervals.
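+	// Note: timeout and interval drive the Eventually polling used throughout
+	// these specs; reconciliation and status updates happen asynchronously, so
+	// assertions retry until the cluster state converges.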
+	const (
+		timeout  = 10 * time.Second
+		interval = 1000 * time.Millisecond
+	)
+
+	Context("Test Failed", func() {
+		It("should update TFJob with failed status", func() {
+			By("creating a TFJob with replicaStatuses initialized")
+			tfJob := testutil.NewTFJob(3, 0)
+			initializeReplicaStatuses(&tfJob.Status, tfv1.TFReplicaTypeWorker)
+
+			By("preparing a pod")
+			refs := []metav1.OwnerReference{
+				*reconciler.GenOwnerReference(tfJob),
+			}
+			pod := testutil.NewBasePod("pod", tfJob, refs)
+			pod.Status.Phase = corev1.PodFailed
+
+			By("updating job replica statuses")
+			updateJobReplicaStatuses(&tfJob.Status, tfv1.TFReplicaTypeWorker, pod)
+			Expect(tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypeWorker].Failed).Should(Equal(int32(1)))
+
+			By("updating job status")
+			Expect(reconciler.UpdateJobStatus(tfJob, tfJob.Spec.TFReplicaSpecs, &tfJob.Status)).To(Succeed())
+
+			By("finding failed job status")
+			found := false
+			for _, condition := range tfJob.Status.Conditions {
+				if condition.Type == commonv1.JobFailed {
+					found = true
+				}
+			}
+			Expect(found).To(BeTrue())
+		})
+	})
+
+	Context("Test Status", func() {
+		It("should update TFJob with desired status", func() {
+			type testCase struct {
+				description string
+				tfJob       *tfv1.TFJob
+
+				expectedFailedPS    int32
+				expectedSucceededPS int32
+				expectedActivePS    int32
+
+				expectedFailedWorker    int32
+				expectedSucceededWorker int32
+				expectedActiveWorker    int32
+
+				expectedFailedChief    int32
+				expectedSucceededChief int32
+				expectedActiveChief    int32
+
+				restart          bool
+				worker0Completed bool
+
+				expectedType commonv1.JobConditionType
+			}
+
+			testCases := []testCase{
+				{
+					description:             "Chief worker is succeeded",
+					tfJob:                   testutil.NewTFJobWithChief(1, 0),
+					expectedFailedPS:        0,
+					expectedSucceededPS:     0,
+					expectedActivePS:        0,
+					expectedFailedWorker:    0,
+					expectedSucceededWorker: 1,
+					expectedActiveWorker:    0,
+					expectedFailedChief:     0,
+					expectedSucceededChief:  1,
+					expectedActiveChief:     0,
+					restart:                 false,
+					worker0Completed:        false,
+					expectedType:            commonv1.JobSucceeded,
+				},
+				{
+					description:             "Chief worker is running",
+					tfJob:                   testutil.NewTFJobWithChief(1, 0),
+					expectedFailedPS:        0,
+					expectedSucceededPS:     0,
+					expectedActivePS:        0,
+					expectedFailedWorker:    0,
+					expectedSucceededWorker: 0,
+					expectedActiveWorker:    0,
+					expectedFailedChief:     0,
+					expectedSucceededChief:  0,
+					expectedActiveChief:     1,
+					restart:                 false,
+					worker0Completed:        false,
+					expectedType:            commonv1.JobRunning,
+				},
+				{
+					description:             "Chief worker is failed",
+					tfJob:                   testutil.NewTFJobWithChief(1, 0),
+					expectedFailedPS:        0,
+					expectedSucceededPS:     0,
+					expectedActivePS:        0,
+					expectedFailedWorker:    0,
+					expectedSucceededWorker: 0,
+					expectedActiveWorker:    0,
+					expectedFailedChief:     1,
+					expectedSucceededChief:  0,
+					expectedActiveChief:     0,
+					restart:                 false,
+					worker0Completed:        false,
+					expectedType:            commonv1.JobFailed,
+				},
+				{
+					description:             "(No chief worker) Worker is failed",
+					tfJob:                   testutil.NewTFJob(1, 0),
+					expectedFailedPS:        0,
+					expectedSucceededPS:     0,
+					expectedActivePS:        0,
+					expectedFailedWorker:    1,
+					expectedSucceededWorker: 0,
+					expectedActiveWorker:    0,
+					expectedFailedChief:     0,
+					expectedSucceededChief:  0,
+					expectedActiveChief:     0,
+					restart:                 false,
+					worker0Completed:        false,
+					expectedType:            commonv1.JobFailed,
+				},
+				{
+					description:             "(No chief worker) Worker is succeeded",
+					tfJob:                   testutil.NewTFJob(1, 0),
+					expectedFailedPS:        0,
+					expectedSucceededPS:     0,
+					expectedActivePS:        0,
+					expectedFailedWorker:    0,
+					expectedSucceededWorker: 1,
+					expectedActiveWorker:    0,
+					expectedFailedChief:     0,
+					
expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobSucceeded, + }, + { + description: "(No chief worker) Worker is running", + tfJob: testutil.NewTFJob(1, 0), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 0, + expectedActiveWorker: 1, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobRunning, + }, + { + description: "(No chief worker) 2 workers are succeeded, 2 workers are active", + tfJob: testutil.NewTFJob(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 0, + expectedSucceededWorker: 2, + expectedActiveWorker: 2, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobRunning, + }, + { + description: "(No chief worker) 2 workers are running, 2 workers are failed", + tfJob: testutil.NewTFJob(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 2, + expectedSucceededWorker: 0, + expectedActiveWorker: 2, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "(No chief worker) 2 workers are succeeded, 2 workers are failed", + tfJob: testutil.NewTFJob(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 2, + expectedSucceededWorker: 2, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "(No chief worker) worker-0 are succeeded, 3 workers are active", + tfJob: testutil.NewTFJob(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 0, + expectedSucceededWorker: 1, + expectedActiveWorker: 3, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: commonv1.JobSucceeded, + }, + { + description: "(No chief worker, successPolicy: AllWorkers) worker-0 are succeeded, 3 workers are active", + tfJob: testutil.NewTFJobWithSuccessPolicy(4, 0, tfv1.SuccessPolicyAllWorkers), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 1, + expectedActiveWorker: 3, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: commonv1.JobRunning, + }, + { + description: "(No chief worker, successPolicy: AllWorkers) 4 workers are succeeded", + tfJob: testutil.NewTFJobWithSuccessPolicy(4, 0, tfv1.SuccessPolicyAllWorkers), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 4, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: commonv1.JobSucceeded, + }, + { + description: "(No chief worker, successPolicy: AllWorkers) worker-0 is succeeded, 2 workers are running, 1 worker is failed", + tfJob: testutil.NewTFJobWithSuccessPolicy(4, 0, 
tfv1.SuccessPolicyAllWorkers), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 1, + expectedSucceededWorker: 1, + expectedActiveWorker: 2, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: commonv1.JobFailed, + }, + { + description: "Chief is running, workers are failed", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 4, + expectedSucceededWorker: 0, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 1, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobRunning, + }, + { + description: "Chief is running, workers are succeeded", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 0, + expectedSucceededWorker: 4, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 1, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobRunning, + }, + { + description: "Chief is running, a PS is failed", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 1, + expectedSucceededPS: 0, + expectedActivePS: 1, + expectedFailedWorker: 0, + expectedSucceededWorker: 4, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 1, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "Chief is failed, workers are succeeded", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 0, + expectedSucceededWorker: 4, + expectedActiveWorker: 0, + expectedFailedChief: 1, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "Chief is succeeded, workers are failed", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 4, + expectedSucceededWorker: 0, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 1, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobSucceeded, + }, + { + description: "Chief is failed and restarting", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 4, + expectedSucceededWorker: 0, + expectedActiveWorker: 0, + expectedFailedChief: 1, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: true, + worker0Completed: false, + expectedType: commonv1.JobRestarting, + }, + } + + jobNameTemplate := "test-status-%d" + for i, c := range testCases { + reconciler.Log.Info("testing case", "description", c.description) + c.tfJob.SetName(fmt.Sprintf(jobNameTemplate, i)) + c.tfJob.SetUID(uuid.NewUUID()) + + initializeReplicaStatuses(&c.tfJob.Status, tfv1.TFReplicaTypeWorker) + initializeReplicaStatuses(&c.tfJob.Status, tfv1.TFReplicaTypeChief) + initializeReplicaStatuses(&c.tfJob.Status, tfv1.TFReplicaTypePS) + + setStatusForTest(c.tfJob, tfv1.TFReplicaTypePS, c.expectedFailedPS, c.expectedSucceededPS, c.expectedActivePS, c.restart, c.worker0Completed, testK8sClient) + setStatusForTest(c.tfJob, 
+		setStatusForTest(c.tfJob, tfv1.TFReplicaTypeWorker, c.expectedFailedWorker, c.expectedSucceededWorker, c.expectedActiveWorker, c.restart, c.worker0Completed, testK8sClient)
+		setStatusForTest(c.tfJob, tfv1.TFReplicaTypeChief, c.expectedFailedChief, c.expectedSucceededChief, c.expectedActiveChief, c.restart, c.worker0Completed, testK8sClient)
+
+		// Make sure all Pods are created and present in the informer cache
+		Eventually(func() error {
+			podList := &corev1.PodList{}
+			basicLabels := reconciler.GenLabels(c.tfJob.GetName())
+			selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
+				MatchLabels: basicLabels,
+			})
+			if err != nil {
+				return err
+			}
+			listOpt := client.MatchingLabelsSelector{
+				Selector: selector,
+			}
+			err = testK8sClient.List(context.Background(), podList, listOpt)
+			if err != nil {
+				return err
+			}
+			totalExpectedPodCount := c.expectedFailedPS + c.expectedSucceededPS + c.expectedActivePS +
+				c.expectedFailedWorker + c.expectedSucceededWorker + c.expectedActiveWorker +
+				c.expectedFailedChief + c.expectedSucceededChief + c.expectedActiveChief
+			if len(podList.Items) != int(totalExpectedPodCount) {
+				return fmt.Errorf("pod number (%d) for %s does not match the expected pod number %d",
+					len(podList.Items), c.tfJob.GetName(), totalExpectedPodCount)
+			}
+			return nil
+		}, timeout, interval).Should(BeNil())
+
+		_ = reconciler.ReconcileJobs(c.tfJob, c.tfJob.Spec.TFReplicaSpecs, c.tfJob.Status, &c.tfJob.Spec.RunPolicy)
+
+		Expect(filterOutConditionTest(c.tfJob.Status)).Should(Succeed())
+
+		reconciler.Log.Info("checking status", "tfJob.Status", c.tfJob.Status)
+		found := false
+		for _, condition := range c.tfJob.Status.Conditions {
+			if condition.Type == c.expectedType {
+				found = true
+			}
+		}
+		Expect(found).To(BeTrue())
+		reconciler.Log.Info("passed!",
+			"job name", c.tfJob.GetName(), "job description", c.description)
+	}
+		})
+	})
+})
+
+func setStatusForTest(tfJob *tfv1.TFJob, rtype commonv1.ReplicaType, failed, succeeded, active int32, restart bool, worker0Completed bool, client client.Client) {
+	if restart {
+		tfJob.Spec.TFReplicaSpecs[rtype].RestartPolicy = commonv1.RestartPolicyExitCode
+	}
+
+	basicLabels := reconciler.GenLabels(tfJob.GetName())
+
+	const (
+		timeout  = 10 * time.Second
+		interval = 1000 * time.Millisecond
+	)
+
+	ctx := context.Background()
+
+	var typ string
+	switch rtype {
+	case tfv1.TFReplicaTypeWorker:
+		typ = testutil.LabelWorker
+	case tfv1.TFReplicaTypePS:
+		typ = testutil.LabelPS
+	case tfv1.TFReplicaTypeChief:
+		typ = testutil.LabelChief
+	default:
+		fmt.Printf("unexpected replica type %v\n", rtype)
+	}
+	refs := []metav1.OwnerReference{
+		*reconciler.GenOwnerReference(tfJob),
+	}
+
+	var i int32
+	index := 0
+	for i = 0; i < succeeded; i++ {
+		pod := testutil.NewPod(tfJob, typ, index, refs)
+		for k, v := range basicLabels {
+			pod.Labels[k] = v
+		}
+		po := &corev1.Pod{}
+		Expect(client.Create(ctx, pod)).Should(Succeed())
+		key := genKeyFromJob(pod)
+		Eventually(func() error {
+			if err := client.Get(ctx, key, po); err != nil {
+				return err
+			}
+
+			po.Status.Phase = corev1.PodSucceeded
+			if worker0Completed && rtype == tfv1.TFReplicaTypeWorker && index == 0 {
+				po.Status.ContainerStatuses = []corev1.ContainerStatus{
+					{
+						Name: tfv1.DefaultContainerName,
+						State: corev1.ContainerState{
+							Terminated: &corev1.ContainerStateTerminated{
+								ExitCode: int32(0), // exit with 0
+							},
+						},
+					},
+				}
+			}
+
+			return client.Status().Update(ctx, po)
+		}, timeout, interval).Should(BeNil())
+
+		updateJobReplicaStatuses(&tfJob.Status, rtype, po)
+
+		index++
+	}
+	for i = 0; i < failed; i++ {
+		pod := testutil.NewPod(tfJob, typ, index, refs)
+		for k, v := range basicLabels {
+			pod.Labels[k] = v
+		}
+		po := &corev1.Pod{}
+		Expect(client.Create(ctx, pod)).Should(Succeed())
+		key := genKeyFromJob(pod)
+		Eventually(func() error {
+			if err := client.Get(ctx, key, po); err != nil {
+				return err
+			}
+
+			po.Status.Phase = corev1.PodFailed
+			if restart {
+				if po.Status.ContainerStatuses == nil {
+					po.Status.ContainerStatuses = []corev1.ContainerStatus{
+						{
+							Name: tfv1.DefaultContainerName,
+							State: corev1.ContainerState{
+								Terminated: &corev1.ContainerStateTerminated{
+									ExitCode: int32(130), // 130 is a retryable exit code
+								},
+							},
+						},
+					}
+				}
+			}
+
+			return client.Status().Update(ctx, po)
+		}, timeout, interval).Should(BeNil())
+
+		updateJobReplicaStatuses(&tfJob.Status, rtype, po)
+		index++
+	}
+	for i = 0; i < active; i++ {
+		pod := testutil.NewPod(tfJob, typ, index, refs)
+		for k, v := range basicLabels {
+			pod.Labels[k] = v
+		}
+		po := &corev1.Pod{}
+		Expect(client.Create(ctx, pod)).Should(Succeed())
+		key := genKeyFromJob(pod)
+		Eventually(func() error {
+			if err := client.Get(ctx, key, po); err != nil {
+				return err
+			}
+
+			po.Status.Phase = corev1.PodRunning
+
+			return client.Status().Update(ctx, po)
+		}, timeout, interval).Should(BeNil())
+
+		updateJobReplicaStatuses(&tfJob.Status, rtype, po)
+		index++
+	}
+}
+
+func genKeyFromJob(job client.Object) types.NamespacedName {
+	ns := metav1.NamespaceDefault
+	if job.GetNamespace() != "" {
+		ns = job.GetNamespace()
+	}
+	return types.NamespacedName{
+		Namespace: ns,
+		Name:      job.GetName(),
+	}
+}
+
+func filterOutConditionTest(status commonv1.JobStatus) error {
+	flag := util.IsFailed(status) || util.IsSucceeded(status)
+	for _, condition := range status.Conditions {
+		if flag && condition.Type == commonv1.JobRunning && condition.Status == corev1.ConditionTrue {
+			return fmt.Errorf("job has succeeded or failed but the JobRunning condition is still true")
+		}
+	}
+	return nil
+}
diff --git a/pkg/controller.v1/tensorflow/suite_test.go b/pkg/controller.v1/tensorflow/suite_test.go
index 640c6284c5..ac2b0a9961 100644
--- a/pkg/controller.v1/tensorflow/suite_test.go
+++ b/pkg/controller.v1/tensorflow/suite_test.go
@@ -15,12 +15,17 @@ package tensorflow
 import (
+	"context"
+	"fmt"
 	"path/filepath"
 	"testing"
+	"time"
 
 	. "github.com/onsi/ginkgo"
 	. "github.com/onsi/gomega"
+	corev1 "k8s.io/api/core/v1"
 	"k8s.io/client-go/kubernetes/scheme"
+	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 	"sigs.k8s.io/controller-runtime/pkg/envtest"
 	"sigs.k8s.io/controller-runtime/pkg/envtest/printer"
@@ -34,8 +39,13 @@ import (
 // These tests use Ginkgo (BDD-style Go testing framework). Refer to
 // http://onsi.github.io/ginkgo/ to learn more about Ginkgo.
-var k8sClient client.Client
-var testEnv *envtest.Environment
+var (
+	testK8sClient client.Client
+	testEnv       *envtest.Environment
+	testCtx       context.Context
+	testCancel    context.CancelFunc
+	reconciler    *TFJobReconciler
+)
 
 func TestAPIs(t *testing.T) {
 	RegisterFailHandler(Fail)
@@ -46,8 +56,14 @@ func TestAPIs(t *testing.T) {
 }
 
 var _ = BeforeSuite(func() {
+	const (
+		timeout  = 10 * time.Second
+		interval = 1000 * time.Millisecond
+	)
 	logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)))
+	testCtx, testCancel = context.WithCancel(context.TODO())
+
 	By("bootstrapping test environment")
 	testEnv = &envtest.Environment{
 		CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")},
@@ -63,14 +79,41 @@ var _ = BeforeSuite(func() {
 
 	//+kubebuilder:scaffold:scheme
 
-	k8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
+	testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
+	Expect(err).NotTo(HaveOccurred())
+	Expect(testK8sClient).NotTo(BeNil())
+
+	mgr, err := ctrl.NewManager(cfg, ctrl.Options{
+		MetricsBindAddress: "0",
+	})
 	Expect(err).NotTo(HaveOccurred())
-	Expect(k8sClient).NotTo(BeNil())
 
+	reconciler = NewReconciler(mgr, false)
+	Expect(reconciler.SetupWithManager(mgr)).NotTo(HaveOccurred())
+
+	go func() {
+		defer GinkgoRecover()
+		err = mgr.Start(testCtx)
+		Expect(err).ToNot(HaveOccurred(), "failed to run manager")
+	}()
+
+	// Make sure the manager's cache has started before running any tests
+	Eventually(func() error {
+		nsList := &corev1.NamespaceList{}
+		if err := testK8sClient.List(context.Background(), nsList); err != nil {
+			return err
+		} else if len(nsList.Items) < 1 {
+			return fmt.Errorf("cannot get at least one namespace, got %d", len(nsList.Items))
+		}
+		return nil
+	}, timeout, interval).Should(BeNil())
 }, 60)
 
 var _ = AfterSuite(func() {
 	By("tearing down the test environment")
+	testCancel()
+	// Give the manager 5 seconds to stop gracefully
+	time.Sleep(5 * time.Second)
 	err := testEnv.Stop()
 	Expect(err).NotTo(HaveOccurred())
 })
diff --git a/pkg/controller.v1/tensorflow/tfjob_controller_test.go b/pkg/controller.v1/tensorflow/tfjob_controller_test.go
new file mode 100644
index 0000000000..db634e4eb4
--- /dev/null
+++ b/pkg/controller.v1/tensorflow/tfjob_controller_test.go
@@ -0,0 +1,328 @@
+// Copyright 2021 The Kubeflow Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+	"context"
+	"fmt"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/util/uuid"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+
+	tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
+	"github.com/kubeflow/training-operator/pkg/common/util/v1/testutil"
+)
+
+var _ = Describe("TFJob controller", func() {
+	// Define utility constants for object names and testing timeouts/durations and intervals.
+	//const (
+	//	timeout  = 10 * time.Second
+	//	interval = 1000 * time.Millisecond
+	//)
+
+	Context("Test Normal Path", func() {
+		It("should create desired Pods and Services", func() {
+			var (
+				tfJobRunning   = commonv1.JobRunning
+				tfJobSucceeded = commonv1.JobSucceeded
+			)
+
+			testCases := map[string]struct {
+				worker int
+				ps     int
+
+				// pod setup
+				// ControllerError error
+				// jobKeyForget bool
+
+				pendingWorkerPods   int32
+				activeWorkerPods    int32
+				succeededWorkerPods int32
+				failedWorkerPods    int32
+
+				pendingPSPods   int32
+				activePSPods    int32
+				succeededPSPods int32
+				failedPSPods    int32
+
+				activeWorkerServices int32
+				activePSServices     int32
+
+				// expectations
+				expectedPodCreations     int32
+				expectedPodDeletions     int32
+				expectedServiceCreations int32
+
+				expectedActiveWorkerPods    int32
+				expectedSucceededWorkerPods int32
+				expectedFailedWorkerPods    int32
+
+				expectedActivePSPods    int32
+				expectedSucceededPSPods int32
+				expectedFailedPSPods    int32
+
+				expectedCondition       *commonv1.JobConditionType
+				expectedConditionReason string
+
+				// Some cases should not check the start time since the field is set in the previous sync loop.
+				needCheckStartTime bool
+			}{
+				"Local TFJob is created": {
+					1, 0,
+					0, 0, 0, 0,
+					0, 0, 0, 0,
+					0, 0,
+					1, 0, 1,
+					0, 0, 0,
+					0, 0, 0,
+					// We cannot check if it is created since the condition is set in addTFJob.
+					nil, "",
+					false,
+				},
+				"Distributed TFJob (4 workers, 2 PS) is created": {
+					4, 2,
+					0, 0, 0, 0,
+					0, 0, 0, 0,
+					0, 0,
+					6, 0, 6,
+					0, 0, 0,
+					0, 0, 0,
+					nil, "",
+					false,
+				},
+				"Distributed TFJob (4 workers, 2 PS) is created and all replicas are pending": {
+					4, 2,
+					4, 0, 0, 0,
+					2, 0, 0, 0,
+					4, 2,
+					0, 0, 0,
+					0, 0, 0,
+					0, 0, 0,
+					nil, "",
+					false,
+				},
+				"Distributed TFJob (4 workers, 2 PS) is created and all replicas are running": {
+					4, 2,
+					0, 4, 0, 0,
+					0, 2, 0, 0,
+					4, 2,
+					0, 0, 0,
+					4, 0, 0,
+					2, 0, 0,
+					&tfJobRunning, tfJobRunningReason,
+					true,
+				},
+				"Distributed TFJob (4 workers, 2 PS) is created, 2 workers, 1 PS are pending": {
+					4, 2,
+					2, 0, 0, 0,
+					1, 0, 0, 0,
+					2, 1,
+					3, 0, 3,
+					0, 0, 0,
+					0, 0, 0,
+					nil, "",
+					false,
+				},
+				"Distributed TFJob (4 workers, 2 PS) is created, 2 workers, 1 PS are pending, 1 worker is running": {
+					4, 2,
+					2, 1, 0, 0,
+					1, 0, 0, 0,
+					3, 1,
+					2, 0, 2,
+					1, 0, 0,
+					0, 0, 0,
+					&tfJobRunning, tfJobRunningReason,
+					false,
+				},
+				"Distributed TFJob (4 workers, 2 PS) is created, 2 workers, 1 PS are pending, 1 worker is succeeded": {
+					4, 2,
+					2, 0, 1, 0,
+					1, 0, 0, 0,
+					3, 1,
+					2, 0, 2,
+					0, 1, 0,
+					0, 0, 0,
+					nil, "",
+					false,
+				},
+				"Distributed TFJob (4 workers, 2 PS) is succeeded": {
+					4, 2,
+					0, 0, 4, 0,
+					0, 0, 2, 0,
+					4, 2,
+					0, 0, 0,
+					0, 4, 0,
+					0, 2, 0,
+					&tfJobSucceeded, tfJobSucceededReason,
+					false,
+				},
+			}
+
+			jobNameTemplate := "test-case-norm-%d"
+			caseIdx := 0
+			for name, tc := range testCases {
+				By(name)
+				ctx := context.Background()
+				jobName := fmt.Sprintf(jobNameTemplate, caseIdx)
+				caseIdx++
+
+				tfJob := testutil.NewTFJob(tc.worker, tc.ps)
+				tfJob.SetName(jobName)
+				tfJob.SetUID(uuid.NewUUID())
+
+				refs := []metav1.OwnerReference{*reconciler.GenOwnerReference(tfJob)}
+				basicLabels := reconciler.GenLabels(tfJob.GetName())
+
+				testutil.SetPodsStatusesV2(testK8sClient, tfJob, testutil.LabelWorker, tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, nil, refs, basicLabels)
+				testutil.SetPodsStatusesV2(testK8sClient, tfJob, testutil.LabelPS, tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, nil, refs, basicLabels)
+
+				testutil.SetServicesV2(testK8sClient, tfJob, testutil.LabelWorker, tc.activeWorkerServices, refs, basicLabels)
+				testutil.SetServicesV2(testK8sClient, tfJob, testutil.LabelPS, tc.activePSServices, refs, basicLabels)
+
+				totalPodNumber := int(tc.pendingWorkerPods + tc.activeWorkerPods + tc.succeededWorkerPods + tc.failedWorkerPods + tc.pendingPSPods + tc.activePSPods + tc.succeededPSPods + tc.failedPSPods)
+				totalServiceNumber := int(tc.activeWorkerServices + tc.activePSServices)
+
+				selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{MatchLabels: reconciler.GenLabels(tfJob.GetName())})
+				Expect(err).Should(BeNil())
+				listOpt := client.MatchingLabelsSelector{Selector: selector}
+				Eventually(func() error {
+					podList := &corev1.PodList{}
+					svcList := &corev1.ServiceList{}
+
+					err = testK8sClient.List(ctx, podList, listOpt)
+					if err != nil {
+						return err
+					}
+					if len(podList.Items) != totalPodNumber {
+						return fmt.Errorf("expected %d Pods, got %d", totalPodNumber, len(podList.Items))
+					}
+
+					err = testK8sClient.List(ctx, svcList, listOpt)
+					if err != nil {
+						return err
+					}
+					if len(svcList.Items) != totalServiceNumber {
+						return fmt.Errorf("expected %d Services, got %d", totalServiceNumber, len(svcList.Items))
+					}
+					return nil
+				}).Should(BeNil())
+
+				_ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy)
+
+				// Check the number of Pods and Services
+				Eventually(func() error {
+					podList := &corev1.PodList{}
+					svcList := &corev1.ServiceList{}
+
+					err = testK8sClient.List(ctx, podList, listOpt)
+					if err != nil {
+						return err
+					}
+					podCreatedNumber := 0
+					if len(podList.Items) > totalPodNumber {
+						podCreatedNumber = len(podList.Items) - totalPodNumber
+					}
+					podDeletedNumber := 0
+					if len(podList.Items) < totalPodNumber {
+						podDeletedNumber = totalPodNumber - len(podList.Items)
+					}
+					if podCreatedNumber != int(tc.expectedPodCreations) {
+						return fmt.Errorf("%s: unexpected number of pod creations. Expected %d, saw %d", name, tc.expectedPodCreations, podCreatedNumber)
+					}
+					if podDeletedNumber != int(tc.expectedPodDeletions) {
+						return fmt.Errorf("%s: unexpected number of pod deletions. Expected %d, saw %d", name, tc.expectedPodDeletions, podDeletedNumber)
+					}
+					// check controller references for all pods
+					for _, p := range podList.Items {
+						for _, ref := range p.GetOwnerReferences() {
+							if ref.APIVersion != tfv1.SchemeGroupVersion.String() {
+								return fmt.Errorf("controllerRef.APIVersion = %q, want %q", ref.APIVersion, tfv1.SchemeGroupVersion.String())
+							}
+							if ref.Kind != tfv1.Kind {
+								return fmt.Errorf("controllerRef.Kind = %q, want %q", ref.Kind, tfv1.Kind)
+							}
+							if ref.Name != tfJob.GetName() {
+								return fmt.Errorf("controllerRef.Name = %q, want %q", ref.Name, tfJob.GetName())
+							}
+							if ref.UID != tfJob.GetUID() {
+								return fmt.Errorf("controllerRef.UID = %q, want %q", ref.UID, tfJob.GetUID())
+							}
+						}
+					}
+
+					err = testK8sClient.List(ctx, svcList, listOpt)
+					if err != nil {
+						return err
+					}
+					serviceCreatedNumber := 0
+					if len(svcList.Items) > totalServiceNumber {
+						serviceCreatedNumber = len(svcList.Items) - totalServiceNumber
+					}
+					if serviceCreatedNumber != int(tc.expectedServiceCreations) {
+						return fmt.Errorf("%s: unexpected number of service creations. Expected %d, saw %d", name, tc.expectedServiceCreations, serviceCreatedNumber)
+					}
+					// check controller reference for all services
+					for _, s := range svcList.Items {
+						for _, ref := range s.GetOwnerReferences() {
+							if ref.APIVersion != tfv1.SchemeGroupVersion.String() {
+								return fmt.Errorf("controllerRef.APIVersion = %q, want %q", ref.APIVersion, tfv1.SchemeGroupVersion.String())
+							}
+							if ref.Kind != tfv1.Kind {
+								return fmt.Errorf("controllerRef.Kind = %q, want %q", ref.Kind, tfv1.Kind)
+							}
+							if ref.Name != tfJob.GetName() {
+								return fmt.Errorf("controllerRef.Name = %q, want %q", ref.Name, tfJob.GetName())
+							}
+							if ref.UID != tfJob.GetUID() {
+								return fmt.Errorf("controllerRef.UID = %q, want %q", ref.UID, tfJob.GetUID())
+							}
+						}
+					}
+					return nil
+				}).Should(BeNil())
+
+				// Validate Worker status
+				if tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypeWorker] != nil {
+					Expect(tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypeWorker].Active).To(Equal(tc.expectedActiveWorkerPods))
+					Expect(tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypeWorker].Succeeded).To(Equal(tc.expectedSucceededWorkerPods))
+					Expect(tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypeWorker].Failed).To(Equal(tc.expectedFailedWorkerPods))
+				}
+				// Validate PS status
+				if tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypePS] != nil {
+					Expect(tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypePS].Active).To(Equal(tc.expectedActivePSPods))
+					Expect(tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypePS].Succeeded).To(Equal(tc.expectedSucceededPSPods))
+					Expect(tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypePS].Failed).To(Equal(tc.expectedFailedPSPods))
+				}
+
+				// Validate StartTime
+				if tc.needCheckStartTime {
+					Expect(tfJob.Status.StartTime).NotTo(BeNil())
+				}
+
+				// Validate Conditions
+				if tc.expectedCondition != nil {
+					Expect(testutil.CheckCondition(tfJob, *tc.expectedCondition, tc.expectedConditionReason)).Should(BeTrue())
+				}
+			}
+		})
+	})
+})
diff --git a/pkg/controller.v1/tensorflow/util_test.go b/pkg/controller.v1/tensorflow/util_test.go
new file mode 100644
index 0000000000..435440fab0
--- /dev/null
+++ b/pkg/controller.v1/tensorflow/util_test.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The Kubeflow Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+	"testing"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/util/uuid"
+
+	tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
+)
+
+func TestGenOwnerReference(t *testing.T) {
+	testName := "test-tfjob"
+	testUID := uuid.NewUUID()
+	tfJob := &tfv1.TFJob{
+		ObjectMeta: metav1.ObjectMeta{
+			Name: testName,
+			UID:  testUID,
+		},
+	}
+
+	ref := reconciler.GenOwnerReference(tfJob)
+	if ref.UID != testUID {
+		t.Errorf("Expected UID %s, got %s", testUID, ref.UID)
+	}
+	if ref.Name != testName {
+		t.Errorf("Expected Name %s, got %s", testName, ref.Name)
+	}
+	if ref.APIVersion != tfv1.SchemeGroupVersion.String() {
+		t.Errorf("Expected APIVersion %s, got %s", tfv1.SchemeGroupVersion.String(), ref.APIVersion)
+	}
+}
+
+func TestGenLabels(t *testing.T) {
+	testJobName := "test/key"
+	expectedVal := "test-key"
+
+	labels := reconciler.GenLabels(testJobName)
+	jobNameLabel := commonv1.JobNameLabel
+	jobNameLabelDeprecated := commonv1.JobNameLabelDeprecated
+
+	if labels[jobNameLabel] != expectedVal {
+		t.Errorf("Expected %s %s, got %s", jobNameLabel, expectedVal, labels[jobNameLabel])
+	}
+
+	if labels[jobNameLabelDeprecated] != expectedVal {
+		t.Errorf("Expected %s %s, got %s", jobNameLabelDeprecated, expectedVal, labels[jobNameLabelDeprecated])
+	}
+
+	if labels[commonv1.GroupNameLabelDeprecated] != tfv1.GroupVersion.Group {
+		t.Errorf("Expected %s %s, got %s", commonv1.GroupNameLabelDeprecated, tfv1.GroupVersion.Group,
+			labels[commonv1.GroupNameLabelDeprecated])
+	}
+
+	if labels[commonv1.OperatorNameLabel] != controllerName {
+		t.Errorf("Expected %s %s, got %s", commonv1.OperatorNameLabel, controllerName,
+			labels[commonv1.OperatorNameLabel])
+	}
+}

From f7acfc6738d45d3969749163c47862c5cb4ac9f7 Mon Sep 17 00:00:00 2001
From: Wang Zhang
Date: Thu, 23 Dec 2021 17:00:02 +0800
Subject: [PATCH 2/3] fix comment

fix test exit code issue

fix pod spec
---
 pkg/common/util/v1/testutil/service.go   |  4 +--
 pkg/controller.v1/tensorflow/pod_test.go | 27 ++++++++++++++-----
 .../tensorflow/tfjob_controller_test.go   |  6 -----
 3 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/pkg/common/util/v1/testutil/service.go b/pkg/common/util/v1/testutil/service.go
index e2451f45e6..327cfd5c9a 100644
--- a/pkg/common/util/v1/testutil/service.go
+++ b/pkg/common/util/v1/testutil/service.go
@@ -27,8 +27,8 @@ import (
 )
 
 const (
-	DummyPortName = "dummy"
-	DummyPort int32 = 1221
+	DummyPortName string = "dummy"
+	DummyPort     int32  = 1221
 )
 
 func NewBaseService(name string, job metav1.Object, refs []metav1.OwnerReference) *corev1.Service {
diff --git a/pkg/controller.v1/tensorflow/pod_test.go b/pkg/controller.v1/tensorflow/pod_test.go
index a853ac03f2..f625fe640f 100644
--- a/pkg/controller.v1/tensorflow/pod_test.go
+++ b/pkg/controller.v1/tensorflow/pod_test.go
@@ -236,14 +236,17 @@ var _ = Describe("TFJob controller", func() {
 			for k, v := range basicLabels {
 				pod.Labels[k] = v
 			}
+			pod.Spec.Containers = append(pod.Spec.Containers, corev1.Container{
+				Name:  tfv1.DefaultContainerName,
+				Image: testutil.DummyContainerImage,
+			})
 			Expect(testK8sClient.Create(ctx, pod)).Should(Succeed())
 
-			po := &corev1.Pod{}
+			created := &corev1.Pod{}
 			key := types.NamespacedName{Namespace: metav1.NamespaceDefault, Name: pod.GetName()}
-			Expect(testK8sClient.Get(ctx, key, po)).Should(Succeed())
-			po.Status.Phase = corev1.PodFailed
-			po.Spec.Containers = append(po.Spec.Containers, corev1.Container{})
-			po.Status.ContainerStatuses = append(po.Status.ContainerStatuses, corev1.ContainerStatus{
+			Expect(testK8sClient.Get(ctx, key, created)).Should(Succeed())
+			created.Status.Phase = corev1.PodFailed
+			created.Status.ContainerStatuses = append(created.Status.ContainerStatuses, corev1.ContainerStatus{
 				Name: tfv1.DefaultContainerName,
 				State: corev1.ContainerState{
 					Terminated: &corev1.ContainerStateTerminated{
@@ -251,7 +254,19 @@
 				},
 			})
-			Expect(testK8sClient.Status().Update(ctx, po))
+			Expect(testK8sClient.Status().Update(ctx, created)).Should(Succeed())
+
+			// Make sure the created Pod's cached copy reflects the desired status
+			Eventually(func() error {
+				updated := &corev1.Pod{}
+				if err := testK8sClient.Get(ctx, key, updated); err != nil {
+					return err
+				}
+				if updated.Status.Phase != corev1.PodFailed {
+					return fmt.Errorf("pod status is not Failed")
+				}
+				return nil
+			}, timeout, interval).Should(BeNil())
 
 			_ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy)
diff --git a/pkg/controller.v1/tensorflow/tfjob_controller_test.go b/pkg/controller.v1/tensorflow/tfjob_controller_test.go
index db634e4eb4..2c8e8014aa 100644
--- a/pkg/controller.v1/tensorflow/tfjob_controller_test.go
+++ b/pkg/controller.v1/tensorflow/tfjob_controller_test.go
@@ -31,12 +31,6 @@ import (
 )
 
 var _ = Describe("TFJob controller", func() {
-	// Define utility constants for object names and testing timeouts/durations and intervals.
-	//const (
-	//	timeout  = 10 * time.Second
-	//	interval = 1000 * time.Millisecond
-	//)
-
 	Context("Test Normal Path", func() {
 		It("should create desired Pods and Services", func() {
 			var (

From d78d8836dee1833eebedf1d2b755b96690089cbe Mon Sep 17 00:00:00 2001
From: Wang Zhang
Date: Wed, 5 Jan 2022 23:04:52 +0800
Subject: [PATCH 3/3] keep SetPodsStatuses
---
 pkg/common/util/v1/testutil/pod.go        |  2 +-
 pkg/controller.v1/tensorflow/job_test.go  | 12 ++++++------
 .../tensorflow/tfjob_controller_test.go   |  4 ++--
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/pkg/common/util/v1/testutil/pod.go b/pkg/common/util/v1/testutil/pod.go
index 0ab1e73848..732436663e 100644
--- a/pkg/common/util/v1/testutil/pod.go
+++ b/pkg/common/util/v1/testutil/pod.go
@@ -72,7 +72,7 @@ func NewPodList(count int32, status corev1.PodPhase, job metav1.Object, typ stri
 	return pods
 }
 
-func SetPodsStatusesV2(client client.Client, job metav1.Object, typ string,
+func SetPodsStatuses(client client.Client, job metav1.Object, typ string,
 	pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32,
 	refs []metav1.OwnerReference, basicLabels map[string]string) {
 	timeout := 10 * time.Second
diff --git a/pkg/controller.v1/tensorflow/job_test.go b/pkg/controller.v1/tensorflow/job_test.go
index b22d77d5b2..d407fb45fc 100644
--- a/pkg/controller.v1/tensorflow/job_test.go
+++ b/pkg/controller.v1/tensorflow/job_test.go
@@ -261,10 +261,10 @@ var _ = Describe("TFJob controller", func() {
 			}
 
 			By("creating Services and Pods with designed phases")
-			testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelWorker,
+			testutil.SetPodsStatuses(testK8sClient, tc.tfJob, testutil.LabelWorker,
 				tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods,
 				nil, refs, basicLabels)
-			testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelPS,
+			testutil.SetPodsStatuses(testK8sClient, tc.tfJob, testutil.LabelPS,
 				tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods,
 				nil, refs, basicLabels)
 
@@ -378,10 +378,10 @@ var _ = Describe("TFJob controller", func() {
 			}
 
 			By("creating Services and Pods with designed phases")
-			testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelWorker,
+			testutil.SetPodsStatuses(testK8sClient, tc.tfJob, testutil.LabelWorker,
 				tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods,
 				nil, refs, basicLabels)
-			testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelPS,
+			testutil.SetPodsStatuses(testK8sClient, tc.tfJob, testutil.LabelPS,
 				tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods,
 				nil, refs, basicLabels)
 
@@ -490,10 +490,10 @@ var _ = Describe("TFJob controller", func() {
 			}
 
 			By("creating Services and Pods with designed phases")
-			testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelWorker,
+			testutil.SetPodsStatuses(testK8sClient, tc.tfJob, testutil.LabelWorker,
 				tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods,
 				tc.restartCounts, refs, basicLabels)
-			testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelPS,
+			testutil.SetPodsStatuses(testK8sClient, tc.tfJob, testutil.LabelPS,
 				tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods,
 				tc.restartCounts, refs, basicLabels)
 
diff --git a/pkg/controller.v1/tensorflow/tfjob_controller_test.go b/pkg/controller.v1/tensorflow/tfjob_controller_test.go
index 2c8e8014aa..76d9cd78fa 100644
--- a/pkg/controller.v1/tensorflow/tfjob_controller_test.go
+++ b/pkg/controller.v1/tensorflow/tfjob_controller_test.go
@@ -184,8 +184,8 @@ var _ = Describe("TFJob controller", func() {
 				refs := []metav1.OwnerReference{*reconciler.GenOwnerReference(tfJob)}
 				basicLabels := reconciler.GenLabels(tfJob.GetName())
 
-				testutil.SetPodsStatusesV2(testK8sClient, tfJob, testutil.LabelWorker, tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, nil, refs, basicLabels)
-				testutil.SetPodsStatusesV2(testK8sClient, tfJob, testutil.LabelPS, tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, nil, refs, basicLabels)
+				testutil.SetPodsStatuses(testK8sClient, tfJob, testutil.LabelWorker, tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, nil, refs, basicLabels)
+				testutil.SetPodsStatuses(testK8sClient, tfJob, testutil.LabelPS, tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, nil, refs, basicLabels)
 
 				testutil.SetServicesV2(testK8sClient, tfJob, testutil.LabelWorker, tc.activeWorkerServices, refs, basicLabels)
 				testutil.SetServicesV2(testK8sClient, tfJob, testutil.LabelPS, tc.activePSServices, refs, basicLabels)