Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: populate default config name to model #4617

Merged
merged 14 commits into from
Apr 26, 2024
6 changes: 4 additions & 2 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> Predictor:
"""Retrieves the default predictor for the model matching the given arguments.

Expand All @@ -65,6 +66,8 @@ def retrieve_default(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
config_name (Optional[str]): The name of the configuration to use for the
predictor. (Default: None)
Returns:
Predictor: The default predictor to use for the model.

Expand All @@ -91,10 +94,9 @@ def retrieve_default(
model_id = inferred_model_id
model_version = model_version or inferred_model_version or "*"
inference_component_name = inference_component_name or inferred_inference_component_name
config_name = inferred_config_name or None
config_name = config_name or inferred_config_name or None
else:
model_version = model_version or "*"
config_name = None
Captainia marked this conversation as resolved.
Show resolved Hide resolved

predictor = Predictor(
endpoint_name=endpoint_name,
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,8 @@ def test_jumpstart_estimator_attach_eula_model(
additional_kwargs={
"model_id": "gemma-model",
"model_version": "*",
'tolerate_vulnerable_model': True,
'tolerate_deprecated_model': True,
"environment": {"accept_eula": "true"},
},
)
Expand Down Expand Up @@ -1056,6 +1058,8 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
additional_kwargs={
"model_id": "js-trainable-model-prepacked",
"model_version": "1.0.0",
'tolerate_vulnerable_model': True,
'tolerate_deprecated_model': True,
},
)

Expand Down
Loading