diff --git a/common/persistence/nosql/nosql_task_store_test.go b/common/persistence/nosql/nosql_task_store_test.go index ffa453d7b81..fe6182f27df 100644 --- a/common/persistence/nosql/nosql_task_store_test.go +++ b/common/persistence/nosql/nosql_task_store_test.go @@ -23,11 +23,22 @@ package nosql import ( + ctx "context" "testing" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin" + "github.com/uber/cadence/common/types" +) + +const ( + TestDomainID = "test-domain-id" + TestDomainName = "test-domain" + TestTaskListName = "test-tasklist" ) func TestNewNoSQLStore(t *testing.T) { @@ -38,3 +49,58 @@ func TestNewNoSQLStore(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, store) } + +func setupNoSQLStoreMocks(t *testing.T) (*nosqlTaskStore, *nosqlplugin.MockDB) { + ctrl := gomock.NewController(t) + dbMock := nosqlplugin.NewMockDB(ctrl) + + nosqlSt := nosqlStore{ + logger: log.NewNoop(), + db: dbMock, + } + + shardedNosqlStoreMock := NewMockshardedNosqlStore(ctrl) + shardedNosqlStoreMock.EXPECT(). + GetStoreShardByTaskList( + TestDomainID, + TestTaskListName, + int(types.TaskListTypeDecision)). + Return(&nosqlSt, nil). + AnyTimes() + + store := &nosqlTaskStore{ + shardedNosqlStore: shardedNosqlStoreMock, + } + + return store, dbMock +} + +func TestGetTaskListSize(t *testing.T) { + store, db := setupNoSQLStoreMocks(t) + + db.EXPECT().GetTasksCount( + gomock.Any(), + &nosqlplugin.TasksFilter{ + TaskListFilter: nosqlplugin.TaskListFilter{ + DomainID: TestDomainID, + TaskListName: TestTaskListName, + TaskListType: int(types.TaskListTypeDecision), + }, + MinTaskID: 456, + }, + ).Return(int64(123), nil) + + size, err := store.GetTaskListSize(ctx.Background(), &persistence.GetTaskListSizeRequest{ + DomainID: TestDomainID, + DomainName: TestDomainName, + TaskListName: TestTaskListName, + TaskListType: int(types.TaskListTypeDecision), + AckLevel: 456, + }) + + assert.NoError(t, err) + assert.Equal(t, + &persistence.GetTaskListSizeResponse{Size: 123}, + size, + ) +}