diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index b5189f102..e5a552c6e 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -157,6 +157,14 @@ type ( taskQueues map[string]struct{} } + updateResult struct { + success interface{} + err error + update_id string + callbacks []updateCallbacksWrapper + completed bool + } + // testWorkflowEnvironmentShared is the shared data between parent workflow and child workflow test environments testWorkflowEnvironmentShared struct { locker sync.Mutex @@ -229,6 +237,7 @@ type ( signalHandler func(name string, input *commonpb.Payloads, header *commonpb.Header) error queryHandler func(string, *commonpb.Payloads, *commonpb.Header) (*commonpb.Payloads, error) updateHandler func(name string, id string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks) + updateMap map[string]*updateResult startedHandler func(r WorkflowExecution, e error) isWorkflowCompleted bool @@ -258,6 +267,13 @@ type ( *sessionEnvironmentImpl testWorkflowEnvironment *testWorkflowEnvironmentImpl } + + // UpdateCallbacksWrapper is a wrapper to UpdateCallbacks. It allows us to dedup duplicate update IDs in the test environment. + updateCallbacksWrapper struct { + uc UpdateCallbacks + env *testWorkflowEnvironmentImpl + updateID string + } ) func newTestWorkflowEnvironmentImpl(s *WorkflowTestSuite, parentRegistry *registry) *testWorkflowEnvironmentImpl { @@ -2917,10 +2933,32 @@ func (env *testWorkflowEnvironmentImpl) updateWorkflow(name string, id string, u if err != nil { panic(err) } - env.postCallback(func() { - // Do not send any headers on test invocations - env.updateHandler(name, id, data, nil, uc) - }, true) + + if env.updateMap == nil { + env.updateMap = make(map[string]*updateResult) + } + + var ucWrapper = updateCallbacksWrapper{uc: uc, env: env, updateID: id} + + // check for duplicate update ID + if result, ok := env.updateMap[id]; ok { + if result.completed { + env.postCallback(func() { + ucWrapper.uc.Accept() + ucWrapper.uc.Complete(result.success, result.err) + }, false) + } else { + result.callbacks = append(result.callbacks, ucWrapper) + } + env.updateMap[id] = result + } else { + env.updateMap[id] = &updateResult{nil, nil, id, []updateCallbacksWrapper{}, false} + env.postCallback(func() { + // Do not send any headers on test invocations + env.updateHandler(name, id, data, nil, ucWrapper) + }, true) + } + } func (env *testWorkflowEnvironmentImpl) updateWorkflowByID(workflowID, name, id string, uc UpdateCallbacks, args ...interface{}) error { @@ -2932,9 +2970,31 @@ func (env *testWorkflowEnvironmentImpl) updateWorkflowByID(workflowID, name, id if err != nil { panic(err) } - workflowHandle.env.postCallback(func() { - workflowHandle.env.updateHandler(name, id, data, nil, uc) - }, true) + + if env.updateMap == nil { + env.updateMap = make(map[string]*updateResult) + } + + var ucWrapper = updateCallbacksWrapper{uc: uc, env: env, updateID: id} + + // Check for duplicate update ID + if result, ok := env.updateMap[id]; ok { + if result.completed { + env.postCallback(func() { + ucWrapper.uc.Accept() + ucWrapper.uc.Complete(result.success, result.err) + }, false) + } else { + result.callbacks = append(result.callbacks, ucWrapper) + } + env.updateMap[id] = result + } else { + env.updateMap[id] = &updateResult{nil, nil, id, []updateCallbacksWrapper{}, false} + workflowHandle.env.postCallback(func() { + workflowHandle.env.updateHandler(name, id, data, nil, ucWrapper) + }, true) + } + return nil } @@ -3075,6 +3135,34 @@ func mockFnGetVersion(string, Version, Version) Version { // make sure interface is implemented var _ WorkflowEnvironment = (*testWorkflowEnvironmentImpl)(nil) +func (uc updateCallbacksWrapper) Accept() { + uc.uc.Accept() +} + +func (uc updateCallbacksWrapper) Reject(err error) { + uc.uc.Reject(err) +} + +func (uc updateCallbacksWrapper) Complete(success interface{}, err error) { + // cache update result so we can dedup duplicate update IDs + if uc.env == nil { + panic("env is needed in updateCallback to cache update results for deduping purposes") + } + if result, ok := uc.env.updateMap[uc.updateID]; ok { + if !result.completed { + result.success = success + result.err = err + uc.uc.Complete(success, err) + result.completed = true + result.post_callbacks(uc.env) + } else { + uc.uc.Complete(result.success, result.err) + } + } else { + panic("updateMap[updateID] should already be created from updateWorkflow()") + } +} + func (h *testNexusOperationHandle) newStartTask() *workflowservice.PollNexusTaskQueueResponse { return &workflowservice.PollNexusTaskQueueResponse{ TaskToken: []byte{}, @@ -3425,3 +3513,13 @@ func newTestNexusOperation(opRef testNexusOperationReference) *testNexusOperatio testNexusOperationReference: opRef, } } + +func (res *updateResult) post_callbacks(env *testWorkflowEnvironmentImpl) { + for _, uc := range res.callbacks { + env.postCallback(func() { + uc.Accept() + uc.Complete(res.success, res.err) + }, false) + } + res.callbacks = []updateCallbacksWrapper{} +} diff --git a/internal/workflow_testsuite_test.go b/internal/workflow_testsuite_test.go index 9abe099c2..9d0684266 100644 --- a/internal/workflow_testsuite_test.go +++ b/internal/workflow_testsuite_test.go @@ -473,6 +473,99 @@ func TestWorkflowUpdateOrderAcceptReject(t *testing.T) { require.Equal(t, "unknown update bad update. KnownUpdates=[update]", updateRejectionErr.Error()) } +func TestWorkflowDuplicateIDDedup(t *testing.T) { + duplicateIDDedup(t, true, false, 1) +} + +func TestWorkflowDuplicateIDDedupInterweave(t *testing.T) { + // The second update should be scheduled before the first update is complete. + // This causes the second update to be completed only after the first update + // is complete and its result is cached for the second update to dedup. + duplicateIDDedup(t, false, false, 1) +} + +func TestWorkflowDuplicateIDDedupWithSleep(t *testing.T) { + duplicateIDDedup(t, false, true, 1) +} + +func TestWorkflowDuplicateIDDedupMore(t *testing.T) { + duplicateIDDedup(t, true, false, 50) +} + +func TestWorkflowDuplicateIDDedupDelayAndSleep(t *testing.T) { + duplicateIDDedup(t, true, true, 50) +} + +func duplicateIDDedup(t *testing.T, delay_second bool, with_sleep bool, additional int) { + var suite WorkflowTestSuite + var second_delay time.Duration + if delay_second { + second_delay = 1 * time.Second + } else { + second_delay = 0 * time.Second + } + additional_update_count := 0 + // Test dev server dedups UpdateWorkflow with same ID + env := suite.NewTestWorkflowEnvironment() + env.RegisterDelayedCallback(func() { + env.UpdateWorkflow("update", "id", &TestUpdateCallback{ + OnReject: func(err error) { + require.Fail(t, fmt.Sprintf("update should not be rejected, err: %v", err)) + }, + OnAccept: func() {}, + OnComplete: func(result interface{}, err error) { + intResult, ok := result.(int) + if !ok { + require.Fail(t, fmt.Sprintf("result should be int: %v\nerr: %v", result, err)) + } else { + require.Equal(t, 0, intResult) + } + }, + }, 0) + }, 0) + + for i := 0; i < additional; i++ { + env.RegisterDelayedCallback(func() { + env.UpdateWorkflow("update", "id", &TestUpdateCallback{ + OnReject: func(err error) { + require.Fail(t, fmt.Sprintf("update should not be rejected, err: %v", err)) + }, + OnAccept: func() {}, + OnComplete: func(result interface{}, err error) { + intResult, ok := result.(int) + if !ok { + require.Fail(t, fmt.Sprintf("result should be int: %v\nerr: %v", result, err)) + } else { + // if dedup, this be okay, even if we pass in 1 as arg, since it's deduping, + // the result should match the first update's result, 0 + require.Equal(t, 0, intResult) + } + additional_update_count += 1 + }, + }, 1) + + }, second_delay) + } + + env.ExecuteWorkflow(func(ctx Context) error { + err := SetUpdateHandler(ctx, "update", func(ctx Context, i int) (int, error) { + if with_sleep { + err := Sleep(ctx, time.Second) + if err != nil { + return 0, err + } + } + return i, nil + }, UpdateHandlerOptions{}) + if err != nil { + return err + } + return Sleep(ctx, time.Hour) + }) + require.NoError(t, env.GetWorkflowError()) + require.Equal(t, additional, additional_update_count) +} + func TestAllHandlersFinished(t *testing.T) { var suite WorkflowTestSuite env := suite.NewTestWorkflowEnvironment()