diff --git a/service/frontend/api/handler_test.go b/service/frontend/api/handler_test.go index e57d6f4f1db..e7e302522f7 100644 --- a/service/frontend/api/handler_test.go +++ b/service/frontend/api/handler_test.go @@ -2310,6 +2310,172 @@ func (s *workflowHandlerSuite) TestRespondActivityTaskCanceled() { } +func (s *workflowHandlerSuite) TestRespondActivityTaskCanceledByID() { + config := s.newConfig(dc.NewInMemoryClient()) + config.EnableClientVersionCheck = dc.GetBoolPropertyFn(true) + wh := NewWorkflowHandler(s.mockResource, config, s.mockVersionChecker, nil) + wh.tokenSerializer = s.mockTokenSerializer + + validInput := &types.RespondActivityTaskCanceledByIDRequest{ + Domain: s.testDomain, + WorkflowID: testWorkflowID, + RunID: testRunID, + ActivityID: "activityID", + Identity: "identity", + Details: make([]byte, 1000), + } + + testInput := map[string]struct { + request *types.RespondActivityTaskCanceledByIDRequest + expectError bool + mockFn func() + }{ + "shutting down": { + request: validInput, + mockFn: func() { + wh.shuttingDown = int32(1) + }, + expectError: true, + }, + "nil request": { + request: nil, + mockFn: func() {}, + expectError: true, + }, + "empty domain name": { + request: &types.RespondActivityTaskCanceledByIDRequest{ + Domain: "", + }, + mockFn: func() {}, + expectError: true, + }, + "cannot get domain ID": { + request: validInput, + mockFn: func() { + s.mockDomainCache.EXPECT().GetDomainID(s.testDomain).Return("", errors.New("error getting domain ID")) + }, + expectError: true, + }, + "empty domain ID": { + request: &types.RespondActivityTaskCanceledByIDRequest{ + Domain: s.testDomain, + }, + mockFn: func() { + s.mockDomainCache.EXPECT().GetDomainID(s.testDomain).Return("", nil) + }, + expectError: true, + }, + "empty workflow ID": { + request: &types.RespondActivityTaskCanceledByIDRequest{ + Domain: s.testDomain, + WorkflowID: "", + }, + mockFn: func() { + s.mockDomainCache.EXPECT().GetDomainID(s.testDomain).Return(s.testDomainID, nil) + }, + expectError: true, + }, + "empty activity ID": { + request: &types.RespondActivityTaskCanceledByIDRequest{ + Domain: s.testDomain, + WorkflowID: testWorkflowID, + ActivityID: "", + }, + mockFn: func() { + s.mockDomainCache.EXPECT().GetDomainID(s.testDomain).Return(s.testDomainID, nil) + }, + expectError: true, + }, + "exceeds id length limit": { + request: validInput, + mockFn: func() { + s.mockDomainCache.EXPECT().GetDomainID(s.testDomain).Return(s.testDomainID, nil) + wh.config.MaxIDLengthWarnLimit = dc.GetIntPropertyFn(1) + wh.config.IdentityMaxLength = dc.GetIntPropertyFilteredByDomain(1) + }, + expectError: true, + }, + "serialization failure": { + request: validInput, + mockFn: func() { + s.mockDomainCache.EXPECT().GetDomainID(s.testDomain).Return(s.testDomainID, nil) + s.mockTokenSerializer.EXPECT().Serialize(gomock.Any()).Return(nil, errors.New("failed to deserialize token")) + }, + expectError: true, + }, + "exceeds blob size limit": { + request: validInput, + mockFn: func() { + s.mockDomainCache.EXPECT().GetDomainID(s.testDomain).Return(s.testDomainID, nil) + s.mockTokenSerializer.EXPECT().Serialize(gomock.Any()).Return(make([]byte, 100), nil) + wh.config.BlobSizeLimitWarn = dc.GetIntPropertyFilteredByDomain(10) + wh.config.BlobSizeLimitError = dc.GetIntPropertyFilteredByDomain(10) + s.mockHistoryClient.EXPECT().RespondActivityTaskFailed(gomock.Any(), gomock.Any()).Return(errors.New("error")) + }, + expectError: true, + }, + "history client returns error": { + request: validInput, + mockFn: func() { + s.mockDomainCache.EXPECT().GetDomainID(s.testDomain).Return(s.testDomainID, nil) + s.mockTokenSerializer.EXPECT().Serialize(gomock.Any()).Return(make([]byte, 5), nil) + wh.config.BlobSizeLimitWarn = dc.GetIntPropertyFilteredByDomain(1000) + wh.config.BlobSizeLimitError = dc.GetIntPropertyFilteredByDomain(1000) + s.mockHistoryClient.EXPECT().RespondActivityTaskCanceled(gomock.Any(), gomock.Any()).Return(errors.New("error")) + }, + expectError: true, + }, + "success": { + request: validInput, + mockFn: func() { + s.mockDomainCache.EXPECT().GetDomainID(s.testDomain).Return(s.testDomainID, nil) + s.mockTokenSerializer.EXPECT().Serialize(gomock.Any()).Return(make([]byte, 5), nil) + wh.config.BlobSizeLimitWarn = dc.GetIntPropertyFilteredByDomain(1000) + wh.config.BlobSizeLimitError = dc.GetIntPropertyFilteredByDomain(1000) + s.mockHistoryClient.EXPECT().RespondActivityTaskCanceled(gomock.Any(), gomock.Any()).Return(nil) + }, + expectError: false, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + err := wh.RespondActivityTaskCanceledByID(context.Background(), input.request) + if input.expectError { + s.Error(err) + } else { + s.NoError(err) + } + wh.shuttingDown = int32(0) + wh.config.MaxIDLengthWarnLimit = dc.GetIntPropertyFn(1000) + wh.config.IdentityMaxLength = dc.GetIntPropertyFilteredByDomain(1000) + }) + } + + // test version checker + s.Run("version checker", func() { + mockCtrl := gomock.NewController(s.T()) + mockResource := resource.NewTest(s.T(), mockCtrl, metrics.Frontend) + mockVersionChecker := client.NewMockVersionChecker(mockCtrl) + + cfg := frontendcfg.NewConfig( + dc.NewCollection( + dc.NewInMemoryClient(), + mockResource.GetLogger(), + ), + numHistoryShards, + false, + "hostname", + ) + cfg.EnableClientVersionCheck = dc.GetBoolPropertyFn(true) + wh := NewWorkflowHandler(mockResource, cfg, mockVersionChecker, nil) + mockVersionChecker.EXPECT().ClientSupported(gomock.Any(), gomock.Any()).Return(errors.New("error")).Times(1) + err := wh.RespondActivityTaskCanceledByID(context.Background(), validInput) + s.Error(err) + }) +} + func updateRequest( historyArchivalURI *string, historyArchivalStatus *types.ArchivalStatus,