Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Retrier and Catcher passed to constructor for Task, Parallel and Map states are not added to the state's Retriers and Catchers #169

Merged
merged 7 commits into from
Oct 7, 2021
48 changes: 37 additions & 11 deletions src/stepfunctions/steps/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,27 +254,29 @@ def accept(self, visitor):

def add_retry(self, retry):
"""
Add a Retry block to the tail end of the list of retriers for the state.
Add a retrier or a list of retriers to the tail end of the list of retriers for the state.
See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.

Args:
retry (Retry): Retry block to add.
retry (Retry or list(Retry)): A retrier or list of retriers to add.
"""
if Field.Retry in self.allowed_fields():
self.retries.append(retry)
self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry)
else:
raise ValueError("{state_type} state does not support retry field. ".format(state_type=type(self).__name__))
raise ValueError(f"{type(self).__name__} state does not support retry field. ")

def add_catch(self, catch):
"""
Add a Catch block to the tail end of the list of catchers for the state.
Add a catcher or a list of catchers to the tail end of the list of catchers for the state.
See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.

Args:
catch (Catch): Catch block to add.
catch (Catch or list(Catch)): A catcher or list of catchers to add.
"""
if Field.Catch in self.allowed_fields():
self.catches.append(catch)
self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch)
else:
raise ValueError("{state_type} state does not support catch field. ".format(state_type=type(self).__name__))
raise ValueError(f"{type(self).__name__} state does not support catch field. ")

def to_dict(self):
result = super(State, self).to_dict()
Expand Down Expand Up @@ -487,10 +489,12 @@ class Parallel(State):
A Parallel state causes the interpreter to execute each branch as concurrently as possible, and wait until each branch terminates (reaches a terminal state) before processing the next state in the Chain.
"""

def __init__(self, state_id, **kwargs):
def __init__(self, state_id, retry=None, catch=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
Expand All @@ -500,6 +504,12 @@ def __init__(self, state_id, **kwargs):
super(Parallel, self).__init__(state_id, 'Parallel', **kwargs)
self.branches = []

if retry:
self.add_retry(retry)

if catch:
self.add_catch(catch)

def allowed_fields(self):
return [
Field.Comment,
Expand Down Expand Up @@ -536,11 +546,13 @@ class Map(State):
A Map state can accept an input with a list of items, execute a state or chain for each item in the list, and return a list, with all corresponding results of each execution, as its output.
"""

def __init__(self, state_id, **kwargs):
def __init__(self, state_id, retry=None, catch=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
iterator (State or Chain): State or chain to execute for each of the items in `items_path`.
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
items_path (str, optional): Path in the input for items to iterate over. (default: '$')
max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0)
comment (str, optional): Human-readable comment or description. (default: None)
Expand All @@ -551,6 +563,12 @@ def __init__(self, state_id, **kwargs):
"""
super(Map, self).__init__(state_id, 'Map', **kwargs)

if retry:
self.add_retry(retry)

if catch:
self.add_catch(catch)

def attach_iterator(self, iterator):
"""
Attach a `State` or `Chain` as the iterator of the Map state; it will be executed for each of the items in `items_path`. If an iterator was previously attached to the Map state, it will be replaced.
Expand Down Expand Up @@ -586,10 +604,12 @@ class Task(State):
Task State causes the interpreter to execute the work identified by the state’s `resource` field.
"""

def __init__(self, state_id, **kwargs):
def __init__(self, state_id, retry=None, catch=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI.
timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60)
timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
Expand All @@ -608,6 +628,12 @@ def __init__(self, state_id, **kwargs):
if self.heartbeat_seconds is not None and self.heartbeat_seconds_path is not None:
raise ValueError("Only one of 'heartbeat_seconds' or 'heartbeat_seconds_path' can be provided.")

if retry:
self.add_retry(retry)

if catch:
self.add_catch(catch)

def allowed_fields(self):
return [
Field.Comment,
Expand Down
131 changes: 79 additions & 52 deletions tests/integ/test_state_machine_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,18 +422,38 @@ def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_para

def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters):
catch_state_name = "TaskWithCatchState"
custom_error = "CustomError"
task_failed_error = "States.TaskFailed"
all_fail_error = "States.ALL"
custom_error_state_name = "Custom Error End"
task_failed_state_name = "Task Failed End"
all_error_state_name = "Catch All End"
timeout_error = "States.Timeout"
task_failed_state_name = "Catch Task Failed End"
timeout_state_name = "Catch Timeout End"
catch_state_result = "Catch Result"
task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync"

# change the parameters to cause task state to fail
# Provide invalid TrainingImage to cause States.TaskFailed error
training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image"

task = steps.Task(
catch_state_name,
parameters=training_job_parameters,
resource=task_resource,
catch=steps.Catch(
error_equals=[timeout_error],
next_step=steps.Pass(timeout_state_name, result=catch_state_result)
)
)
task.add_catch(
steps.Catch(
error_equals=[task_failed_error],
next_step=steps.Pass(task_failed_state_name, result=catch_state_result)
)
)

workflow = Workflow(
unique_name_from_base('Test_Catch_Workflow'),
definition=task,
role=sfn_role_arn
)

asl_state_machine_definition = {
"StartAt": catch_state_name,
"States": {
Expand All @@ -445,80 +465,61 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
"Catch": [
{
"ErrorEquals": [
all_fail_error
timeout_error
],
"Next": all_error_state_name
"Next": timeout_state_name
},
{
"ErrorEquals": [
task_failed_error
],
"Next": task_failed_state_name
}
]
},
all_error_state_name: {
task_failed_state_name: {
"Type": "Pass",
"Result": catch_state_result,
"End": True
}
},
timeout_state_name: {
"Type": "Pass",
"Result": catch_state_result,
"End": True
},
}
}
task = steps.Task(
catch_state_name,
parameters=training_job_parameters,
resource=task_resource
)
task.add_catch(
steps.Catch(
error_equals=[all_fail_error],
next_step=steps.Pass(all_error_state_name, result=catch_state_result)
)
)

workflow = Workflow(
unique_name_from_base('Test_Catch_Workflow'),
definition=task,
role=sfn_role_arn
)

workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result)


def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters):
retry_state_name = "RetryStateName"
all_fail_error = "Starts.ALL"
task_failed_error = "States.TaskFailed"
timeout_error = "States.Timeout"
interval_seconds = 1
max_attempts = 2
backoff_rate = 2
task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync"

# change the parameters to cause task state to fail
# Provide invalid TrainingImage to cause States.TaskFailed error
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you, it's a lot less mysterious what we're trying to do now

training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image"

asl_state_machine_definition = {
"StartAt": retry_state_name,
"States": {
retry_state_name: {
"Resource": task_resource,
"Parameters": training_job_parameters,
"Type": "Task",
"End": True,
"Retry": [
{
"ErrorEquals": [all_fail_error],
"IntervalSeconds": interval_seconds,
"MaxAttempts": max_attempts,
"BackoffRate": backoff_rate
}
]
}
}
}

task = steps.Task(
retry_state_name,
parameters=training_job_parameters,
resource=task_resource
resource=task_resource,
retry=steps.Retry(
error_equals=[timeout_error],
interval_seconds=interval_seconds,
max_attempts=max_attempts,
backoff_rate=backoff_rate
)
)

task.add_retry(
steps.Retry(
error_equals=[all_fail_error],
error_equals=[task_failed_error],
interval_seconds=interval_seconds,
max_attempts=max_attempts,
backoff_rate=backoff_rate
Expand All @@ -531,4 +532,30 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
role=sfn_role_arn
)

workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)
asl_state_machine_definition = {
"StartAt": retry_state_name,
"States": {
retry_state_name: {
"Resource": task_resource,
"Parameters": training_job_parameters,
"Type": "Task",
"End": True,
"Retry": [
{
"ErrorEquals": [timeout_error],
"IntervalSeconds": interval_seconds,
"MaxAttempts": max_attempts,
"BackoffRate": backoff_rate
},
{
"ErrorEquals": [task_failed_error],
"IntervalSeconds": interval_seconds,
"MaxAttempts": max_attempts,
"BackoffRate": backoff_rate
}
]
}
}
}

workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)
Loading