Skip to content

Commit

Permalink
fix: allow for inf spec and server override to be passed (#4769)
Browse files Browse the repository at this point in the history
* fix: allow for just inf spec and server overide to pass

* fix formatting
  • Loading branch information
samruds authored Jul 3, 2024
1 parent ce10e01 commit 3d1a4f7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,8 +881,8 @@ def _build_for_model_server(self): # pylint: disable=R0911, R1710
if self.model_metadata:
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)

if not self.model and not mlflow_path:
raise ValueError("Missing required parameter `model` or 'ml_flow' path")
if not self.model and not mlflow_path and not self.inference_spec:
raise ValueError("Missing required parameter `model` or 'ml_flow' path or inf_spec")

if self.model_server == ModelServer.TORCHSERVE:
return self._build_for_torchserve()
Expand Down
18 changes: 16 additions & 2 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_sett
)
self.assertRaisesRegex(
Exception,
"Missing required parameter `model` or 'ml_flow' path",
"Missing required parameter `model` or 'ml_flow' path or inf_spec",
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
Expand All @@ -168,12 +168,26 @@ def test_model_server_override_torchserve_with_model(

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve")
def test_model_server_override_torchserve_with_inf_spec(
self, mock_build_for_ts, mock_serve_settings
):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, inference_spec="some value")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings):
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
self.assertRaisesRegex(
Exception,
"Missing required parameter `model` or 'ml_flow' path",
"Missing required parameter `model` or 'ml_flow' path or inf_spec",
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
Expand Down

0 comments on commit 3d1a4f7

Please sign in to comment.