diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py
index e618b54e44..e064564961 100644
--- a/src/sagemaker/serve/builder/transformers_builder.py
+++ b/src/sagemaker/serve/builder/transformers_builder.py
@@ -13,8 +13,10 @@
 """Transformers build logic with model builder"""
 from __future__ import absolute_import
 import logging
+import os
 from abc import ABC, abstractmethod
 from typing import Type
+from pathlib import Path
 from packaging.version import Version
 
 from sagemaker.model import Model
@@ -26,7 +28,12 @@ from sagemaker.huggingface import HuggingFaceModel
 from sagemaker.serve.model_server.multi_model_server.prepare import (
     _create_dir_structure,
+    prepare_for_mms,
 )
+from sagemaker.serve.detector.image_detector import (
+    auto_detect_container,
+)
+from sagemaker.serve.detector.pickler import save_pkl
 from sagemaker.serve.utils.optimize_utils import _is_optimized
 from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer
@@ -73,6 +80,8 @@ def __init__(self):
         self.pytorch_version = None
         self.instance_type = None
         self.schema_builder = None
+        self.inference_spec = None
+        self.shared_libs = None
 
     @abstractmethod
     def _prepare_for_mode(self):
@@ -110,7 +119,7 @@ def _get_hf_metadata_create_model(self) -> Type[Model]:
         """
         hf_model_md = get_huggingface_model_metadata(
-            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+            self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
         )
         hf_config = image_uris.config_for_framework("huggingface").get("inference")
         config = hf_config["versions"]
@@ -245,18 +254,22 @@ def _build_transformers_env(self):
         _create_dir_structure(self.model_path)
 
         if not hasattr(self, "pysdk_model"):
-            self.env_vars.update({"HF_MODEL_ID": self.model})
+
+            if self.inference_spec is not None:
+                self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
+            else:
+                self.env_vars.update({"HF_MODEL_ID": self.model})
             logger.info(self.env_vars)
 
             # TODO: Move to a helper function
             if hasattr(self.env_vars, "HF_API_TOKEN"):
                 self.hf_model_config = _get_model_config_properties_from_hf(
-                    self.model, self.env_vars.get("HF_API_TOKEN")
+                    self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_API_TOKEN")
                 )
             else:
                 self.hf_model_config = _get_model_config_properties_from_hf(
-                    self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+                    self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
                 )
 
             self.pysdk_model = self._create_transformers_model()
@@ -292,6 +305,42 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
                 versions_to_return.append(base_fw_version)
         return sorted(versions_to_return, reverse=True)[0]
 
+    def _auto_detect_container(self):
+        """Set image_uri by detecting container via model name or inference spec"""
+        # Auto detect the container image uri
+        if self.image_uri:
+            logger.info(
+                "Skipping auto detection as the image uri is provided %s",
+                self.image_uri,
+            )
+            return
+
+        if self.model:
+            logger.info(
+                "Auto detect container url for the provided model and on instance %s",
+                self.instance_type,
+            )
+            self.image_uri = auto_detect_container(
+                self.model, self.sagemaker_session.boto_region_name, self.instance_type
+            )
+
+        elif self.inference_spec:
+            # TODO: this won't work for larger image.
+            # Fail and let the customer include the image uri
+            logger.warning(
+                "model_path provided with no image_uri. Attempting to autodetect the image "
+                "by loading the model using inference_spec.load()..."
+            )
+            self.image_uri = auto_detect_container(
+                self.inference_spec.load(self.model_path),
+                self.sagemaker_session.boto_region_name,
+                self.instance_type,
+            )
+        else:
+            raise ValueError(
+                "Cannot detect and set image_uri. Please pass model or inference spec."
+            )
+
     def _build_for_transformers(self):
         """Method that triggers model build
 
@@ -300,6 +349,26 @@ def _build_for_transformers(self):
         self.secret_key = None
         self.model_server = ModelServer.MMS
 
+        if self.inference_spec:
+
+            os.makedirs(self.model_path, exist_ok=True)
+
+            code_path = Path(self.model_path).joinpath("code")
+
+            save_pkl(code_path, (self.inference_spec, self.schema_builder))
+            logger.info("PKL file saved to file: %s", code_path)
+
+            self._auto_detect_container()
+
+            self.secret_key = prepare_for_mms(
+                model_path=self.model_path,
+                shared_libs=self.shared_libs,
+                dependencies=self.dependencies,
+                session=self.sagemaker_session,
+                image_uri=self.image_uri,
+                inference_spec=self.inference_spec,
+            )
+
         self._build_transformers_env()
 
         if self.role_arn:
diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
index 6f9bf8307f..2f09d3d572 100644
--- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
+++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
@@ -130,6 +130,7 @@ def prepare(
                 model_path=model_path,
                 sagemaker_session=sagemaker_session,
                 s3_model_data_url=s3_model_data_url,
+                secret_key=secret_key,
                 image=image,
                 should_upload_artifacts=should_upload_artifacts,
             )
diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py
new file mode 100644
index 0000000000..1ee7b5e4dc
--- /dev/null
+++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py
@@ -0,0 +1,100 @@
+"""This module is for SageMaker inference.py."""
+
+from __future__ import absolute_import
+import os
+import io
+import cloudpickle
+import shutil
+import platform
+from pathlib import Path
+from functools import partial
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.validations.check_integrity import perform_integrity_check
+import logging
+
+logger = logging.getLogger(__name__)
+
+inference_spec = None
+schema_builder = None
+SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
+SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
+METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
+
+
+def model_fn(model_dir):
+    """Overrides default method for loading a model"""
+    shared_libs_path = Path(model_dir + "/shared_libs")
+
+    if shared_libs_path.exists():
+        # before importing, place dynamic linked libraries in shared lib path
+        shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)
+
+    serve_path = Path(__file__).parent.joinpath("serve.pkl")
+    with open(str(serve_path), mode="rb") as file:
+        global inference_spec, schema_builder
+        obj = cloudpickle.load(file)
+        if isinstance(obj[0], InferenceSpec):
+            inference_spec, schema_builder = obj
+
+    if inference_spec:
+        return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
+
+
+def input_fn(input_data, content_type):
+    """Deserializes the bytes that were received from the model server"""
+    try:
+        if hasattr(schema_builder, "custom_input_translator"):
+            return schema_builder.custom_input_translator.deserialize(
+                io.BytesIO(input_data), content_type
+            )
+        else:
+            return schema_builder.input_deserializer.deserialize(
+                io.BytesIO(input_data), content_type[0]
+            )
+    except Exception as e:
+        logger.error("Encountered error: %s in deserialize_request." % e)
+        raise Exception("Encountered error in deserialize_request.") from e
+
+
+def predict_fn(input_data, predict_callable):
+    """Invokes the model that is taken in by model server"""
+    return predict_callable(input_data)
+
+
+def output_fn(predictions, accept_type):
+    """Prediction is serialized to bytes and sent back to the customer"""
+    try:
+        if hasattr(schema_builder, "custom_output_translator"):
+            return schema_builder.custom_output_translator.serialize(predictions, accept_type)
+        else:
+            return schema_builder.output_serializer.serialize(predictions)
+    except Exception as e:
+        logger.error("Encountered error: %s in serialize_response." % e)
+        raise Exception("Encountered error in serialize_response.") from e
+
+
+def _run_preflight_diagnostics():
+    _py_vs_parity_check()
+    _pickle_file_integrity_check()
+
+
+def _py_vs_parity_check():
+    container_py_vs = platform.python_version()
+    local_py_vs = os.getenv("LOCAL_PYTHON")
+
+    if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
+        logger.warning(
+            f"The local python version {local_py_vs} differs from the python version "
+            f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
+        )
+
+
+def _pickle_file_integrity_check():
+    with open(SERVE_PATH, "rb") as f:
+        buffer = f.read()
+
+    perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH)
+
+
+# on import, execute
+_run_preflight_diagnostics()
diff --git a/src/sagemaker/serve/model_server/multi_model_server/prepare.py b/src/sagemaker/serve/model_server/multi_model_server/prepare.py
index 7059d9026d..48cf5c878a 100644
--- a/src/sagemaker/serve/model_server/multi_model_server/prepare.py
+++ b/src/sagemaker/serve/model_server/multi_model_server/prepare.py
@@ -14,12 +14,23 @@
 from __future__ import absolute_import
 import logging
-from pathlib import Path
-from typing import List
 
 from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts
 from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
 
+from pathlib import Path
+import shutil
+from typing import List
+
+from sagemaker.session import Session
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.detector.dependency_manager import capture_dependencies
+from sagemaker.serve.validations.check_integrity import (
+    generate_secret_key,
+    compute_hash,
+)
+from sagemaker.remote_function.core.serialization import _MetaData
+
 logger = logging.getLogger(__name__)
 
 
@@ -63,3 +74,56 @@ def prepare_mms_js_resources(
         model_path, code_dir = _create_dir_structure(model_path)
 
     return _copy_jumpstart_artifacts(model_data, js_id, code_dir)
+
+
+def prepare_for_mms(
+    model_path: str,
+    shared_libs: List[str],
+    dependencies: dict,
+    session: Session,
+    image_uri: str,
+    inference_spec: InferenceSpec = None,
+) -> str:
+    """Prepares model_path for InferenceSpec, writes inference.py, and captures dependencies to generate secret_key.
+
+    Args:
+        model_path (str) : Argument
+        shared_libs (List[str]) : Argument
+        dependencies (dict) : Argument
+        session (Session) : Argument
+        image_uri (str) : Argument
+        inference_spec (InferenceSpec, optional) : Argument
+            (default is None)
+    Returns:
+        ( str ) : secret_key
+    """
+    model_path = Path(model_path)
+    if not model_path.exists():
+        model_path.mkdir()
+    elif not model_path.is_dir():
+        raise Exception("model_dir is not a valid directory")
+
+    if inference_spec:
+        inference_spec.prepare(str(model_path))
+
+    code_dir = model_path.joinpath("code")
+    code_dir.mkdir(exist_ok=True)
+
+    shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
+
+    logger.info("Finished writing inference.py to code directory")
+
+    shared_libs_dir = model_path.joinpath("shared_libs")
+    shared_libs_dir.mkdir(exist_ok=True)
+    for shared_lib in shared_libs:
+        shutil.copy2(Path(shared_lib), shared_libs_dir)
+
+    capture_dependencies(dependencies=dependencies, work_dir=code_dir)
+
+    secret_key = generate_secret_key()
+    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
+        buffer = f.read()
+    hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
+    with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
+        metadata.write(_MetaData(hash_value).to_json())
+
+    return secret_key
diff --git a/src/sagemaker/serve/model_server/multi_model_server/server.py b/src/sagemaker/serve/model_server/multi_model_server/server.py
index 91d585b4cf..ccb73d8cb6 100644
--- a/src/sagemaker/serve/model_server/multi_model_server/server.py
+++ b/src/sagemaker/serve/model_server/multi_model_server/server.py
@@ -4,6 +4,7 @@
 import requests
 import logging
+import platform
 from pathlib import Path
 
 from sagemaker import Session, fw_utils
 from sagemaker.serve.utils.exceptions import LocalModelInvocationException
@@ -31,6 +32,17 @@ def _start_serving(
         env_vars: dict,
     ):
         """Placeholder docstring"""
+        env = {
+            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+            "LOCAL_PYTHON": platform.python_version(),
+        }
+        if env_vars:
+            env_vars.update(env)
+        else:
+            env_vars = env
+
         self.container = client.containers.run(
             image,
             "serve",
@@ -43,7 +55,7 @@
                     "mode": "rw",
                 },
            },
-            environment=_update_env_vars(env_vars),
+            environment=env_vars,
         )
 
     def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
@@ -81,6 +93,7 @@ class SageMakerMultiModelServer:
     def _upload_server_artifacts(
         self,
         model_path: str,
+        secret_key: str,
         sagemaker_session: Session,
         s3_model_data_url: str = None,
         image: str = None,
@@ -127,6 +140,16 @@
             else None
         )
 
+        if secret_key:
+            env_vars = {
+                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+                "SAGEMAKER_PROGRAM": "inference.py",
+                "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+                "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
+                "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
+                "LOCAL_PYTHON": platform.python_version(),
+            }
+
         return model_data, _update_env_vars(env_vars)
diff --git a/src/sagemaker/serve/spec/inference_spec.py b/src/sagemaker/serve/spec/inference_spec.py
index b61d7d55ea..2598a38d01 100644
--- a/src/sagemaker/serve/spec/inference_spec.py
+++ b/src/sagemaker/serve/spec/inference_spec.py
@@ -30,3 +30,6 @@ def invoke(self, input_object: object, model: object):
 
     def prepare(self, *args, **kwargs):
         """Custom prepare function"""
+
+    def get_model(self):
+        """Return HuggingFace model name for inference spec"""
diff --git a/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py b/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py
index 895ed3907f..e877c1e7e9 100644
--- a/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py
+++ b/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py
@@ -12,13 +12,67 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
+from pathlib import PosixPath
+import platform
 from unittest import TestCase
 from unittest.mock import Mock, patch
 
+import numpy as np
+
 from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure
+from sagemaker.serve.model_server.multi_model_server.server import (
+    LocalMultiModelServer,
+)
+
+CPU_TF_IMAGE = (
+    "763104351884.dkr.ecr.us-east-1.amazonaws.com/"
+    "huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
+)
+MODEL_PATH = "model_path"
+MODEL_REPO = f"{MODEL_PATH}/1"
+ENV_VAR = {"KEY": "VALUE"}
+PAYLOAD = np.random.rand(3, 4).astype(dtype=np.float32)
+DTYPE = "TYPE_FP32"
+SECRET_KEY = "secret_key"
+INFER_RESPONSE = {"outputs": [{"name": "output_name"}]}
+
 
 class MultiModelServerPrepareTests(TestCase):
+    def test_start_invoke_destroy_local_multi_model_server(self):
+        mock_container = Mock()
+        mock_docker_client = Mock()
+        mock_docker_client.containers.run.return_value = mock_container
+
+        local_multi_model_server = LocalMultiModelServer()
+        mock_schema_builder = Mock()
+        mock_schema_builder.input_serializer.serialize.return_value = PAYLOAD
+        local_multi_model_server.schema_builder = mock_schema_builder
+
+        local_multi_model_server._start_serving(
+            client=mock_docker_client,
+            model_path=MODEL_PATH,
+            secret_key=SECRET_KEY,
+            env_vars=ENV_VAR,
+            image=CPU_TF_IMAGE,
+        )
+
+        mock_docker_client.containers.run.assert_called_once_with(
+            CPU_TF_IMAGE,
+            "serve",
+            network_mode="host",
+            detach=True,
+            auto_remove=True,
+            volumes={PosixPath("model_path/code"): {"bind": "/opt/ml/model/", "mode": "rw"}},
+            environment={
+                "KEY": "VALUE",
+                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+                "SAGEMAKER_PROGRAM": "inference.py",
+                "SAGEMAKER_SERVE_SECRET_KEY": "secret_key",
+                "LOCAL_PYTHON": platform.python_version(),
+            },
+        )
+
     @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_disk_space")
     @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_docker_disk_usage")
     @patch("sagemaker.serve.model_server.multi_model_server.prepare.Path")
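
For context, a minimal sketch of how the InferenceSpec-driven multi model server path added above might be exercised end to end. It is illustrative only and not part of the diff: the ModelBuilder keyword arguments are assumed to be accepted together as shown, and the Hugging Face model id is a placeholder.

# Illustrative sketch only -- assumes ModelBuilder accepts inference_spec,
# schema_builder, and model_server together; adjust to the actual public API.
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils.types import ModelServer
from transformers import pipeline


class TextClassificationSpec(InferenceSpec):
    def get_model(self):
        # Hugging Face hub id (placeholder); used to populate HF_MODEL_ID
        return "distilbert-base-uncased-finetuned-sst-2-english"

    def load(self, model_dir: str):
        # Called by the generated inference.py (model_fn) inside the container
        return pipeline("text-classification", model=self.get_model())

    def invoke(self, input_object: object, model: object):
        return model(input_object)


model_builder = ModelBuilder(
    inference_spec=TextClassificationSpec(),
    schema_builder=SchemaBuilder(
        sample_input="I love this!",
        sample_output=[{"label": "POSITIVE", "score": 0.99}],
    ),
    model_server=ModelServer.MMS,  # the multi model server path touched by this change
)
model = model_builder.build()  # pickles the spec/schema and runs prepare_for_mms()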