diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index c42134cf1..18d26b817 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -374,6 +374,7 @@ func (env *testWorkflowEnvironmentImpl) newTestWorkflowEnvironmentForChild(param childEnv.testWorkflowEnvironmentShared = env.testWorkflowEnvironmentShared childEnv.workerOptions = env.workerOptions childEnv.dataConverter = params.DataConverter + childEnv.failureConverter = env.failureConverter childEnv.registry = env.registry childEnv.detachedChildWaitDisabled = env.detachedChildWaitDisabled @@ -1403,8 +1404,14 @@ func (env *testWorkflowEnvironmentImpl) executeActivityWithRetryForTest( // check if a retry is needed if request, ok := result.(*workflowservice.RespondActivityTaskFailedRequest); ok && parameters.RetryPolicy != nil { + failure := request.GetFailure() + + if failure.GetApplicationFailureInfo().GetNonRetryable() { + break + } + p := fromProtoRetryPolicy(parameters.RetryPolicy) - backoff := getRetryBackoffWithNowTime(p, task.GetAttempt(), env.failureConverter.FailureToError(request.GetFailure()), env.Now(), expireTime) + backoff := getRetryBackoffWithNowTime(p, task.GetAttempt(), env.failureConverter.FailureToError(failure), env.Now(), expireTime) if backoff > 0 { // need a retry waitCh := make(chan struct{}) @@ -1987,6 +1994,7 @@ func (env *testWorkflowEnvironmentImpl) newTestActivityTaskHandler(taskQueue str MetricsHandler: env.metricsHandler, Logger: env.logger, UserContext: env.workerOptions.BackgroundActivityContext, + FailureConverter: env.failureConverter, DataConverter: dataConverter, WorkerStopChannel: env.workerStopChannel, ContextPropagators: env.contextPropagators, diff --git a/internal/workflow_testsuite_test.go b/internal/workflow_testsuite_test.go index 7d1dc4e11..465ab41fd 100644 --- a/internal/workflow_testsuite_test.go +++ b/internal/workflow_testsuite_test.go @@ -28,10 +28,13 @@ import ( "context" "errors" "strings" + "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + failurepb "go.temporal.io/api/failure/v1" + "go.temporal.io/sdk/converter" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -533,3 +536,64 @@ func TestMockCallWrapperNotBefore(t *testing.T) { require.ErrorAs(t, env.GetWorkflowError(), &expectedErr) require.ErrorContains(t, expectedErr, "Must not be called before") } + +func TestCustomFailureConverter(t *testing.T) { + t.Parallel() + + var suite WorkflowTestSuite + env := suite.NewTestWorkflowEnvironment() + env.SetFailureConverter(testFailureConverter{ + fallback: defaultFailureConverter, + }) + + var calls atomic.Int32 + activity := func(context.Context) error { + _ = calls.Add(1) + return testCustomError{} + } + env.RegisterActivity(activity) + + env.ExecuteWorkflow(func(ctx Context) error { + ctx = WithActivityOptions(ctx, ActivityOptions{ + StartToCloseTimeout: time.Hour, + }) + return ExecuteActivity(ctx, activity).Get(ctx, nil) + }) + require.True(t, env.IsWorkflowCompleted()) + + // Failure converter should've reconstructed the custom error type. + require.True(t, errors.As(env.GetWorkflowError(), &testCustomError{})) + + // Activity should've only been called once because the failure converter + // set the NonRetryable flag. + require.Equal(t, 1, int(calls.Load())) +} + +type testCustomError struct{} + +func (testCustomError) Error() string { return "this is a custom error type" } + +type testFailureConverter struct { + fallback converter.FailureConverter +} + +func (c testFailureConverter) ErrorToFailure(err error) *failurepb.Failure { + if errors.As(err, &testCustomError{}) { + return &failurepb.Failure{ + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "CUSTOM ERROR", + NonRetryable: true, + }, + }, + } + } + return c.fallback.ErrorToFailure(err) +} + +func (c testFailureConverter) FailureToError(failure *failurepb.Failure) error { + if failure.GetApplicationFailureInfo().GetType() == "CUSTOM ERROR" { + return testCustomError{} + } + return c.fallback.FailureToError(failure) +}