diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 0c192084ec..9666ce828f 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -92,7 +92,9 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" - MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name" + + INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name" + TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name" class SerializerType(str, Enum): diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index b20a5c1ead..4939be4041 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -733,7 +733,7 @@ def attach( config_name = None if model_id is None: - model_id, model_version, config_name = get_model_info_from_training_job( + model_id, model_version, _, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) @@ -1139,7 +1139,9 @@ def set_training_config(self, config_name: str) -> None: Args: config_name (str): The name of the config. """ - self.__init__(**self.init_kwargs, config_name=config_name) + self.__init__( + model_id=self.model_id, model_version=self.model_version, config_name=config_name + ) def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 2d5b29b52f..604b20bc81 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -61,7 +61,7 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( - add_jumpstart_model_id_version_tags, + add_jumpstart_model_info_tags, get_eula_message, update_dict_if_key_not_present, resolve_estimator_sagemaker_config_field, @@ -477,11 +477,12 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( + kwargs.tags = add_jumpstart_model_info_tags( kwargs.tags, kwargs.model_id, full_model_version, config_name=kwargs.config_name, + scope=JumpStartScriptScope.TRAINING, ) return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 53508b56f3..54301973e8 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -44,7 +44,7 @@ JumpStartModelRegisterKwargs, ) from sagemaker.jumpstart.utils import ( - add_jumpstart_model_id_version_tags, + add_jumpstart_model_info_tags, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, @@ -495,8 +495,13 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name + kwargs.tags = add_jumpstart_model_info_tags( + kwargs.tags, + kwargs.model_id, + full_model_version, + kwargs.model_type, + config_name=kwargs.config_name, + scope=JumpStartScriptScope.INFERENCE, ) return kwargs diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index 0fa7722f91..7953b67913 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -17,7 +17,7 @@ from typing import Optional, Tuple from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn +from sagemaker.jumpstart.utils import get_jumpstart_model_info_from_resource_arn from sagemaker.session import Session from sagemaker.utils import aws_partition @@ -26,7 +26,7 @@ def get_model_info_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str], Optional[str]]: +) -> Tuple[str, str, Optional[str], Optional[str], Optional[str]]: """Optionally inference component names, return the model ID, version and config name. Infers the model ID and version based on the resource tags. Returns a tuple of the model ID @@ -46,7 +46,8 @@ def get_model_info_from_endpoint( ( model_id, model_version, - config_name, + inference_config_name, + training_config_name, ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 inference_component_name, sagemaker_session ) @@ -55,17 +56,29 @@ def get_model_info_from_endpoint( ( model_id, model_version, - config_name, + inference_config_name, + training_config_name, inference_component_name, ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) else: - model_id, model_version, config_name = _get_model_info_from_model_based_endpoint( + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = _get_model_info_from_model_based_endpoint( endpoint_name, inference_component_name, sagemaker_session ) - return model_id, model_version, inference_component_name, config_name + return ( + model_id, + model_version, + inference_component_name, + inference_config_name, + training_config_name, + ) def _get_model_info_from_inference_component_endpoint_without_inference_component_name( @@ -125,9 +138,12 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n f"inference-component/{inference_component_name}" ) - model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( - inference_component_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(inference_component_arn, sagemaker_session) if not model_id: raise ValueError( @@ -136,14 +152,14 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n "when retrieving default predictor for this inference component." ) - return model_id, model_version, config_name + return model_id, model_version, inference_config_name, training_config_name def _get_model_info_from_model_based_endpoint( endpoint_name: str, inference_component_name: Optional[str], sagemaker_session: Session, -) -> Tuple[str, str, Optional[str]]: +) -> Tuple[str, str, Optional[str], Optional[str]]: """Returns the model ID, version and config name inferred from a model-based endpoint. Raises: @@ -163,9 +179,12 @@ def _get_model_info_from_model_based_endpoint( endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}" - model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( - endpoint_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(endpoint_arn, sagemaker_session) if not model_id: raise ValueError( @@ -174,13 +193,13 @@ def _get_model_info_from_model_based_endpoint( "predictor for this endpoint." ) - return model_id, model_version, config_name + return model_id, model_version, inference_config_name, training_config_name def get_model_info_from_training_job( training_job_name: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str]]: +) -> Tuple[str, str, Optional[str], Optional[str]]: """Returns the model ID and version and config name inferred from a training job. Raises: @@ -199,8 +218,9 @@ def get_model_info_from_training_job( ( model_id, inferred_model_version, - config_name, - ) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session) + inference_config_name, + trainig_config_name, + ) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session) model_version = inferred_model_version or None @@ -211,4 +231,4 @@ def get_model_info_from_training_job( "for this training job." ) - return model_id, model_version, config_name + return model_id, model_version, inference_config_name, trainig_config_name diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index a8c4bd7c21..d2a0a396b5 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -320,7 +320,8 @@ def add_single_jumpstart_tag( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) - or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags) + or tag_key_in_array(enums.JumpStartTag.INFERENCE_CONFIG_NAME, curr_tags) + or tag_key_in_array(enums.JumpStartTag.TRAINING_CONFIG_NAME, curr_tags) ) if is_uri else False @@ -351,12 +352,13 @@ def get_jumpstart_base_name_if_jumpstart_model( return None -def add_jumpstart_model_id_version_tags( +def add_jumpstart_model_info_tags( tags: Optional[List[TagsDict]], model_id: str, model_version: str, model_type: Optional[enums.JumpStartModelType] = None, config_name: Optional[str] = None, + scope: enums.JumpStartScriptScope = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -380,10 +382,17 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) - if config_name: + if config_name and scope == enums.JumpStartScriptScope.INFERENCE: tags = add_single_jumpstart_tag( config_name, - enums.JumpStartTag.MODEL_CONFIG_NAME, + enums.JumpStartTag.INFERENCE_CONFIG_NAME, + tags, + is_uri=False, + ) + if config_name and scope == enums.JumpStartScriptScope.TRAINING: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.TRAINING_CONFIG_NAME, tags, is_uri=False, ) @@ -840,10 +849,10 @@ def _extract_value_from_list_of_tags( return resolved_value -def get_jumpstart_model_id_version_from_resource_arn( +def get_jumpstart_model_info_from_resource_arn( resource_arn: str, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[Optional[str], Optional[str], Optional[str]]: +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: """Returns the JumpStart model ID, version and config name if in resource tags. Returns 'None' if model ID or version or config name cannot be inferred from tags. @@ -853,7 +862,8 @@ def get_jumpstart_model_id_version_from_resource_arn( model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS] model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS] - model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME] + inference_config_name_keys = [enums.JumpStartTag.INFERENCE_CONFIG_NAME] + training_config_name_keys = [enums.JumpStartTag.TRAINING_CONFIG_NAME] model_id: Optional[str] = _extract_value_from_list_of_tags( tag_keys=model_id_keys, @@ -869,14 +879,21 @@ def get_jumpstart_model_id_version_from_resource_arn( resource_arn=resource_arn, ) - config_name: Optional[str] = _extract_value_from_list_of_tags( - tag_keys=model_config_name_keys, + inference_config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=inference_config_name_keys, + list_tags_result=list_tags_result, + resource_name="inference config name", + resource_arn=resource_arn, + ) + + training_config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=training_config_name_keys, list_tags_result=list_tags_result, - resource_name="model config name", + resource_name="training config name", resource_arn=resource_arn, ) - return model_id, model_version, config_name + return model_id, model_version, inference_config_name, training_config_name def get_region_fallback( diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index c5d08eb8f4..780a1a56c8 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -82,6 +82,7 @@ def retrieve_default( inferred_model_version, inferred_inference_component_name, inferred_config_name, + _, ) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session) if not inferred_model_id: diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 20701edb01..89db48ffd8 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -39,7 +39,7 @@ import botocore from botocore.utils import merge_dicts from six.moves.urllib import parse -import pandas as pd +from six import viewitems from sagemaker import deprecations from sagemaker.config import validate_sagemaker_config @@ -1605,44 +1605,80 @@ def can_model_package_source_uri_autopopulate(source_uri: str): ) -def flatten_dict(source_dict: Dict[str, Any], sep: str = "^") -> Dict[str, Any]: - """Flatten a nested dictionary. +def flatten_dict( + d: Dict[str, Any], + max_flatten_depth=None, +) -> Dict[str, Any]: + """Flatten a dictionary object. - Args: - source_dict (dict): The dictionary to be flattened. - sep (str): The separator to be used in the flattened dictionary. - Returns: - transformed_dict: The flattened dictionary. + d (Dict[str, Any]): + The dict that will be flattened. + max_flatten_depth (Optional[int]): + Maximum depth to merge. """ - flat_dict_list = pd.json_normalize(source_dict, sep=sep).to_dict(orient="records") - if flat_dict_list: - return flat_dict_list[0] - return {} + def tuple_reducer(k1, k2): + if k1 is None: + return (k2,) + return k1 + (k2,) -def unflatten_dict(source_dict: Dict[str, Any], sep: str = "^") -> Dict[str, Any]: - """Unflatten a flattened dictionary back into a nested dictionary. + # check max_flatten_depth + if max_flatten_depth is not None and max_flatten_depth < 1: + raise ValueError("max_flatten_depth should not be less than 1.") - Args: - source_dict (dict): The input flattened dictionary. - sep (str): The separator used in the flattened keys. + reducer = tuple_reducer - Returns: - transformed_dict: The reconstructed nested dictionary. + flat_dict = {} + + def _flatten(_d, depth, parent=None): + key_value_iterable = viewitems(_d) + has_item = False + for key, value in key_value_iterable: + has_item = True + flat_key = reducer(parent, key) + if isinstance(value, dict) and (max_flatten_depth is None or depth < max_flatten_depth): + has_child = _flatten(value, depth=depth + 1, parent=flat_key) + if has_child: + continue + + if flat_key in flat_dict: + raise ValueError("duplicated key '{}'".format(flat_key)) + flat_dict[flat_key] = value + + return has_item + + _flatten(d, depth=1) + return flat_dict + + +def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None: + """Set a value to a sequence of nested keys.""" + + key = keys[0] + + if len(keys) == 1: + d[key] = value + return + if not d: + return + + d = d.setdefault(key, {}) + nested_set_dict(d, keys[1:], value) + + +def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """Unflatten dict-like object. + + d (Dict[str, Any]) : + The dict that will be unflattened. """ - if not source_dict: - return {} - - result = {} - for key, value in source_dict.items(): - keys = key.split(sep) - current = result - for k in keys[:-1]: - if k not in current: - current[k] = {} - current = current[k] if current[k] is not None else current - current[keys[-1]] = value - return result + + unflattened_dict = {} + for flat_key, value in viewitems(d): + key_tuple = flat_key + nested_set_dict(unflattened_dict, key_tuple, value) + + return unflattened_dict def deep_override_dict( diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 7fbfffcc1a..2202b15ece 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1036,6 +1036,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( "js-trainable-model-prepacked", "1.0.0", None, + None, ) mock_get_model_specs.side_effect = get_special_model_spec @@ -1900,7 +1901,59 @@ def test_estimator_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-training"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "neuron-training"}, + ], + enable_network_isolation=False, + ) + + estimator.fit() + + mock_estimator_fit.assert_called_once_with(wait=True) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_set_config_name( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id) + + estimator.set_training_config(config_name="neuron-training") + + mock_estimator_init.assert_called_with( + instance_type="ml.p2.xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" + "neuron-training/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"}, + role="fake role! do not use!", + sagemaker_session=sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "neuron-training"}, ], enable_network_isolation=False, ) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index cd11d950d5..f32687bd99 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1563,7 +1563,7 @@ def test_model_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1627,7 +1627,7 @@ def test_model_set_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1645,7 +1645,7 @@ def test_model_set_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 1cc8f292f0..a3425a7b90 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -97,7 +97,7 @@ def test_proprietary_predictor_support( def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_get_model_specs, patched_verify_model_region_and_return_specs, - patched_get_jumpstart_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, patched_get_default_predictor, patched_predictor, ): @@ -105,20 +105,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_get_jumpstart_model_id_version_from_endpoint.return_value = ( + patched_get_model_info_from_endpoint.return_value = ( "predictor-specs-model", "1.2.3", None, None, + None, ) mock_session = Mock() predictor.retrieve_default(endpoint_name="blah", sagemaker_session=mock_session) - patched_get_jumpstart_model_id_version_from_endpoint.assert_called_once_with( - "blah", None, mock_session - ) + patched_get_model_info_from_endpoint.assert_called_once_with("blah", None, mock_session) patched_get_default_predictor.assert_called_once_with( predictor=patched_predictor.return_value, diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 9dc8acb32a..ce06a189bd 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -12,61 +12,63 @@ ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_training_job_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", None, + None, ) retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None) + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_training_job_config_name( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", - "config_name", + None, + "training_config_name", ) retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "config_name") + assert retval == ("model_id", "model_version", None, "training_config_name") - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_training_job_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( None, None, ) @@ -75,43 +77,45 @@ def test_get_model_info_from_training_job_no_model_id_inferred( get_model_info_from_training_job("blah", sagemaker_session=mock_sm_session) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_model_based_endpoint_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", None, + None, ) retval = _get_model_info_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", None) + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:endpoint/blah", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_model_based_endpoint_inference_component_supplied( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", None, + None, ) with pytest.raises(ValueError): @@ -120,15 +124,16 @@ def test_get_model_info_from_model_based_endpoint_inference_component_supplied( ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + None, None, None, ) @@ -139,40 +144,42 @@ def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", None, + None, ) retval = _get_model_info_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", None) + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:inference-component/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + None, None, None, None, @@ -272,11 +279,12 @@ def test_get_model_info_from_endpoint_non_inference_component_endpoint( "model_id", "model_version", None, + None, ) retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None, None) + assert retval == ("model_id", "model_version", None, None, None) mock_get_model_info_from_model_based_endpoint.assert_called_once_with( "blah", None, mock_sm_session ) @@ -296,13 +304,14 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_with_inferenc "model_id", "model_version", None, + None, ) retval = get_model_info_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", "icname", None) + assert retval == ("model_id", "model_version", "icname", None, None) mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) @@ -322,12 +331,13 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_without_infer "model_id", "model_version", None, + None, "inferred-icname", ) retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname", None) + assert retval == ("model_id", "model_version", "inferred-icname", None, None) mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() @@ -335,7 +345,7 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_without_infer "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_info_from_endpoint_inference_component_endpoint_with_config_name( +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_config_name( mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() @@ -343,11 +353,35 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_with_config_n mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", - "config_name", + "inference_config_name", + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", "inference_config_name", None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_with_training_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + None, + "training_config_name", "inferred-icname", ) retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname", "config_name") + assert retval == ("model_id", "model_version", "inferred-icname", None, "training_config_name") mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index f7458a29e9..5b30a94dd6 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -210,16 +210,16 @@ def test_is_jumpstart_model_uri(): assert utils.is_jumpstart_model_uri(random_jumpstart_s3_uri("random_key")) -def test_add_jumpstart_model_id_version_tags(): +def test_add_jumpstart_model_info_tags(): tags = None model_id = "model_id" version = "version" + inference_config_name = "inference_config_name" + training_config_name = "training_config_name" assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, @@ -231,9 +231,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version_2"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -244,9 +242,7 @@ def test_add_jumpstart_model_id_version_tags(): {"Key": "random key", "Value": "random_value"}, {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, @@ -257,9 +253,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -268,8 +262,58 @@ def test_add_jumpstart_model_id_version_tags(): version = None assert [ {"Key": "random key", "Value": "random_value"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-inference-config-name", "Value": "inference_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=inference_config_name, + scope=JumpStartScriptScope.INFERENCE, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-training-config-name", "Value": "training_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, + scope=JumpStartScriptScope.TRAINING, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, ) @@ -1322,10 +1366,8 @@ def test_no_model_id_no_version_found(self): mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1339,10 +1381,8 @@ def test_model_id_no_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id", None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1356,10 +1396,8 @@ def test_no_model_id_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, "model_version", None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, "model_version", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1370,27 +1408,54 @@ def test_no_config_name_found(self): mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") - def test_config_name_found(self): + def test_inference_config_name_found(self): mock_list_tags = Mock() mock_sagemaker_session = Mock() mock_sagemaker_session.list_tags = mock_list_tags mock_list_tags.return_value = [ {"Key": "blah", "Value": "blah1"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name"}, ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, "config_name"), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "config_name", None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_training_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, "config_name"), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_both_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "inference_config_name"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "training_config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "inference_config_name", "training_config_name"), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1405,10 +1470,8 @@ def test_model_id_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id", "model_version", None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", "model_version", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1425,10 +1488,8 @@ def test_multiple_model_id_versions_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1445,10 +1506,8 @@ def test_multiple_model_id_versions_found_aliases_consistent(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id_1", "model_version_1", None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1465,10 +1524,8 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1480,15 +1537,13 @@ def test_multiple_config_names_found_aliases_inconsistent(self): {"Key": "blah", "Value": "blah1"}, {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_1"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_2"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_2"}, ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id_1", "model_version_1", None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), ) mock_list_tags.assert_called_once_with("some-arn") diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 083e2dd09a..93abcfc7a8 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1819,7 +1819,13 @@ def test_can_model_package_source_uri_autopopulate(): class TestDeepMergeDict(TestCase): def test_flatten_dict_basic(self): nested_dict = {"a": 1, "b": {"x": 2, "y": {"p": 3, "q": 4}}, "c": 5} - flattened_dict = {"a": 1, "b^x": 2, "b^y^p": 3, "b^y^q": 4, "c": 5} + flattened_dict = { + ("a",): 1, + ("b", "x"): 2, + ("b", "y", "p"): 3, + ("b", "y", "q"): 4, + ("c",): 5, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) @@ -1831,13 +1837,19 @@ def test_flatten_dict_empty(self): def test_flatten_dict_no_nested(self): nested_dict = {"a": 1, "b": 2, "c": 3} - flattened_dict = {"a": 1, "b": 2, "c": 3} + flattened_dict = {("a",): 1, ("b",): 2, ("c",): 3} self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) def test_flatten_dict_with_various_types(self): nested_dict = {"a": [1, 2, 3], "b": {"x": None, "y": {"p": [], "q": ""}}, "c": 9} - flattened_dict = {"a": [1, 2, 3], "b^x": None, "b^y^p": [], "b^y^q": "", "c": 9} + flattened_dict = { + ("a",): [1, 2, 3], + ("b", "x"): None, + ("b", "y", "p"): [], + ("b", "y", "q"): "", + ("c",): 9, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict)