-
Notifications
You must be signed in to change notification settings - Fork 863
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
T5 Translation with torch.compile & TensorRT backend (#3223)
* Added T5 TensorRT example with torch.compile * Added T5 TensorRT example with torch.compile * lint check * Update examples/torch_tensorrt/torchcompile/T5/README.md Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com> * Update T5_handler.py review comments --------- Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
- Loading branch information
Showing
10 changed files
with
216 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# TorchServe inference with torch.compile with tensorrt backend | ||
|
||
This example shows how to run TorchServe inference with T5 [Torch-TensorRT](https://github.com/pytorch/TensorRT) model | ||
|
||
|
||
|
||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#inference) is an encode-decoder model used for a variety of text tasks out of the box by prepending a different text corresponding to each task. In this example, we use T5 for translation from English to German. | ||
|
||
### Pre-requisites | ||
|
||
- Verified to be working with `torch-tensorrt==2.3.0` | ||
Installation instructions can be found in [pytorch/TensorRT](https://github.com/pytorch/TensorRT) | ||
|
||
Change directory to examples directory `cd examples/torch_tensorrt/T5/torchcompile` | ||
|
||
### torch.compile config | ||
|
||
To use `tensorrt` backend with `torch.compile`, we specify the following config in `model-config.yaml` | ||
|
||
``` | ||
pt2: | ||
compile: | ||
enable: True | ||
backend: tensorrt | ||
``` | ||
|
||
### Download the model | ||
|
||
``` | ||
python ../../../large_models/Huggingface_accelerate/Download_model.py --model_name google-t5/t5-small | ||
``` | ||
|
||
### Create the model archive | ||
``` | ||
mkdir model_store | ||
torch-model-archiver --model-name t5-translation --version 1.0 --handler T5_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive --export-path model_store -f | ||
mv model model_store/t5-translation/. | ||
``` | ||
|
||
### Start TorchServe | ||
|
||
``` | ||
torchserve --start --ncs --ts-config config.properties --model-store model_store --models t5-translation --disable-token-auth | ||
``` | ||
|
||
### Run Inference | ||
|
||
``` | ||
curl -X POST http://127.0.0.1:8080/predictions/t5-translation -T sample_text.txt | ||
``` | ||
|
||
results in | ||
|
||
``` | ||
Das Haus ist wunderbar | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import logging | ||
|
||
import torch | ||
from transformers import T5ForConditionalGeneration, T5Tokenizer | ||
|
||
from ts.torch_handler.base_handler import BaseHandler | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class T5Handler(BaseHandler): | ||
""" | ||
Transformers handler class for sequence, token classification and question answering. | ||
""" | ||
|
||
def __init__(self): | ||
super(T5Handler, self).__init__() | ||
self.tokenizer = None | ||
self.model = None | ||
self.initialized = False | ||
|
||
def initialize(self, ctx): | ||
"""In this initialize function, the T5 model is loaded. It also has | ||
the torch.compile calls for encoder and decoder. | ||
Args: | ||
ctx (context): It is a JSON Object containing information | ||
pertaining to the model artifacts parameters. | ||
""" | ||
self.manifest = ctx.manifest | ||
self.model_yaml_config = ( | ||
ctx.model_yaml_config | ||
if ctx is not None and hasattr(ctx, "model_yaml_config") | ||
else {} | ||
) | ||
properties = ctx.system_properties | ||
model_dir = properties.get("model_dir") | ||
|
||
self.device = torch.device( | ||
"cuda:" + str(properties.get("gpu_id")) | ||
if torch.cuda.is_available() and properties.get("gpu_id") is not None | ||
else "cpu" | ||
) | ||
|
||
# read configs for the mode, model_name, etc. from the handler config | ||
model_path = self.model_yaml_config.get("handler", {}).get("model_path", None) | ||
if not model_path: | ||
logger.error("Missing model path") | ||
|
||
self.tokenizer = T5Tokenizer.from_pretrained(model_path) | ||
self.model = T5ForConditionalGeneration.from_pretrained(model_path) | ||
self.model.to(self.device) | ||
|
||
self.model.eval() | ||
|
||
pt2_value = self.model_yaml_config.get("pt2", {}) | ||
if "compile" in pt2_value: | ||
compile_options = pt2_value["compile"] | ||
if compile_options["enable"] == True: | ||
del compile_options["enable"] | ||
|
||
compile_options_str = ", ".join( | ||
[f"{k} {v}" for k, v in compile_options.items()] | ||
) | ||
self.model.encoder = torch.compile( | ||
self.model.encoder, | ||
**compile_options, | ||
) | ||
self.model.decoder = torch.compile( | ||
self.model.decoder, | ||
**compile_options, | ||
) | ||
logger.info(f"Compiled model with {compile_options_str}") | ||
logger.info("T5 model from path %s loaded successfully", model_dir) | ||
|
||
self.initialized = True | ||
|
||
def preprocess(self, requests): | ||
""" | ||
Basic text preprocessing, based on the user's choice of application mode. | ||
Args: | ||
requests (list): A list of dictionaries with a "data" or "body" field, each | ||
containing the input text to be processed. | ||
Returns: | ||
inputs: A batched tensor of inputs: the batch of input ids and | ||
attention masks. | ||
""" | ||
|
||
# Prefix for translation from English to German | ||
task_prefix = "translate English to German: " | ||
input_texts = [task_prefix + self.preprocess_requests(r) for r in requests] | ||
|
||
logger.debug("Received texts: '%s'", input_texts) | ||
inputs = self.tokenizer( | ||
input_texts, | ||
padding=True, | ||
return_tensors="pt", | ||
).to(self.device) | ||
|
||
return inputs | ||
|
||
def preprocess_requests(self, request): | ||
""" | ||
Preprocess request | ||
Args: | ||
request : Request to be decoded. | ||
Returns: | ||
str: Decoded input text | ||
""" | ||
input_text = request.get("data") or request.get("body") | ||
if isinstance(input_text, (bytes, bytearray)): | ||
input_text = input_text.decode("utf-8") | ||
return input_text | ||
|
||
@torch.inference_mode() | ||
def inference(self, input_batch): | ||
""" | ||
Generates the translated text for the given input | ||
Args: | ||
input_batch : A tensors: the batch of input ids and attention masks, as returned by the | ||
preprocess function. | ||
Returns: | ||
list: A list of strings with the translated text for each input text in the batch. | ||
""" | ||
outputs = self.model.generate( | ||
**input_batch, | ||
) | ||
|
||
inferences = self.tokenizer.batch_decode( | ||
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
) | ||
|
||
logger.debug("Generated text: %s", inferences) | ||
return inferences | ||
|
||
def postprocess(self, inference_output): | ||
"""Post Process Function converts the predicted response into Torchserve readable format. | ||
Args: | ||
inference_output (list): It contains the predicted response of the input text. | ||
Returns: | ||
(list): Returns a list of the Predictions. | ||
""" | ||
return inference_output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
inference_address=http://127.0.0.1:8080 | ||
management_address=http://127.0.0.1:8081 | ||
metrics_address=http://127.0.0.1:8082 | ||
enable_envvars_config=true | ||
install_py_dep_per_model=true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
minWorkers: 1 | ||
maxWorkers: 1 | ||
handler: | ||
model_path: model/models--google-t5--t5-small/snapshots/df1b051c49625cf57a3d0d8d3863ed4d13564fe4 | ||
pt2: | ||
compile: | ||
enable: True | ||
backend: tensorrt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
transformers>=4.41.2 | ||
sentencepiece>=0.2.0 | ||
torch-tensorrt>=2.3.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
The house is wonderful |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.