Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unit tests for tensorflow controller #1511

Merged
merged 3 commits into from
Jan 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 67 additions & 49 deletions pkg/common/util/v1/testutil/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,81 +15,99 @@
package testutil

import (
"context"
"fmt"
"testing"
"time"

v1 "k8s.io/api/core/v1"
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
. "github.com/onsi/gomega"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason to import gomega this way?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to be consistent with the tutorial of kubebuilder: https://book.kubebuilder.io/cronjob-tutorial/writing-tests.html

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/tools/cache"

tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const (
// labels for pods and servers.
tfReplicaTypeLabel = "replica-type"
Jeffwan marked this conversation as resolved.
Show resolved Hide resolved
tfReplicaIndexLabel = "replica-index"
DummyContainerName = "dummy"
DummyContainerImage = "dummy/dummy:latest"
)

var (
// controllerKind is the kind handed to metav1.NewControllerRef when the
// TFJob-specific helpers build owner references for test pods and services.
// It is derived from tfv1.GroupVersion combined with TFJobKind.
controllerKind = tfv1.GroupVersion.WithKind(TFJobKind)
)
func NewBasePod(name string, job metav1.Object, refs []metav1.OwnerReference) *corev1.Pod {

func NewBasePod(name string, tfJob *tfv1.TFJob) *v1.Pod {
return &v1.Pod{
return &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Labels: GenLabels(tfJob.Name),
Namespace: tfJob.Namespace,
OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)},
Labels: map[string]string{},
Namespace: job.GetNamespace(),
OwnerReferences: refs,
},
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: DummyContainerName,
Image: DummyContainerImage,
},
},
},
}
}

func NewPod(tfJob *tfv1.TFJob, typ string, index int) *v1.Pod {
pod := NewBasePod(fmt.Sprintf("%s-%d", typ, index), tfJob)
pod.Labels[tfReplicaTypeLabel] = typ
pod.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index)
func NewPod(job metav1.Object, typ string, index int, refs []metav1.OwnerReference) *corev1.Pod {
pod := NewBasePod(fmt.Sprintf("%s-%s-%d", job.GetName(), typ, index), job, refs)
pod.Labels[commonv1.ReplicaTypeLabelDeprecated] = typ
pod.Labels[commonv1.ReplicaTypeLabel] = typ
pod.Labels[commonv1.ReplicaIndexLabelDeprecated] = fmt.Sprintf("%d", index)
pod.Labels[commonv1.ReplicaIndexLabel] = fmt.Sprintf("%d", index)
return pod
}

// create count pods with the given phase for the given tfJob
func NewPodList(count int32, status v1.PodPhase, tfJob *tfv1.TFJob, typ string, start int32) []*v1.Pod {
pods := []*v1.Pod{}
// NewPodList creates count pods with the given phase for the given job, with replica indices starting at start.
func NewPodList(count int32, status corev1.PodPhase, job metav1.Object, typ string, start int32, refs []metav1.OwnerReference) []*corev1.Pod {
pods := []*corev1.Pod{}
for i := int32(0); i < count; i++ {
newPod := NewPod(tfJob, typ, int(start+i))
newPod.Status = v1.PodStatus{Phase: status}
newPod := NewPod(job, typ, int(start+i), refs)
newPod.Status = corev1.PodStatus{Phase: status}
pods = append(pods, newPod)
}
return pods
}

func SetPodsStatuses(podIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32, t *testing.T) {
func SetPodsStatuses(client client.Client, job metav1.Object, typ string,
pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32,
refs []metav1.OwnerReference, basicLabels map[string]string) {
timeout := 10 * time.Second
interval := 1000 * time.Millisecond
var index int32
for _, pod := range NewPodList(pendingPods, v1.PodPending, tfJob, typ, index) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
}
}
index += pendingPods
for i, pod := range NewPodList(activePods, v1.PodRunning, tfJob, typ, index) {
if restartCounts != nil {
pod.Status.ContainerStatuses = []v1.ContainerStatus{{RestartCount: restartCounts[i]}}
}
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
}
taskMap := map[corev1.PodPhase]int32{
corev1.PodFailed: failedPods,
corev1.PodPending: pendingPods,
corev1.PodRunning: activePods,
corev1.PodSucceeded: succeededPods,
}
index += activePods
for _, pod := range NewPodList(succeededPods, v1.PodSucceeded, tfJob, typ, index) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
}
}
index += succeededPods
for _, pod := range NewPodList(failedPods, v1.PodFailed, tfJob, typ, index) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
ctx := context.Background()

for podPhase, desiredCount := range taskMap {
for i, pod := range NewPodList(desiredCount, podPhase, job, typ, index, refs) {
for k, v := range basicLabels {
pod.Labels[k] = v
}
_ = client.Create(ctx, pod)
launcherKey := types.NamespacedName{
Namespace: metav1.NamespaceDefault,
Name: pod.GetName(),
}
Eventually(func() error {
po := &corev1.Pod{}
if err := client.Get(ctx, launcherKey, po); err != nil {
return err
}
po.Status.Phase = podPhase
if podPhase == corev1.PodRunning && restartCounts != nil {
po.Status.ContainerStatuses = []corev1.ContainerStatus{{RestartCount: restartCounts[i]}}
}
return client.Status().Update(ctx, po)
}, timeout, interval).Should(BeNil())
}
index += desiredCount
}
}
66 changes: 45 additions & 21 deletions pkg/common/util/v1/testutil/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,48 +15,72 @@
package testutil

import (
"context"
"fmt"
"testing"

v1 "k8s.io/api/core/v1"
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/tools/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
)

tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
// Placeholder port settings used for the single ServicePort of the services
// built by NewBaseService.
const (
// DummyPortName is the name assigned to that port.
DummyPortName string = "dummy"
// DummyPort is the port number assigned to that port.
DummyPort int32 = 1221
)

func NewBaseService(name string, tfJob *tfv1.TFJob, t *testing.T) *v1.Service {
return &v1.Service{
func NewBaseService(name string, job metav1.Object, refs []metav1.OwnerReference) *corev1.Service {
return &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Labels: GenLabels(tfJob.Name),
Namespace: tfJob.Namespace,
OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)},
Labels: map[string]string{},
Namespace: job.GetNamespace(),
OwnerReferences: refs,
},
Spec: corev1.ServiceSpec{
Ports: []corev1.ServicePort{
{
Name: DummyPortName,
Port: DummyPort,
},
},
},
}
}

func NewService(tfJob *tfv1.TFJob, typ string, index int, t *testing.T) *v1.Service {
service := NewBaseService(fmt.Sprintf("%s-%d", typ, index), tfJob, t)
service.Labels[tfReplicaTypeLabel] = typ
service.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index)
return service
func NewService(job metav1.Object, typ string, index int, refs []metav1.OwnerReference) *corev1.Service {
svc := NewBaseService(fmt.Sprintf("%s-%s-%d", job.GetName(), typ, index), job, refs)
svc.Labels[commonv1.ReplicaTypeLabelDeprecated] = typ
svc.Labels[commonv1.ReplicaTypeLabel] = typ
svc.Labels[commonv1.ReplicaIndexLabelDeprecated] = fmt.Sprintf("%d", index)
svc.Labels[commonv1.ReplicaIndexLabel] = fmt.Sprintf("%d", index)
return svc
}

// NewServiceList creates count services for the given job, with replica indices 0 through count-1.
func NewServiceList(count int32, tfJob *tfv1.TFJob, typ string, t *testing.T) []*v1.Service {
services := []*v1.Service{}
func NewServiceList(count int32, job metav1.Object, typ string, refs []metav1.OwnerReference) []*corev1.Service {
services := []*corev1.Service{}
for i := int32(0); i < count; i++ {
newService := NewService(tfJob, typ, int(i), t)
newService := NewService(job, typ, int(i), refs)
services = append(services, newService)
}
return services
}

func SetServices(serviceIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, activeWorkerServices int32, t *testing.T) {
for _, service := range NewServiceList(activeWorkerServices, tfJob, typ, t) {
if err := serviceIndexer.Add(service); err != nil {
t.Errorf("unexpected error when adding service %v", err)
func SetServicesV2(client client.Client, job metav1.Object, typ string, activeWorkerServices int32,
refs []metav1.OwnerReference, basicLabels map[string]string) {
ctx := context.Background()
for _, svc := range NewServiceList(activeWorkerServices, job, typ, refs) {
for k, v := range basicLabels {
svc.Labels[k] = v
}
err := client.Create(ctx, svc)
if errors.IsAlreadyExists(err) {
return
} else {
Expect(err).To(BeNil())
}
}
}
2 changes: 1 addition & 1 deletion pkg/common/util/v1/testutil/tfjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ package testutil
import (
"time"

commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
)

Expand Down
26 changes: 9 additions & 17 deletions pkg/common/util/v1/testutil/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
package testutil

import (
"strings"
"testing"

common "github.com/kubeflow/common/pkg/apis/common/v1"
tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/cache"

tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
)

const (
Expand All @@ -43,21 +43,13 @@ var (
ControllerName = "training-operator"
)

// GenLabels builds the base label set for objects that belong to the job
// named jobName. Every "/" in the name is replaced with "-" before the name
// is stored as a label value; the same sanitised value is written under both
// the current and the deprecated job-name label keys.
func GenLabels(jobName string) map[string]string {
	sanitized := strings.ReplaceAll(jobName, "/", "-")

	labels := make(map[string]string, 3)
	labels[LabelGroupName] = GroupName
	labels[JobNameLabel] = sanitized
	labels[DeprecatedLabelTFJobName] = sanitized
	return labels
}

func GenOwnerReference(tfjob *tfv1.TFJob) *metav1.OwnerReference {
func GenOwnerReference(job metav1.Object, apiVersion string, kind string) *metav1.OwnerReference {
boolPtr := func(b bool) *bool { return &b }
controllerRef := &metav1.OwnerReference{
APIVersion: tfv1.GroupVersion.Version,
Kind: TFJobKind,
Name: tfjob.Name,
UID: tfjob.UID,
APIVersion: apiVersion,
Kind: kind,
Name: job.GetName(),
UID: job.GetUID(),
BlockOwnerDeletion: boolPtr(true),
Controller: boolPtr(true),
}
Expand Down Expand Up @@ -85,7 +77,7 @@ func GetKey(tfJob *tfv1.TFJob, t *testing.T) string {
return key
}

func CheckCondition(tfJob *tfv1.TFJob, condition common.JobConditionType, reason string) bool {
func CheckCondition(tfJob *tfv1.TFJob, condition commonv1.JobConditionType, reason string) bool {
for _, v := range tfJob.Status.Conditions {
if v.Type == condition && v.Status == v1.ConditionTrue && v.Reason == reason {
return true
Expand Down
Loading