feat: Support placeholders for processing step (#155)
* Support placeholders for processor parameters in processing step

* Merge SageMaker-generated parameters with placeholder-compatible parameters received in args

Co-authored-by: Shiv Lakshminarayan <shivlaks@amazon.com>
Co-authored-by: Adam Wong <55506708+wong-a@users.noreply.github.com>
3 people authored Aug 20, 2021
1 parent 6b62bf7 commit 01e18c3
Showing 6 changed files with 321 additions and 21 deletions.
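
In short: the processor-derived arguments (container_arguments, container_entrypoint, kms_key_id, tags) now accept Placeholders, and a new parameters kwarg is merged over the SDK-generated CreateProcessingJob request payload. Below is a minimal sketch of the intended usage, mirroring the tests added in this commit; sklearn_processor stands in for any pre-configured SageMaker Processor and is an assumption, not part of the diff.

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps.sagemaker import ProcessingStep

# Placeholder values are supplied at execution time through the workflow input.
execution_input = ExecutionInput(schema={
    'entrypoint': str,
    'container_arguments': [str],
    'instance_count': int,
})

step = ProcessingStep(
    'preprocess',
    processor=sklearn_processor,                            # assumed: a pre-configured Processor
    job_name='my-processing-job',
    container_entrypoint=execution_input['entrypoint'],     # resolved at run time
    container_arguments=execution_input['container_arguments'],
    parameters={                                            # merged over the generated payload
        'ProcessingResources': {
            'ClusterConfig': {'InstanceCount': execution_input['instance_count']}
        }
    },
)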
1 change: 0 additions & 1 deletion src/stepfunctions/steps/fields.py
@@ -59,7 +59,6 @@ class Field(Enum):
HeartbeatSeconds = 'heartbeat_seconds'
HeartbeatSecondsPath = 'heartbeat_seconds_path'


# Retry and catch fields
ErrorEquals = 'error_equals'
IntervalSeconds = 'interval_seconds'
45 changes: 27 additions & 18 deletions src/stepfunctions/steps/sagemaker.py
@@ -19,7 +19,7 @@
from stepfunctions.inputs import Placeholder
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import tags_dict_to_kv_list
from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn

from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
@@ -30,6 +30,7 @@

SAGEMAKER_SERVICE_NAME = "sagemaker"


class SageMakerApi(Enum):
CreateTrainingJob = "createTrainingJob"
CreateTransformJob = "createTransformJob"
@@ -479,7 +480,9 @@ class ProcessingStep(Task):
Creates a Task State to execute a SageMaker Processing Job.
"""

def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs):
def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None,
container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True,
tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -491,15 +494,18 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None)
container_arguments ([str]): The arguments for a container used to run a processing job.
container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None)
container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job.
container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job.
kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
The KmsKeyId is applied to all outputs.
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateProcessingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html>`_.
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
"""
if wait_for_completion:
"""
Expand All @@ -518,22 +524,25 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
SageMakerApi.CreateProcessingJob)

if isinstance(job_name, str):
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
else:
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)

if isinstance(job_name, Placeholder):
parameters['ProcessingJobName'] = job_name
processing_parameters['ProcessingJobName'] = job_name

if experiment_config is not None:
parameters['ExperimentConfig'] = experiment_config
processing_parameters['ExperimentConfig'] = experiment_config

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)

if 'S3Operations' in parameters:
del parameters['S3Operations']

kwargs[Field.Parameters.value] = parameters
processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)

if 'S3Operations' in processing_parameters:
del processing_parameters['S3Operations']

if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
# Update processing_parameters with input parameters
merge_dicts(processing_parameters, kwargs[Field.Parameters.value])

kwargs[Field.Parameters.value] = processing_parameters
super(ProcessingStep, self).__init__(state_id, **kwargs)
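
Note the precedence this merge establishes: any key supplied through the parameters kwarg overwrites the value that processing_config generated. A hedged sketch (assuming a pre-built processor object, which is not part of the diff) of using parameters to set fields the constructor does not expose directly:

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps.sagemaker import ProcessingStep

execution_input = ExecutionInput(schema={'env': str, 'role': str})

# Environment and RoleArn are not constructor arguments; parameters lets
# them be set (or overridden) with placeholders resolved at execution time.
step = ProcessingStep(
    'processing-with-overrides',
    processor=processor,            # assumed: a pre-configured sagemaker Processor
    job_name='my-processing-job',
    parameters={
        'Environment': execution_input['env'],
        'RoleArn': execution_input['role'],
    },
)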
26 changes: 26 additions & 0 deletions src/stepfunctions/steps/utils.py
@@ -14,6 +14,7 @@

import boto3
import logging
from stepfunctions.inputs import Placeholder

logger = logging.getLogger('stepfunctions')

@@ -45,3 +46,28 @@ def get_aws_partition():
return cur_partition

return cur_partition


def merge_dicts(target, source):
"""
Merges source dictionary into the target dictionary.
Values in the target dict are updated with the values of the source dict.
Args:
target (dict): Base dictionary into which source is merged
source (dict): Dictionary used to update target. If the same key is present in both dictionaries, source's value
will overwrite target's value for the corresponding key
"""
if isinstance(target, dict) and isinstance(source, dict):
for key, value in source.items():
if key in target:
if isinstance(target[key], dict) and isinstance(source[key], dict):
merge_dicts(target[key], source[key])
elif target[key] == value:
pass
else:
logger.info(
f"Property: <{key}> with value: <{target[key]}>"
f" will be overwritten with provided value: <{value}>")
target[key] = source[key]
else:
target[key] = source[key]
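
A quick illustration of the merge semantics: nested dicts are merged recursively, a conflicting leaf value is overwritten by source (with an INFO log noting the overwrite), and keys missing from target are copied over.

from stepfunctions.steps.utils import merge_dicts

generated = {'ClusterConfig': {'InstanceCount': 1, 'VolumeSizeInGB': 30}}
overrides = {'ClusterConfig': {'InstanceCount': 2}, 'Tags': [{'Key': 'team', 'Value': 'ml'}]}

merge_dicts(generated, overrides)
assert generated == {
    'ClusterConfig': {'InstanceCount': 2, 'VolumeSizeInGB': 30},  # conflict: overrides wins
    'Tags': [{'Key': 'team', 'Value': 'ml'}],                     # new key copied over
}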
96 changes: 96 additions & 0 deletions tests/integ/test_sagemaker_steps.py
@@ -29,6 +29,7 @@
from sagemaker.tuner import HyperparameterTuner
from sagemaker.processing import ProcessingInput, ProcessingOutput

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps import Chain
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
from stepfunctions.workflow import Workflow
@@ -352,3 +353,98 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
# End of Cleanup


def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,
sagemaker_role_arn):
region = boto3.session.Session().region_name
input_data = f"s3://sagemaker-sample-data-{region}/processing/census/census-income.csv"

input_s3 = sagemaker_session.upload_data(
path=os.path.join(DATA_DIR, 'sklearn_processing'),
bucket=sagemaker_session.default_bucket(),
key_prefix='integ-test-data/sklearn_processing/code'
)

output_s3 = f"s3://{sagemaker_session.default_bucket()}/integ-test-data/sklearn_processing"

inputs = [
ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'),
ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code',
input_name='code'),
]

outputs = [
ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data',
output_name='train_data'),
ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data',
output_name='test_data'),
]

# Build workflow definition
execution_input = ExecutionInput(schema={
'image_uri': str,
'instance_count': int,
'entrypoint': str,
'role': str,
'volume_size_in_gb': int,
'max_runtime_in_seconds': int,
'container_arguments': [str],
})

parameters = {
'AppSpecification': {
'ContainerEntrypoint': execution_input['entrypoint'],
'ImageUri': execution_input['image_uri']
},
'ProcessingResources': {
'ClusterConfig': {
'InstanceCount': execution_input['instance_count'],
'VolumeSizeInGB': execution_input['volume_size_in_gb']
}
},
'RoleArn': execution_input['role'],
'StoppingCondition': {
'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
}
}

job_name = generate_job_name()
processing_step = ProcessingStep('create_processing_job_step',
processor=sklearn_processor_fixture,
job_name=job_name,
inputs=inputs,
outputs=outputs,
container_arguments=execution_input['container_arguments'],
container_entrypoint=execution_input['entrypoint'],
parameters=parameters
)
workflow_graph = Chain([processing_step])

with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
workflow = create_workflow_and_check_definition(
workflow_graph=workflow_graph,
workflow_name=unique_name_from_base("integ-test-processing-step-workflow"),
sfn_client=sfn_client,
sfn_role_arn=sfn_role_arn
)

execution_input = {
'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
'instance_count': 1,
'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
'role': sagemaker_role_arn,
'volume_size_in_gb': 30,
'max_runtime_in_seconds': 500,
'container_arguments': ['--train-test-split-ratio', '0.2']
}

# Execute workflow
execution = workflow.execute(inputs=execution_input)
execution_output = execution.get_output(wait=True)

# Check workflow output
assert execution_output.get("ProcessingJobStatus") == "Completed"

# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
137 changes: 136 additions & 1 deletion tests/unit/test_sagemaker_steps.py
@@ -27,7 +27,9 @@

from unittest.mock import MagicMock, patch
from stepfunctions.inputs import ExecutionInput, StepInput
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep
from stepfunctions.steps.fields import Field
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep,\
ProcessingStep
from stepfunctions.steps.sagemaker import tuning_config

from tests.unit.utils import mock_boto_api_call
@@ -962,3 +964,136 @@ def test_processing_step_creation(sklearn_processor):
'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
'End': True
}


def test_processing_step_creation_with_placeholders(sklearn_processor):
execution_input = ExecutionInput(schema={
'image_uri': str,
'instance_count': int,
'entrypoint': str,
'output_kms_key': str,
'role': str,
'env': str,
'volume_size_in_gb': int,
'volume_kms_key': str,
'max_runtime_in_seconds': int,
'tags': [{str: str}],
'container_arguments': [str]
})

step_input = StepInput(schema={
'instance_type': str
})

parameters = {
'AppSpecification': {
'ContainerEntrypoint': execution_input['entrypoint'],
'ImageUri': execution_input['image_uri']
},
'Environment': execution_input['env'],
'ProcessingOutputConfig': {
'KmsKeyId': execution_input['output_kms_key']
},
'ProcessingResources': {
'ClusterConfig': {
'InstanceCount': execution_input['instance_count'],
'InstanceType': step_input['instance_type'],
'VolumeKmsKeyId': execution_input['volume_kms_key'],
'VolumeSizeInGB': execution_input['volume_size_in_gb']
}
},
'RoleArn': execution_input['role'],
'StoppingCondition': {
'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
},
'Tags': execution_input['tags']
}

inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')]
outputs = [
ProcessingOutput(source='/opt/ml/processing/output/train'),
ProcessingOutput(source='/opt/ml/processing/output/validation'),
ProcessingOutput(source='/opt/ml/processing/output/test')
]
step = ProcessingStep(
'Feature Transformation',
sklearn_processor,
'MyProcessingJob',
container_entrypoint=execution_input['entrypoint'],
container_arguments=execution_input['container_arguments'],
kms_key_id=execution_input['output_kms_key'],
inputs=inputs,
outputs=outputs,
parameters=parameters
)
assert step.to_dict() == {
'Type': 'Task',
'Parameters': {
'AppSpecification': {
'ContainerArguments.$': "$$.Execution.Input['container_arguments']",
'ContainerEntrypoint.$': "$$.Execution.Input['entrypoint']",
'ImageUri.$': "$$.Execution.Input['image_uri']"
},
'Environment.$': "$$.Execution.Input['env']",
'ProcessingInputs': [
{
'InputName': None,
'AppManaged': False,
'S3Input': {
'LocalPath': '/opt/ml/processing/input',
'S3CompressionType': 'None',
'S3DataDistributionType': 'FullyReplicated',
'S3DataType': 'S3Prefix',
'S3InputMode': 'File',
'S3Uri': 'dataset.csv'
}
}
],
'ProcessingOutputConfig': {
'KmsKeyId.$': "$$.Execution.Input['output_kms_key']",
'Outputs': [
{
'OutputName': None,
'AppManaged': False,
'S3Output': {
'LocalPath': '/opt/ml/processing/output/train',
'S3UploadMode': 'EndOfJob',
'S3Uri': None
}
},
{
'OutputName': None,
'AppManaged': False,
'S3Output': {
'LocalPath': '/opt/ml/processing/output/validation',
'S3UploadMode': 'EndOfJob',
'S3Uri': None
}
},
{
'OutputName': None,
'AppManaged': False,
'S3Output': {
'LocalPath': '/opt/ml/processing/output/test',
'S3UploadMode': 'EndOfJob',
'S3Uri': None
}
}
]
},
'ProcessingResources': {
'ClusterConfig': {
'InstanceCount.$': "$$.Execution.Input['instance_count']",
'InstanceType.$': "$['instance_type']",
'VolumeKmsKeyId.$': "$$.Execution.Input['volume_kms_key']",
'VolumeSizeInGB.$': "$$.Execution.Input['volume_size_in_gb']"
}
},
'ProcessingJobName': 'MyProcessingJob',
'RoleArn.$': "$$.Execution.Input['role']",
'Tags.$': "$$.Execution.Input['tags']",
'StoppingCondition': {'MaxRuntimeInSeconds.$': "$$.Execution.Input['max_runtime_in_seconds']"},
},
'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
'End': True
}