diff --git a/service/history/replication/dlq_handler.go b/service/history/replication/dlq_handler.go index 453de452be4..3627cd192d8 100644 --- a/service/history/replication/dlq_handler.go +++ b/service/history/replication/dlq_handler.go @@ -29,6 +29,7 @@ import ( "github.com/uber/cadence/common" "github.com/uber/cadence/common/backoff" + "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" "github.com/uber/cadence/common/metrics" @@ -82,6 +83,7 @@ type ( metricsClient metrics.Client done chan struct{} status int32 + timeSource clock.TimeSource mu sync.Mutex latestCounts map[string]int64 @@ -106,6 +108,7 @@ func NewDLQHandler( logger: shard.GetLogger(), metricsClient: shard.GetMetricsClient(), done: make(chan struct{}), + timeSource: clock.NewRealTimeSource(), } } @@ -176,12 +179,12 @@ func (r *dlqHandlerImpl) emitDLQSizeMetricsLoop() { ) } - timer := time.NewTimer(getInterval()) + timer := r.timeSource.NewTimer(getInterval()) defer timer.Stop() for { select { - case <-timer.C: + case <-timer.Chan(): r.fetchAndEmitMessageCount(context.Background()) timer.Reset(getInterval()) case <-r.done: @@ -322,7 +325,7 @@ func (r *dlqHandlerImpl) MergeMessages( } } - // If hydrated replication task does not exists in remote cluster - continue merging + // If hydrated replication task does not exist in remote cluster - continue merging // Record lastMessageID with raw task id, so that they can be purged after. if lastMessageID < raw.TaskID { lastMessageID = raw.TaskID diff --git a/service/history/replication/dlq_handler_test.go b/service/history/replication/dlq_handler_test.go index 3244d2879c1..643c1b1eddf 100644 --- a/service/history/replication/dlq_handler_test.go +++ b/service/history/replication/dlq_handler_test.go @@ -22,16 +22,21 @@ package replication import ( "context" + "errors" "testing" + "time" "github.com/golang/mock/gomock" "github.com/pborman/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "go.uber.org/goleak" "github.com/uber/cadence/client" "github.com/uber/cadence/client/admin" + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/mocks" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/types" @@ -110,11 +115,150 @@ func (s *dlqHandlerSuite) TearDownTest() { s.mockShard.Finish(s.T()) } +func (s *dlqHandlerSuite) TestNewDLQHandler_panic() { + s.Panics(func() { NewDLQHandler(s.mockShard, nil) }, "Failed to initialize replication DLQ handler due to nil task executors") +} + +func (s *dlqHandlerSuite) TestStartStop() { + tests := []struct { + name string + status int32 + }{ + { + name: "started", + status: common.DaemonStatusInitialized, + }, + { + name: "not started", + status: common.DaemonStatusStopped, + }, + } + + for _, tc := range tests { + s.T().Run(tc.name, func(t *testing.T) { + defer goleak.VerifyNone(t) + + s.messageHandler.status = tc.status + + s.messageHandler.Start() + + s.messageHandler.Stop() + }) + } +} + +func (s *dlqHandlerSuite) TestGetMessageCount() { + size := int64(1) + tests := []struct { + name string + latestCounts map[string]int64 + forceFetch bool + err error + }{ + { + name: "success", + latestCounts: map[string]int64{s.sourceCluster: size}, + }, + { + name: "success with fetchAndEmitMessageCount call", + forceFetch: true, + }, + { + name: "error", + forceFetch: true, + err: errors.New("fetchAndEmitMessageCount error"), + }, + } + + for _, tc := range tests { + s.T().Run(tc.name, func(t *testing.T) { + s.messageHandler.latestCounts = tc.latestCounts + + if tc.forceFetch || tc.latestCounts == nil { + s.executionManager.On("GetReplicationDLQSize", mock.Anything, mock.Anything).Return(&persistence.GetReplicationDLQSizeResponse{Size: size}, tc.err).Times(1) + } + + counts, err := s.messageHandler.GetMessageCount(context.Background(), tc.forceFetch) + + if tc.err != nil { + s.Error(err) + s.Equal(tc.err, err) + } else if tc.latestCounts != nil { + s.NoError(err) + s.Equal(size, counts[s.sourceCluster]) + } else { + s.NoError(err) + } + }) + } +} + +func (s *dlqHandlerSuite) TestFetchAndEmitMessageCount() { + tests := []struct { + name string + err error + }{ + { + name: "success", + err: nil, + }, + { + name: "error", + err: errors.New("error"), + }, + } + + for _, tc := range tests { + s.T().Run(tc.name, func(t *testing.T) { + size := int64(3) + rets := &persistence.GetReplicationDLQSizeResponse{Size: size} + s.messageHandler.latestCounts = make(map[string]int64) + + s.executionManager.On("GetReplicationDLQSize", context.Background(), mock.Anything).Return(rets, tc.err).Times(1) + + err := s.messageHandler.fetchAndEmitMessageCount(context.Background()) + + if tc.err != nil { + s.Error(err) + s.Equal(tc.err, err) + } else { + s.NoError(err) + s.Equal(len(s.messageHandler.latestCounts), len(s.taskExecutors)) + s.Equal(size, s.messageHandler.latestCounts[s.sourceCluster]) + } + }) + } +} + +func (s *dlqHandlerSuite) TestEmitDLQSizeMetricsLoop_FetchesAndEmitsMetricsPeriodically() { + defer goleak.VerifyNone(s.T()) + + emissionNumber := 2 + + s.messageHandler.status = common.DaemonStatusStarted + s.executionManager.On("GetReplicationDLQSize", mock.Anything, mock.Anything).Return(&persistence.GetReplicationDLQSizeResponse{Size: 1}, nil).Times(emissionNumber) + mockTimeSource := clock.NewMockedTimeSource() + s.messageHandler.timeSource = mockTimeSource + + go s.messageHandler.emitDLQSizeMetricsLoop() + + for i := 0; i < emissionNumber; i++ { + mockTimeSource.BlockUntil(1) + + // Advance time to trigger the next emission + mockTimeSource.Advance(dlqMetricsEmitTimerInterval + time.Duration(int64(float64(dlqMetricsEmitTimerInterval)*(1+dlqMetricsEmitTimerCoefficient)))) + } + + s.messageHandler.Stop() + + s.Equal(common.DaemonStatusStopped, s.messageHandler.status) +} + func (s *dlqHandlerSuite) TestReadMessages_OK() { ctx := context.Background() lastMessageID := int64(1) pageSize := 1 - pageToken := []byte{} + var pageToken []byte resp := &persistence.GetReplicationTasksFromDLQResponse{ Tasks: []*persistence.ReplicationTaskInfo{ @@ -150,26 +294,187 @@ func (s *dlqHandlerSuite) TestReadMessages_OK() { s.Nil(tasks) } -func (s *dlqHandlerSuite) TestPurgeMessages_OK() { - sourceCluster := "test" - lastMessageID := int64(1) +func (s *dlqHandlerSuite) TestReadMessagesWithAckLevel_OK() { + replicationTasksResponse := &persistence.GetReplicationTasksFromDLQResponse{ + Tasks: []*persistence.ReplicationTaskInfo{ + { + DomainID: "domainID", + TaskID: 123, + WorkflowID: "workflowID", + RunID: "runID", + TaskType: 5, + Version: 1, + FirstEventID: 1, + NextEventID: 2, + ScheduledID: 3, + }, + }, + NextPageToken: []byte("token"), + } - s.executionManager.On("RangeDeleteReplicationTaskFromDLQ", mock.Anything, - &persistence.RangeDeleteReplicationTaskFromDLQRequest{ - SourceClusterName: sourceCluster, - ExclusiveBeginTaskID: -1, - InclusiveEndTaskID: lastMessageID, - }).Return(&persistence.RangeDeleteReplicationTaskFromDLQResponse{TasksCompleted: persistence.UnknownNumRowsAffected}, nil).Times(1) + DLQReplicationMessagesResponse := &types.GetDLQReplicationMessagesResponse{ + ReplicationTasks: []*types.ReplicationTask{ + { + SourceTaskID: 123, + }, + }, + } + + ctx := context.Background() + lastMessageID := int64(123) + pageSize := 12 + pageToken := []byte("token") + + req := &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: s.sourceCluster, + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + ReadLevel: defaultBeginningMessageID, + MaxReadLevel: lastMessageID, + BatchSize: pageSize, + NextPageToken: pageToken, + }, + } + + s.executionManager.On("GetReplicationTasksFromDLQ", ctx, req).Return(replicationTasksResponse, nil).Times(1) + + s.adminClient.EXPECT(). + GetDLQReplicationMessages(ctx, gomock.Any()). + Return(DLQReplicationMessagesResponse, nil).Times(1) + + replicationTasks, taskInfo, nextPageToken, err := s.messageHandler.readMessagesWithAckLevel(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) - err := s.messageHandler.PurgeMessages(context.Background(), sourceCluster, lastMessageID) s.NoError(err) + s.Equal(replicationTasks, DLQReplicationMessagesResponse.ReplicationTasks) + s.Len(taskInfo, len(replicationTasksResponse.Tasks)) + // testing content of taskInfo because it's assembled in the method using tasks from replicationTasksFromDLQ + for i, task := range taskInfo { + s.Equal(task.GetDomainID(), replicationTasksResponse.Tasks[i].GetDomainID()) + s.Equal(task.GetWorkflowID(), replicationTasksResponse.Tasks[i].GetWorkflowID()) + s.Equal(task.GetRunID(), replicationTasksResponse.Tasks[i].GetRunID()) + s.Equal(task.GetTaskID(), replicationTasksResponse.Tasks[i].GetTaskID()) + s.Equal(task.GetTaskType(), int16(replicationTasksResponse.Tasks[i].GetTaskType())) + s.Equal(task.GetVersion(), replicationTasksResponse.Tasks[i].GetVersion()) + s.Equal(task.FirstEventID, replicationTasksResponse.Tasks[i].FirstEventID) + s.Equal(task.NextEventID, replicationTasksResponse.Tasks[i].NextEventID) + s.Equal(task.ScheduledID, replicationTasksResponse.Tasks[i].ScheduledID) + } + s.Equal(nextPageToken, replicationTasksResponse.NextPageToken) +} + +func (s *dlqHandlerSuite) TestReadMessagesWithAckLevel_GetReplicationTasksFromDLQFailed() { + errorMessage := "GetReplicationTasksFromDLQFailed" + ctx := context.Background() + lastMessageID := int64(123) + pageSize := 12 + pageToken := []byte("token") + + req := &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: s.sourceCluster, + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + ReadLevel: defaultBeginningMessageID, + MaxReadLevel: lastMessageID, + BatchSize: pageSize, + NextPageToken: pageToken, + }, + } + + s.executionManager.On("GetReplicationTasksFromDLQ", ctx, req).Return(nil, errors.New(errorMessage)).Times(1) + + _, _, _, err := s.messageHandler.readMessagesWithAckLevel(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) + + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + +func (s *dlqHandlerSuite) TestReadMessagesWithAckLevel_InvalidCluster() { + s.executionManager.On("GetReplicationTasksFromDLQ", mock.Anything, mock.Anything).Return(nil, nil).Times(1) + + s.mockShard.Resource.ClientBean = client.NewMockBean(s.controller) + s.mockShard.Resource.ClientBean.EXPECT().GetRemoteAdminClient("invalidCluster").Return(nil).Times(1) + + _, _, _, err := s.messageHandler.readMessagesWithAckLevel(context.Background(), "invalidCluster", 123, 12, []byte("token")) + + s.Error(err) + s.Equal(errInvalidCluster, err) +} + +func (s *dlqHandlerSuite) TestReadMessagesWithAckLevel_GetDLQReplicationMessagesFailed() { + errorMessage := "GetDLQReplicationMessagesFailed" + ctx := context.Background() + lastMessageID := int64(123) + pageSize := 12 + pageToken := []byte("token") + + req := &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: s.sourceCluster, + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + ReadLevel: defaultBeginningMessageID, + MaxReadLevel: lastMessageID, + BatchSize: pageSize, + NextPageToken: pageToken, + }, + } + + replicationTasksResponse := &persistence.GetReplicationTasksFromDLQResponse{ + Tasks: []*persistence.ReplicationTaskInfo{ + { + DomainID: "domainID", + }, + }, + } + + s.executionManager.On("GetReplicationTasksFromDLQ", ctx, req).Return(replicationTasksResponse, nil).Times(1) + + s.adminClient.EXPECT(). + GetDLQReplicationMessages(ctx, gomock.Any()). + Return(nil, errors.New(errorMessage)).Times(1) + + _, _, _, err := s.messageHandler.readMessagesWithAckLevel(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) + + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + +func (s *dlqHandlerSuite) TestPurgeMessages() { + tests := []struct { + name string + err error + }{ + { + name: "success", + }, + { + name: "error", + err: errors.New("error"), + }, + } + + for _, tc := range tests { + s.T().Run(tc.name, func(t *testing.T) { + lastMessageID := int64(1) + s.executionManager.On("RangeDeleteReplicationTaskFromDLQ", mock.Anything, + &persistence.RangeDeleteReplicationTaskFromDLQRequest{ + SourceClusterName: s.sourceCluster, + ExclusiveBeginTaskID: -1, + InclusiveEndTaskID: lastMessageID, + }).Return(&persistence.RangeDeleteReplicationTaskFromDLQResponse{TasksCompleted: persistence.UnknownNumRowsAffected}, tc.err).Times(1) + + err := s.messageHandler.PurgeMessages(context.Background(), s.sourceCluster, lastMessageID) + + if tc.err != nil { + s.Error(err) + } else { + s.NoError(err) + } + }) + } } func (s *dlqHandlerSuite) TestMergeMessages_OK() { ctx := context.Background() lastMessageID := int64(2) pageSize := 1 - pageToken := []byte{} + var pageToken []byte resp := &persistence.GetReplicationTasksFromDLQResponse{ Tasks: []*persistence.ReplicationTaskInfo{ @@ -222,6 +527,57 @@ func (s *dlqHandlerSuite) TestMergeMessages_OK() { s.Equal(1, len(s.taskExecutor.executedTasks)) } +func (s *dlqHandlerSuite) TestMergeMessages_InvalidCluster() { + _, err := s.messageHandler.MergeMessages(context.Background(), "invalid", 1, 1, nil) + s.Error(err) + s.Equal(errInvalidCluster, err) +} + +func (s *dlqHandlerSuite) TestMergeMessages_GetReplicationTasksFromDLQFailed() { + errorMessage := "GetReplicationTasksFromDLQFailed" + s.executionManager.On("GetReplicationTasksFromDLQ", mock.Anything, mock.Anything).Return(nil, errors.New(errorMessage)).Times(1) + _, err := s.messageHandler.MergeMessages(context.Background(), s.sourceCluster, 1, 1, nil) + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + +func (s *dlqHandlerSuite) TestMergeMessages_RangeDeleteReplicationTaskFromDLQFailed() { + errorMessage := "RangeDeleteReplicationTaskFromDLQFailed" + s.executionManager.On("GetReplicationTasksFromDLQ", mock.Anything, mock.Anything).Return(&persistence.GetReplicationTasksFromDLQResponse{}, nil).Times(1) + s.executionManager.On("RangeDeleteReplicationTaskFromDLQ", mock.Anything, mock.Anything).Return(nil, errors.New(errorMessage)).Times(1) + _, err := s.messageHandler.MergeMessages(context.Background(), s.sourceCluster, 1, 1, nil) + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + +func (s *dlqHandlerSuite) TestMergeMessages_executeFailed() { + errorMessage := "executeFailed" + s.taskExecutors[s.sourceCluster] = &fakeTaskExecutor{err: errors.New(errorMessage)} + + ctx := context.Background() + lastMessageID := int64(2) + pageSize := 1 + var pageToken []byte + + s.executionManager.On("GetReplicationTasksFromDLQ", mock.Anything, &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: s.sourceCluster, + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + ReadLevel: -1, + MaxReadLevel: lastMessageID, + BatchSize: pageSize, + NextPageToken: pageToken, + }, + }).Return(&persistence.GetReplicationTasksFromDLQResponse{Tasks: []*persistence.ReplicationTaskInfo{{TaskID: 1}}}, nil).Times(1) + + s.adminClient.EXPECT().GetDLQReplicationMessages(ctx, gomock.Any()). + Return(&types.GetDLQReplicationMessagesResponse{ReplicationTasks: []*types.ReplicationTask{{SourceTaskID: 1}}}, nil) + + _, err := s.messageHandler.MergeMessages(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) + + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + type fakeTaskExecutor struct { scope int err error @@ -229,7 +585,7 @@ type fakeTaskExecutor struct { executedTasks []*types.ReplicationTask } -func (e *fakeTaskExecutor) execute(replicationTask *types.ReplicationTask, forceApply bool) (int, error) { +func (e *fakeTaskExecutor) execute(replicationTask *types.ReplicationTask, _ bool) (int, error) { e.executedTasks = append(e.executedTasks, replicationTask) return e.scope, e.err }