Skip to content

Commit

Permalink
fix: Retrier and Catcher passed to constructor for Task, Parallel and…
Browse files Browse the repository at this point in the history
… Map states are not added to the state's Retriers and Catchers (#169)
  • Loading branch information
ca-nguyen authored Oct 7, 2021
1 parent 58bed5d commit 7f223e8
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 64 deletions.
48 changes: 37 additions & 11 deletions src/stepfunctions/steps/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,27 +254,29 @@ def accept(self, visitor):

def add_retry(self, retry):
"""
Add a Retry block to the tail end of the list of retriers for the state.
Add a retrier or a list of retriers to the tail end of the list of retriers for the state.
See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
Args:
retry (Retry): Retry block to add.
retry (Retry or list(Retry)): A retrier or list of retriers to add.
"""
if Field.Retry in self.allowed_fields():
self.retries.append(retry)
self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry)
else:
raise ValueError("{state_type} state does not support retry field. ".format(state_type=type(self).__name__))
raise ValueError(f"{type(self).__name__} state does not support retry field. ")

def add_catch(self, catch):
"""
Add a Catch block to the tail end of the list of catchers for the state.
Add a catcher or a list of catchers to the tail end of the list of catchers for the state.
See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
Args:
catch (Catch): Catch block to add.
catch (Catch or list(Catch): catcher or list of catchers to add.
"""
if Field.Catch in self.allowed_fields():
self.catches.append(catch)
self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch)
else:
raise ValueError("{state_type} state does not support catch field. ".format(state_type=type(self).__name__))
raise ValueError(f"{type(self).__name__} state does not support catch field. ")

def to_dict(self):
result = super(State, self).to_dict()
Expand Down Expand Up @@ -487,10 +489,12 @@ class Parallel(State):
A Parallel state causes the interpreter to execute each branch as concurrently as possible, and wait until each branch terminates (reaches a terminal state) before processing the next state in the Chain.
"""

def __init__(self, state_id, **kwargs):
def __init__(self, state_id, retry=None, catch=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
Expand All @@ -500,6 +504,12 @@ def __init__(self, state_id, **kwargs):
super(Parallel, self).__init__(state_id, 'Parallel', **kwargs)
self.branches = []

if retry:
self.add_retry(retry)

if catch:
self.add_catch(catch)

def allowed_fields(self):
return [
Field.Comment,
Expand Down Expand Up @@ -536,11 +546,13 @@ class Map(State):
A Map state can accept an input with a list of items, execute a state or chain for each item in the list, and return a list, with all corresponding results of each execution, as its output.
"""

def __init__(self, state_id, **kwargs):
def __init__(self, state_id, retry=None, catch=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
iterator (State or Chain): State or chain to execute for each of the items in `items_path`.
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
items_path (str, optional): Path in the input for items to iterate over. (default: '$')
max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0)
comment (str, optional): Human-readable comment or description. (default: None)
Expand All @@ -551,6 +563,12 @@ def __init__(self, state_id, **kwargs):
"""
super(Map, self).__init__(state_id, 'Map', **kwargs)

if retry:
self.add_retry(retry)

if catch:
self.add_catch(catch)

def attach_iterator(self, iterator):
"""
Attach `State` or `Chain` as iterator to the Map state, that will execute for each of the items in `items_path`. If an iterator was attached previously with the Map state, it will be replaced.
Expand Down Expand Up @@ -586,10 +604,12 @@ class Task(State):
Task State causes the interpreter to execute the work identified by the state’s `resource` field.
"""

def __init__(self, state_id, **kwargs):
def __init__(self, state_id, retry=None, catch=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI.
timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60)
timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
Expand All @@ -608,6 +628,12 @@ def __init__(self, state_id, **kwargs):
if self.heartbeat_seconds is not None and self.heartbeat_seconds_path is not None:
raise ValueError("Only one of 'heartbeat_seconds' or 'heartbeat_seconds_path' can be provided.")

if retry:
self.add_retry(retry)

if catch:
self.add_catch(catch)

def allowed_fields(self):
return [
Field.Comment,
Expand Down
131 changes: 79 additions & 52 deletions tests/integ/test_state_machine_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,18 +422,38 @@ def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_para

def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters):
catch_state_name = "TaskWithCatchState"
custom_error = "CustomError"
task_failed_error = "States.TaskFailed"
all_fail_error = "States.ALL"
custom_error_state_name = "Custom Error End"
task_failed_state_name = "Task Failed End"
all_error_state_name = "Catch All End"
timeout_error = "States.Timeout"
task_failed_state_name = "Catch Task Failed End"
timeout_state_name = "Catch Timeout End"
catch_state_result = "Catch Result"
task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync"

# change the parameters to cause task state to fail
# Provide invalid TrainingImage to cause States.TaskFailed error
training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image"

task = steps.Task(
catch_state_name,
parameters=training_job_parameters,
resource=task_resource,
catch=steps.Catch(
error_equals=[timeout_error],
next_step=steps.Pass(timeout_state_name, result=catch_state_result)
)
)
task.add_catch(
steps.Catch(
error_equals=[task_failed_error],
next_step=steps.Pass(task_failed_state_name, result=catch_state_result)
)
)

workflow = Workflow(
unique_name_from_base('Test_Catch_Workflow'),
definition=task,
role=sfn_role_arn
)

asl_state_machine_definition = {
"StartAt": catch_state_name,
"States": {
Expand All @@ -445,80 +465,61 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
"Catch": [
{
"ErrorEquals": [
all_fail_error
timeout_error
],
"Next": all_error_state_name
"Next": timeout_state_name
},
{
"ErrorEquals": [
task_failed_error
],
"Next": task_failed_state_name
}
]
},
all_error_state_name: {
task_failed_state_name: {
"Type": "Pass",
"Result": catch_state_result,
"End": True
}
},
timeout_state_name: {
"Type": "Pass",
"Result": catch_state_result,
"End": True
},
}
}
task = steps.Task(
catch_state_name,
parameters=training_job_parameters,
resource=task_resource
)
task.add_catch(
steps.Catch(
error_equals=[all_fail_error],
next_step=steps.Pass(all_error_state_name, result=catch_state_result)
)
)

workflow = Workflow(
unique_name_from_base('Test_Catch_Workflow'),
definition=task,
role=sfn_role_arn
)

workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result)


def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters):
retry_state_name = "RetryStateName"
all_fail_error = "Starts.ALL"
task_failed_error = "States.TaskFailed"
timeout_error = "States.Timeout"
interval_seconds = 1
max_attempts = 2
backoff_rate = 2
task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync"

# change the parameters to cause task state to fail
# Provide invalid TrainingImage to cause States.TaskFailed error
training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image"

asl_state_machine_definition = {
"StartAt": retry_state_name,
"States": {
retry_state_name: {
"Resource": task_resource,
"Parameters": training_job_parameters,
"Type": "Task",
"End": True,
"Retry": [
{
"ErrorEquals": [all_fail_error],
"IntervalSeconds": interval_seconds,
"MaxAttempts": max_attempts,
"BackoffRate": backoff_rate
}
]
}
}
}

task = steps.Task(
retry_state_name,
parameters=training_job_parameters,
resource=task_resource
resource=task_resource,
retry=steps.Retry(
error_equals=[timeout_error],
interval_seconds=interval_seconds,
max_attempts=max_attempts,
backoff_rate=backoff_rate
)
)

task.add_retry(
steps.Retry(
error_equals=[all_fail_error],
error_equals=[task_failed_error],
interval_seconds=interval_seconds,
max_attempts=max_attempts,
backoff_rate=backoff_rate
Expand All @@ -531,4 +532,30 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
role=sfn_role_arn
)

workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)
asl_state_machine_definition = {
"StartAt": retry_state_name,
"States": {
retry_state_name: {
"Resource": task_resource,
"Parameters": training_job_parameters,
"Type": "Task",
"End": True,
"Retry": [
{
"ErrorEquals": [timeout_error],
"IntervalSeconds": interval_seconds,
"MaxAttempts": max_attempts,
"BackoffRate": backoff_rate
},
{
"ErrorEquals": [task_failed_error],
"IntervalSeconds": interval_seconds,
"MaxAttempts": max_attempts,
"BackoffRate": backoff_rate
}
]
}
}
}

workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)
Loading

0 comments on commit 7f223e8

Please sign in to comment.