feat: InferenceSpec support for MMS and testing (#4763)
* feat: InferenceSpec support for MMS and testing

* Fix formatting

* CR Fixes for InferenceSpec MMS

* remove code

* Changes to environment, avoid duplicates

* Remove loggers and add docstring updates

* Changes for unit tests in transformers build

* formatting changes

* Add secret_key to endpoint mode

* get_model, docstring, and if changes

* pre-push fixes

* integ test edits

* formatting fixes

* format changes

* updated value error

* formatting changes for value error update

---------

Co-authored-by: Bryannah Hernandez <brymh@amazon.com>
bryannahm1 and Bryannah Hernandez authored Jul 10, 2024
1 parent b7621dc commit 6789b61
Showing 7 changed files with 321 additions and 7 deletions.
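
The feature under review lets a custom InferenceSpec drive the multi-model-server (MMS) path of ModelBuilder: the spec and schema builder are pickled to <model_path>/code/serve.pkl, inference.py is copied next to them, and a generated secret key guards the pickle's integrity at container start. A minimal usage sketch follows; the ModelBuilder keyword wiring and the model id are assumptions for illustration based on the existing builder API, not something this diff shows directly.

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils.types import ModelServer


class MySpec(InferenceSpec):
    """Hypothetical spec wrapping a Hugging Face text-classification pipeline."""

    def get_model(self):
        # New hook added by this commit; the builder uses it to set HF_MODEL_ID.
        return "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative id

    def load(self, model_dir: str):
        from transformers import pipeline  # resolved inside the container
        return pipeline("text-classification", model=self.get_model())

    def invoke(self, input_object, model):
        return model(input_object)


# Assumed ModelBuilder wiring; image_uri is auto-detected when omitted.
builder = ModelBuilder(
    inference_spec=MySpec(),
    schema_builder=SchemaBuilder(
        sample_input="the movie was great",
        sample_output=[{"label": "POSITIVE", "score": 0.99}],
    ),
    model_server=ModelServer.MMS,
)
model = builder.build()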
77 changes: 73 additions & 4 deletions src/sagemaker/serve/builder/transformers_builder.py
@@ -13,8 +13,10 @@
"""Transformers build logic with model builder"""
from __future__ import absolute_import
import logging
import os
from abc import ABC, abstractmethod
from typing import Type
from pathlib import Path
from packaging.version import Version

from sagemaker.model import Model
@@ -26,7 +28,12 @@
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.serve.model_server.multi_model_server.prepare import (
_create_dir_structure,
prepare_for_mms,
)
from sagemaker.serve.detector.image_detector import (
auto_detect_container,
)
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.utils.optimize_utils import _is_optimized
from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
@@ -73,6 +80,8 @@ def __init__(self):
self.pytorch_version = None
self.instance_type = None
self.schema_builder = None
self.inference_spec = None
self.shared_libs = None

@abstractmethod
def _prepare_for_mode(self):
@@ -110,7 +119,7 @@ def _get_hf_metadata_create_model(self) -> Type[Model]:
"""

hf_model_md = get_huggingface_model_metadata(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
hf_config = image_uris.config_for_framework("huggingface").get("inference")
config = hf_config["versions"]
@@ -245,18 +254,22 @@ def _build_transformers_env(self):

_create_dir_structure(self.model_path)
if not hasattr(self, "pysdk_model"):
self.env_vars.update({"HF_MODEL_ID": self.model})

if self.inference_spec is not None:
self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
else:
self.env_vars.update({"HF_MODEL_ID": self.model})

logger.info(self.env_vars)

# TODO: Move to a helper function
if hasattr(self.env_vars, "HF_API_TOKEN"):
self.hf_model_config = _get_model_config_properties_from_hf(
self.model, self.env_vars.get("HF_API_TOKEN")
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_API_TOKEN")
)
else:
self.hf_model_config = _get_model_config_properties_from_hf(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)

self.pysdk_model = self._create_transformers_model()
@@ -292,6 +305,42 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
versions_to_return.append(base_fw_version)
return sorted(versions_to_return, reverse=True)[0]

def _auto_detect_container(self):
"""Set image_uri by detecting container via model name or inference spec"""
# Auto detect the container image uri
if self.image_uri:
logger.info(
"Skipping auto detection as the image uri is provided %s",
self.image_uri,
)
return

if self.model:
logger.info(
"Auto detect container url for the provided model and on instance %s",
self.instance_type,
)
self.image_uri = auto_detect_container(
self.model, self.sagemaker_session.boto_region_name, self.instance_type
)

elif self.inference_spec:
# TODO: this won't work for larger image.
# Fail and let the customer include the image uri
logger.warning(
"model_path provided with no image_uri. Attempting to autodetect the image\
by loading the model using inference_spec.load()..."
)
self.image_uri = auto_detect_container(
self.inference_spec.load(self.model_path),
self.sagemaker_session.boto_region_name,
self.instance_type,
)
else:
raise ValueError(
"Cannot detect and set image_uri. Please pass model or inference spec."
)

def _build_for_transformers(self):
"""Method that triggers model build
@@ -300,6 +349,26 @@ def _build_for_transformers(self):
self.secret_key = None
self.model_server = ModelServer.MMS

if self.inference_spec:

os.makedirs(self.model_path, exist_ok=True)

code_path = Path(self.model_path).joinpath("code")

save_pkl(code_path, (self.inference_spec, self.schema_builder))
logger.info("PKL file saved to file: %s", code_path)

self._auto_detect_container()

self.secret_key = prepare_for_mms(
model_path=self.model_path,
shared_libs=self.shared_libs,
dependencies=self.dependencies,
session=self.sagemaker_session,
image_uri=self.image_uri,
inference_spec=self.inference_spec,
)

self._build_transformers_env()

if self.role_arn:
1 change: 1 addition & 0 deletions src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
@@ -130,6 +130,7 @@ def prepare(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
secret_key=secret_key,
image=image,
should_upload_artifacts=should_upload_artifacts,
)
100 changes: 100 additions & 0 deletions src/sagemaker/serve/model_server/multi_model_server/inference.py
@@ -0,0 +1,100 @@
"""This module is for SageMaker inference.py."""

from __future__ import absolute_import
import os
import io
import cloudpickle
import shutil
import platform
from pathlib import Path
from functools import partial
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.validations.check_integrity import perform_integrity_check
import logging

logger = logging.getLogger(__name__)

inference_spec = None
schema_builder = None
SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")


def model_fn(model_dir):
"""Overrides default method for loading a model"""
shared_libs_path = Path(model_dir + "/shared_libs")

if shared_libs_path.exists():
# before importing, place dynamic linked libraries in shared lib path
shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)

serve_path = Path(__file__).parent.joinpath("serve.pkl")
with open(str(serve_path), mode="rb") as file:
global inference_spec, schema_builder
obj = cloudpickle.load(file)
if isinstance(obj[0], InferenceSpec):
inference_spec, schema_builder = obj

if inference_spec:
return partial(inference_spec.invoke, model=inference_spec.load(model_dir))


def input_fn(input_data, content_type):
"""Deserializes the bytes that were received from the model server"""
try:
if hasattr(schema_builder, "custom_input_translator"):
return schema_builder.custom_input_translator.deserialize(
io.BytesIO(input_data), content_type
)
else:
return schema_builder.input_deserializer.deserialize(
io.BytesIO(input_data), content_type[0]
)
except Exception as e:
logger.error("Encountered error: %s in deserialize_response." % e)
raise Exception("Encountered error in deserialize_request.") from e


def predict_fn(input_data, predict_callable):
"""Invokes the model that is taken in by model server"""
return predict_callable(input_data)


def output_fn(predictions, accept_type):
"""Prediction is serialized to bytes and sent back to the customer"""
try:
if hasattr(schema_builder, "custom_output_translator"):
return schema_builder.custom_output_translator.serialize(predictions, accept_type)
else:
return schema_builder.output_serializer.serialize(predictions)
except Exception as e:
logger.error("Encountered error: %s in serialize_response." % e)
raise Exception("Encountered error in serialize_response.") from e


def _run_preflight_diagnostics():
_py_vs_parity_check()
_pickle_file_integrity_check()


def _py_vs_parity_check():
container_py_vs = platform.python_version()
local_py_vs = os.getenv("LOCAL_PYTHON")

if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
logger.warning(
f"The local python version {local_py_vs} differs from the python version "
f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
)


def _pickle_file_integrity_check():
with open(SERVE_PATH, "rb") as f:
buffer = f.read()

perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH)


# on import, execute
_run_preflight_diagnostics()
68 changes: 66 additions & 2 deletions src/sagemaker/serve/model_server/multi_model_server/prepare.py
@@ -14,12 +14,23 @@

from __future__ import absolute_import
import logging
from pathlib import Path
from typing import List

from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage

from pathlib import Path
import shutil
from typing import List

from sagemaker.session import Session
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.remote_function.core.serialization import _MetaData

logger = logging.getLogger(__name__)


@@ -63,3 +74,56 @@ def prepare_mms_js_resources(
model_path, code_dir = _create_dir_structure(model_path)

return _copy_jumpstart_artifacts(model_data, js_id, code_dir)


def prepare_for_mms(
model_path: str,
shared_libs: List[str],
dependencies: dict,
session: Session,
image_uri: str,
inference_spec: InferenceSpec = None,
) -> str:
"""Prepares for InferenceSpec using model_path, writes inference.py, and captures dependencies to generate secret_key.
Args:to
model_path (str) : Argument
shared_libs (List[]) : Argument
dependencies (dict) : Argument
session (Session) : Argument
inference_spec (InferenceSpec, optional) : Argument
(default is None)
Returns:
( str ) : secret_key
"""
model_path = Path(model_path)
if not model_path.exists():
model_path.mkdir()
elif not model_path.is_dir():
raise Exception("model_dir is not a valid directory")

if inference_spec:
inference_spec.prepare(str(model_path))

code_dir = model_path.joinpath("code")
code_dir.mkdir(exist_ok=True)

shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)

logger.info("Finished writing inference.py to code directory")

shared_libs_dir = model_path.joinpath("shared_libs")
shared_libs_dir.mkdir(exist_ok=True)
for shared_lib in shared_libs:
shutil.copy2(Path(shared_lib), shared_libs_dir)

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
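
The secret key returned here ties the build step to the preflight check in inference.py above: compute_hash signs serve.pkl, the signature is stored in metadata.json, and the container recomputes it using the SAGEMAKER_SERVE_SECRET_KEY environment variable set by the server code below. A rough sketch of that idea, assuming an HMAC-SHA256 construction (the SDK's actual check_integrity implementation may differ):

import hashlib
import hmac
import os
import secrets


def generate_key() -> str:
    # stand-in for generate_secret_key(): a random per-build key
    return secrets.token_hex(32)


def sign(buffer: bytes, key: str) -> str:
    # build time: keyed digest of serve.pkl, persisted in metadata.json
    return hmac.new(key.encode("utf-8"), buffer, hashlib.sha256).hexdigest()


def verify(buffer: bytes, expected: str) -> None:
    # container start: the key arrives via SAGEMAKER_SERVE_SECRET_KEY
    key = os.environ["SAGEMAKER_SERVE_SECRET_KEY"]
    if not hmac.compare_digest(sign(buffer, key), expected):
        raise ValueError("serve.pkl failed the integrity check")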
25 changes: 24 additions & 1 deletion src/sagemaker/serve/model_server/multi_model_server/server.py
@@ -4,6 +4,7 @@

import requests
import logging
import platform
from pathlib import Path
from sagemaker import Session, fw_utils
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
@@ -31,6 +32,17 @@ def _start_serving(
env_vars: dict,
):
"""Placeholder docstring"""
env = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
if env_vars:
env_vars.update(env)
else:
env_vars = env

self.container = client.containers.run(
image,
"serve",
@@ -43,7 +55,7 @@
"mode": "rw",
},
},
environment=_update_env_vars(env_vars),
environment=env_vars,
)

def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
@@ -81,6 +93,7 @@ class SageMakerMultiModelServer:
def _upload_server_artifacts(
self,
model_path: str,
secret_key: str,
sagemaker_session: Session,
s3_model_data_url: str = None,
image: str = None,
@@ -127,6 +140,16 @@ def _upload_server_artifacts(
else None
)

if secret_key:
env_vars = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"LOCAL_PYTHON": platform.python_version(),
}

return model_data, _update_env_vars(env_vars)


3 changes: 3 additions & 0 deletions src/sagemaker/serve/spec/inference_spec.py
@@ -30,3 +30,6 @@ def invoke(self, input_object: object, model: object):

def prepare(self, *args, **kwargs):
"""Custom prepare function"""

def get_model(self):
"""Return HuggingFace model name for inference spec"""
