diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 78aa655e04..7541425868 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -77,6 +77,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -98,6 +99,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default accept type to use for the model. @@ -117,4 +119,5 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 46d0361f67..627feca0d6 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -77,6 +77,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -98,6 +99,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default content type to use for the model. @@ -117,6 +119,7 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 1a4be43897..02e61149ec 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -97,6 +97,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -118,6 +119,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseDeserializer: The default deserializer to use for the model. @@ -138,4 +140,5 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index b67066fcde..8fa52c3ec8 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -36,6 +36,7 @@ def retrieve_default( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the default container environment variables for the model matching the arguments. @@ -65,6 +66,7 @@ def retrieve_default( variables specific for the instance type. script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: The variables to use for the model. @@ -87,4 +89,5 @@ def retrieve_default( sagemaker_session=sagemaker_session, instance_type=instance_type, script=script, + config_name=config_name, ) diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 5873e37b9f..5c22409c50 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -36,6 +36,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the default training hyperparameters for the model matching the given arguments. @@ -66,6 +67,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: The hyperparameters to use for the model. @@ -86,6 +88,7 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 143ecc9bdb..97471f2c41 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -68,6 +68,7 @@ def retrieve( inference_tool=None, serverless_inference_config=None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name=None, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -121,6 +122,7 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The ECR URI for the corresponding SageMaker Docker image. @@ -160,6 +162,7 @@ def retrieve( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 48aaab0ac8..c4af4b2036 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -36,6 +36,7 @@ def retrieve_default( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -64,6 +65,7 @@ def retrieve_default( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default instance type to use for the model. @@ -88,6 +90,7 @@ def retrieve_default( sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index c28c27ed4e..fcb3ce3bf2 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -39,6 +39,7 @@ def _retrieve_default_environment_variables( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the inference environment variables for the model matching the given arguments. @@ -68,6 +69,7 @@ def _retrieve_default_environment_variables( environment variables specific for the instance type. script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the inference environment variables to use for the model. """ @@ -84,6 +86,7 @@ def _retrieve_default_environment_variables( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_environment_variables: Dict[str, str] = {} @@ -121,6 +124,7 @@ def _retrieve_default_environment_variables( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, instance_type=instance_type, + config_name=config_name, ) ) @@ -167,6 +171,7 @@ def _retrieve_gated_model_uri_env_var_value( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves the gated model env var URI matching the given arguments. @@ -190,6 +195,7 @@ def _retrieve_gated_model_uri_env_var_value( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get environment variables specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: Optional[str]: the s3 URI to use for the environment variable, or None if the model does not @@ -211,6 +217,7 @@ def _retrieve_gated_model_uri_env_var_value( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) s3_key: Optional[str] = ( diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index d19530ecfb..67db7d260f 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -36,6 +36,7 @@ def _retrieve_default_hyperparameters( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ): """Retrieves the training hyperparameters for the model matching the given arguments. @@ -66,6 +67,7 @@ def _retrieve_default_hyperparameters( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the hyperparameters to use for the model. """ @@ -82,6 +84,7 @@ def _retrieve_default_hyperparameters( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_hyperparameters: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 9d19d5e069..72633320f5 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -46,6 +46,7 @@ def _retrieve_image_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the container image URI for JumpStart models. @@ -95,6 +96,7 @@ def _retrieve_image_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the ECR URI for the corresponding SageMaker Docker image. @@ -116,6 +118,7 @@ def _retrieve_image_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if image_scope == JumpStartScriptScope.INFERENCE: @@ -200,4 +203,5 @@ def _retrieve_image_uri( distribution=distribution, base_framework_version=base_framework_version_override or base_framework_version, training_compiler_config=training_compiler_config, + config_name=config_name, ) diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 1b3c6f4b29..8bbe089354 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -33,6 +33,7 @@ def _model_supports_incremental_training( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports incremental training. @@ -54,6 +55,7 @@ def _model_supports_incremental_training( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: bool: the support status for incremental training. """ @@ -70,6 +72,7 @@ def _model_supports_incremental_training( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.supports_incremental_training() diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index e7c9c5911d..f4bf212c1c 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -40,6 +40,7 @@ def _retrieve_default_instance_type( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model. @@ -68,6 +69,7 @@ def _retrieve_default_instance_type( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default instance type to use for the model or None. @@ -89,6 +91,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -128,6 +131,7 @@ def _retrieve_instance_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported instance types for the model. @@ -156,6 +160,7 @@ def _retrieve_instance_types( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported instance types to use for the model or None. @@ -176,6 +181,7 @@ def _retrieve_instance_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 9cd152b0bb..ceb88d9b26 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -37,6 +37,7 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model`. @@ -58,6 +59,7 @@ def _retrieve_model_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. """ @@ -75,6 +77,7 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -94,6 +97,7 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -117,6 +121,7 @@ def _retrieve_model_deploy_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. @@ -135,6 +140,7 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: @@ -151,6 +157,7 @@ def _retrieve_estimator_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Estimator`. @@ -174,6 +181,7 @@ def _retrieve_estimator_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. """ @@ -190,6 +198,7 @@ def _retrieve_estimator_init_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) kwargs = deepcopy(model_specs.estimator_kwargs) @@ -210,6 +219,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Estimator.fit`. @@ -231,6 +241,7 @@ def _retrieve_estimator_fit_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. @@ -248,6 +259,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.fit_kwargs diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 57f66155c7..f23b66aed4 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -35,6 +35,7 @@ def _retrieve_default_training_metric_definitions( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -58,6 +59,7 @@ def _retrieve_default_training_metric_definitions( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the default training metric definitions to use for the model or None. """ @@ -74,6 +76,7 @@ def _retrieve_default_training_metric_definitions( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_metric_definitions = ( diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index aa22351771..12166b1a76 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -37,6 +37,7 @@ def _retrieve_model_package_arn( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -60,6 +61,7 @@ def _retrieve_model_package_arn( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model package arn to use for the model or None. @@ -78,6 +80,7 @@ def _retrieve_model_package_arn( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -118,6 +121,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves s3 artifact uri associated with model package. @@ -141,6 +145,7 @@ def _retrieve_model_package_model_artifact_s3_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model package artifact uri to use for the model or None. @@ -162,6 +167,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if model_specs.training_model_package_artifact_uris is None: diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 6bb2e576fc..00c6d8b9aa 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -95,6 +95,7 @@ def _retrieve_model_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -120,6 +121,8 @@ def _retrieve_model_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + Returns: str: the model artifact S3 URI for the corresponding model. @@ -141,6 +144,7 @@ def _retrieve_model_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) model_artifact_key: str @@ -182,6 +186,7 @@ def _model_supports_training_model_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports training with model uri field. @@ -203,6 +208,7 @@ def _model_supports_training_model_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: bool: the support status for model uri with training. """ @@ -219,6 +225,7 @@ def _model_supports_training_model_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.use_training_model_artifact() diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3359e32732..2f4a8bb0ac 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -37,6 +37,7 @@ def _retrieve_example_payloads( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. @@ -58,6 +59,7 @@ def _retrieve_example_payloads( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases to the serializable payload object. @@ -76,6 +78,7 @@ def _retrieve_example_payloads( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_payloads = model_specs.default_payloads diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 4f6dfe1fe3..635f063e05 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -78,6 +78,7 @@ def _retrieve_default_deserializer( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -98,6 +99,7 @@ def _retrieve_default_deserializer( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseDeserializer: the default deserializer to use for the model. @@ -111,6 +113,7 @@ def _retrieve_default_deserializer( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -124,6 +127,7 @@ def _retrieve_default_serializer( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -144,6 +148,7 @@ def _retrieve_default_serializer( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseSerializer: the default serializer to use for the model. """ @@ -156,6 +161,7 @@ def _retrieve_default_serializer( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -169,6 +175,7 @@ def _retrieve_deserializer_options( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -189,6 +196,7 @@ def _retrieve_deserializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseDeserializer]: the supported deserializers to use for the model. """ @@ -201,6 +209,7 @@ def _retrieve_deserializer_options( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -227,6 +236,7 @@ def _retrieve_serializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model. @@ -247,6 +257,7 @@ def _retrieve_serializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseSerializer]: the supported serializers to use for the model. """ @@ -258,6 +269,7 @@ def _retrieve_serializer_options( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -285,6 +297,7 @@ def _retrieve_default_content_type( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model. @@ -305,6 +318,7 @@ def _retrieve_default_content_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default content type to use for the model. """ @@ -322,6 +336,7 @@ def _retrieve_default_content_type( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -336,6 +351,7 @@ def _retrieve_default_accept_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model. @@ -356,6 +372,7 @@ def _retrieve_default_accept_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default accept type to use for the model. """ @@ -373,6 +390,7 @@ def _retrieve_default_accept_type( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -388,6 +406,7 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -408,6 +427,7 @@ def _retrieve_supported_accept_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported accept types to use for the model. """ @@ -425,6 +445,7 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -440,6 +461,7 @@ def _retrieve_supported_content_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported content types for the model. @@ -460,6 +482,7 @@ def _retrieve_supported_content_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported content types to use for the model. """ @@ -477,6 +500,7 @@ def _retrieve_supported_content_types( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index cffd46d043..b4fdac770b 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -35,6 +35,8 @@ def _retrieve_resource_name_base( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> bool: """Returns default resource name. @@ -56,6 +58,7 @@ def _retrieve_resource_name_base( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config. (Default: None). Returns: str: the default resource name. """ @@ -67,12 +70,13 @@ def _retrieve_resource_name_base( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, - scope=JumpStartScriptScope.INFERENCE, + scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.resource_name_base diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 369acac85f..49126da336 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -54,6 +54,7 @@ def _retrieve_default_resources( model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model. @@ -79,6 +80,7 @@ def _retrieve_default_resources( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default resource requirements to use for the model or None. @@ -102,6 +104,7 @@ def _retrieve_default_resources( tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index f69732d2e0..97313ec626 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -37,6 +37,7 @@ def _retrieve_script_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -62,6 +63,7 @@ def _retrieve_script_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model script URI for the corresponding model. @@ -83,6 +85,7 @@ def _retrieve_script_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if script_scope == JumpStartScriptScope.INFERENCE: @@ -108,6 +111,7 @@ def _model_supports_inference_script_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports inference with script uri field. @@ -145,6 +149,7 @@ def _model_supports_inference_script_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.use_inference_script_uri() diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 88927ae931..cf9b720607 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -34,7 +34,9 @@ from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job +from sagemaker.jumpstart.types import JumpStartMetadataConfig from sagemaker.jumpstart.utils import ( + get_jumpstart_configs, validate_model_id_and_get_type, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, @@ -109,6 +111,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, ): """Initializes a ``JumpStartEstimator``. @@ -500,6 +503,8 @@ def __init__( to Amazon S3 without compression after training finishes. enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job + config_name (Optional[str]): + Name of the JumpStart Model config to apply. (Default: None). Raises: ValueError: If the model ID is not recognized by JumpStart. @@ -578,6 +583,7 @@ def _validate_model_id_and_get_type_hook(): disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + config_name=config_name, ) self.model_id = estimator_init_kwargs.model_id @@ -591,6 +597,8 @@ def _validate_model_id_and_get_type_hook(): self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation + self.config_name = estimator_init_kwargs.config_name + self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -665,6 +673,7 @@ def fit( tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -1076,6 +1085,7 @@ def deploy( git_config=git_config, use_compiled_model=use_compiled_model, training_instance_type=self.instance_type, + config_name=self.config_name, ) predictor = super(JumpStartEstimator, self).deploy( @@ -1092,11 +1102,39 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + # config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor return predictor + def list_training_configs(self) -> List[JumpStartMetadataConfig]: + """Returns a list of configs associated with the estimator. + + Raises: + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + configs_dict = get_jumpstart_configs( + model_id=self.model_id, + model_version=self.model_version, + model_type=self.model_type, + region=self.region, + scope=JumpStartScriptScope.TRAINING, + sagemaker_session=self.sagemaker_session, + ) + return list(configs_dict.values()) + + def set_training_config(self, config_name: str) -> None: + """Sets the config to apply to the model. + + Args: + config_name (str): The name of the config. + """ + self.__init__(**self.init_kwargs, config_name=config_name) + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 875ec9d003..926f313b68 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -130,6 +130,7 @@ def get_init_kwargs( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -188,6 +189,7 @@ def get_init_kwargs( disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + config_name=config_name, ) estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) @@ -221,6 +223,7 @@ def get_fit_kwargs( tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + config_name: Optional[str] = None, ) -> JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -236,6 +239,7 @@ def get_fit_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) estimator_fit_kwargs = _add_model_version_to_kwargs(estimator_fit_kwargs) @@ -287,6 +291,7 @@ def get_deploy_kwargs( use_compiled_model: Optional[bool] = None, model_name: Optional[str] = None, training_instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> JumpStartEstimatorDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object.""" @@ -314,6 +319,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs( @@ -342,6 +348,7 @@ def get_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, training_instance_type=training_instance_type, disable_instance_type_logging=True, + config_name=config_name, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( @@ -386,6 +393,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model, use_compiled_model=use_compiled_model, + config_name=config_name, ) return estimator_deploy_kwargs @@ -441,6 +449,7 @@ def _add_instance_type_and_count_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.instance_count = kwargs.instance_count or 1 @@ -464,6 +473,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: @@ -486,6 +496,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -511,6 +522,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE sagemaker_session=kwargs.sagemaker_session, region=kwargs.region, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) if ( @@ -523,6 +535,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) ): JUMPSTART_LOGGER.warning( @@ -558,6 +571,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, region=kwargs.region, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -578,6 +592,7 @@ def _add_env_to_kwargs( sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( @@ -588,6 +603,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) if model_package_artifact_uri: @@ -615,6 +631,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) if model_specs.is_gated_model(): raise ValueError( @@ -644,9 +661,11 @@ def _add_training_job_name_to_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, region=kwargs.region, + scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.job_name = kwargs.job_name or ( @@ -673,6 +692,7 @@ def _add_hyperparameters_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) for key, value in default_hyperparameters.items(): @@ -706,6 +726,7 @@ def _add_metric_definitions_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) or [] ) @@ -735,6 +756,7 @@ def _add_estimator_extra_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) for key, value in estimator_kwargs_to_add.items(): @@ -759,6 +781,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) for key, value in fit_kwargs_to_add.items(): diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 28746990e3..25a1d63215 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -72,6 +72,7 @@ def get_default_predictor( tolerate_deprecated_model: bool, sagemaker_session: Session, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -94,6 +95,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -103,6 +105,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -112,6 +115,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -121,6 +125,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return predictor @@ -184,7 +189,6 @@ def _add_instance_type_to_kwargs( """Sets instance type based on default or override, returns full kwargs.""" orig_instance_type = kwargs.instance_type - kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default( region=kwargs.region, model_id=kwargs.model_id, @@ -195,6 +199,7 @@ def _add_instance_type_to_kwargs( sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -226,6 +231,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -247,6 +253,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"): @@ -287,6 +294,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ): source_dir = source_dir or script_uris.retrieve( script_scope=JumpStartScriptScope.INFERENCE, @@ -296,6 +304,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.source_dir = source_dir @@ -319,6 +328,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ): entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME @@ -350,6 +360,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.INFERENCE, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) for key, value in extra_env_vars.items(): @@ -380,6 +391,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.model_package_arn = model_package_arn @@ -397,6 +409,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) for key, value in model_kwargs_to_add.items(): @@ -433,6 +446,7 @@ def _add_endpoint_name_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -455,6 +469,7 @@ def _add_model_name_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.name = kwargs.name or ( @@ -476,6 +491,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: @@ -498,6 +514,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) for key, value in deploy_kwargs_to_add.items(): @@ -520,6 +537,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) return kwargs @@ -555,6 +573,7 @@ def get_deploy_kwargs( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: Optional[EndpointType] = None, + config_name: Optional[str] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -586,6 +605,7 @@ def get_deploy_kwargs( accept_eula=accept_eula, endpoint_logging=endpoint_logging, resources=resources, + config_name=config_name, ) deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) @@ -639,6 +659,7 @@ def get_register_kwargs( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + config_name: Optional[str] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" @@ -671,6 +692,7 @@ def get_register_kwargs( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + config_name=config_name, ) model_specs = verify_model_region_and_return_specs( @@ -681,6 +703,7 @@ def get_register_kwargs( sagemaker_session=sagemaker_session, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + config_name=config_name, ) register_kwargs.content_types = ( @@ -723,6 +746,7 @@ def get_init_kwargs( training_instance_type: Optional[str] = None, disable_instance_type_logging: bool = False, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -754,6 +778,7 @@ def get_init_kwargs( model_package_arn=model_package_arn, training_instance_type=training_instance_type, resources=resources, + config_name=config_name, ) model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 4529bc11b9..ad70ffa805 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -92,6 +92,7 @@ def __init__( git_config: Optional[Dict[str, str]] = None, model_package_arn: Optional[str] = None, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, ): """Initializes a ``JumpStartModel``. @@ -277,6 +278,8 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + config_name (Optional[str]): The name of the JumpStartConfig that can be + optionally applied to the model and override corresponding fields. Raises: ValueError: If the model ID is not recognized by JumpStart. """ @@ -326,6 +329,7 @@ def _validate_model_id_and_type(): git_config=git_config, model_package_arn=model_package_arn, resources=resources, + config_name=config_name, ) self.orig_predictor_cls = predictor_cls @@ -338,6 +342,7 @@ def _validate_model_id_and_type(): self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + self.config_name = config_name if self.model_type == JumpStartModelType.PROPRIETARY: self.log_subscription_warning() @@ -345,6 +350,7 @@ def _validate_model_id_and_type(): super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) self.model_package_arn = model_init_kwargs.model_package_arn + self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" @@ -402,6 +408,18 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) + def set_deployment_config(self, config_name: Optional[str]) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (Optional[str]): + The name of the deployment config. Set to None to unset + any existing config that is applied to the model. + """ + self.__init__( + model_id=self.model_id, model_version=self.model_version, config_name=config_name + ) + def _create_sagemaker_model( self, instance_type=None, @@ -625,6 +643,7 @@ def deploy( managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, model_type=self.model_type, + config_name=self.config_name, ) if ( self.model_type == JumpStartModelType.PROPRIETARY @@ -644,6 +663,7 @@ def deploy( model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, ).model_subscription_link get_proprietary_model_subscription_error(e, subscription_link) raise @@ -659,6 +679,7 @@ def deploy( tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, model_type=self.model_type, + config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor @@ -769,6 +790,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + config_name=self.config_name, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 732493ce3b..52645f89f3 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -532,6 +532,7 @@ def get_model_url( model_version: str, region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieve web url describing pretrained model. @@ -560,5 +561,6 @@ def get_model_url( sagemaker_session=sagemaker_session, scope=JumpStartScriptScope.INFERENCE, model_type=model_type, + config_name=config_name, ) return model_specs.url diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 8e53bf6f83..1de0f662da 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -744,12 +744,12 @@ def _get_regional_property( class JumpStartBenchmarkStat(JumpStartDataHolderType): - """Data class JumpStart benchmark stats.""" + """Data class JumpStart benchmark stat.""" __slots__ = ["name", "value", "unit"] def __init__(self, spec: Dict[str, Any]): - """Initializes a JumpStartBenchmarkStat object + """Initializes a JumpStartBenchmarkStat object. Args: spec (Dict[str, Any]): Dictionary representation of benchmark stat. @@ -858,7 +858,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "model_subscription_link", ] - def __init__(self, fields: Optional[Dict[str, Any]]): + def __init__(self, fields: Dict[str, Any]): """Initializes a JumpStartMetadataFields object. Args: @@ -877,7 +877,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.version: str = json_obj.get("version") self.min_sdk_version: str = json_obj.get("min_sdk_version") self.incremental_training_supported: bool = bool( - json_obj.get("incremental_training_supported") + json_obj.get("incremental_training_supported", False) ) self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) @@ -1038,7 +1038,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields): __slots__ = slots + JumpStartMetadataBaseFields.__slots__ - def __init__( # pylint: disable=super-init-not-called + def __init__( self, component_name: str, component: Optional[Dict[str, Any]], @@ -1049,7 +1049,10 @@ def __init__( # pylint: disable=super-init-not-called component_name (str): Name of the component. component (Dict[str, Any]): Dictionary representation of the config component. + Raises: + ValueError: If the component field is invalid. """ + super().__init__(component) self.component_name = component_name self.from_json(component) @@ -1080,7 +1083,7 @@ def __init__( self, base_fields: Dict[str, Any], config_components: Dict[str, JumpStartConfigComponent], - benchmark_metrics: Dict[str, JumpStartBenchmarkStat], + benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]], ): """Initializes a JumpStartMetadataConfig object from its json representation. @@ -1089,12 +1092,12 @@ def __init__( The default base fields that are used to construct the final resolved config. config_components (Dict[str, JumpStartConfigComponent]): The list of components that are used to construct the resolved config. - benchmark_metrics (Dict[str, JumpStartBenchmarkStat]): + benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]): The dictionary of benchmark metrics with name being the key. """ self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components - self.benchmark_metrics: Dict[str, JumpStartBenchmarkStat] = benchmark_metrics + self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics self.resolved_metadata_config: Optional[Dict[str, Any]] = None def to_json(self) -> Dict[str, Any]: @@ -1104,7 +1107,7 @@ def to_json(self) -> Dict[str, Any]: @property def resolved_config(self) -> Dict[str, Any]: - """Returns the final config that is resolved from the list of components. + """Returns the final config that is resolved from the components map. Construct the final config by applying the list of configs from list index, and apply to the base default fields in the current model specs. @@ -1139,7 +1142,7 @@ def __init__( Args: configs (Dict[str, JumpStartMetadataConfig]): - List of configs that the current model has. + The map of JumpStartMetadataConfig object, with config name being the key. config_rankings (JumpStartConfigRanking): Config ranking class represents the ranking of the configs in the model. scope (JumpStartScriptScope): @@ -1158,19 +1161,30 @@ def get_top_config_from_ranking( self, ranking_name: str = JumpStartConfigRankingName.DEFAULT, instance_type: Optional[str] = None, - ) -> JumpStartMetadataConfig: - """Gets the best the config based on config ranking.""" + ) -> Optional[JumpStartMetadataConfig]: + """Gets the best the config based on config ranking. + + Args: + ranking_name (str): + The ranking name that config priority is based on. + instance_type (Optional[str]): + The instance type which the config selection is based on. + + Raises: + ValueError: If the config exists but missing config ranking. + NotImplementedError: If the scope is unrecognized. + """ if self.configs and ( not self.config_rankings or not self.config_rankings.get(ranking_name) ): - raise ValueError("Config exists but missing config ranking.") + raise ValueError(f"Config exists but missing config ranking {ranking_name}.") if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" elif self.scope == JumpStartScriptScope.TRAINING: instance_type_attribute = "supported_training_instance_types" else: - raise ValueError(f"Unknown script scope {self.scope}") + raise NotImplementedError(f"Unknown script scope {self.scope}") rankings = self.config_rankings.get(ranking_name) for config_name in rankings.rankings: @@ -1198,12 +1212,13 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields): __slots__ = JumpStartMetadataBaseFields.__slots__ + slots - def __init__(self, spec: Dict[str, Any]): # pylint: disable=super-init-not-called + def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartModelSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. """ + super().__init__(spec) self.from_json(spec) if self.inference_configs and self.inference_configs.get_top_config_from_ranking(): super().from_json(self.inference_configs.get_top_config_from_ranking().resolved_config) @@ -1245,8 +1260,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ), ( { - stat_name: JumpStartBenchmarkStat(stat) - for stat_name, stat in config.get("benchmark_metrics").items() + stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] + for stat_name, stats in config.get("benchmark_metrics").items() } if config and config.get("benchmark_metrics") else None @@ -1297,8 +1312,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ), ( { - stat_name: JumpStartBenchmarkStat(stat) - for stat_name, stat in config.get("benchmark_metrics").items() + stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] + for stat_name, stats in config.get("benchmark_metrics").items() } if config and config.get("benchmark_metrics") else None @@ -1330,13 +1345,26 @@ def set_config( config_name (str): Name of the config. scope (JumpStartScriptScope, optional): Scope of the config. Defaults to JumpStartScriptScope.INFERENCE. + + Raises: + ValueError: If the scope is not supported, or cannot find config name. """ if scope == JumpStartScriptScope.INFERENCE: - super().from_json(self.inference_configs.configs[config_name].resolved_config) + metadata_configs = self.inference_configs elif scope == JumpStartScriptScope.TRAINING and self.training_supported: - super().from_json(self.training_configs.configs[config_name].resolved_config) + metadata_configs = self.training_configs else: - raise ValueError(f"Unknown Jumpstart Script scope {scope}.") + raise ValueError(f"Unknown Jumpstart script scope {scope}.") + + config_object = metadata_configs.configs.get(config_name) + if not config_object: + error_msg = f"Cannot find Jumpstart config name {config_name}. " + config_names = list(metadata_configs.configs.keys()) + if config_names: + error_msg += f"List of config names that is supported by the model: {config_names}" + raise ValueError(error_msg) + + super().from_json(config_object.resolved_config) def supports_prepacked_inference(self) -> bool: """Returns True if the model has a prepacked inference artifact.""" @@ -1437,11 +1465,11 @@ class JumpStartKwargs(JumpStartDataHolderType): SERIALIZATION_EXCLUSION_SET: Set[str] = set() - def to_kwargs_dict(self): + def to_kwargs_dict(self, exclude_keys: bool = True): """Serializes object to dictionary to be used for kwargs for method arguments.""" kwargs_dict = {} for field in self.__slots__: - if field not in self.SERIALIZATION_EXCLUSION_SET: + if exclude_keys and field not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys: att_value = getattr(self, field) if att_value is not None: kwargs_dict[field] = getattr(self, field) @@ -1479,6 +1507,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_package_arn", "training_instance_type", "resources", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1491,6 +1520,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "region", "model_package_arn", "training_instance_type", + "config_name", } def __init__( @@ -1522,6 +1552,7 @@ def __init__( model_package_arn: Optional[str] = None, training_instance_type: Optional[str] = None, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartModelInitKwargs object.""" @@ -1552,6 +1583,7 @@ def __init__( self.model_package_arn = model_package_arn self.training_instance_type = training_instance_type self.resources = resources + self.config_name = config_name class JumpStartModelDeployKwargs(JumpStartKwargs): @@ -1587,6 +1619,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "endpoint_logging", "resources", "endpoint_type", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1598,6 +1631,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "sagemaker_session", "training_instance_type", + "config_name", } def __init__( @@ -1631,6 +1665,7 @@ def __init__( endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -1663,6 +1698,7 @@ def __init__( self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type + self.config_name = config_name class JumpStartEstimatorInitKwargs(JumpStartKwargs): @@ -1723,6 +1759,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "disable_output_compression", "enable_infra_check", "enable_remote_debug", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1732,6 +1769,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_id", "model_version", "model_type", + "config_name", } def __init__( @@ -1790,6 +1828,7 @@ def __init__( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1849,6 +1888,7 @@ def __init__( self.disable_output_compression = disable_output_compression self.enable_infra_check = enable_infra_check self.enable_remote_debug = enable_remote_debug + self.config_name = config_name class JumpStartEstimatorFitKwargs(JumpStartKwargs): @@ -1867,6 +1907,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1877,6 +1918,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", + "config_name", } def __init__( @@ -1893,6 +1935,7 @@ def __init__( tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1908,6 +1951,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.sagemaker_session = sagemaker_session + self.config_name = config_name class JumpStartEstimatorDeployKwargs(JumpStartKwargs): @@ -1953,6 +1997,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_name", "use_compiled_model", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1962,6 +2007,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "model_id", "model_version", "sagemaker_session", + "config_name", } def __init__( @@ -2005,6 +2051,7 @@ def __init__( tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, use_compiled_model: bool = False, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -2047,6 +2094,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.use_compiled_model = use_compiled_model + self.config_name = config_name class JumpStartModelRegisterKwargs(JumpStartKwargs): @@ -2081,6 +2129,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "data_input_configuration", "skip_model_validation", "source_uri", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -2090,6 +2139,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "model_id", "model_version", "sagemaker_session", + "config_name", } def __init__( @@ -2122,6 +2172,7 @@ def __init__( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartModelRegisterKwargs object.""" @@ -2154,3 +2205,4 @@ def __init__( self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation self.source_uri = source_uri + self.config_name = config_name diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 63cfac0939..1459594faa 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -547,6 +547,7 @@ def verify_model_region_and_return_specs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -569,6 +570,7 @@ def verify_model_region_and_return_specs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Raises: NotImplementedError: If the scope is not supported. @@ -634,6 +636,9 @@ def verify_model_region_and_return_specs( scope=constants.JumpStartScriptScope.TRAINING, ) + if model_specs and config_name: + model_specs.set_config(config_name, scope) + return model_specs @@ -890,7 +895,11 @@ def get_config_names( scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: - """Returns a list of config names for the given model ID and region.""" + """Returns a list of config names for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -905,7 +914,7 @@ def get_config_names( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") return list(metadata_configs.configs.keys()) if metadata_configs else [] @@ -919,7 +928,11 @@ def get_benchmark_stats( scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, List[JumpStartBenchmarkStat]]: - """Returns benchmark stats for the given model ID and region.""" + """Returns benchmark stats for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -934,7 +947,7 @@ def get_benchmark_stats( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: config_names = metadata_configs.configs.keys() if metadata_configs else [] @@ -942,7 +955,7 @@ def get_benchmark_stats( benchmark_stats = {} for config_name in config_names: if config_name not in metadata_configs.configs: - raise ValueError(f"Unknown config name: '{config_name}'") + raise ValueError(f"Unknown config name: {config_name}") benchmark_stats[config_name] = metadata_configs.configs.get(config_name).benchmark_metrics return benchmark_stats @@ -956,8 +969,12 @@ def get_jumpstart_configs( sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, -) -> Dict[str, List[JumpStartMetadataConfig]]: - """Returns metadata configs for the given model ID and region.""" +) -> Dict[str, JumpStartMetadataConfig]: + """Returns metadata configs for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -972,7 +989,7 @@ def get_jumpstart_configs( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: config_names = metadata_configs.configs.keys() if metadata_configs else [] diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index c7098a1185..bcb0365f7b 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -171,6 +171,7 @@ def validate_hyperparameters( sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + config_name: Optional[str] = None, ) -> None: """Validate hyperparameters for JumpStart models. @@ -193,6 +194,7 @@ def validate_hyperparameters( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Raises: JumpStartHyperparametersError: If the hyperparameters are not formatted correctly, @@ -218,6 +220,7 @@ def validate_hyperparameters( sagemaker_session=sagemaker_session, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + config_name=config_name, ) hyperparameters_specs = model_specs.hyperparameters diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 71dd26db45..0c066ff801 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -33,6 +33,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model matching the given arguments. @@ -56,6 +57,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: The default metric definitions to use for the model or None. @@ -76,4 +78,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 937180bd44..122647e536 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -34,6 +34,7 @@ def retrieve( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments. @@ -57,6 +58,8 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + Returns: str: The model artifact S3 URI for the corresponding model. @@ -81,4 +84,5 @@ def retrieve( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index df14ac558f..7808d0172a 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -37,6 +37,7 @@ def retrieve_default( model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model matching the given arguments. @@ -62,6 +63,7 @@ def retrieve_default( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default resource requirements to use for the model. @@ -87,4 +89,5 @@ def retrieve_default( model_type=model_type, sagemaker_session=sagemaker_session, instance_type=instance_type, + config_name=config_name, ) diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 9a1c4933d2..6e10785498 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -33,6 +33,7 @@ def retrieve( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -55,6 +56,7 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The model script URI for the corresponding model. @@ -78,4 +80,5 @@ def retrieve( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index aefb52bd97..d197df731c 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -45,6 +45,7 @@ def retrieve_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model matching the given arguments. @@ -66,6 +67,7 @@ def retrieve_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[SimpleBaseSerializer]: The supported serializers to use for the model. @@ -85,6 +87,7 @@ def retrieve_options( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) @@ -96,6 +99,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -117,6 +121,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: SimpleBaseSerializer: The default serializer to use for the model. @@ -137,4 +142,5 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 2b6856b1f3..f165a513a9 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -6270,6 +6270,10 @@ "framework_version": "1.5.0", "py_version": "py3", }, + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"], + "default_training_instance_type": "ml.p2.xlarge", + "supported_training_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"], "hosting_artifact_key": "pytorch-infer/infer-pytorch-eqa-bert-base-cased.tar.gz", "hosting_script_key": "source-directory-tarballs/pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", "inference_vulnerable": False, @@ -7658,25 +7662,25 @@ "inference_configs": { "neuron-inference": { "benchmark_metrics": { - "ml.inf2.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, - "component_names": ["neuron-base"], + "component_names": ["neuron-inference"], }, "neuron-inference-budget": { "benchmark_metrics": { - "ml.inf2.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["neuron-base"], }, "gpu-inference-budget": { "benchmark_metrics": { - "ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["gpu-inference-budget"], }, "gpu-inference": { "benchmark_metrics": { - "ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["gpu-inference"], }, @@ -7686,7 +7690,13 @@ "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] }, "neuron-inference": { + "default_inference_instance_type": "ml.inf2.xlarge", "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], + "hosting_ecr_specs": { + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", + }, "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", "hosting_instance_type_variants": { "regional_aliases": { @@ -7738,27 +7748,27 @@ "training_configs": { "neuron-training": { "benchmark_metrics": { - "ml.tr1n1.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}, - "ml.tr1n1.4xlarge": {"name": "Latency", "value": "50", "unit": "Tokens/S"}, + "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], + "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], }, "component_names": ["neuron-training"], }, "neuron-training-budget": { "benchmark_metrics": { - "ml.tr1n1.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}, - "ml.tr1n1.4xlarge": {"name": "Latency", "value": "50", "unit": "Tokens/S"}, + "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], + "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], }, "component_names": ["neuron-training-budget"], }, "gpu-training": { "benchmark_metrics": { - "ml.p3.2xlarge": {"name": "Latency", "value": "200", "unit": "Tokens/S"}, + "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], }, "component_names": ["gpu-training"], }, "gpu-training-budget": { "benchmark_metrics": { - "ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["gpu-training-budget"], }, diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index ce5f15b287..ae79bb8b55 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -46,6 +46,8 @@ from sagemaker.model import Model from sagemaker.predictor import Predictor from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_manifest, + get_prototype_spec_with_configs, get_special_model_spec, overwrite_dictionary, ) @@ -1109,6 +1111,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", + "config_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1366,6 +1369,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @@ -1417,6 +1421,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ -1848,6 +1853,55 @@ def test_jumpstart_estimator_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_initialization_with_config_name( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="neuron-training") + + mock_estimator_init.assert_called_once_with( + instance_type="ml.p2.xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" + "neuron-training/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"}, + role="fake role! do not use!", + sagemaker_session=sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + enable_network_isolation=False, + ) + + estimator.fit() + + mock_estimator_fit.assert_called_once_with(wait=True) + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 8b00eb5bcd..cb7b602fbf 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -40,6 +40,7 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_spec_with_configs, get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, @@ -715,6 +716,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): "tolerate_deprecated_model", "instance_type", "model_package_arn", + "config_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -786,6 +788,7 @@ def test_no_predictor_returns_default_predictor( tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -1414,6 +1417,153 @@ def test_model_local_mode( endpoint_logging=False, ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_initialization_with_config_name( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + model.set_deployment_config("neuron-inference") + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_unset_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_model_spec + model.set_deployment_config(None) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 3048bbc320..5ca01c3c52 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy +import pytest from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import ( JumpStartBenchmarkStat, @@ -934,9 +935,9 @@ def test_inference_configs_parsing(): assert specs1.incremental_training_supported assert specs1.hosting_ecr_specs == JumpStartECRSpecs( { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", } ) assert specs1.training_ecr_specs == JumpStartECRSpecs( @@ -946,7 +947,10 @@ def test_inference_configs_parsing(): "py_version": "py3", } ) - assert specs1.hosting_artifact_key == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" + assert ( + specs1.hosting_artifact_key + == "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/" + ) assert specs1.training_artifact_key == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" assert ( specs1.hosting_script_key @@ -1019,16 +1023,58 @@ def test_inference_configs_parsing(): config = specs1.inference_configs.get_top_config_from_ranking() assert config.benchmark_metrics == { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] } assert len(config.config_components) == 1 - assert config.config_components["neuron-base"] == JumpStartConfigComponent( - "neuron-base", - {"supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"]}, + assert config.config_components["neuron-inference"] == JumpStartConfigComponent( + "neuron-inference", + { + "default_inference_instance_type": "ml.inf2.xlarge", + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], + "hosting_ecr_specs": { + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", + }, + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + } + }, + "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, + }, ) - assert list(config.config_components.keys()) == ["neuron-base"] + assert list(config.config_components.keys()) == ["neuron-inference"] + + +def test_set_inference_configs(): + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + specs1 = JumpStartModelSpecs(spec) + + assert list(specs1.inference_config_components.keys()) == [ + "neuron-base", + "neuron-inference", + "neuron-budget", + "gpu-inference", + "gpu-inference-budget", + ] + + with pytest.raises(ValueError) as error: + specs1.set_config("invalid_name") + assert "Cannot find Jumpstart config name invalid_name." + "List of config names that is supported by the model: " + "['neuron-inference', 'neuron-inference-budget', " + "'gpu-inference-budget', 'gpu-inference']" in str(error.value) + + assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"] + specs1.set_config("gpu-inference") + assert specs1.supported_inference_instance_types == ["ml.p2.xlarge", "ml.p3.2xlarge"] def test_training_configs_parsing(): @@ -1133,12 +1179,12 @@ def test_training_configs_parsing(): config = specs1.training_configs.get_top_config_from_ranking() assert config.benchmark_metrics == { - "ml.tr1n1.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ), - "ml.tr1n1.4xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "50", "unit": "Tokens/S"} - ), + "ml.tr1n1.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + "ml.tr1n1.4xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + ], } assert len(config.config_components) == 1 assert config.config_components["neuron-training"] == JumpStartConfigComponent( @@ -1192,3 +1238,13 @@ def test_set_training_config(): specs1.training_artifact_key == "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/" ) + + with pytest.raises(ValueError) as error: + specs1.set_config("invalid_name", scope=JumpStartScriptScope.TRAINING) + assert "Cannot find Jumpstart config name invalid_name." + "List of config names that is supported by the model: " + "['neuron-training', 'neuron-training-budget', " + "'gpu-training-budget', 'gpu-training']" in str(error.value) + + with pytest.raises(ValueError) as error: + specs1.set_config("invalid_name", scope="unknown scope") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index e7a7d522c3..c1ea8abcb8 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1598,24 +1598,24 @@ def test_get_jumpstart_benchmark_stats_full_list( "mock-region", "mock-model", "mock-model-version", config_names=None ) == { "neuron-inference": { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, "neuron-inference-budget": { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, "gpu-inference-budget": { - "ml.p3.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, "gpu-inference": { - "ml.p3.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, } @@ -1633,14 +1633,14 @@ def test_get_jumpstart_benchmark_stats_partial_list( config_names=["neuron-inference-budget", "gpu-inference-budget"], ) == { "neuron-inference-budget": { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, "gpu-inference-budget": { - "ml.p3.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, } @@ -1658,9 +1658,9 @@ def test_get_jumpstart_benchmark_stats_single_stat( config_names=["neuron-inference-budget"], ) == { "neuron-inference-budget": { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] } } @@ -1695,16 +1695,16 @@ def test_get_jumpstart_benchmark_stats_training( config_names=["neuron-training", "gpu-training-budget"], ) == { "neuron-training": { - "ml.tr1n1.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ), - "ml.tr1n1.4xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "50", "unit": "Tokens/S"} - ), + "ml.tr1n1.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + "ml.tr1n1.4xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + ], }, "gpu-training-budget": { - "ml.p3.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, } diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e102251060..aee1497ec9 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -222,6 +222,23 @@ def get_base_spec_with_prototype_configs( return JumpStartModelSpecs(spec) +def get_prototype_spec_with_configs( + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, +) -> JumpStartModelSpecs: + spec = copy.deepcopy(PROTOTYPICAL_MODEL_SPECS_DICT[model_id]) + inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} + + spec.update(inference_configs) + spec.update(training_configs) + + return JumpStartModelSpecs(spec) + + def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, key: JumpStartCachedS3ContentKey,