Feat: Pull latest tei container for sentence similarity models on HuggingFace hub #4686

Merged on May 17, 2024
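In short, this change teaches ModelBuilder to route Hugging Face models whose task is sentence-similarity to a new TEI builder, served through the TGI stack. A minimal usage sketch, mirroring the integration test added in this PR (the model ID, task value, and instance type are taken from that test; the role ARN is a placeholder):

from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
from sagemaker.serve.builder.schema_builder import SchemaBuilder

# Placeholder role; substitute a real SageMaker execution role ARN.
role_arn = "arn:aws:iam::123456789012:role/SageMakerRole"

# Illustrative request/response examples for the schema builder.
sample_input = {"inputs": "The man worked as a [MASK]."}
sample_output = [{"score": 0.1, "token_str": "carpenter"}]

model_builder = ModelBuilder(
    model="BAAI/bge-m3",
    schema_builder=SchemaBuilder(sample_input, sample_output),
    model_metadata={"HF_TASK": "sentence-similarity"},
)
model = model_builder.build(mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn)
predictor = model.deploy(instance_type="ml.g5.2xlarge", initial_instance_count=1)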
Changes from all commits · 31 commits
2e00238
Update: Pull latest tei container for sentence similarity models
samruds May 15, 2024
43ce1ba
Fix formatting
samruds May 15, 2024
6211227
Address PR comments
samruds May 15, 2024
0441436
Fix formatting
samruds May 15, 2024
4973f8f
Fix check
samruds May 16, 2024
f8cd864
Switch sentence similarity to be deployed on tgi
samruds May 16, 2024
a5fa0e9
Fix formatting
samruds May 16, 2024
e524134
Fix formatting
samruds May 16, 2024
4263a44
Fix formatting
samruds May 16, 2024
eb3b6d3
Fix formatting
samruds May 16, 2024
2b9ba2a
Introduce TEI builder with TGI server
samruds May 16, 2024
33d5b04
Fix formatting
samruds May 16, 2024
20687f0
Add integ test
samruds May 16, 2024
d85425f
Fix formatting
samruds May 16, 2024
bbdff4c
Add integ test
samruds May 16, 2024
a526416
Add integ test
samruds May 16, 2024
1e49f88
Add integ test
samruds May 16, 2024
af78426
Add integ test
samruds May 16, 2024
a5e665a
Add integ test
samruds May 16, 2024
e58f622
Fix formatting
samruds May 16, 2024
4c336dd
Merge branch 'master' into master
samruds May 16, 2024
ea900bf
Move to G5 for integ test
samruds May 16, 2024
cffe46a
Fix formatting
samruds May 16, 2024
48205ad
Integ test updates
samruds May 17, 2024
312d837
Integ test updates
samruds May 17, 2024
29ea1c5
Integ test updates
samruds May 17, 2024
f6f8116
Fix formatting
samruds May 17, 2024
166e570
Integ test updates
samruds May 17, 2024
4bb5522
Move back to generate for ping
samruds May 17, 2024
17645f7
Integ test updates
samruds May 17, 2024
e8341c2
Integ test updates
samruds May 17, 2024
11 changes: 7 additions & 4 deletions src/sagemaker/serve/builder/model_builder.py
@@ -36,6 +36,7 @@
 from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
 from sagemaker.serve.builder.serve_settings import _ServeSettings
 from sagemaker.serve.builder.djl_builder import DJL
+from sagemaker.serve.builder.tei_builder import TEI
 from sagemaker.serve.builder.tgi_builder import TGI
 from sagemaker.serve.builder.jumpstart_builder import JumpStart
 from sagemaker.serve.builder.transformers_builder import Transformers
@@ -95,9 +96,9 @@
 }


-# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901
+# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901, disable=R1705
 @dataclass
-class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing):
+class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, TEI):
     """Class that builds a deployable model.

     Args:
@@ -753,7 +754,7 @@ def build(  # pylint: disable=R0911
         model_task = self.model_metadata.get("HF_TASK")
         if self._is_jumpstart_model_id():
             return self._build_for_jumpstart()
-        if self._is_djl():  # pylint: disable=R1705
+        if self._is_djl():
             return self._build_for_djl()
         else:
             hf_model_md = get_huggingface_model_metadata(
@@ -764,8 +765,10 @@ def build(  # pylint: disable=R0911
                 model_task = hf_model_md.get("pipeline_tag")
             if self.schema_builder is None and model_task is not None:
                 self._hf_schema_builder_init(model_task)
-            if model_task == "text-generation":  # pylint: disable=R1705
+            if model_task == "text-generation":
                 return self._build_for_tgi()
+            if model_task == "sentence-similarity":
+                return self._build_for_tei()
             elif self._can_fit_on_single_gpu():
                 return self._build_for_transformers()
             elif (
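When the HF_TASK hint is absent, build() falls back to the model's pipeline_tag fetched from the Hugging Face Hub (via get_huggingface_model_metadata). A rough standalone sketch of that lookup using the public Hub REST endpoint rather than the SDK helper (the endpoint shape is stated from memory, not from this PR):

import json
import urllib.request

def hub_pipeline_tag(model_id: str) -> str:
    # Public Hub metadata endpoint; the SDK helper wraps an equivalent call
    # and forwards HUGGING_FACE_HUB_TOKEN when one is configured.
    url = f"https://huggingface.co/api/models/{model_id}"
    with urllib.request.urlopen(url) as resp:
        return json.load(resp).get("pipeline_tag")

# "text-generation" routes to _build_for_tgi(); "sentence-similarity" now
# routes to the new _build_for_tei().
print(hub_pipeline_tag("BAAI/bge-m3"))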
222 changes: 222 additions & 0 deletions src/sagemaker/serve/builder/tei_builder.py
@@ -0,0 +1,222 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Holds mixin logic to support deployment of Model ID"""
from __future__ import absolute_import
import logging
from typing import Type
from abc import ABC, abstractmethod

from sagemaker import image_uris
from sagemaker.model import Model
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf

from sagemaker.huggingface import HuggingFaceModel
from sagemaker.serve.utils.local_hardware import (
    _get_nb_instance,
)
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
from sagemaker.serve.utils.predictors import TgiLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.base_predictor import PredictorBase

logger = logging.getLogger(__name__)

_CODE_FOLDER = "code"


class TEI(ABC):
    """TEI build logic for ModelBuilder()"""

    def __init__(self):
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._is_custom_image_uri = False
        self.image_config = None
        self.vpc_config = None
        self._original_deploy = None
        self.hf_model_config = None
        self._default_tensor_parallel_degree = None
        self._default_data_type = None
        self._default_max_tokens = None
        self.pysdk_model = None
        self.schema_builder = None
        self.env_vars = None
        self.nb_instance_type = None
        self.ram_usage_model_load = None
        self.secret_key = None
        self.jumpstart = None
        self.role_arn = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Prepare model artifacts and configuration for the selected deploy mode."""

    @abstractmethod
    def _get_client_translators(self):
        """Return the serializer/deserializer pair used for client requests."""

    def _set_to_tgi(self):
        """Fall back to the TGI model server when another server is configured."""
        if self.model_server != ModelServer.TGI:
            messaging = (
                "HuggingFace Model ID is not currently supported on model server "
                f"{self.model_server}. "
                f"Defaulting to {ModelServer.TGI}"
            )
            logger.warning(messaging)
            self.model_server = ModelServer.TGI

    def _create_tei_model(self, **kwargs) -> Type[Model]:
        """Create a HuggingFaceModel with the TEI image and wrap its deploy()."""
        if self.nb_instance_type and "instance_type" not in kwargs:
            kwargs.update({"instance_type": self.nb_instance_type})

        if not self.image_uri:
            self.image_uri = image_uris.retrieve(
                "huggingface-tei",
                image_scope="inference",
                instance_type=kwargs.get("instance_type"),
                region=self.sagemaker_session.boto_region_name,
            )

        pysdk_model = HuggingFaceModel(
            image_uri=self.image_uri,
            image_config=self.image_config,
            vpc_config=self.vpc_config,
            env=self.env_vars,
            role=self.role_arn,
            sagemaker_session=self.sagemaker_session,
        )

        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._tei_model_builder_deploy_wrapper
        return pysdk_model

    @_capture_telemetry("tei.deploy")
    def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
        """Deploy to a local container or a SageMaker endpoint depending on mode."""
        timeout = kwargs.get("model_data_download_timeout")
        if timeout:
            self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

        if "mode" in kwargs and kwargs.get("mode") != self.mode:
            overwrite_mode = kwargs.get("mode")
            # mode overwritten by customer during model.deploy()
            logger.warning(
                "Deploying in %s Mode, overriding existing configurations set for %s mode",
                overwrite_mode,
                self.mode,
            )

            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            elif overwrite_mode == Mode.LOCAL_CONTAINER:
                self._prepare_for_mode()
                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            else:
                raise ValueError("Mode %s is not supported!" % overwrite_mode)

        serializer = self.schema_builder.input_serializer
        deserializer = self.schema_builder._output_deserializer
        if self.mode == Mode.LOCAL_CONTAINER:
            timeout = kwargs.get("model_data_download_timeout")

            predictor = TgiLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )

            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
                self.image_uri,
                timeout if timeout else 1800,
                None,
                predictor,
                self.pysdk_model.env,
                jumpstart=False,
            )

            return predictor

        if "mode" in kwargs:
            del kwargs["mode"]
        if "role" in kwargs:
            self.pysdk_model.role = kwargs.get("role")
            del kwargs["role"]

        # set model_data to uncompressed s3 dict
        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

        # if the weights have been cached via local container mode -> set to offline
        if str(Mode.LOCAL_CONTAINER) in self.modes:
            self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
        else:
            # if it has not been built for local container, we must use a cache
            # that hosting has write access to.
            self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp"
            self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"

        if "endpoint_logging" not in kwargs:
            kwargs["endpoint_logging"] = True

        if not self.nb_instance_type and "instance_type" not in kwargs:
            raise ValueError(
                "Instance type must be provided when deploying to SageMaker Endpoint mode."
            )

        if "initial_instance_count" not in kwargs:
            kwargs.update({"initial_instance_count": 1})

        predictor = self._original_deploy(*args, **kwargs)

        predictor.serializer = serializer
        predictor.deserializer = deserializer
        return predictor

    def _build_for_hf_tei(self):
        """Prepare artifacts, env vars, and the PySDK model for a Hugging Face model ID."""
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )

        self.pysdk_model = self._create_tei_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _build_for_tei(self):
        """Entry point for ModelBuilder when the detected task is sentence-similarity."""
        self.secret_key = None

        self._set_to_tgi()

        self.pysdk_model = self._build_for_hf_tei()
        return self.pysdk_model
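For reference, the container resolution performed in _create_tei_model can be reproduced directly with the SDK's image_uris helper; a small sketch (the region and instance type here are assumptions, not values from this PR):

from sagemaker import image_uris

# Resolve the huggingface-tei inference image the builder would select.
uri = image_uris.retrieve(
    "huggingface-tei",
    image_scope="inference",
    instance_type="ml.g5.2xlarge",
    region="us-west-2",
)
print(uri)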
123 changes: 123 additions & 0 deletions tests/integ/sagemaker/serve/test_serve_tei.py
@@ -0,0 +1,123 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode

from tests.integ.sagemaker.serve.constants import (
    HF_DIR,
    PYTHON_VERSION_IS_NOT_310,
    SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
)

from tests.integ.timeout import timeout
from tests.integ.utils import cleanup_model_resources
import logging

logger = logging.getLogger(__name__)

sample_input = {
    "inputs": "The man worked as a [MASK].",
}

loaded_response = [
    {
        "score": 0.0974755585193634,
        "token": 10533,
        "token_str": "carpenter",
        "sequence": "the man worked as a carpenter.",
    },
    {
        "score": 0.052383411675691605,
        "token": 15610,
        "token_str": "waiter",
        "sequence": "the man worked as a waiter.",
    },
    {
        "score": 0.04962712526321411,
        "token": 13362,
        "token_str": "barber",
        "sequence": "the man worked as a barber.",
    },
    {
        "score": 0.0378861166536808,
        "token": 15893,
        "token_str": "mechanic",
        "sequence": "the man worked as a mechanic.",
    },
    {
        "score": 0.037680838257074356,
        "token": 18968,
        "token_str": "salesman",
        "sequence": "the man worked as a salesman.",
    },
]


@pytest.fixture
def model_input():
    return {"inputs": "The man worked as a [MASK]."}


@pytest.fixture
def model_builder_model_schema_builder():
    return ModelBuilder(
        model_path=HF_DIR,
        model="BAAI/bge-m3",
        schema_builder=SchemaBuilder(sample_input, loaded_response),
        model_metadata={
            "HF_TASK": "sentence-similarity",
        },
    )


@pytest.fixture
def model_builder(request):
    return request.getfixturevalue(request.param)


@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="Testing feature needs latest metadata",
)
@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
def test_tei_sagemaker_endpoint(sagemaker_session, model_builder, model_input):
    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
    caught_ex = None

    iam_client = sagemaker_session.boto_session.client("iam")
    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

    model = model_builder.build(
        mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
    )

    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        try:
            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(instance_type="ml.g5.2xlarge", initial_instance_count=1)
            predictor.predict(model_input)
            assert predictor is not None
        except Exception as e:
            caught_ex = e
        finally:
            cleanup_model_resources(
                sagemaker_session=model_builder.sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )
            if caught_ex:
                logger.exception(caught_ex)
                assert False, f"{caught_ex} was thrown when running tei sagemaker endpoint test"