
Support placeholders for processing step #155

Merged
23 commits merged on Aug 20, 2021
Commits (23)
927b24f
documentation: Add setup instructions to run/debug tests locally
ca-nguyen Jul 16, 2021
003b5e8
Merge branch 'main' into update-contributing
shivlaks Aug 9, 2021
a7700a6
Added sub section for debug setup and linked to run tests instructions
ca-nguyen Aug 10, 2021
6b6443a
Update table
ca-nguyen Aug 12, 2021
7f6ef30
Support placeholders for processor parameters in processingstep
ca-nguyen Aug 12, 2021
00830f3
Added doc
ca-nguyen Aug 12, 2021
c708da7
Removed contributing changes (included in another PR)
ca-nguyen Aug 12, 2021
2ea9e1f
Merge sagemaker generated parameters with placeholder compatible para…
ca-nguyen Aug 17, 2021
17543ed
documentation: Add setup instructions to run/debug tests locally
ca-nguyen Jul 16, 2021
36e2ee8
Added sub section for debug setup and linked to run tests instructions
ca-nguyen Aug 10, 2021
ea40f7c
Update table
ca-nguyen Aug 12, 2021
e499108
Support placeholders for processor parameters in processingstep
ca-nguyen Aug 12, 2021
4c63229
Added doc
ca-nguyen Aug 12, 2021
34bb281
Removed contributing changes (included in another PR)
ca-nguyen Aug 12, 2021
a098c61
Merge sagemaker generated parameters with placeholder compatible para…
ca-nguyen Aug 17, 2021
06eb069
Merge branch 'support-placeholders-for-processing-step' of https://gi…
ca-nguyen Aug 17, 2021
da99c92
Using == instead of is()
ca-nguyen Aug 17, 2021
37b2422
Removed unused InvalidPathToPlaceholderParameter exception
ca-nguyen Aug 17, 2021
c433576
Merge branch 'main' into support-placeholders-for-processing-step
ca-nguyen Aug 17, 2021
fd640ab
Added doc and renamed args
ca-nguyen Aug 18, 2021
1dfa0e3
Update src/stepfunctions/steps/sagemaker.py parameters description
ca-nguyen Aug 19, 2021
6143783
Removed dict name args to opt for more generic log message when overw…
ca-nguyen Aug 19, 2021
ebc5e22
Using fstring in test
ca-nguyen Aug 20, 2021
30 changes: 0 additions & 30 deletions src/stepfunctions/steps/constants.py

This file was deleted.

14 changes: 0 additions & 14 deletions src/stepfunctions/steps/fields.py
@@ -65,17 +65,3 @@ class Field(Enum):
MaxAttempts = 'max_attempts'
BackoffRate = 'backoff_rate'
NextStep = 'next_step'

# Sagemaker step fields
# Processing Step: Processor
Role = 'role'
ImageUri = 'image_uri'
InstanceCount = 'instance_count'
InstanceType = 'instance_type'
Entrypoint = 'entrypoint'
VolumeSizeInGB = 'volume_size_in_gb'
VolumeKMSKey = 'volume_kms_key'
OutputKMSKey = 'output_kms_key'
MaxRuntimeInSeconds = 'max_runtime_in_seconds'
Env = 'env'
Tags = 'tags'
148 changes: 25 additions & 123 deletions src/stepfunctions/steps/sagemaker.py
@@ -13,28 +13,24 @@
from __future__ import absolute_import

import logging
import operator

from enum import Enum
from functools import reduce

from stepfunctions.exceptions import InvalidPathToPlaceholderParameter
from stepfunctions.inputs import Placeholder
from stepfunctions.steps.constants import placeholder_paths
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import tags_dict_to_kv_list
from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn

from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
from sagemaker.model import Model, FrameworkModel
from sagemaker.model_monitor import DataCaptureConfig
from sagemaker.processing import ProcessingJob

logger = logging.getLogger('stepfunctions.sagemaker')

SAGEMAKER_SERVICE_NAME = "sagemaker"


Contributor

nice to see us embracing pep8 in files we touch 🙌

Contributor Author

🙌🙌🙌

class SageMakerApi(Enum):
CreateTrainingJob = "createTrainingJob"
CreateTransformJob = "createTransformJob"
@@ -46,104 +42,6 @@ class SageMakerApi(Enum):
CreateProcessingJob = "createProcessingJob"


class SageMakerTask(Task):

"""
Task State causes the interpreter to execute the work identified by the state’s `resource` field.
"""

def __init__(self, state_id, step_type, tags, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI.
timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60)
timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name.
heartbeat_seconds_path (str, optional): Path specifying the state's heartbeat value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
self._replace_sagemaker_placeholders(step_type, kwargs)
if tags:
self.set_tags_config(tags, kwargs[Field.Parameters.value], step_type)

super(SageMakerTask, self).__init__(state_id, **kwargs)


def allowed_fields(self):
sagemaker_fields = [
# ProcessingStep: Processor
Field.Role,
Field.ImageUri,
Field.InstanceCount,
Field.InstanceType,
Field.Entrypoint,
Field.VolumeSizeInGB,
Field.VolumeKMSKey,
Field.OutputKMSKey,
Field.MaxRuntimeInSeconds,
Field.Env,
Field.Tags,
]

return super(SageMakerTask, self).allowed_fields() + sagemaker_fields


def _replace_sagemaker_placeholders(self, step_type, args):
# Fetch path from type
sagemaker_parameters = args[Field.Parameters.value]
paths = placeholder_paths.get(step_type)
treated_args = []

for arg_name, value in args.items():
if arg_name in [Field.Parameters.value]:
continue
if arg_name in paths.keys():
path = paths.get(arg_name)
if self._set_placeholder(sagemaker_parameters, path, value, arg_name):
treated_args.append(arg_name)

SageMakerTask.remove_treated_args(treated_args, args)

@staticmethod
def get_value_from_path(parameters, path):
value_from_path = reduce(operator.getitem, path, parameters)
return value_from_path
# return reduce(operator.getitem, path, parameters)

@staticmethod
def _set_placeholder(parameters, path, value, arg_name):
is_set = False
try:
SageMakerTask.get_value_from_path(parameters, path[:-1])[path[-1]] = value
is_set = True
except KeyError as e:
message = f"Invalid path {path} for {arg_name}: {e}"
raise InvalidPathToPlaceholderParameter(message)
return is_set

@staticmethod
def remove_treated_args(treated_args, args):
for treated_arg in treated_args:
try:
del args[treated_arg]
except KeyError as e:
pass

def set_tags_config(self, tags, parameters, step_type):
if isinstance(tags, Placeholder):
# Replace with placeholder
path = placeholder_paths.get(step_type).get(Field.Tags.value)
if path:
self._set_placeholder(parameters, path, tags, Field.Tags.value)
else:
parameters['Tags'] = tags_dict_to_kv_list(tags)


class TrainingStep(Task):

"""
@@ -576,7 +474,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
super(TuningStep, self).__init__(state_id, **kwargs)


class ProcessingStep(SageMakerTask):
class ProcessingStep(Task):

"""
Creates a Task State to execute a SageMaker Processing Job.
@@ -588,7 +486,7 @@ class ProcessingStep(SageMakerTask):

def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None,
container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True,
tags=None, max_runtime_in_seconds=None, **kwargs):
tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -600,16 +498,16 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None)
container_arguments ([str]): The arguments for a container used to run a processing job.
container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None)
container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job.
container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job.
kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
The KmsKeyId is applied to all outputs.
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
max_runtime_in_seconds (int or Placeholder): Specifies the maximum runtime in seconds for the processing job
parameters(dict, optional): The value of this field becomes the effective input for the state.
"""
if wait_for_completion:
"""
@@ -628,22 +526,26 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
SageMakerApi.CreateProcessingJob)

if isinstance(job_name, str):
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
else:
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)

if isinstance(job_name, Placeholder):
parameters['ProcessingJobName'] = job_name
processing_parameters['ProcessingJobName'] = job_name

if experiment_config is not None:
parameters['ExperimentConfig'] = experiment_config

if 'S3Operations' in parameters:
del parameters['S3Operations']
processing_parameters['ExperimentConfig'] = experiment_config

if max_runtime_in_seconds:
parameters['StoppingCondition'] = ProcessingJob.prepare_stopping_condition(max_runtime_in_seconds)

kwargs[Field.Parameters.value] = parameters
if tags:
processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)

if 'S3Operations' in processing_parameters:
del processing_parameters['S3Operations']

if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
# Update processing_parameters with input parameters
merge_dicts(processing_parameters, kwargs[Field.Parameters.value], "Processing Parameters",
"Input Parameters")

super(ProcessingStep, self).__init__(state_id, __class__.__name__, tags, **kwargs)
kwargs[Field.Parameters.value] = processing_parameters
super(ProcessingStep, self).__init__(state_id, **kwargs)
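
For illustration, a minimal sketch of how the reworked constructor can now take placeholders through the generic parameters field (condensed from the integration test further down; processor is assumed to be an existing sagemaker Processor object and the job name is illustrative):

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps.sagemaker import ProcessingStep

execution_input = ExecutionInput(schema={'role': str, 'instance_count': int})

# Anything supplied here is deep-merged into the SageMaker-generated config via merge_dicts,
# so placeholders can override individual nested fields.
parameters = {
    'RoleArn': execution_input['role'],
    'ProcessingResources': {
        'ClusterConfig': {'InstanceCount': execution_input['instance_count']}
    }
}

processing_step = ProcessingStep('processing-step',
                                 processor=processor,  # existing sagemaker Processor (assumed)
                                 job_name='my-processing-job',  # illustrative
                                 parameters=parameters)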
22 changes: 22 additions & 0 deletions src/stepfunctions/steps/utils.py
@@ -14,6 +14,7 @@

import boto3
import logging
from stepfunctions.inputs import Placeholder

logger = logging.getLogger('stepfunctions')

@@ -45,3 +46,24 @@ def get_aws_partition():
return cur_partition

return cur_partition


def merge_dicts(first, second, first_name, second_name):
Contributor Author

This could also be used to merge the hyperparameters in TrainingStep - will make the changes in another PR

Contributor

suggestion: First and second don't describe the side effects and which dict gets merged into what. Borrowing from JavaScript's Object.assign:

Suggested change
def merge_dicts(first, second, first_name, second_name):
def merge_dicts(target, source):

Contributor

+1 - I also like to push for docstrings where behaviour is not entirely intuitive, i.e. what happens if there are clashes, are overwrites allowed, etc.

"""
Merges first and second dictionaries into the first one.
Values in the first dict are updated with the values of the second one.
"""
if all(isinstance(d, dict) for d in [first, second]):
for key, value in second.items():
if key in first:
if isinstance(first[key], dict) and isinstance(second[key], dict):
merge_dicts(first[key], second[key], first_name, second_name)
elif first[key] is value:
pass
else:
logger.info(
f"{first_name} property: <{key}> with value: <{first[key]}>"
f" will be overwritten with value provided in {second_name} : <{value}>")
Contributor

question: Do we think this is useful? If not, can just use Python's built-in dict.update

Contributor Author (@ca-nguyen, Aug 18, 2021)
The built in update() does not take into account nested dictionary values - for example:

d1 = {'a': {'aa': 1, 'bb': 2, 'c': 3}}
d2 = {'a': {'bb': 1}}

d1.update(d2)
print(d1)

This will have the following output: {'a': {'bb': 1}}

Since we would expect to get {'a': {'aa': 1, 'bb': 1, 'c': 3}}, we can't use the update() function in our case.
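
For comparison, a quick sketch of the same example run through the new merge_dicts helper (assuming it is imported from stepfunctions.steps.utils; the two name arguments only feed the overwrite log message):

from stepfunctions.steps.utils import merge_dicts

d1 = {'a': {'aa': 1, 'bb': 2, 'c': 3}}
d2 = {'a': {'bb': 1}}

# Recursively merges d2 into d1, logging that 'bb' will be overwritten
merge_dicts(d1, d2, 'd1', 'd2')
print(d1)  # {'a': {'aa': 1, 'bb': 1, 'c': 3}}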

Contributor Author

Initially added them to facilitate troubleshooting, but I'm open to remove the logs if we deem them not useful enough or too noisy

Contributor

If the expected behaviour is well documented it seems unnecessary. Since the method only exists for logging, if we get rid of it there's less code to maintain. What do you think, @shivlaks?

Contributor

"The built in update() does not take into account nested dictionary values"

Missed this comment. Since we need a deep merge, dict.update is not going to work here

first[key] = second[key]
else:
first[key] = second[key]
54 changes: 34 additions & 20 deletions tests/integ/test_sagemaker_steps.py
@@ -31,7 +31,6 @@

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps import Chain
from stepfunctions.steps.fields import Field
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
from stepfunctions.workflow import Workflow

@@ -384,27 +383,41 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_

# Build workflow definition
execution_input = ExecutionInput(schema={
Field.ImageUri.value: str,
Field.InstanceCount.value: int,
Field.Entrypoint.value: str,
Field.Role.value: str,
Field.VolumeSizeInGB.value: int,
Field.MaxRuntimeInSeconds.value: int
'image_uri': str,
Contributor Author

Since we're only using these values for test purposes, using the direct string values for better code readability

'instance_count': int,
'entrypoint': str,
'role': str,
'volume_size_in_gb': int,
'max_runtime_in_seconds': int,
'container_arguments': [str],
})

parameters = {
'AppSpecification': {
'ContainerEntrypoint': execution_input['entrypoint'],
'ImageUri': execution_input['image_uri']
},
'ProcessingResources': {
'ClusterConfig': {
'InstanceCount': execution_input['instance_count'],
'VolumeSizeInGB': execution_input['volume_size_in_gb']
}
},
'RoleArn': execution_input['role'],
'StoppingCondition': {
'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
}
}

job_name = generate_job_name()
processing_step = ProcessingStep('create_processing_job_step',
processor=sklearn_processor_fixture,
job_name=job_name,
inputs=inputs,
outputs=outputs,
container_arguments=['--train-test-split-ratio', '0.2'],
container_entrypoint=execution_input[Field.Entrypoint.value],
image_uri=execution_input[Field.ImageUri.value],
instance_count=execution_input[Field.InstanceCount.value],
role=execution_input[Field.Role.value],
volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value],
max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value]
container_arguments=execution_input['container_arguments'],
container_entrypoint=execution_input['entrypoint'],
parameters=parameters
)
workflow_graph = Chain([processing_step])

@@ -418,12 +431,13 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_
)

execution_input = {
Field.ImageUri.value: '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
Field.InstanceCount.value: 1,
Field.Entrypoint.value: ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
Field.Role.value: sagemaker_role_arn,
Field.VolumeSizeInGB.value: 30,
Field.MaxRuntimeInSeconds.value: 500
'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
'instance_count': 1,
'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
'role': sagemaker_role_arn,
'volume_size_in_gb': 30,
'max_runtime_in_seconds': 500,
'container_arguments': ['--train-test-split-ratio', '0.2']
}

# Execute workflow