From c98a6d27e61cd5b3bba2e1423804860fe264937a Mon Sep 17 00:00:00 2001 From: Zijian Date: Tue, 4 Jun 2024 23:25:06 +0000 Subject: [PATCH] Add unit tests for matching engine --- service/matching/handler/engine.go | 38 +- service/matching/handler/engine_test.go | 619 ++++++++++++++++++ service/matching/tasklist/interfaces.go | 60 ++ service/matching/tasklist/interfaces_mock.go | 241 +++++++ .../matching/tasklist/task_list_manager.go | 29 - 5 files changed, 937 insertions(+), 50 deletions(-) create mode 100644 service/matching/handler/engine_test.go create mode 100644 service/matching/tasklist/interfaces.go create mode 100644 service/matching/tasklist/interfaces_mock.go diff --git a/service/matching/handler/engine.go b/service/matching/handler/engine.go index a09c8650708..fd5fd51b170 100644 --- a/service/matching/handler/engine.go +++ b/service/matching/handler/engine.go @@ -92,6 +92,8 @@ type ( membershipResolver membership.Resolver partitioner partition.Partitioner timeSource clock.TimeSource + + waitForQueryResultFn func(hCtx *handlerContext, isStrongConsistencyQuery bool, queryResultCh <-chan *queryResult) (*types.QueryWorkflowResponse, error) } // HistoryInfo consists of two integer regarding the history size and history count @@ -128,7 +130,7 @@ func NewEngine(taskManager persistence.TaskManager, partitioner partition.Partitioner, timeSource clock.TimeSource, ) Engine { - return &matchingEngineImpl{ + e := &matchingEngineImpl{ taskManager: taskManager, clusterMetadata: clusterMetadata, historyService: historyService, @@ -145,6 +147,8 @@ func NewEngine(taskManager persistence.TaskManager, partitioner: partitioner, timeSource: timeSource, } + e.waitForQueryResultFn = e.waitForQueryResult + return e } func (e *matchingEngineImpl) Start() { @@ -249,15 +253,14 @@ func (e *matchingEngineImpl) getTaskListByDomainLocked(domainID string) *types.G decisionTaskListMap := make(map[string]*types.DescribeTaskListResponse) activityTaskListMap := make(map[string]*types.DescribeTaskListResponse) for tl, tlm := range e.taskLists { - if tlm.GetTaskListKind() == types.TaskListKindNormal && tl.GetDomainID() == domainID { + if tl.GetDomainID() == domainID && tlm.GetTaskListKind() == types.TaskListKindNormal { if types.TaskListType(tl.GetType()) == types.TaskListTypeDecision { decisionTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false) + } else { + activityTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false) } - // TODO: review this logic - activityTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false) } } - return &types.GetTaskListsByDomainResponse{ DecisionTaskListMap: decisionTaskListMap, ActivityTaskListMap: activityTaskListMap, @@ -712,23 +715,24 @@ func (e *matchingEngineImpl) QueryWorkflow( queryResultCh := make(chan *queryResult, 1) e.lockableQueryTaskMap.put(taskID, queryResultCh) defer e.lockableQueryTaskMap.delete(taskID) + return e.waitForQueryResultFn(hCtx, queryRequest.GetQueryRequest().GetQueryConsistencyLevel() == types.QueryConsistencyLevelStrong, queryResultCh) +} +func (e *matchingEngineImpl) waitForQueryResult(hCtx *handlerContext, isStrongConsistencyQuery bool, queryResultCh <-chan *queryResult) (*types.QueryWorkflowResponse, error) { select { case result := <-queryResultCh: if result.internalError != nil { return nil, result.internalError } - workerResponse := result.workerResponse // if query was intended as consistent query check to see if worker supports consistent query - if queryRequest.GetQueryRequest().GetQueryConsistencyLevel() == types.QueryConsistencyLevelStrong { + if isStrongConsistencyQuery { if err := e.versionChecker.SupportsConsistentQuery( workerResponse.GetCompletedRequest().GetWorkerVersionInfo().GetImpl(), workerResponse.GetCompletedRequest().GetWorkerVersionInfo().GetFeatureVersion()); err != nil { return nil, err } } - switch workerResponse.GetCompletedRequest().GetCompletedType() { case types.QueryTaskCompletedTypeCompleted: return &types.QueryWorkflowResponse{QueryResult: workerResponse.GetCompletedRequest().GetQueryResult()}, nil @@ -878,30 +882,22 @@ func (e *matchingEngineImpl) getAllPartitions( request *types.MatchingListTaskListPartitionsRequest, taskListType int, ) ([]string, error) { - var partitionKeys []string domainID, err := e.domainCache.GetDomainID(request.GetDomain()) if err != nil { - return partitionKeys, err + return nil, err } taskList := request.GetTaskList() taskListID, err := tasklist.NewIdentifier(domainID, taskList.GetName(), taskListType) if err != nil { - return partitionKeys, err - } - rootPartition := taskListID.GetRoot() - - partitionKeys = append(partitionKeys, rootPartition) - - nWritePartitions := e.config.NumTasklistWritePartitions - n := nWritePartitions(request.GetDomain(), rootPartition, taskListType) - if n <= 0 { - return partitionKeys, nil + return nil, err } + rootPartition := taskListID.GetRoot() + partitionKeys := []string{rootPartition} + n := e.config.NumTasklistWritePartitions(request.GetDomain(), rootPartition, taskListType) for i := 1; i < n; i++ { partitionKeys = append(partitionKeys, fmt.Sprintf("%v%v/%v", common.ReservedTaskListPrefix, rootPartition, i)) } - return partitionKeys, nil } diff --git a/service/matching/handler/engine_test.go b/service/matching/handler/engine_test.go new file mode 100644 index 00000000000..87c84088362 --- /dev/null +++ b/service/matching/handler/engine_test.go @@ -0,0 +1,619 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package handler + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/uber/cadence/common/cache" + "github.com/uber/cadence/common/client" + "github.com/uber/cadence/common/clock" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/membership" + "github.com/uber/cadence/common/metrics" + "github.com/uber/cadence/common/types" + "github.com/uber/cadence/service/matching/config" + "github.com/uber/cadence/service/matching/tasklist" +) + +func TestGetTaskListsByDomain(t *testing.T) { + testCases := []struct { + name string + mockSetup func(*cache.MockDomainCache, map[tasklist.Identifier]*tasklist.MockManager, map[tasklist.Identifier]*tasklist.MockManager) + wantErr bool + want *types.GetTaskListsByDomainResponse + }{ + { + name: "domain cache error", + mockSetup: func(mockDomainCache *cache.MockDomainCache, mockTaskListManagers map[tasklist.Identifier]*tasklist.MockManager, mockStickyManagers map[tasklist.Identifier]*tasklist.MockManager) { + mockDomainCache.EXPECT().GetDomainID("test-domain").Return("", errors.New("cache failure")) + }, + wantErr: true, + }, + { + name: "success", + mockSetup: func(mockDomainCache *cache.MockDomainCache, mockTaskListManagers map[tasklist.Identifier]*tasklist.MockManager, mockStickyManagers map[tasklist.Identifier]*tasklist.MockManager) { + mockDomainCache.EXPECT().GetDomainID("test-domain").Return("test-domain-id", nil) + for id, mockManager := range mockTaskListManagers { + if id.GetDomainID() == "test-domain-id" { + mockManager.EXPECT().GetTaskListKind().Return(types.TaskListKindNormal) + mockManager.EXPECT().DescribeTaskList(false).Return(&types.DescribeTaskListResponse{ + Pollers: []*types.PollerInfo{ + { + Identity: fmt.Sprintf("test-poller-%s", id.GetRoot()), + }, + }, + }) + } + } + for id, mockManager := range mockStickyManagers { + if id.GetDomainID() == "test-domain-id" { + mockManager.EXPECT().GetTaskListKind().Return(types.TaskListKindSticky) + } + } + }, + wantErr: false, + want: &types.GetTaskListsByDomainResponse{ + DecisionTaskListMap: map[string]*types.DescribeTaskListResponse{ + "decision0": { + Pollers: []*types.PollerInfo{ + { + Identity: "test-poller-decision0", + }, + }, + }, + }, + ActivityTaskListMap: map[string]*types.DescribeTaskListResponse{ + "activity0": { + Pollers: []*types.PollerInfo{ + { + Identity: "test-poller-activity0", + }, + }, + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + decisionTasklistID, err := tasklist.NewIdentifier("test-domain-id", "decision0", 0) + require.NoError(t, err) + activityTasklistID, err := tasklist.NewIdentifier("test-domain-id", "activity0", 1) + require.NoError(t, err) + otherDomainTasklistID, err := tasklist.NewIdentifier("other-domain-id", "other0", 0) + require.NoError(t, err) + mockDecisionTaskListManager := tasklist.NewMockManager(mockCtrl) + mockActivityTaskListManager := tasklist.NewMockManager(mockCtrl) + mockOtherDomainTaskListManager := tasklist.NewMockManager(mockCtrl) + mockTaskListManagers := map[tasklist.Identifier]*tasklist.MockManager{ + *decisionTasklistID: mockDecisionTaskListManager, + *activityTasklistID: mockActivityTaskListManager, + *otherDomainTasklistID: mockOtherDomainTaskListManager, + } + stickyTasklistID, err := tasklist.NewIdentifier("test-domain-id", "sticky0", 0) + require.NoError(t, err) + mockStickyManager := tasklist.NewMockManager(mockCtrl) + mockStickyManagers := map[tasklist.Identifier]*tasklist.MockManager{ + *stickyTasklistID: mockStickyManager, + } + tc.mockSetup(mockDomainCache, mockTaskListManagers, mockStickyManagers) + + engine := &matchingEngineImpl{ + domainCache: mockDomainCache, + taskLists: map[tasklist.Identifier]tasklist.Manager{ + *decisionTasklistID: mockDecisionTaskListManager, + *activityTasklistID: mockActivityTaskListManager, + *otherDomainTasklistID: mockOtherDomainTaskListManager, + *stickyTasklistID: mockStickyManager, + }, + } + resp, err := engine.GetTaskListsByDomain(nil, &types.GetTaskListsByDomainRequest{Domain: "test-domain"}) + + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.want, resp) + } + }) + } +} + +func TestListTaskListPartitions(t *testing.T) { + testCases := []struct { + name string + req *types.MatchingListTaskListPartitionsRequest + mockSetup func(*cache.MockDomainCache, *membership.MockResolver) + wantErr bool + want *types.ListTaskListPartitionsResponse + }{ + { + name: "domain cache error", + req: &types.MatchingListTaskListPartitionsRequest{ + Domain: "test-domain", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + }, + mockSetup: func(mockDomainCache *cache.MockDomainCache, mockResolver *membership.MockResolver) { + mockDomainCache.EXPECT().GetDomainID("test-domain").Return("", errors.New("cache failure")) + }, + wantErr: true, + }, + { + name: "invalid tasklist name", + req: &types.MatchingListTaskListPartitionsRequest{ + Domain: "test-domain", + TaskList: &types.TaskList{ + Name: "/__cadence_sys/invalid-tasklist-name", + }, + }, + mockSetup: func(mockDomainCache *cache.MockDomainCache, mockResolver *membership.MockResolver) { + mockDomainCache.EXPECT().GetDomainID("test-domain").Return("test-domain-id", nil) + }, + wantErr: true, + }, + { + name: "success", + req: &types.MatchingListTaskListPartitionsRequest{ + Domain: "test-domain", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + }, + mockSetup: func(mockDomainCache *cache.MockDomainCache, mockResolver *membership.MockResolver) { + // activity tasklist + mockDomainCache.EXPECT().GetDomainID("test-domain").Return("test-domain-id", nil) + mockResolver.EXPECT().Lookup(gomock.Any(), "test-tasklist").Return(membership.NewHostInfo("addr2"), nil) + mockResolver.EXPECT().Lookup(gomock.Any(), "/__cadence_sys/test-tasklist/1").Return(membership.HostInfo{}, errors.New("some error")) + mockResolver.EXPECT().Lookup(gomock.Any(), "/__cadence_sys/test-tasklist/2").Return(membership.NewHostInfo("addr3"), nil) + // decision tasklist + mockDomainCache.EXPECT().GetDomainID("test-domain").Return("test-domain-id", nil) + mockResolver.EXPECT().Lookup(gomock.Any(), "test-tasklist").Return(membership.NewHostInfo("addr0"), nil) + mockResolver.EXPECT().Lookup(gomock.Any(), "/__cadence_sys/test-tasklist/1").Return(membership.HostInfo{}, errors.New("some error")) + mockResolver.EXPECT().Lookup(gomock.Any(), "/__cadence_sys/test-tasklist/2").Return(membership.NewHostInfo("addr1"), nil) + }, + wantErr: false, + want: &types.ListTaskListPartitionsResponse{ + DecisionTaskListPartitions: []*types.TaskListPartitionMetadata{ + { + Key: "test-tasklist", + OwnerHostName: "addr0", + }, + { + Key: "/__cadence_sys/test-tasklist/1", + OwnerHostName: "", + }, + { + Key: "/__cadence_sys/test-tasklist/2", + OwnerHostName: "addr1", + }, + }, + ActivityTaskListPartitions: []*types.TaskListPartitionMetadata{ + { + Key: "test-tasklist", + OwnerHostName: "addr2", + }, + { + Key: "/__cadence_sys/test-tasklist/1", + OwnerHostName: "", + }, + { + Key: "/__cadence_sys/test-tasklist/2", + OwnerHostName: "addr3", + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + mockResolver := membership.NewMockResolver(mockCtrl) + tc.mockSetup(mockDomainCache, mockResolver) + + engine := &matchingEngineImpl{ + domainCache: mockDomainCache, + membershipResolver: mockResolver, + config: &config.Config{ + NumTasklistWritePartitions: dynamicconfig.GetIntPropertyFilteredByTaskListInfo(3), + }, + } + resp, err := engine.ListTaskListPartitions(nil, tc.req) + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.want, resp) + } + }) + } +} + +func TestCancelOutstandingPoll(t *testing.T) { + testCases := []struct { + name string + req *types.CancelOutstandingPollRequest + mockSetup func(*tasklist.MockManager) + wantErr bool + }{ + { + name: "invalid tasklist name", + req: &types.CancelOutstandingPollRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "/__cadence_sys/invalid-tasklist-name", + }, + PollerID: "test-poller-id", + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + wantErr: true, + }, + { + name: "success", + req: &types.CancelOutstandingPollRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + PollerID: "test-poller-id", + }, + mockSetup: func(mockManager *tasklist.MockManager) { + mockManager.EXPECT().CancelPoller("test-poller-id") + }, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockManager := tasklist.NewMockManager(mockCtrl) + tc.mockSetup(mockManager) + tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 0) + require.NoError(t, err) + engine := &matchingEngineImpl{ + taskLists: map[tasklist.Identifier]tasklist.Manager{ + *tasklistID: mockManager, + }, + } + err = engine.CancelOutstandingPoll(nil, tc.req) + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestRespondQueryTaskCompleted(t *testing.T) { + testCases := []struct { + name string + req *types.MatchingRespondQueryTaskCompletedRequest + queryTaskMap map[string]chan *queryResult + wantErr bool + }{ + { + name: "success", + req: &types.MatchingRespondQueryTaskCompletedRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskID: "id-0", + }, + queryTaskMap: map[string]chan *queryResult{ + "id-0": make(chan *queryResult, 1), + }, + wantErr: false, + }, + { + name: "query task not found", + req: &types.MatchingRespondQueryTaskCompletedRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskID: "id-0", + }, + queryTaskMap: map[string]chan *queryResult{}, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + engine := &matchingEngineImpl{ + lockableQueryTaskMap: lockableQueryTaskMap{ + queryTaskMap: tc.queryTaskMap, + }, + } + err := engine.RespondQueryTaskCompleted(&handlerContext{scope: metrics.NewNoopMetricsClient().Scope(0)}, tc.req) + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestQueryWorkflow(t *testing.T) { + testCases := []struct { + name string + req *types.MatchingQueryWorkflowRequest + hCtx *handlerContext + mockSetup func(*tasklist.MockManager) + waitForQueryResultFn func(hCtx *handlerContext, isStrongConsistencyQuery bool, queryResultCh <-chan *queryResult) (*types.QueryWorkflowResponse, error) + wantErr bool + want *types.QueryWorkflowResponse + }{ + { + name: "invalid tasklist name", + req: &types.MatchingQueryWorkflowRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "/__cadence_sys/invalid-tasklist-name", + }, + }, + mockSetup: func(mockManager *tasklist.MockManager) {}, + wantErr: true, + }, + { + name: "sticky worker unavailable", + req: &types.MatchingQueryWorkflowRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + Kind: types.TaskListKindSticky.Ptr(), + }, + }, + mockSetup: func(mockManager *tasklist.MockManager) { + mockManager.EXPECT().HasPollerAfter(gomock.Any()).Return(false) + }, + wantErr: true, + }, + { + name: "failed to dispatch query task", + req: &types.MatchingQueryWorkflowRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + mockManager.EXPECT().DispatchQueryTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("some error")) + }, + wantErr: true, + }, + { + name: "success", + req: &types.MatchingQueryWorkflowRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + }, + hCtx: &handlerContext{ + Context: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + mockManager.EXPECT().DispatchQueryTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + }, + waitForQueryResultFn: func(hCtx *handlerContext, isStrongConsistencyQuery bool, queryResultCh <-chan *queryResult) (*types.QueryWorkflowResponse, error) { + return &types.QueryWorkflowResponse{ + QueryResult: []byte("some result"), + }, nil + }, + wantErr: false, + want: &types.QueryWorkflowResponse{ + QueryResult: []byte("some result"), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockManager := tasklist.NewMockManager(mockCtrl) + tc.mockSetup(mockManager) + tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 0) + require.NoError(t, err) + engine := &matchingEngineImpl{ + taskLists: map[tasklist.Identifier]tasklist.Manager{ + *tasklistID: mockManager, + }, + timeSource: clock.NewRealTimeSource(), + lockableQueryTaskMap: lockableQueryTaskMap{queryTaskMap: make(map[string]chan *queryResult)}, + waitForQueryResultFn: tc.waitForQueryResultFn, + } + resp, err := engine.QueryWorkflow(tc.hCtx, tc.req) + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.want, resp) + } + }) + } +} + +func TestWaitForQueryResult(t *testing.T) { + testCases := []struct { + name string + result *queryResult + mockSetup func(*client.VersionCheckerMock) + wantErr bool + assertErr func(*testing.T, error) + want *types.QueryWorkflowResponse + }{ + { + name: "internal error", + result: &queryResult{ + internalError: errors.New("some error"), + }, + mockSetup: func(mockVersionChecker *client.VersionCheckerMock) {}, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, "some error", err.Error()) + }, + wantErr: true, + }, + { + name: "strong consistency query not supported", + result: &queryResult{ + workerResponse: &types.MatchingRespondQueryTaskCompletedRequest{ + CompletedRequest: &types.RespondQueryTaskCompletedRequest{ + WorkerVersionInfo: &types.WorkerVersionInfo{ + Impl: "uber-go", + FeatureVersion: "1.0.0", + }, + }, + }, + }, + mockSetup: func(mockVersionChecker *client.VersionCheckerMock) { + mockVersionChecker.EXPECT().SupportsConsistentQuery("uber-go", "1.0.0").Return(errors.New("version error")) + }, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, "version error", err.Error()) + }, + wantErr: true, + }, + { + name: "success - query task completed", + result: &queryResult{ + workerResponse: &types.MatchingRespondQueryTaskCompletedRequest{ + CompletedRequest: &types.RespondQueryTaskCompletedRequest{ + WorkerVersionInfo: &types.WorkerVersionInfo{ + Impl: "uber-go", + FeatureVersion: "1.0.0", + }, + CompletedType: types.QueryTaskCompletedTypeCompleted.Ptr(), + QueryResult: []byte("some result"), + }, + }, + }, + mockSetup: func(mockVersionChecker *client.VersionCheckerMock) { + mockVersionChecker.EXPECT().SupportsConsistentQuery("uber-go", "1.0.0").Return(nil) + }, + wantErr: false, + want: &types.QueryWorkflowResponse{ + QueryResult: []byte("some result"), + }, + }, + { + name: "query task failed", + result: &queryResult{ + workerResponse: &types.MatchingRespondQueryTaskCompletedRequest{ + CompletedRequest: &types.RespondQueryTaskCompletedRequest{ + WorkerVersionInfo: &types.WorkerVersionInfo{ + Impl: "uber-go", + FeatureVersion: "1.0.0", + }, + CompletedType: types.QueryTaskCompletedTypeFailed.Ptr(), + ErrorMessage: "query failed", + }, + }, + }, + mockSetup: func(mockVersionChecker *client.VersionCheckerMock) { + mockVersionChecker.EXPECT().SupportsConsistentQuery("uber-go", "1.0.0").Return(nil) + }, + assertErr: func(t *testing.T, err error) { + var e *types.QueryFailedError + assert.ErrorAs(t, err, &e) + assert.Equal(t, "query failed", e.Message) + }, + wantErr: true, + }, + { + name: "unknown query result", + result: &queryResult{ + workerResponse: &types.MatchingRespondQueryTaskCompletedRequest{ + CompletedRequest: &types.RespondQueryTaskCompletedRequest{ + WorkerVersionInfo: &types.WorkerVersionInfo{ + Impl: "uber-go", + FeatureVersion: "1.0.0", + }, + CompletedType: types.QueryTaskCompletedType(100).Ptr(), + }, + }, + }, + mockSetup: func(mockVersionChecker *client.VersionCheckerMock) { + mockVersionChecker.EXPECT().SupportsConsistentQuery("uber-go", "1.0.0").Return(nil) + }, + assertErr: func(t *testing.T, err error) { + var e *types.InternalServiceError + assert.ErrorAs(t, err, &e) + assert.Equal(t, "unknown query completed type", e.Message) + }, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockVersionChecker := client.NewMockVersionChecker(mockCtrl) + tc.mockSetup(mockVersionChecker) + engine := &matchingEngineImpl{ + versionChecker: mockVersionChecker, + } + hCtx := &handlerContext{ + Context: context.Background(), + } + ch := make(chan *queryResult, 1) + ch <- tc.result + resp, err := engine.waitForQueryResult(hCtx, true, ch) + if tc.wantErr { + require.Error(t, err) + if tc.assertErr != nil { + tc.assertErr(t, err) + } + } else { + require.NoError(t, err) + assert.Equal(t, tc.want, resp) + } + }) + } +} diff --git a/service/matching/tasklist/interfaces.go b/service/matching/tasklist/interfaces.go new file mode 100644 index 00000000000..ab6e9b6dd64 --- /dev/null +++ b/service/matching/tasklist/interfaces.go @@ -0,0 +1,60 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination interfaces_mock.go github.com/uber/cadence/service/matching/tasklist Manager +package tasklist + +import ( + "context" + "time" + + "github.com/uber/cadence/common/types" +) + +type ( + Manager interface { + Start() error + Stop() + // AddTask adds a task to the task list. This method will first attempt a synchronous + // match with a poller. When that fails, task will be written to database and later + // asynchronously matched with a poller + AddTask(ctx context.Context, params AddTaskParams) (syncMatch bool, err error) + // GetTask blocks waiting for a task Returns error when context deadline is exceeded + // maxDispatchPerSecond is the max rate at which tasks are allowed to be dispatched + // from this task list to pollers + GetTask(ctx context.Context, maxDispatchPerSecond *float64) (*InternalTask, error) + // DispatchTask dispatches a task to a poller. When there are no pollers to pick + // up the task, this method will return error. Task will not be persisted to db + DispatchTask(ctx context.Context, task *InternalTask) error + // DispatchQueryTask will dispatch query to local or remote poller. If forwarded then result or error is returned, + // if dispatched to local poller then nil and nil is returned. + DispatchQueryTask(ctx context.Context, taskID string, request *types.MatchingQueryWorkflowRequest) (*types.QueryWorkflowResponse, error) + CancelPoller(pollerID string) + GetAllPollerInfo() []*types.PollerInfo + HasPollerAfter(accessTime time.Time) bool + // DescribeTaskList returns information about the target tasklist + DescribeTaskList(includeTaskListStatus bool) *types.DescribeTaskListResponse + String() string + GetTaskListKind() types.TaskListKind + TaskListID() *Identifier + } +) diff --git a/service/matching/tasklist/interfaces_mock.go b/service/matching/tasklist/interfaces_mock.go new file mode 100644 index 00000000000..6642208bc0c --- /dev/null +++ b/service/matching/tasklist/interfaces_mock.go @@ -0,0 +1,241 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces.go + +// Package tasklist is a generated GoMock package. +package tasklist + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + + types "github.com/uber/cadence/common/types" +) + +// MockManager is a mock of Manager interface. +type MockManager struct { + ctrl *gomock.Controller + recorder *MockManagerMockRecorder +} + +// MockManagerMockRecorder is the mock recorder for MockManager. +type MockManagerMockRecorder struct { + mock *MockManager +} + +// NewMockManager creates a new mock instance. +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockManager) EXPECT() *MockManagerMockRecorder { + return m.recorder +} + +// AddTask mocks base method. +func (m *MockManager) AddTask(ctx context.Context, params AddTaskParams) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddTask", ctx, params) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddTask indicates an expected call of AddTask. +func (mr *MockManagerMockRecorder) AddTask(ctx, params interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddTask", reflect.TypeOf((*MockManager)(nil).AddTask), ctx, params) +} + +// CancelPoller mocks base method. +func (m *MockManager) CancelPoller(pollerID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelPoller", pollerID) +} + +// CancelPoller indicates an expected call of CancelPoller. +func (mr *MockManagerMockRecorder) CancelPoller(pollerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelPoller", reflect.TypeOf((*MockManager)(nil).CancelPoller), pollerID) +} + +// DescribeTaskList mocks base method. +func (m *MockManager) DescribeTaskList(includeTaskListStatus bool) *types.DescribeTaskListResponse { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DescribeTaskList", includeTaskListStatus) + ret0, _ := ret[0].(*types.DescribeTaskListResponse) + return ret0 +} + +// DescribeTaskList indicates an expected call of DescribeTaskList. +func (mr *MockManagerMockRecorder) DescribeTaskList(includeTaskListStatus interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeTaskList", reflect.TypeOf((*MockManager)(nil).DescribeTaskList), includeTaskListStatus) +} + +// DispatchQueryTask mocks base method. +func (m *MockManager) DispatchQueryTask(ctx context.Context, taskID string, request *types.MatchingQueryWorkflowRequest) (*types.QueryWorkflowResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DispatchQueryTask", ctx, taskID, request) + ret0, _ := ret[0].(*types.QueryWorkflowResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DispatchQueryTask indicates an expected call of DispatchQueryTask. +func (mr *MockManagerMockRecorder) DispatchQueryTask(ctx, taskID, request interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DispatchQueryTask", reflect.TypeOf((*MockManager)(nil).DispatchQueryTask), ctx, taskID, request) +} + +// DispatchTask mocks base method. +func (m *MockManager) DispatchTask(ctx context.Context, task *InternalTask) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DispatchTask", ctx, task) + ret0, _ := ret[0].(error) + return ret0 +} + +// DispatchTask indicates an expected call of DispatchTask. +func (mr *MockManagerMockRecorder) DispatchTask(ctx, task interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DispatchTask", reflect.TypeOf((*MockManager)(nil).DispatchTask), ctx, task) +} + +// GetAllPollerInfo mocks base method. +func (m *MockManager) GetAllPollerInfo() []*types.PollerInfo { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllPollerInfo") + ret0, _ := ret[0].([]*types.PollerInfo) + return ret0 +} + +// GetAllPollerInfo indicates an expected call of GetAllPollerInfo. +func (mr *MockManagerMockRecorder) GetAllPollerInfo() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPollerInfo", reflect.TypeOf((*MockManager)(nil).GetAllPollerInfo)) +} + +// GetTask mocks base method. +func (m *MockManager) GetTask(ctx context.Context, maxDispatchPerSecond *float64) (*InternalTask, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTask", ctx, maxDispatchPerSecond) + ret0, _ := ret[0].(*InternalTask) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTask indicates an expected call of GetTask. +func (mr *MockManagerMockRecorder) GetTask(ctx, maxDispatchPerSecond interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTask", reflect.TypeOf((*MockManager)(nil).GetTask), ctx, maxDispatchPerSecond) +} + +// GetTaskListKind mocks base method. +func (m *MockManager) GetTaskListKind() types.TaskListKind { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTaskListKind") + ret0, _ := ret[0].(types.TaskListKind) + return ret0 +} + +// GetTaskListKind indicates an expected call of GetTaskListKind. +func (mr *MockManagerMockRecorder) GetTaskListKind() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskListKind", reflect.TypeOf((*MockManager)(nil).GetTaskListKind)) +} + +// HasPollerAfter mocks base method. +func (m *MockManager) HasPollerAfter(accessTime time.Time) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasPollerAfter", accessTime) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasPollerAfter indicates an expected call of HasPollerAfter. +func (mr *MockManagerMockRecorder) HasPollerAfter(accessTime interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPollerAfter", reflect.TypeOf((*MockManager)(nil).HasPollerAfter), accessTime) +} + +// Start mocks base method. +func (m *MockManager) Start() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start") + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockManagerMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockManager)(nil).Start)) +} + +// Stop mocks base method. +func (m *MockManager) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop. +func (mr *MockManagerMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockManager)(nil).Stop)) +} + +// String mocks base method. +func (m *MockManager) String() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "String") + ret0, _ := ret[0].(string) + return ret0 +} + +// String indicates an expected call of String. +func (mr *MockManagerMockRecorder) String() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockManager)(nil).String)) +} + +// TaskListID mocks base method. +func (m *MockManager) TaskListID() *Identifier { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TaskListID") + ret0, _ := ret[0].(*Identifier) + return ret0 +} + +// TaskListID indicates an expected call of TaskListID. +func (mr *MockManagerMockRecorder) TaskListID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TaskListID", reflect.TypeOf((*MockManager)(nil).TaskListID)) +} diff --git a/service/matching/tasklist/task_list_manager.go b/service/matching/tasklist/task_list_manager.go index f44c81d3f4d..61fbe58e7cd 100644 --- a/service/matching/tasklist/task_list_manager.go +++ b/service/matching/tasklist/task_list_manager.go @@ -78,33 +78,6 @@ type ( ActivityTaskDispatchInfo *types.ActivityTaskDispatchInfo } - Manager interface { - Start() error - Stop() - // AddTask adds a task to the task list. This method will first attempt a synchronous - // match with a poller. When that fails, task will be written to database and later - // asynchronously matched with a poller - AddTask(ctx context.Context, params AddTaskParams) (syncMatch bool, err error) - // GetTask blocks waiting for a task Returns error when context deadline is exceeded - // maxDispatchPerSecond is the max rate at which tasks are allowed to be dispatched - // from this task list to pollers - GetTask(ctx context.Context, maxDispatchPerSecond *float64) (*InternalTask, error) - // DispatchTask dispatches a task to a poller. When there are no pollers to pick - // up the task, this method will return error. Task will not be persisted to db - DispatchTask(ctx context.Context, task *InternalTask) error - // DispatchQueryTask will dispatch query to local or remote poller. If forwarded then result or error is returned, - // if dispatched to local poller then nil and nil is returned. - DispatchQueryTask(ctx context.Context, taskID string, request *types.MatchingQueryWorkflowRequest) (*types.QueryWorkflowResponse, error) - CancelPoller(pollerID string) - GetAllPollerInfo() []*types.PollerInfo - HasPollerAfter(accessTime time.Time) bool - // DescribeTaskList returns information about the target tasklist - DescribeTaskList(includeTaskListStatus bool) *types.DescribeTaskListResponse - String() string - GetTaskListKind() types.TaskListKind - TaskListID() *Identifier - } - outstandingPollerInfo struct { isolationGroup string cancel context.CancelFunc @@ -152,8 +125,6 @@ const ( maxSyncMatchWaitTime = 200 * time.Millisecond ) -var _ Manager = (*taskListManagerImpl)(nil) - var errRemoteSyncMatchFailed = &types.RemoteSyncMatchedError{Message: "remote sync match failed"} func NewManager(