diff --git a/src/stepfunctions/steps/fields.py b/src/stepfunctions/steps/fields.py
index 24c3949..8eb102d 100644
--- a/src/stepfunctions/steps/fields.py
+++ b/src/stepfunctions/steps/fields.py
@@ -59,7 +59,6 @@ class Field(Enum):
     HeartbeatSeconds = 'heartbeat_seconds'
     HeartbeatSecondsPath = 'heartbeat_seconds_path'
 
-
     # Retry and catch fields
     ErrorEquals = 'error_equals'
     IntervalSeconds = 'interval_seconds'
diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py
index 30e3d7c..9530478 100644
--- a/src/stepfunctions/steps/sagemaker.py
+++ b/src/stepfunctions/steps/sagemaker.py
@@ -19,7 +19,7 @@
 from stepfunctions.inputs import Placeholder
 from stepfunctions.steps.states import Task
 from stepfunctions.steps.fields import Field
-from stepfunctions.steps.utils import tags_dict_to_kv_list
+from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list
 from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn
 
 from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
@@ -30,6 +30,7 @@
 
 SAGEMAKER_SERVICE_NAME = "sagemaker"
 
+
 class SageMakerApi(Enum):
     CreateTrainingJob = "createTrainingJob"
     CreateTransformJob = "createTransformJob"
@@ -479,7 +480,9 @@ class ProcessingStep(Task):
     Creates a Task State to execute a SageMaker Processing Job.
     """
 
-    def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs):
+    def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None,
+                 container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True,
+                 tags=None, **kwargs):
         """
         Args:
             state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -491,15 +494,18 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
             outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for the processing job. These
                 can be specified as either path strings or :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
-            experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None)
-            container_arguments ([str]): The arguments for a container used to run a processing job.
-            container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
-            kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
+            experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None)
+            container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job.
+            container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job.
+            kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
                 uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key, ARN of a KMS key,
                 alias of a KMS key, or alias of a KMS key. The KmsKeyId is applied to all outputs.
             wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow.
                 Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
-            tags (list[dict], optional): `List to tags `_ to associate with the resource.
+            tags (list[dict] or Placeholder, optional): `List to tags `_ to associate with the resource.
+            parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateProcessingJob`_.
+                You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders`_.
+
         """
 
         if wait_for_completion:
             """
@@ -518,22 +524,25 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
                                                                        SageMakerApi.CreateProcessingJob)
 
         if isinstance(job_name, str):
-            parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
+            processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
         else:
-            parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
+            processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
 
         if isinstance(job_name, Placeholder):
-            parameters['ProcessingJobName'] = job_name
+            processing_parameters['ProcessingJobName'] = job_name
 
         if experiment_config is not None:
-            parameters['ExperimentConfig'] = experiment_config
-
+            processing_parameters['ExperimentConfig'] = experiment_config
+
         if tags:
-            parameters['Tags'] = tags_dict_to_kv_list(tags)
-
-        if 'S3Operations' in parameters:
-            del parameters['S3Operations']
-
-        kwargs[Field.Parameters.value] = parameters
+            processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
+
+        if 'S3Operations' in processing_parameters:
+            del processing_parameters['S3Operations']
+
+        if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
+            # Update processing_parameters with input parameters
+            merge_dicts(processing_parameters, kwargs[Field.Parameters.value])
+        kwargs[Field.Parameters.value] = processing_parameters
 
         super(ProcessingStep, self).__init__(state_id, **kwargs)
diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py
index 6f44481..5c6861f 100644
--- a/src/stepfunctions/steps/utils.py
+++ b/src/stepfunctions/steps/utils.py
@@ -14,6 +14,7 @@
 
 import boto3
 import logging
+from stepfunctions.inputs import Placeholder
 
 logger = logging.getLogger('stepfunctions')
 
@@ -45,3 +46,28 @@ def get_aws_partition():
             return cur_partition
 
     return cur_partition
+
+
+def merge_dicts(target, source):
+    """
+    Merges source dictionary into the target dictionary.
+    Values in the target dict are updated with the values of the source dict.
+    Args:
+        target (dict): Base dictionary into which source is merged
+        source (dict): Dictionary used to update target. If the same key is present in both dictionaries, source's value
+            will overwrite target's value for the corresponding key
+    """
+    if isinstance(target, dict) and isinstance(source, dict):
+        for key, value in source.items():
+            if key in target:
+                if isinstance(target[key], dict) and isinstance(source[key], dict):
+                    merge_dicts(target[key], source[key])
+                elif target[key] == value:
+                    pass
+                else:
+                    logger.info(
+                        f"Property: <{key}> with value: <{target[key]}>"
+                        f" will be overwritten with provided value: <{value}>")
+                    target[key] = source[key]
+            else:
+                target[key] = source[key]
diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py
index 63c060a..f840302 100644
--- a/tests/integ/test_sagemaker_steps.py
+++ b/tests/integ/test_sagemaker_steps.py
@@ -29,6 +29,7 @@
 from sagemaker.tuner import HyperparameterTuner
 from sagemaker.processing import ProcessingInput, ProcessingOutput
 
+from stepfunctions.inputs import ExecutionInput
 from stepfunctions.steps import Chain
 from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
 from stepfunctions.workflow import Workflow
@@ -352,3 +353,98 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
     # Cleanup
     state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
     # End of Cleanup
+
+
+def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,
+                                            sagemaker_role_arn):
+    region = boto3.session.Session().region_name
+    input_data = f"s3://sagemaker-sample-data-{region}/processing/census/census-income.csv"
+
+    input_s3 = sagemaker_session.upload_data(
+        path=os.path.join(DATA_DIR, 'sklearn_processing'),
+        bucket=sagemaker_session.default_bucket(),
+        key_prefix='integ-test-data/sklearn_processing/code'
+    )
+
+    output_s3 = f"s3://{sagemaker_session.default_bucket()}/integ-test-data/sklearn_processing"
+
+    inputs = [
+        ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'),
+        ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code',
+                        input_name='code'),
+    ]
+
+    outputs = [
+        ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data',
+                         output_name='train_data'),
+        ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data',
+                         output_name='test_data'),
+    ]
+
+    # Build workflow definition
+    execution_input = ExecutionInput(schema={
+        'image_uri': str,
+        'instance_count': int,
+        'entrypoint': str,
+        'role': str,
+        'volume_size_in_gb': int,
+        'max_runtime_in_seconds': int,
+        'container_arguments': [str],
+    })
+
+    parameters = {
+        'AppSpecification': {
+            'ContainerEntrypoint': execution_input['entrypoint'],
+            'ImageUri': execution_input['image_uri']
+        },
+        'ProcessingResources': {
+            'ClusterConfig': {
+                'InstanceCount': execution_input['instance_count'],
+                'VolumeSizeInGB': execution_input['volume_size_in_gb']
+            }
+        },
+        'RoleArn': execution_input['role'],
+        'StoppingCondition': {
+            'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
+        }
+    }
+
+    job_name = generate_job_name()
+    processing_step = ProcessingStep('create_processing_job_step',
+                                     processor=sklearn_processor_fixture,
+                                     job_name=job_name,
+                                     inputs=inputs,
+                                     outputs=outputs,
+                                     container_arguments=execution_input['container_arguments'],
+                                     container_entrypoint=execution_input['entrypoint'],
+                                     parameters=parameters
+                                     )
+    workflow_graph = Chain([processing_step])
+
+    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
+        workflow = create_workflow_and_check_definition(
+            workflow_graph=workflow_graph,
+            workflow_name=unique_name_from_base("integ-test-processing-step-workflow"),
+            sfn_client=sfn_client,
+            sfn_role_arn=sfn_role_arn
+        )
+
+        execution_input = {
+            'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
+            'instance_count': 1,
+            'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
+            'role': sagemaker_role_arn,
+            'volume_size_in_gb': 30,
+            'max_runtime_in_seconds': 500,
+            'container_arguments': ['--train-test-split-ratio', '0.2']
+        }
+
+        # Execute workflow
+        execution = workflow.execute(inputs=execution_input)
+        execution_output = execution.get_output(wait=True)
+
+        # Check workflow output
+        assert execution_output.get("ProcessingJobStatus") == "Completed"
+
+        # Cleanup
+        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py
index c643468..664a498 100644
--- a/tests/unit/test_sagemaker_steps.py
+++ b/tests/unit/test_sagemaker_steps.py
@@ -27,7 +27,9 @@
 from unittest.mock import MagicMock, patch
 
 from stepfunctions.inputs import ExecutionInput, StepInput
-from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep
+from stepfunctions.steps.fields import Field
+from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep,\
+    ProcessingStep
 from stepfunctions.steps.sagemaker import tuning_config
 
 from tests.unit.utils import mock_boto_api_call
@@ -962,3 +964,136 @@ def test_processing_step_creation(sklearn_processor):
         'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
         'End': True
     }
+
+
+def test_processing_step_creation_with_placeholders(sklearn_processor):
+    execution_input = ExecutionInput(schema={
+        'image_uri': str,
+        'instance_count': int,
+        'entrypoint': str,
+        'output_kms_key': str,
+        'role': str,
+        'env': str,
+        'volume_size_in_gb': int,
+        'volume_kms_key': str,
+        'max_runtime_in_seconds': int,
+        'tags': [{str: str}],
+        'container_arguments': [str]
+    })
+
+    step_input = StepInput(schema={
+        'instance_type': str
+    })
+
+    parameters = {
+        'AppSpecification': {
+            'ContainerEntrypoint': execution_input['entrypoint'],
+            'ImageUri': execution_input['image_uri']
+        },
+        'Environment': execution_input['env'],
+        'ProcessingOutputConfig': {
+            'KmsKeyId': execution_input['output_kms_key']
+        },
+        'ProcessingResources': {
+            'ClusterConfig': {
+                'InstanceCount': execution_input['instance_count'],
+                'InstanceType': step_input['instance_type'],
+                'VolumeKmsKeyId': execution_input['volume_kms_key'],
+                'VolumeSizeInGB': execution_input['volume_size_in_gb']
+            }
+        },
+        'RoleArn': execution_input['role'],
+        'StoppingCondition': {
+            'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
+        },
+        'Tags': execution_input['tags']
+    }
+
+    inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')]
+    outputs = [
+        ProcessingOutput(source='/opt/ml/processing/output/train'),
+        ProcessingOutput(source='/opt/ml/processing/output/validation'),
+        ProcessingOutput(source='/opt/ml/processing/output/test')
+    ]
+    step = ProcessingStep(
+        'Feature Transformation',
+        sklearn_processor,
+        'MyProcessingJob',
+        container_entrypoint=execution_input['entrypoint'],
+        container_arguments=execution_input['container_arguments'],
+        kms_key_id=execution_input['output_kms_key'],
+        inputs=inputs,
+        outputs=outputs,
+        parameters=parameters
+    )
+    assert step.to_dict() == {
+        'Type': 'Task',
+        'Parameters': {
+            'AppSpecification': {
+                'ContainerArguments.$': "$$.Execution.Input['container_arguments']",
+                'ContainerEntrypoint.$': "$$.Execution.Input['entrypoint']",
+                'ImageUri.$': "$$.Execution.Input['image_uri']"
+            },
+            'Environment.$': "$$.Execution.Input['env']",
+            'ProcessingInputs': [
+                {
+                    'InputName': None,
+                    'AppManaged': False,
+                    'S3Input': {
+                        'LocalPath': '/opt/ml/processing/input',
+                        'S3CompressionType': 'None',
+                        'S3DataDistributionType': 'FullyReplicated',
+                        'S3DataType': 'S3Prefix',
+                        'S3InputMode': 'File',
+                        'S3Uri': 'dataset.csv'
+                    }
+                }
+            ],
+            'ProcessingOutputConfig': {
+                'KmsKeyId.$': "$$.Execution.Input['output_kms_key']",
+                'Outputs': [
+                    {
+                        'OutputName': None,
+                        'AppManaged': False,
+                        'S3Output': {
+                            'LocalPath': '/opt/ml/processing/output/train',
+                            'S3UploadMode': 'EndOfJob',
+                            'S3Uri': None
+                        }
+                    },
+                    {
+                        'OutputName': None,
+                        'AppManaged': False,
+                        'S3Output': {
+                            'LocalPath': '/opt/ml/processing/output/validation',
+                            'S3UploadMode': 'EndOfJob',
+                            'S3Uri': None
+                        }
+                    },
+                    {
+                        'OutputName': None,
+                        'AppManaged': False,
+                        'S3Output': {
+                            'LocalPath': '/opt/ml/processing/output/test',
+                            'S3UploadMode': 'EndOfJob',
+                            'S3Uri': None
+                        }
+                    }
+                ]
+            },
+            'ProcessingResources': {
+                'ClusterConfig': {
+                    'InstanceCount.$': "$$.Execution.Input['instance_count']",
+                    'InstanceType.$': "$['instance_type']",
+                    'VolumeKmsKeyId.$': "$$.Execution.Input['volume_kms_key']",
+                    'VolumeSizeInGB.$': "$$.Execution.Input['volume_size_in_gb']"
+                }
+            },
+            'ProcessingJobName': 'MyProcessingJob',
+            'RoleArn.$': "$$.Execution.Input['role']",
+            'Tags.$': "$$.Execution.Input['tags']",
+            'StoppingCondition': {'MaxRuntimeInSeconds.$': "$$.Execution.Input['max_runtime_in_seconds']"},
+        },
+        'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
+        'End': True
+    }
diff --git a/tests/unit/test_steps_utils.py b/tests/unit/test_steps_utils.py
index 6eb0885..7e06e37 100644
--- a/tests/unit/test_steps_utils.py
+++ b/tests/unit/test_steps_utils.py
@@ -13,7 +13,7 @@
 
 # Test if boto3 session can fetch correct aws partition info from test environment
 
-from stepfunctions.steps.utils import get_aws_partition
+from stepfunctions.steps.utils import get_aws_partition, merge_dicts
 from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn
 import boto3
 from unittest.mock import patch
@@ -51,3 +51,38 @@ def test_arn_builder_sagemaker_wait_completion():
                                       IntegrationPattern.WaitForCompletion)
     assert arn == "arn:aws:states:::sagemaker:createTrainingJob.sync"
 
+
+def test_merge_dicts():
+    d1 = {
+        'a': {
+            'aa': 1,
+            'bb': 2,
+            'cc': 3
+        },
+        'b': 1
+    }
+
+    d2 = {
+        'a': {
+            'bb': {
+                'aaa': 1,
+                'bbb': 2
+            }
+        },
+        'b': 2,
+        'c': 3
+    }
+
+    merge_dicts(d1, d2)
+    assert d1 == {
+        'a': {
+            'aa': 1,
+            'bb': {
+                'aaa': 1,
+                'bbb': 2
+            },
+            'cc': 3
+        },
+        'b': 2,
+        'c': 3
+    }
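
For reference, a minimal sketch (not part of the patch) of the merge semantics that back the new `parameters` argument: `merge_dicts` merges nested dicts key by key, lets a caller-supplied value overwrite the generated one (logging the overwrite), leaves equal values untouched, and adds keys that were not generated at all. The literal values below are hypothetical stand-ins for what `processing_config` would produce; only the precedence rules are taken from the implementation above.

    from stepfunctions.steps.utils import merge_dicts

    # Request fields as processing_config(...) might generate them (hypothetical values).
    generated = {
        'AppSpecification': {
            'ImageUri': 'my-processing-image:1.0',
            'ContainerEntrypoint': ['python3', 'preprocessor.py']
        },
        'RoleArn': 'arn:aws:iam::123456789012:role/my-sagemaker-role'
    }

    # Caller-supplied `parameters`: nested dicts are merged recursively,
    # differing values are overwritten (with an info log), new keys are added.
    overrides = {
        'AppSpecification': {'ImageUri': 'my-processing-image:2.0'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 300}
    }

    merge_dicts(generated, overrides)

    assert generated == {
        'AppSpecification': {
            'ImageUri': 'my-processing-image:2.0',                   # overridden by `parameters`
            'ContainerEntrypoint': ['python3', 'preprocessor.py']    # preserved
        },
        'RoleArn': 'arn:aws:iam::123456789012:role/my-sagemaker-role',
        'StoppingCondition': {'MaxRuntimeInSeconds': 300}            # added
    }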