diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 066846564e..58a5fabc2f 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -181,6 +181,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: bool = False, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -544,7 +545,9 @@ def __init__( enable_infra_check (bool or PipelineVariable): Optional. Specifies whether it is running Sagemaker built-in infra check jobs. enable_remote_debug (bool or PipelineVariable): Optional. - Specifies whether RemoteDebug is enabled for the training job + Specifies whether RemoteDebug is enabled for the training job. + enable_session_tag_chaining (bool or PipelineVariable): Optional. + Specifies whether SessionTagChaining is enabled for the training job. """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -785,6 +788,8 @@ def __init__( self._enable_remote_debug = enable_remote_debug + self._enable_session_tag_chaining = enable_session_tag_chaining + @abstractmethod def training_image_uri(self): """Return the Docker image to use for training. @@ -2318,6 +2323,14 @@ def get_remote_debug_config(self): else {"EnableRemoteDebug": self._enable_remote_debug} ) + def get_session_chaining_config(self): + """dict: Return the configuration of SessionChaining""" + return ( + None + if self._enable_session_tag_chaining is None + else {"EnableSessionTagChaining": self._enable_session_tag_chaining} + ) + def enable_remote_debug(self): """Enable remote debug for a training job.""" self._update_remote_debug(True) @@ -2574,6 +2587,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config): if estimator.get_remote_debug_config() is not None: train_args["remote_debug_config"] = estimator.get_remote_debug_config() + if estimator.get_session_chaining_config() is not None: + train_args["session_chaining_config"] = estimator.get_session_chaining_config() + return train_args @classmethod @@ -2766,6 +2782,7 @@ def __init__( disable_output_compression: bool = False, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -3129,6 +3146,8 @@ def __init__( Specifies whether it is running Sagemaker built-in infra check jobs. enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job + enable_session_tag_chaining (bool or PipelineVariable): Optional. + Specifies whether SessionTagChaining is enabled for the training job """ self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} @@ -3181,6 +3200,7 @@ def __init__( container_arguments=container_arguments, disable_output_compression=disable_output_compression, enable_remote_debug=enable_remote_debug, + enable_session_tag_chaining=enable_session_tag_chaining, **kwargs, ) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 88927ae931..bade834cc6 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -109,6 +109,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ): """Initializes a ``JumpStartEstimator``. @@ -500,6 +501,8 @@ def __init__( to Amazon S3 without compression after training finishes. enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job + enable_session_tag_chaining (bool or PipelineVariable): Optional. + Specifies whether SessionTagChaining is enabled for the training job Raises: ValueError: If the model ID is not recognized by JumpStart. @@ -578,6 +581,7 @@ def _validate_model_id_and_get_type_hook(): disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + enable_session_tag_chaining=enable_session_tag_chaining, ) self.model_id = estimator_init_kwargs.model_id diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 875ec9d003..387a4a843c 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -130,6 +130,7 @@ def get_init_kwargs( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -188,6 +189,7 @@ def get_init_kwargs( disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + enable_session_tag_chaining=enable_session_tag_chaining, ) estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 05c6a00961..dae879494e 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1751,6 +1751,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "disable_output_compression", "enable_infra_check", "enable_remote_debug", + "enable_session_tag_chaining", ] SERIALIZATION_EXCLUSION_SET = { @@ -1818,6 +1819,7 @@ def __init__( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1877,6 +1879,7 @@ def __init__( self.disable_output_compression = disable_output_compression self.enable_infra_check = enable_infra_check self.enable_remote_debug = enable_remote_debug + self.enable_session_tag_chaining = enable_session_tag_chaining class JumpStartEstimatorFitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 9e593706c1..5ea3d5f8a1 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -758,6 +758,7 @@ def train( # noqa: C901 environment: Optional[Dict[str, str]] = None, retry_strategy=None, remote_debug_config=None, + session_chaining_config=None, ): """Create an Amazon SageMaker training job. @@ -877,6 +878,15 @@ def train( # noqa: C901 remote_debug_config = { "EnableRemoteDebug": True, } + session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``) + The dict can contain 'EnableSessionTagChaining'(bool). + For example, + + .. code:: python + + session_chaining_config = { + "EnableSessionTagChaining": True, + } environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``) retry_strategy(dict): Defines RetryStrategy for InternalServerFailures. @@ -970,6 +980,7 @@ def train( # noqa: C901 profiler_rule_configs=profiler_rule_configs, profiler_config=inferred_profiler_config, remote_debug_config=remote_debug_config, + session_chaining_config=session_chaining_config, environment=environment, retry_strategy=retry_strategy, ) @@ -1013,6 +1024,7 @@ def _get_train_request( # noqa: C901 profiler_rule_configs=None, profiler_config=None, remote_debug_config=None, + session_chaining_config=None, environment=None, retry_strategy=None, ): @@ -1133,6 +1145,15 @@ def _get_train_request( # noqa: C901 remote_debug_config = { "EnableRemoteDebug": True, } + session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``) + The dict can contain 'EnableSessionTagChaining'(bool). + For example, + + .. code:: python + + session_chaining_config = { + "EnableSessionTagChaining": True, + } environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``) retry_strategy(dict): Defines RetryStrategy for InternalServerFailures. @@ -1239,6 +1260,9 @@ def _get_train_request( # noqa: C901 if remote_debug_config is not None: train_request["RemoteDebugConfig"] = remote_debug_config + if session_chaining_config is not None: + train_request["SessionChainingConfig"] = session_chaining_config + if retry_strategy is not None: train_request["RetryStrategy"] = retry_strategy diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 382c48fde6..fd45601801 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2089,6 +2089,41 @@ def test_framework_disable_remote_debug(sagemaker_session): assert len(args) == 2 +def test_framework_with_session_chaining_config(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + enable_session_tag_chaining=True, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args["session_chaining_config"]["EnableSessionTagChaining"] + assert f.get_session_chaining_config()["EnableSessionTagChaining"] + + +def test_framework_without_session_chaining_config(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args.get("SessionTagChaining") is None + assert f.get_remote_debug_config() is None + + @patch("time.strftime", return_value=TIMESTAMP) def test_custom_code_bucket(time, sagemaker_session): code_bucket = "codebucket" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 19f9d0ae3d..944f22acff 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2197,6 +2197,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): CONTAINER_ENTRY_POINT = ["bin/bash", "test.sh"] CONTAINER_ARGUMENTS = ["--arg1", "value1", "--arg2", "value2"] remote_debug_config = {"EnableRemoteDebug": True} + session_chaining_config = {"EnableSessionTagChaining": True} sagemaker_session.train( image_uri=IMAGE, @@ -2222,6 +2223,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): container_entry_point=CONTAINER_ENTRY_POINT, container_arguments=CONTAINER_ARGUMENTS, remote_debug_config=remote_debug_config, + session_chaining_config=session_chaining_config, ) _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] @@ -2245,6 +2247,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): ) assert actual_train_args["AlgorithmSpecification"]["ContainerArguments"] == CONTAINER_ARGUMENTS assert actual_train_args["RemoteDebugConfig"]["EnableRemoteDebug"] + assert actual_train_args["SessionChainingConfig"]["EnableSessionTagChaining"] def test_create_transform_job_with_sagemaker_config_injection(sagemaker_session):