T5 Translation with torch.compile & TensorRT backend #3223

Merged (7 commits) on Jul 9, 2024
56 changes: 56 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/README.md
@@ -0,0 +1,56 @@
# TorchServe inference with torch.compile using the TensorRT backend

This example shows how to run TorchServe inference with a T5 model compiled via `torch.compile` using the [Torch-TensorRT](https://github.com/pytorch/TensorRT) backend.

[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#inference) is an encoder-decoder model that handles a variety of text tasks out of the box by prepending a task-specific prefix to the input text. In this example, we use T5 to translate from English to German.
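
For reference, the translation flow this example serves looks roughly like this in plain `transformers` (a minimal sketch; the handler below wires the same calls into TorchServe):

```
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Sketch of the translation flow; assumes the google-t5/t5-small weights are
# available locally or can be fetched from the Hugging Face Hub.
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").eval()

# T5 selects the task via a text prefix prepended to the input
inputs = tokenizer("translate English to German: The house is wonderful", return_tensors="pt")
with torch.inference_mode():
    outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))  # ['Das Haus ist wunderbar']
```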

### Pre-requisites

- Verified to be working with `torch-tensorrt==2.3.0`
- Installation instructions can be found in [pytorch/TensorRT](https://github.com/pytorch/TensorRT)

Change directory to the example directory: `cd examples/torch_tensorrt/torchcompile/T5`
> **Collaborator:** Should we revert the order and do `examples/torch_compile/tensor_rt` instead? Then we could move all of `examples/pt2` in there too, so people find everything about compile in a single place. I assume there will be no other integration point for TRT in the future other than compile.
>
> **Collaborator (author):** For tensorrt and onnx we actually want the starting point to be tensorrt/onnx, since customers are looking specifically for those; there are GitHub issues where customers are still asking for an onnx example. Will add an onnx example next.


### torch.compile config

To use the `tensorrt` backend with `torch.compile`, specify the following config in `model-config.yaml`:

```
pt2:
  compile:
    enable: True
    backend: tensorrt
```
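
With this config, the handler (shown further down) effectively applies `torch.compile` as in the following sketch, assuming `torch-tensorrt` is installed and the `google-t5/t5-small` weights are available:

```
import torch
import torch_tensorrt  # noqa: F401  makes the "tensorrt" backend available to torch.compile
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").eval()

# Rough equivalent of `backend: tensorrt` in model-config.yaml: the encoder and
# decoder are compiled separately, and generate() still orchestrates decoding.
model.encoder = torch.compile(model.encoder, backend="tensorrt")
model.decoder = torch.compile(model.decoder, backend="tensorrt")
```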

### Download the model

```
python ../../../large_models/Huggingface_accelerate/Download_model.py --model_name google-t5/t5-small
```
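
If you prefer to fetch the weights directly, the following is roughly equivalent (a sketch using `huggingface_hub`; the helper script may apply additional filtering):

```
from huggingface_hub import snapshot_download

# Downloads into ./model, producing a layout like
# model/models--google-t5--t5-small/snapshots/<revision>/ as referenced by model-config.yaml
snapshot_download("google-t5/t5-small", cache_dir="model")
```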

### Create the model archive
```
mkdir model_store
torch-model-archiver --model-name t5-translation --version 1.0 --handler T5_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive --export-path model_store -f
mv model model_store/t5-translation/.
```

### Start TorchServe

```
torchserve --start --ncs --ts-config config.properties --model-store model_store --models t5-translation --disable-token
```

### Run Inference

```
curl -X POST http://127.0.0.1:8080/predictions/t5-translation -T sample_text.txt
```

results in

```
Das Haus ist wunderbar
```
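
Equivalently, a small Python client can send the request (a sketch assuming the default inference address from `config.properties`):

```
import requests

# POST the sample text to the t5-translation endpoint started above
with open("sample_text.txt", "rb") as f:
    response = requests.post("http://127.0.0.1:8080/predictions/t5-translation", data=f)
print(response.text)  # Das Haus ist wunderbar
```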
142 changes: 142 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/T5_handler.py
@@ -0,0 +1,142 @@
import logging

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class T5Handler(BaseHandler):
    """
    Transformers handler class for translating text from English to German with T5.
    """

    def __init__(self):
        super(T5Handler, self).__init__()
        self.tokenizer = None
        self.model = None
        self.initialized = False

    def initialize(self, ctx):
        """In this initialize function, the T5 model is loaded. It also has
        the torch.compile calls for the encoder and decoder.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        self.manifest = ctx.manifest
        self.model_yaml_config = (
            ctx.model_yaml_config
            if ctx is not None and hasattr(ctx, "model_yaml_config")
            else {}
        )
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )

        # read the model_path from the handler section of the model config
        model_path = self.model_yaml_config.get("handler", {}).get("model_path", None)
        if not model_path:
            raise RuntimeError("Missing model_path in the handler section of model-config.yaml")
> **Collaborator:** This should be an error.
>
> **Collaborator (author):** Done, thanks.


        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.model.to(self.device)

        self.model.eval()

        pt2_value = self.model_yaml_config.get("pt2", {})
        if "compile" in pt2_value:
            compile_options = pt2_value["compile"]
            if compile_options["enable"] == True:
                del compile_options["enable"]

                compile_options_str = ", ".join(
                    [f"{k} {v}" for k, v in compile_options.items()]
                )
                # Compile the encoder and decoder separately; generate() still
                # orchestrates decoding with the compiled sub-modules.
                self.model.encoder = torch.compile(
                    self.model.encoder,
                    **compile_options,
                )
                self.model.decoder = torch.compile(
                    self.model.decoder,
                    **compile_options,
                )
                logger.info(f"Compiled model with {compile_options_str}")
        logger.info("T5 model from path %s loaded successfully", model_dir)

        self.initialized = True

    def preprocess(self, requests):
        """
        Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (list): A list of dictionaries with a "data" or "body" field, each
                containing the input text to be processed.
        Returns:
            inputs: A batched tensor of inputs: the batch of input ids and
                attention masks.
        """

        # Prefix for translation from English to German
        task_prefix = "translate English to German: "
        input_texts = [task_prefix + self.preprocess_requests(r) for r in requests]

        logger.debug("Received texts: '%s'", input_texts)
        inputs = self.tokenizer(
            input_texts,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        return inputs

    def preprocess_requests(self, request):
        """
        Preprocess request
        Args:
            request : Request to be decoded.
        Returns:
            str: Decoded input text
        """
        input_text = request.get("data") or request.get("body")
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode("utf-8")
        return input_text

    @torch.inference_mode()
    def inference(self, input_batch):
        """
        Generates the translated text for the given input
        Args:
            input_batch : A dict of tensors (the batch of input ids and attention masks), as
                returned by the preprocess function.
        Returns:
            list: A list of strings with the translated text for each input text in the batch.
        """
        outputs = self.model.generate(
            **input_batch,
        )

        inferences = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        logger.debug("Generated text: %s", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Post Process Function converts the predicted response into TorchServe readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the predictions.
        """
        return inference_output
5 changes: 5 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/config.properties
@@ -0,0 +1,5 @@
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082
enable_envvars_config=true
install_py_dep_per_model=true
8 changes: 8 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/model-config.yaml
@@ -0,0 +1,8 @@
minWorkers: 1
maxWorkers: 1
handler:
  model_path: model/models--google-t5--t5-small/snapshots/df1b051c49625cf57a3d0d8d3863ed4d13564fe4
pt2:
  compile:
    enable: True
    backend: tensorrt
3 changes: 3 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/requirements.txt
@@ -0,0 +1,3 @@
transformers>=4.41.2
sentencepiece>=0.2.0
torch-tensorrt>=2.3.0
1 change: 1 addition & 0 deletions examples/torch_tensorrt/torchcompile/T5/sample_text.txt
@@ -0,0 +1 @@
The house is wonderful
@@ -6,7 +6,7 @@ This example shows how to run TorchServe inference with [Torch-TensorRT](https:/

- Verified to be working with `torch-tensorrt==2.3.0`

Change directory to examples directory `cd examples/torch_tensorrt/torchcompile`
Change directory to examples directory `cd examples/torch_tensorrt/resnet50/torchcompile`

### torch.compile config
