forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat: Pull latest tei container for sentence similarity models on Hu…
…ggingFace hub (aws#4686) * Update: Pull latest tei container for sentence similarity models * Fix formatting * Address PR comments * Fix formatting * Fix check * Switch sentence similarity to be deployed on tgi * Fix formatting * Fix formatting * Fix formatting * Fix formatting * Introduce TEI builder with TGI server * Fix formatting * Add integ test * Fix formatting * Add integ test * Add integ test * Add integ test * Add integ test * Add integ test * Fix formatting * Move to G5 for integ test * Fix formatting * Integ test updates * Integ test updates * Integ test updates * Fix formatting * Integ test updates * Move back to generate for ping * Integ test updates * Integ test updates
- Loading branch information
Showing
6 changed files
with
543 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Holds mixin logic to support deployment of Model ID""" | ||
from __future__ import absolute_import | ||
import logging | ||
from typing import Type | ||
from abc import ABC, abstractmethod | ||
|
||
from sagemaker import image_uris | ||
from sagemaker.model import Model | ||
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf | ||
|
||
from sagemaker.huggingface import HuggingFaceModel | ||
from sagemaker.serve.utils.local_hardware import ( | ||
_get_nb_instance, | ||
) | ||
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure | ||
from sagemaker.serve.utils.predictors import TgiLocalModePredictor | ||
from sagemaker.serve.utils.types import ModelServer | ||
from sagemaker.serve.mode.function_pointers import Mode | ||
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry | ||
from sagemaker.base_predictor import PredictorBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
_CODE_FOLDER = "code" | ||
|
||
|
||
class TEI(ABC):
    """TEI build logic for ModelBuilder().

    Mixin that builds a Hugging Face model for serving on the TGI model
    server and swaps the model's ``deploy()`` for
    ``_tei_model_builder_deploy_wrapper`` so both LOCAL_CONTAINER and
    SAGEMAKER_ENDPOINT modes are handled at deploy time.
    """

    def __init__(self):
        # Build inputs, populated by the composing ModelBuilder.
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        # Container image settings.
        self.image_uri = None
        self._is_custom_image_uri = False
        self.image_config = None
        self.vpc_config = None
        # Original Model.deploy, saved before it is wrapped.
        self._original_deploy = None
        # Hugging Face model config and serving defaults.
        self.hf_model_config = None
        self._default_tensor_parallel_degree = None
        self._default_data_type = None
        self._default_max_tokens = None
        self.pysdk_model = None
        self.schema_builder = None
        self.env_vars = None
        self.nb_instance_type = None
        self.ram_usage_model_load = None
        self.secret_key = None
        self.jumpstart = None
        self.role_arn = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Prepare artifacts for the active deploy mode.

        In SAGEMAKER_ENDPOINT mode the implementation is expected to return a
        ``(model_data, env_vars)`` tuple (see ``_tei_model_builder_deploy_wrapper``).
        """

    @abstractmethod
    def _get_client_translators(self):
        """Return the client-side request/response translators."""

    def _set_to_tgi(self):
        """Force the model server to TGI, warning if another server was configured."""
        if self.model_server != ModelServer.TGI:
            messaging = (
                "HuggingFace Model ID support on model server: "
                f"{self.model_server} is not currently supported. "
                f"Defaulting to {ModelServer.TGI}"
            )
            logger.warning(messaging)
            self.model_server = ModelServer.TGI

    def _create_tei_model(self, **kwargs) -> Type[Model]:
        """Create a HuggingFaceModel and wrap its ``deploy`` method.

        Falls back to the notebook instance type and, when no custom image
        was supplied, retrieves the latest ``huggingface-tei`` inference
        image for the session's region.
        """
        if self.nb_instance_type and "instance_type" not in kwargs:
            kwargs.update({"instance_type": self.nb_instance_type})

        if not self.image_uri:
            self.image_uri = image_uris.retrieve(
                "huggingface-tei",
                image_scope="inference",
                instance_type=kwargs.get("instance_type"),
                region=self.sagemaker_session.boto_region_name,
            )

        pysdk_model = HuggingFaceModel(
            image_uri=self.image_uri,
            image_config=self.image_config,
            vpc_config=self.vpc_config,
            env=self.env_vars,
            role=self.role_arn,
            sagemaker_session=self.sagemaker_session,
        )

        # Typo fix: message previously read "with the the deployment".
        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

        # Swap deploy() for the wrapper so mode handling happens at deploy time.
        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._tei_model_builder_deploy_wrapper
        return pysdk_model

    @_capture_telemetry("tei.deploy")
    def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
        """Deploy the model in LOCAL_CONTAINER or SAGEMAKER_ENDPOINT mode.

        Replaces ``Model.deploy``; honors a ``mode`` override passed at
        deploy time and returns a predictor wired with the schema builder's
        serializer and deserializer.

        Raises:
            ValueError: If an unsupported mode is requested, or no instance
                type is available for SageMaker endpoint deployment.
        """
        timeout = kwargs.get("model_data_download_timeout")
        if timeout:
            self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

        if "mode" in kwargs and kwargs.get("mode") != self.mode:
            overwrite_mode = kwargs.get("mode")
            # mode overwritten by customer during model.deploy()
            logger.warning(
                "Deploying in %s Mode, overriding existing configurations set for %s mode",
                overwrite_mode,
                self.mode,
            )

            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            elif overwrite_mode == Mode.LOCAL_CONTAINER:
                self._prepare_for_mode()
                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            else:
                raise ValueError("Mode %s is not supported!" % overwrite_mode)

        serializer = self.schema_builder.input_serializer
        deserializer = self.schema_builder._output_deserializer
        if self.mode == Mode.LOCAL_CONTAINER:
            timeout = kwargs.get("model_data_download_timeout")

            predictor = TgiLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )

            # 1800s = 30-minute default container start timeout.
            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
                self.image_uri,
                timeout if timeout else 1800,
                None,
                predictor,
                self.pysdk_model.env,
                jumpstart=False,
            )

            return predictor

        if "mode" in kwargs:
            del kwargs["mode"]
        if "role" in kwargs:
            self.pysdk_model.role = kwargs.get("role")
            del kwargs["role"]

        # set model_data to uncompressed s3 dict
        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

        # if the weights have been cached via local container mode -> set to offline
        if str(Mode.LOCAL_CONTAINER) in self.modes:
            self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
        else:
            # if has not been built for local container we must use cache
            # that hosting has write access to.
            self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp"
            self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"

        if "endpoint_logging" not in kwargs:
            kwargs["endpoint_logging"] = True

        if not self.nb_instance_type and "instance_type" not in kwargs:
            raise ValueError(
                "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
            )

        if "initial_instance_count" not in kwargs:
            kwargs.update({"initial_instance_count": 1})

        predictor = self._original_deploy(*args, **kwargs)

        predictor.serializer = serializer
        predictor.deserializer = deserializer
        return predictor

    def _build_for_hf_tei(self):
        """Build the PySDK model object for a Hugging Face TEI deployment."""
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        # NOTE(review): __init__ assigns ``self.pysdk_model = None``, so this
        # hasattr() check is only False when __init__ was bypassed by the
        # composing class — confirm ``not self.pysdk_model`` was not intended.
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )

        self.pysdk_model = self._create_tei_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _build_for_tei(self):
        """Entry point: force the TGI server, then build the TEI model."""
        self.secret_key = None

        self._set_to_tgi()

        self.pysdk_model = self._build_for_hf_tei()
        return self.pysdk_model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import pytest | ||
from sagemaker.serve.builder.schema_builder import SchemaBuilder | ||
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode | ||
|
||
from tests.integ.sagemaker.serve.constants import ( | ||
HF_DIR, | ||
PYTHON_VERSION_IS_NOT_310, | ||
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, | ||
) | ||
|
||
from tests.integ.timeout import timeout | ||
from tests.integ.utils import cleanup_model_resources | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
sample_input = { | ||
"inputs": "The man worked as a [MASK].", | ||
} | ||
|
||
loaded_response = [ | ||
{ | ||
"score": 0.0974755585193634, | ||
"token": 10533, | ||
"token_str": "carpenter", | ||
"sequence": "the man worked as a carpenter.", | ||
}, | ||
{ | ||
"score": 0.052383411675691605, | ||
"token": 15610, | ||
"token_str": "waiter", | ||
"sequence": "the man worked as a waiter.", | ||
}, | ||
{ | ||
"score": 0.04962712526321411, | ||
"token": 13362, | ||
"token_str": "barber", | ||
"sequence": "the man worked as a barber.", | ||
}, | ||
{ | ||
"score": 0.0378861166536808, | ||
"token": 15893, | ||
"token_str": "mechanic", | ||
"sequence": "the man worked as a mechanic.", | ||
}, | ||
{ | ||
"score": 0.037680838257074356, | ||
"token": 18968, | ||
"token_str": "salesman", | ||
"sequence": "the man worked as a salesman.", | ||
}, | ||
] | ||
|
||
|
||
@pytest.fixture
def model_input():
    """Single fill-mask request payload passed to the deployed endpoint."""
    payload = {"inputs": "The man worked as a [MASK]."}
    return payload
|
||
|
||
@pytest.fixture
def model_builder_model_schema_builder():
    """ModelBuilder wired for the BAAI/bge-m3 sentence-similarity model."""
    schema = SchemaBuilder(sample_input, loaded_response)
    metadata = {"HF_TASK": "sentence-similarity"}
    return ModelBuilder(
        model_path=HF_DIR,
        model="BAAI/bge-m3",
        schema_builder=schema,
        model_metadata=metadata,
    )
|
||
|
||
@pytest.fixture
def model_builder(request):
    """Indirect fixture: resolve the builder fixture named by the parametrize id."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
|
||
|
||
@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="Testing feature needs latest metadata",
)
@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
def test_tei_sagemaker_endpoint(sagemaker_session, model_builder, model_input):
    """Deploy the TEI model to a real SageMaker endpoint and run one prediction.

    Any exception raised during deploy/predict is captured so the endpoint
    and model resources are always cleaned up, then re-surfaced as a failure.
    """
    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
    deploy_error = None

    iam_client = sagemaker_session.boto_session.client("iam")
    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

    model = model_builder.build(
        mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
    )

    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        try:
            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(instance_type="ml.g5.2xlarge", initial_instance_count=1)
            predictor.predict(model_input)
            assert predictor is not None
        except Exception as err:  # capture so cleanup always runs; re-raised as assert below
            deploy_error = err
        finally:
            cleanup_model_resources(
                sagemaker_session=model_builder.sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )
            if deploy_error:
                logger.exception(deploy_error)
                assert False, f"{deploy_error} was thrown when running tei sagemaker endpoint test"
Oops, something went wrong.