T5 Translation with torch.compile & TensorRT backend #3223
Merged
Changes shown are from 5 of 7 commits.

Commits:
- e98ab05 Added T5 TensorRT example with torch.compile (agunapal)
- 9daf269 Added T5 TensorRT example with torch.compile (agunapal)
- 8cd3edc Merge branch 'master' into examples/t5_tensorrt (agunapal)
- 34d1e01 lint check (agunapal)
- 6c43368 Merge remote-tracking branch 'origin/master' into examples/t5_tensorrt (mreso)
- 8fe32b3 Update examples/torch_tensorrt/torchcompile/T5/README.md (agunapal)
- 55ef609 Update T5_handler.py (agunapal)
File: examples/torch_tensorrt/torchcompile/T5/README.md
@@ -0,0 +1,56 @@

# TorchServe inference with torch.compile using the TensorRT backend

This example shows how to run TorchServe inference with a T5 model compiled through [Torch-TensorRT](https://github.com/pytorch/TensorRT).

[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#inference) is an encoder-decoder model that handles a variety of text tasks out of the box by prepending a task-specific prefix to the input text. In this example, we use T5 to translate from English to German.
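For orientation, the translation flow that the handler below implements can be sketched outside TorchServe in a few lines (a minimal sketch; loading by model id rather than the local snapshot path is an assumption made for brevity):

```
# Minimal standalone sketch of T5 translation with transformers.
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

# T5 selects the task via a text prefix prepended to the input.
inputs = tokenizer(
    "translate English to German: The house is wonderful", return_tensors="pt"
)
outputs = model.generate(**inputs)
# Expected output is a German translation, e.g. "Das Haus ist wunderbar".
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```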
### Pre-requisites

- Verified to be working with `torch-tensorrt==2.3.0`. Installation instructions can be found in [pytorch/TensorRT](https://github.com/pytorch/TensorRT).

Change directory to the example directory: `cd examples/torch_tensorrt/torchcompile/T5`
### torch.compile config

To use the `tensorrt` backend with `torch.compile`, we specify the following config in `model-config.yaml`:

```
pt2:
  compile:
    enable: True
    backend: tensorrt
```
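In the handler, every key under `compile:` except `enable` is forwarded as a keyword argument to `torch.compile`, so the config above amounts to the following calls (a sketch; it assumes `torch-tensorrt` is installed, which registers the `tensorrt` backend):

```
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
# The handler compiles the encoder and decoder submodules separately.
model.encoder = torch.compile(model.encoder, backend="tensorrt")
model.decoder = torch.compile(model.decoder, backend="tensorrt")
```

Other `torch.compile` keyword arguments, such as `dynamic` or `mode`, can be added under `compile:` in the same way.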
### Download the model

```
python ../../../large_models/Huggingface_accelerate/Download_model.py --model_name google-t5/t5-small
```
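The script stores the Hugging Face snapshot under a local `model` directory, which is the path `model-config.yaml` points at. If the helper script is unavailable, a roughly equivalent download (an assumption about the script's effect, not its actual implementation) can be done with `huggingface_hub`:

```
from huggingface_hub import snapshot_download

# Downloads into ./model using the HF cache layout:
# model/models--google-t5--t5-small/snapshots/<hash>
path = snapshot_download(repo_id="google-t5/t5-small", cache_dir="model")
print(path)  # directory to use as handler.model_path
```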
### Create the model archive

```
mkdir model_store
torch-model-archiver --model-name t5-translation --version 1.0 --handler T5_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive --export-path model_store -f
mv model model_store/t5-translation/.
```
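`--archive-format no-archive` writes the archive as a plain `model_store/t5-translation/` directory instead of a zipped `.mar` file, which is why the downloaded `model` folder can simply be moved into it afterwards.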
### Start TorchServe

```
torchserve --start --ncs --ts-config config.properties --model-store model_store --models t5-translation --disable-token
```
### Run Inference

```
curl -X POST http://127.0.0.1:8080/predictions/t5-translation -T sample_text.txt
```

results in

```
Das Haus ist wunderbar
```
File: T5_handler.py
@@ -0,0 +1,142 @@
import logging

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class T5Handler(BaseHandler):
    """
    Transformers handler class for T5 text translation.
    """

    def __init__(self):
        super(T5Handler, self).__init__()
        self.tokenizer = None
        self.model = None
        self.initialized = False
    def initialize(self, ctx):
        """In this initialize function, the T5 model is loaded. It also has
        the torch.compile calls for encoder and decoder.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        self.manifest = ctx.manifest
        self.model_yaml_config = (
            ctx.model_yaml_config
            if ctx is not None and hasattr(ctx, "model_yaml_config")
            else {}
        )
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )

        # read configs for the mode, model_name, etc. from the handler config
        model_path = self.model_yaml_config.get("handler", {}).get("model_path", None)
        if not model_path:
            logger.warning("Missing model path")
        [Review comment on the line above: This should be an error. Reply: done. Thanks]
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.model.to(self.device)

        self.model.eval()

        pt2_value = self.model_yaml_config.get("pt2", {})
        if "compile" in pt2_value:
            compile_options = pt2_value["compile"]
            if compile_options["enable"] == True:
                del compile_options["enable"]

                compile_options_str = ", ".join(
                    [f"{k} {v}" for k, v in compile_options.items()]
                )
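                # Note: the encoder and decoder are compiled separately;
                # model.generate() itself stays in eager mode and invokes the
                # compiled submodules on each decoding step.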
                self.model.encoder = torch.compile(
                    self.model.encoder,
                    **compile_options,
                )
                self.model.decoder = torch.compile(
                    self.model.decoder,
                    **compile_options,
                )
                logger.info(f"Compiled model with {compile_options_str}")
        logger.info("T5 model from path %s loaded successfully", model_dir)

        self.initialized = True
    def preprocess(self, requests):
        """
        Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (list): A list of dictionaries with a "data" or "body" field, each
                containing the input text to be processed.
        Returns:
            inputs: A batched tensor of inputs: the batch of input ids and
                attention masks.
        """

        # Prefix for translation from English to German
        task_prefix = "translate English to German: "
        input_texts = [task_prefix + self.preprocess_requests(r) for r in requests]

        logger.debug("Received texts: '%s'", input_texts)
        inputs = self.tokenizer(
            input_texts,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        return inputs
    def preprocess_requests(self, request):
        """
        Preprocess request
        Args:
            request : Request to be decoded.
        Returns:
            str: Decoded input text
        """
        input_text = request.get("data") or request.get("body")
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode("utf-8")
        return input_text
    @torch.inference_mode()
    def inference(self, input_batch):
        """
        Generates the translated text for the given input
        Args:
            input_batch : A batch of tensors: the input ids and attention masks,
                as returned by the preprocess function.
        Returns:
            list: A list of strings with the translated text for each input text in the batch.
        """
        outputs = self.model.generate(
            **input_batch,
        )

        inferences = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        logger.debug("Generated text: %s", inferences)
        return inferences
    def postprocess(self, inference_output):
        """Post Process Function converts the predicted response into TorchServe readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the predictions.
        """
        return inference_output
File: config.properties
@@ -0,0 +1,5 @@
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082
enable_envvars_config=true
install_py_dep_per_model=true
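With `install_py_dep_per_model=true`, TorchServe installs the `requirements.txt` packaged with the model archive when the model is loaded, so `transformers`, `sentencepiece`, and `torch-tensorrt` do not need to be preinstalled in the serving environment.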
File: model-config.yaml
@@ -0,0 +1,8 @@
minWorkers: 1
maxWorkers: 1
handler:
  model_path: model/models--google-t5--t5-small/snapshots/df1b051c49625cf57a3d0d8d3863ed4d13564fe4
pt2:
  compile:
    enable: True
    backend: tensorrt
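`handler.model_path` is the snapshot directory produced by the download step (the hash corresponds to the downloaded model revision and may differ), and the `pt2.compile` block is what `T5Handler.initialize` reads to decide whether, and with which backend, to call `torch.compile`.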
File: requirements.txt
@@ -0,0 +1,3 @@
transformers>=4.41.2
sentencepiece>=0.2.0
torch-tensorrt>=2.3.0
File: sample_text.txt
@@ -0,0 +1 @@
The house is wonderful
3 files renamed without changes.
Review comment: Should we reverse the order and use examples/torch_compile/tensor_rt instead? Then we could move all of examples/pt2 in there too, so people find everything about compile in a single place. I assume there will be no integration point for TRT other than compile in the future.
Reply: For tensorrt and onnx, we actually want the starting point to be tensorrt and onnx, since customers are looking specifically for these. There are GitHub issues where customers are still asking for an ONNX example. Will add an ONNX example next.