Support placeholders for processing step #155

Merged Aug 20, 2021 (23 commits)
927b24f
documentation: Add setup instructions to run/debug tests locally
ca-nguyen Jul 16, 2021
003b5e8
Merge branch 'main' into update-contributing
shivlaks Aug 9, 2021
a7700a6
Added sub section for debug setup and linked to run tests instructions
ca-nguyen Aug 10, 2021
6b6443a
Update table
ca-nguyen Aug 12, 2021
7f6ef30
Support placeholders for processor parameters in processingstep
ca-nguyen Aug 12, 2021
00830f3
Added doc
ca-nguyen Aug 12, 2021
c708da7
Removed contibuting changes(included in another pr)
ca-nguyen Aug 12, 2021
2ea9e1f
Merge sagemaker generated parameters with placeholder compatible para…
ca-nguyen Aug 17, 2021
17543ed
documentation: Add setup instructions to run/debug tests locally
ca-nguyen Jul 16, 2021
36e2ee8
Added sub section for debug setup and linked to run tests instructions
ca-nguyen Aug 10, 2021
ea40f7c
Update table
ca-nguyen Aug 12, 2021
e499108
Support placeholders for processor parameters in processingstep
ca-nguyen Aug 12, 2021
4c63229
Added doc
ca-nguyen Aug 12, 2021
34bb281
Removed contibuting changes(included in another pr)
ca-nguyen Aug 12, 2021
a098c61
Merge sagemaker generated parameters with placeholder compatible para…
ca-nguyen Aug 17, 2021
06eb069
Merge branch 'support-placeholders-for-processing-step' of https://gi…
ca-nguyen Aug 17, 2021
da99c92
Using == instead of is()
ca-nguyen Aug 17, 2021
37b2422
Removed unused InvalidPathToPlaceholderParameter exception
ca-nguyen Aug 17, 2021
c433576
Merge branch 'main' into support-placeholders-for-processing-step
ca-nguyen Aug 17, 2021
fd640ab
Added doc and renamed args
ca-nguyen Aug 18, 2021
1dfa0e3
Update src/stepfunctions/steps/sagemaker.py parameters description
ca-nguyen Aug 19, 2021
6143783
Removed dict name args to opt for more generic log message when overw…
ca-nguyen Aug 19, 2021
ebc5e22
Using fstring in test
ca-nguyen Aug 20, 2021
3 changes: 2 additions & 1 deletion src/stepfunctions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ class MissingRequiredParameter(Exception):


class DuplicateStatesInChain(Exception):
pass
pass

1 change: 0 additions & 1 deletion src/stepfunctions/steps/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class Field(Enum):
HeartbeatSeconds = 'heartbeat_seconds'
HeartbeatSecondsPath = 'heartbeat_seconds_path'


# Retry and catch fields
ErrorEquals = 'error_equals'
IntervalSeconds = 'interval_seconds'
Expand Down
47 changes: 29 additions & 18 deletions src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from stepfunctions.inputs import Placeholder
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import tags_dict_to_kv_list
from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn

from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
Expand All @@ -30,6 +30,7 @@

SAGEMAKER_SERVICE_NAME = "sagemaker"


Contributor: nice to see us embracing pep8 in files we touch 🙌

Contributor Author: 🙌🙌🙌
class SageMakerApi(Enum):
CreateTrainingJob = "createTrainingJob"
CreateTransformJob = "createTransformJob"
Expand Down Expand Up @@ -479,7 +480,9 @@ class ProcessingStep(Task):
Creates a Task State to execute a SageMaker Processing Job.
"""

def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs):
def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None,
container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True,
tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
Expand All @@ -491,15 +494,19 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None)
container_arguments ([str]): The arguments for a container used to run a processing job.
container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None)
container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job.
container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job.
kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
The KmsKeyId is applied to all outputs.
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
parameters (dict, optional): The value of this field becomes the request for the
`CreateProcessingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html>`_ call made by the processing step.
All parameters fields are compatible with `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
Any value defined in the parameters argument will overwrite the ones defined in the other arguments, including properties that were previously defined in the processor.
"""
if wait_for_completion:
"""
Expand All @@ -518,22 +525,26 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
SageMakerApi.CreateProcessingJob)

if isinstance(job_name, str):
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
else:
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)

if isinstance(job_name, Placeholder):
parameters['ProcessingJobName'] = job_name
processing_parameters['ProcessingJobName'] = job_name

if experiment_config is not None:
parameters['ExperimentConfig'] = experiment_config
processing_parameters['ExperimentConfig'] = experiment_config

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)

if 'S3Operations' in parameters:
del parameters['S3Operations']

kwargs[Field.Parameters.value] = parameters
processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)

if 'S3Operations' in processing_parameters:
del processing_parameters['S3Operations']

if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
# Update processing_parameters with input parameters
merge_dicts(processing_parameters, kwargs[Field.Parameters.value], "Processing Parameters",
"Input Parameters")

kwargs[Field.Parameters.value] = processing_parameters
super(ProcessingStep, self).__init__(state_id, **kwargs)
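For context, here is a caller-side sketch of the override hook this diff adds: the shape of a `parameters` dict that overrides fields of the generated CreateProcessingJob request. This is a minimal standalone illustration with no SDK dependency; the JsonPath strings below are hypothetical stand-ins for the SDK's ExecutionInput placeholders, while the key names follow the CreateProcessingJob API.

```python
# Sketch of a `parameters` override for ProcessingStep. In real use these
# values would be ExecutionInput placeholders; plain JsonPath strings stand
# in for them here so the example runs without the SDK installed.
parameters_override = {
    'AppSpecification': {
        # Overrides the image/entrypoint derived from the processor object
        'ImageUri': "$$.Execution.Input['image_uri']",
        'ContainerEntrypoint': "$$.Execution.Input['entrypoint']",
    },
    'ProcessingResources': {
        'ClusterConfig': {
            # Overrides the instance count configured on the processor
            'InstanceCount': "$$.Execution.Input['instance_count']",
        }
    },
    'StoppingCondition': {
        'MaxRuntimeInSeconds': "$$.Execution.Input['max_runtime_in_seconds']",
    },
}

# Per the docstring above, every key present here wins over the value the
# step generated from the processor and the other constructor arguments.
print(sorted(parameters_override))
```

Keys omitted from the override (for example `Environment` or the generated `ProcessingInputs`) are left exactly as the step computed them, because the merge only touches keys present in the source dict.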
28 changes: 28 additions & 0 deletions src/stepfunctions/steps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import boto3
import logging
from stepfunctions.inputs import Placeholder

logger = logging.getLogger('stepfunctions')

Expand Down Expand Up @@ -45,3 +46,30 @@ def get_aws_partition():
return cur_partition

return cur_partition


def merge_dicts(target, source, target_name, source_name):
"""
Merges source dictionary into the target dictionary.
Values in the target dict are updated with the values of the source dict.
Args:
target (dict): Base dictionary into which source is merged
source (dict): Dictionary used to update target. If the same key is present in both dictionaries, source's value
will overwrite target's value for the corresponding key
target_name (str): Name of target dictionary used for logging purposes
source_name (str): Name of source dictionary used for logging purposes
"""
if isinstance(target, dict) and isinstance(source, dict):
for key, value in source.items():
if key in target:
if isinstance(target[key], dict) and isinstance(source[key], dict):
merge_dicts(target[key], source[key], target_name, source_name)
elif target[key] == value:
pass
else:
logger.info(
f"{target_name} property: <{key}> with value: <{target[key]}>"
f" will be overwritten with value provided in {source_name} : <{value}>")
target[key] = source[key]
else:
target[key] = source[key]
98 changes: 98 additions & 0 deletions tests/integ/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sagemaker.tuner import HyperparameterTuner
from sagemaker.processing import ProcessingInput, ProcessingOutput

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps import Chain
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
from stepfunctions.workflow import Workflow
Expand Down Expand Up @@ -352,3 +353,100 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
# End of Cleanup


def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,
sagemaker_role_arn):
region = boto3.session.Session().region_name
input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region)
Contributor: nit: why not use f strings here too instead of format?

Contributor Author: I agree that using f-strings is more readable and efficient. format was used for all the other tests, so I kept it for consistency. Will change it for this added test, and perhaps we can make the change for the rest of the file in a separate PR.

input_s3 = sagemaker_session.upload_data(
path=os.path.join(DATA_DIR, 'sklearn_processing'),
bucket=sagemaker_session.default_bucket(),
key_prefix='integ-test-data/sklearn_processing/code'
)

output_s3 = 's3://' + sagemaker_session.default_bucket() + '/integ-test-data/sklearn_processing'
Contributor: nit: why not use f strings here instead of concatenation?

Contributor Author: Agreed, using an f-string would be more readable and efficient. Same comment: format was used for all the other tests, so I kept it for consistency. Will change it for this added test, and perhaps we can make the change for the rest of the file in a separate PR.

inputs = [
ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'),
ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code',
input_name='code'),
]

outputs = [
ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data',
output_name='train_data'),
ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data',
output_name='test_data'),
]

# Build workflow definition
execution_input = ExecutionInput(schema={
'image_uri': str,
Contributor Author: Since we're only using these values for test purposes, using the direct string values for better code readability
'instance_count': int,
'entrypoint': str,
'role': str,
'volume_size_in_gb': int,
'max_runtime_in_seconds': int,
'container_arguments': [str],
})

parameters = {
'AppSpecification': {
'ContainerEntrypoint': execution_input['entrypoint'],
'ImageUri': execution_input['image_uri']
},
'ProcessingResources': {
'ClusterConfig': {
'InstanceCount': execution_input['instance_count'],
'VolumeSizeInGB': execution_input['volume_size_in_gb']
}
},
'RoleArn': execution_input['role'],
'StoppingCondition': {
'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
}
}

job_name = generate_job_name()
processing_step = ProcessingStep('create_processing_job_step',
processor=sklearn_processor_fixture,
job_name=job_name,
inputs=inputs,
outputs=outputs,
container_arguments=execution_input['container_arguments'],
container_entrypoint=execution_input['entrypoint'],
parameters=parameters
)
workflow_graph = Chain([processing_step])

with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
# Create workflow and check definition
Contributor: nit: unnecessary comment as the method name expresses this in snake case

Contributor Author: Agreed, will be removed with the next commit

workflow = create_workflow_and_check_definition(
workflow_graph=workflow_graph,
workflow_name=unique_name_from_base("integ-test-processing-step-workflow"),
sfn_client=sfn_client,
sfn_role_arn=sfn_role_arn
)

execution_input = {
'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
'instance_count': 1,
'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
'role': sagemaker_role_arn,
'volume_size_in_gb': 30,
'max_runtime_in_seconds': 500,
'container_arguments': ['--train-test-split-ratio', '0.2']
}

# Execute workflow
execution = workflow.execute(inputs=execution_input)
execution_output = execution.get_output(wait=True)

# Check workflow output
assert execution_output.get("ProcessingJobStatus") == "Completed"

# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
# End of Cleanup
Contributor: nit: I think the code is self explanatory. We can drop this comment 😅

Contributor Author: You're right! I'll remove the comments :) They are included in all the other tests; will do a cleanup for the other tests in another PR.

Contributor: did you forget to remove this?

Contributor Author: Yes, will remove it in the next commit!