Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support config_name in all JumpStart interfaces #4583

Merged
merged 12 commits into from
Apr 22, 2024
3 changes: 3 additions & 0 deletions src/sagemaker/accept_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the future, we should think of a way to consolidate all these JS related fields perhaps into a dataclass or kwargs, so we don't need to update all these function prototypes whenever a new feature is added.

) -> str:
"""Retrieves the default accept type for the model matching the given arguments.

Expand All @@ -98,6 +99,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check with @judyheflin about how to describe this

Returns:
str: The default accept type to use for the model.

Expand All @@ -117,4 +119,5 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> str:
"""Retrieves the default content type for the model matching the given arguments.

Expand All @@ -98,6 +99,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default content type to use for the model.

Expand All @@ -117,6 +119,7 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)


Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> BaseDeserializer:
"""Retrieves the default deserializer for the model matching the given arguments.

Expand All @@ -118,6 +119,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
BaseDeserializer: The default deserializer to use for the model.

Expand All @@ -138,4 +140,5 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def retrieve_default(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the default container environment variables for the model matching the arguments.

Expand Down Expand Up @@ -65,6 +66,7 @@ def retrieve_default(
variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: The variables to use for the model.

Expand All @@ -87,4 +89,5 @@ def retrieve_default(
sagemaker_session=sagemaker_session,
instance_type=instance_type,
script=script,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the default training hyperparameters for the model matching the given arguments.

Expand Down Expand Up @@ -66,6 +67,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: The hyperparameters to use for the model.

Expand All @@ -86,6 +88,7 @@ def retrieve_default(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)


Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def retrieve(
inference_tool=None,
serverless_inference_config=None,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name=None,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -121,6 +122,7 @@ def retrieve(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

Returns:
str: The ECR URI for the corresponding SageMaker Docker image.
Expand Down Expand Up @@ -160,6 +162,7 @@ def retrieve(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def retrieve_default(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> str:
"""Retrieves the default instance type for the model matching the given arguments.

Expand Down Expand Up @@ -64,6 +65,7 @@ def retrieve_default(
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default instance type to use for the model.

Expand All @@ -88,6 +90,7 @@ def retrieve_default(
sagemaker_session=sagemaker_session,
training_instance_type=training_instance_type,
model_type=model_type,
config_name=config_name,
)


Expand Down
7 changes: 7 additions & 0 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _retrieve_default_environment_variables(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the inference environment variables for the model matching the given arguments.

Expand Down Expand Up @@ -68,6 +69,7 @@ def _retrieve_default_environment_variables(
environment variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve
environment variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the inference environment variables to use for the model.
"""
Expand All @@ -84,6 +86,7 @@ def _retrieve_default_environment_variables(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

default_environment_variables: Dict[str, str] = {}
Expand Down Expand Up @@ -121,6 +124,7 @@ def _retrieve_default_environment_variables(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
config_name=config_name,
)
)

Expand Down Expand Up @@ -167,6 +171,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
) -> Optional[str]:
"""Retrieves the gated model env var URI matching the given arguments.

Expand All @@ -190,6 +195,7 @@ def _retrieve_gated_model_uri_env_var_value(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
environment variables specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

Returns:
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
Expand All @@ -211,6 +217,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

s3_key: Optional[str] = (
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/artifacts/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _retrieve_default_hyperparameters(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
):
"""Retrieves the training hyperparameters for the model matching the given arguments.

Expand Down Expand Up @@ -66,6 +67,7 @@ def _retrieve_default_hyperparameters(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the hyperparameters to use for the model.
"""
Expand All @@ -82,6 +84,7 @@ def _retrieve_default_hyperparameters(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

default_hyperparameters: Dict[str, str] = {}
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def _retrieve_image_uri(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
):
"""Retrieves the container image URI for JumpStart models.

Expand Down Expand Up @@ -95,6 +96,7 @@ def _retrieve_image_uri(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Expand All @@ -116,6 +118,7 @@ def _retrieve_image_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if image_scope == JumpStartScriptScope.INFERENCE:
Expand Down Expand Up @@ -200,4 +203,5 @@ def _retrieve_image_uri(
distribution=distribution,
base_framework_version=base_framework_version_override or base_framework_version,
training_compiler_config=training_compiler_config,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/artifacts/incremental_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _model_supports_incremental_training(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
) -> bool:
"""Returns True if the model supports incremental training.

Expand All @@ -54,6 +55,7 @@ def _model_supports_incremental_training(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
bool: the support status for incremental training.
"""
Expand All @@ -70,6 +72,7 @@ def _model_supports_incremental_training(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

return model_specs.supports_incremental_training()
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _retrieve_default_instance_type(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> str:
"""Retrieves the default instance type for the model.

Expand Down Expand Up @@ -68,6 +69,7 @@ def _retrieve_default_instance_type(
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the default instance type to use for the model or None.

Expand All @@ -89,6 +91,7 @@ def _retrieve_default_instance_type(
tolerate_deprecated_model=tolerate_deprecated_model,
model_type=model_type,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if scope == JumpStartScriptScope.INFERENCE:
Expand Down Expand Up @@ -128,6 +131,7 @@ def _retrieve_instance_types(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
config_name: Optional[str] = None,
) -> List[str]:
"""Retrieves the supported instance types for the model.

Expand Down Expand Up @@ -156,6 +160,7 @@ def _retrieve_instance_types(
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
list: the supported instance types to use for the model or None.

Expand All @@ -176,6 +181,7 @@ def _retrieve_instance_types(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if scope == JumpStartScriptScope.INFERENCE:
Expand Down
Loading