diff --git a/internal/test/mock/k8s/client.go b/internal/test/mock/k8s/client.go new file mode 100644 index 0000000000..04abcea2a5 --- /dev/null +++ b/internal/test/mock/k8s/client.go @@ -0,0 +1,59 @@ +package k8s + +import ( + "context" + + "github.com/stretchr/testify/mock" + "github.com/vdaas/vald/internal/k8s/client" + + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/selection" + "k8s.io/apimachinery/pkg/watch" + crclient "sigs.k8s.io/controller-runtime/pkg/client" +) + +type ValdK8sClientMock struct { + mock.Mock +} + +var _ client.Client = (*ValdK8sClientMock)(nil) + +func (m *ValdK8sClientMock) Get(ctx context.Context, name string, namespace string, obj client.Object, opts ...crclient.GetOption) error { + args := m.Called(ctx, name, namespace, obj, opts) + return args.Error(0) +} + +func (m *ValdK8sClientMock) List(ctx context.Context, list crclient.ObjectList, opts ...client.ListOption) error { + args := m.Called(ctx, list, opts) + return args.Error(0) +} + +func (m *ValdK8sClientMock) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *ValdK8sClientMock) Delete(ctx context.Context, obj client.Object, opts ...crclient.DeleteOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *ValdK8sClientMock) Update(ctx context.Context, obj client.Object, opts ...crclient.UpdateOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *ValdK8sClientMock) Patch(ctx context.Context, obj client.Object, patch crclient.Patch, opts ...crclient.PatchOption) error { + args := m.Called(ctx, obj, patch, opts) + return args.Error(0) +} + +func (m *ValdK8sClientMock) Watch(ctx context.Context, obj crclient.ObjectList, opts ...client.ListOption) (watch.Interface, error) { + args := m.Called(ctx, obj, opts) + return args.Get(0).(watch.Interface), args.Error(1) +} + +func (m *ValdK8sClientMock) LabelSelector(key string, op selection.Operator, vals []string) (labels.Selector, error) { + args := m.Called(key, op, vals) + return args.Get(0).(labels.Selector), args.Error(1) +} diff --git a/pkg/index/job/readreplica/rotate/service/rotator_test.go b/pkg/index/job/readreplica/rotate/service/rotator_test.go index 8ed40cb13a..f47e479a57 100644 --- a/pkg/index/job/readreplica/rotate/service/rotator_test.go +++ b/pkg/index/job/readreplica/rotate/service/rotator_test.go @@ -16,9 +16,15 @@ package service import ( "testing" + tmock "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/vdaas/vald/internal/errors" "github.com/vdaas/vald/internal/k8s/client" + "github.com/vdaas/vald/internal/test/mock/k8s" + + appsv1 "k8s.io/api/apps/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" ) func Test_getNewBaseName(t *testing.T) { @@ -82,6 +88,7 @@ func Test_getNewBaseName(t *testing.T) { } func Test_parseReplicaID(t *testing.T) { + labelKey := "foo" type args struct { replicaID string c client.Client @@ -90,11 +97,12 @@ func Test_parseReplicaID(t *testing.T) { ids []string err error } - tests := []struct { + type test struct { name string args args want want - }{ + } + tests := []test{ { name: "single replicaID", args: args{ @@ -128,12 +136,51 @@ func Test_parseReplicaID(t *testing.T) { err: errors.ErrReadReplicaIDEmpty, }, }, + func() test { + wantId1 := "bar" + wantId2 := "baz" + mock := &k8s.ValdK8sClientMock{} + mock.On("LabelSelector", tmock.Anything, tmock.Anything, tmock.Anything).Return(labels.NewSelector(), nil) + mock.On("List", tmock.Anything, tmock.Anything, tmock.Anything).Run(func(args tmock.Arguments) { + if depList, ok := args.Get(1).(*appsv1.DeploymentList); ok { + depList.Items = []appsv1.Deployment{ + { + ObjectMeta: v1.ObjectMeta{ + Labels: map[string]string{ + labelKey: wantId1, + }, + }, + }, + { + ObjectMeta: v1.ObjectMeta{ + Labels: map[string]string{ + labelKey: wantId2, + }, + }, + }, + } + } + }).Return(nil) + return test{ + name: "returns all ids when rotate-all option is set", + args: args{ + replicaID: rotateAllID, + c: mock, + }, + want: want{ + ids: []string{wantId1, wantId2}, + err: nil, + }, + } + }(), } for _, test := range tests { tt := test t.Run(tt.name, func(t *testing.T) { t.Parallel() - r := &rotator{} + r := &rotator{ + readReplicaLabelKey: labelKey, + } ids, err := r.parseReplicaID(tt.args.replicaID, tt.args.c) require.Equal(t, tt.want.ids, ids) require.Equal(t, tt.want.err, err)