Commit

Merge branch 'master' into grpc-token-auth-model-control
namannandan authored Jul 11, 2024
2 parents 7edb5b8 + 1125bb1 commit bb53231
Showing 20 changed files with 455 additions and 11 deletions.
12 changes: 11 additions & 1 deletion .github/workflows/docker-ci.yaml
@@ -24,12 +24,22 @@ jobs:
         working-directory: docker
         run: ./test_build_image_tagging.sh ${{ matrix.python-version }}

+      - name: Determine the branch name
+        id: branch-name
+        run: |
+          if [[ "${GITHUB_REF}" == refs/heads/* ]]; then
+            GITHUB_BRANCH=${GITHUB_REF#refs/heads/}
+          elif [[ "${GITHUB_REF}" == refs/pull/*/merge ]]; then
+            GITHUB_BRANCH=${GITHUB_HEAD_REF}
+          fi
+          echo "GITHUB_BRANCH=${GITHUB_BRANCH}" >> $GITHUB_OUTPUT
       - name: Build Image for container test
         id: image_build
         working-directory: docker
         run: |
           IMAGE_TAG=test-image-${{ matrix.python-version }}
-          ./build_image.sh -py "${{ matrix.python-version }}" -t "${IMAGE_TAG}" -b ${GITHUB_HEAD_REF} -s
+          ./build_image.sh -py "${{ matrix.python-version }}" -t "${IMAGE_TAG}" -b ${{ steps.branch-name.outputs.GITHUB_BRANCH }} -s
           echo "IMAGE_TAG=${IMAGE_TAG}" >> $GITHUB_OUTPUT
       - name: Container Healthcheck
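For readers less familiar with the shell parameter expansion used in the branch-name step above, here is a minimal Python sketch of the same branch-resolution logic (an editorial illustration only, not part of the workflow): on push events `GITHUB_REF` looks like `refs/heads/<branch>`, while on pull requests it is `refs/pull/<number>/merge` and the source branch is available in `GITHUB_HEAD_REF`.

```python
import os


def resolve_branch(ref: str, head_ref: str = "") -> str:
    """Mirror the workflow step: strip refs/heads/ for push events,
    fall back to the PR head ref for refs/pull/<n>/merge refs."""
    if ref.startswith("refs/heads/"):
        return ref[len("refs/heads/") :]
    if ref.startswith("refs/pull/") and ref.endswith("/merge"):
        return head_ref
    return ""


# Example: resolve the branch name from the environment, as the workflow step does.
print(resolve_branch(os.environ.get("GITHUB_REF", ""), os.environ.get("GITHUB_HEAD_REF", "")))
```
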
4 changes: 2 additions & 2 deletions docker/build_image.sh
@@ -54,8 +54,8 @@ do
    -g|--gpu)
        MACHINE=gpu
        DOCKER_TAG="pytorch/torchserve:latest-gpu"
-       BASE_IMAGE="nvidia/cuda:11.8.0-base-ubuntu20.04"
-       CUDA_VERSION="cu117"
+       BASE_IMAGE="nvidia/cuda:12.1.1-base-ubuntu20.04"
+       CUDA_VERSION="cu121"
        shift
        ;;
    -bi|--baseimage)
3 changes: 1 addition & 2 deletions examples/Huggingface_Transformers/README.md
@@ -114,8 +114,7 @@ To register the model on TorchServe using the above model archive file, we run the following commands
```
mkdir model_store
mv BERTSeqClassification.mar model_store/
-torchserve --start --model-store model_store --models my_tc=BERTSeqClassification.mar --disable-token --ncs --disable-token-auth --enable-model-api
+torchserve --start --model-store model_store --models my_tc=BERTSeqClassification.mar --ncs --disable-token-auth --enable-model-api
```

### Run an inference
86 changes: 86 additions & 0 deletions examples/large_models/trt_llm/llama/README.md
@@ -0,0 +1,86 @@
# Llama TensorRT-LLM Engine integration with TorchServe

[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) provides users with an option to build TensorRT engines for LLMs that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.

## Pre-requisites

TRT-LLM requires Python 3.10, and this example was tested with CUDA 12.1.
Once TorchServe is installed, install TensorRT-LLM using the commands below. This will downgrade the installed versions of PyTorch and Triton, but it does not cause any issues.

```
pip install tensorrt_llm==0.10.0 --extra-index-url https://pypi.nvidia.com
pip install tensorrt-cu12==10.1.0
python -c "import tensorrt_llm"
```
This should print:
```
[TensorRT-LLM] TensorRT-LLM version: 0.10.0
```

## Download model from HuggingFace
```
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```
```
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct
```

## Create TensorRT-LLM Engine
Clone TensorRT-LLM, which will be used to create the TensorRT-LLM engine:

```
git clone -b v0.10.0 https://github.com/NVIDIA/TensorRT-LLM.git
```

Compile the model into a TensorRT engine with model weights and a model definition written in the TensorRT-LLM Python API.

```
python TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16
```
```
trtllm-build --checkpoint_dir tllm_checkpoint_1gpu_bf16 --gemm_plugin bfloat16 --gpt_attention_plugin bfloat16 --output_dir ./llama-3-8b-engine
```

You can test whether the TensorRT-LLM engine has been compiled correctly by running the following:
```
python TensorRT-LLM/examples/run.py --engine_dir ./llama-3-8b-engine --max_output_len 100 --tokenizer_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --input_text "How do I count to nine in French?"
```

You should see output similar to the following:
```
Input [Text 0]: "<|begin_of_text|>How do I count to nine in French?"
Output [Text 0 Beam 0]: " Counting to nine in French is easy and fun. Here's how you can do it:
One: Un
Two: Deux
Three: Trois
Four: Quatre
Five: Cinq
Six: Six
Seven: Sept
Eight: Huit
Nine: Neuf
That's it! You can now count to nine in French. Just remember that the numbers one to five are similar to their English counterparts, but the numbers six to nine have different pronunciations"
```

## Create model archive

```
mkdir model_store
torch-model-archiver --model-name llama3-8b --version 1.0 --handler trt_llm_handler.py --config-file model-config.yaml --archive-format no-archive --export-path model_store -f
mv model model_store/llama3-8b/.
mv llama-3-8b-engine model_store/llama3-8b/.
```

## Start TorchServe
```
torchserve --start --ncs --model-store model_store --models llama3-8b --disable-token-auth
```

## Run Inference
```
python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 -m llama3-8b --prompt-text "@prompt.json" --prompt-json
```
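
As an alternative to the test script, the streamed chunks can also be consumed directly with a minimal Python client (an editorial sketch; it assumes TorchServe's default inference port 8080 and the `llama3-8b` model name registered above):

```python
import json

import requests

# Load the example prompt shipped with this example.
with open("prompt.json") as f:
    payload = json.load(f)

# The handler streams intermediate responses; read them chunk by chunk.
response = requests.post(
    "http://localhost:8080/predictions/llama3-8b",
    json=payload,
    stream=True,
)
for chunk in response.iter_content(chunk_size=None):
    if chunk:
        print(json.loads(chunk)["text"], end="", flush=True)
print()
```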
12 changes: 12 additions & 0 deletions examples/large_models/trt_llm/llama/model-config.yaml
@@ -0,0 +1,12 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
asyncCommunication: true

handler:
    tokenizer_dir: "model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/"
    trt_llm_engine_config:
        engine_dir: "llama-3-8b-engine"
3 changes: 3 additions & 0 deletions examples/large_models/trt_llm/llama/prompt.json
@@ -0,0 +1,3 @@
{"prompt": "How is the climate in San Francisco?",
"temperature":0.5,
"max_new_tokens": 200}
118 changes: 118 additions & 0 deletions examples/large_models/trt_llm/llama/trt_llm_handler.py
@@ -0,0 +1,118 @@
import json
import logging
import time

import torch
from tensorrt_llm.runtime import ModelRunner
from transformers import AutoTokenizer

from ts.handler_utils.utils import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class TRTLLMHandler(BaseHandler):
    def __init__(self):
        super().__init__()

        self.trt_llm_engine = None
        self.tokenizer = None
        self.model = None
        self.model_dir = None
        self.lora_ids = {}
        self.adapters = None
        self.initialized = False

    def initialize(self, ctx):
        self.model_dir = ctx.system_properties.get("model_dir")

        # Engine location and tokenizer path come from the handler section of model-config.yaml
        trt_llm_engine_config = ctx.model_yaml_config.get("handler").get(
            "trt_llm_engine_config"
        )

        tokenizer_dir = ctx.model_yaml_config.get("handler").get("tokenizer_dir")
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_dir,
            legacy=False,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=True,
            use_fast=True,
        )

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.trt_llm_engine = ModelRunner.from_dir(**trt_llm_engine_config)
        self.initialized = True

    async def handle(self, data, context):
        start_time = time.time()

        metrics = context.metrics

        data_preprocess = await self.preprocess(data)
        output, input_batch = await self.inference(data_preprocess, context)
        output = await self.postprocess(output, input_batch, context)

        stop_time = time.time()
        metrics.add_time(
            "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms"
        )
        return output

    async def preprocess(self, requests):
        input_batch = []
        assert len(requests) == 1, "Expecting batch_size = 1"
        for req_data in requests:
            data = req_data.get("data") or req_data.get("body")
            if isinstance(data, (bytes, bytearray)):
                data = data.decode("utf-8")

            prompt = data.get("prompt")
            temperature = data.get("temperature", 1.0)
            max_new_tokens = data.get("max_new_tokens", 50)
            input_ids = self.tokenizer.encode(
                prompt, add_special_tokens=True, truncation=True
            )
            input_batch.append(input_ids)

        input_batch = [torch.tensor(x, dtype=torch.int32) for x in input_batch]

        return (input_batch, temperature, max_new_tokens)

    async def inference(self, input_batch, context):
        input_ids_batch, temperature, max_new_tokens = input_batch

        # streaming=True makes generate() yield incremental results, which
        # postprocess() forwards to the client as intermediate responses.
        with torch.no_grad():
            outputs = self.trt_llm_engine.generate(
                batch_input_ids=input_ids_batch,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                end_id=self.tokenizer.eos_token_id,
                pad_id=self.tokenizer.pad_token_id,
                output_sequence_lengths=True,
                streaming=True,
                return_dict=True,
            )
        return outputs, input_ids_batch

    async def postprocess(self, inference_outputs, input_batch, context):
        for inference_output in inference_outputs:
            output_ids = inference_output["output_ids"]
            sequence_lengths = inference_output["sequence_lengths"]

            batch_size, _, _ = output_ids.size()
            for batch_idx in range(batch_size):
                output_end = sequence_lengths[batch_idx][0]
                outputs = output_ids[batch_idx][0][output_end - 1 : output_end].tolist()
                output_text = self.tokenizer.decode(outputs)
                # Send each newly generated token back to the client immediately.
                send_intermediate_predict_response(
                    [json.dumps({"text": output_text})],
                    context.request_ids,
                    "Result",
                    200,
                    context,
                )
        # The generated text has already been streamed; return an empty final response.
        return [""] * len(input_batch)
6 changes: 3 additions & 3 deletions examples/large_models/utils/test_llm_streaming_response.py
@@ -1,9 +1,9 @@
import argparse
+import json
import random
import threading
from queue import Queue

-import orjson
import requests

max_prompt_random_tokens = 20
@@ -27,7 +27,7 @@ def _predict(self):
        combined_text = ""
        for chunk in response.iter_content(chunk_size=None):
            if chunk:
-                data = orjson.loads(chunk)
+                data = json.loads(chunk)
                if self.args.demo_streaming:
                    print(data["text"], end="", flush=True)
                else:
@@ -41,7 +41,7 @@ def _get_url(self):
    def _format_payload(self):
        prompt_input = _load_curl_like_data(self.args.prompt_text)
        if self.args.prompt_json:
-            prompt_input = orjson.loads(prompt_input)
+            prompt_input = json.loads(prompt_input)
        prompt = prompt_input.get("prompt", None)
        assert prompt is not None
        prompt_list = prompt.split(" ")
2 changes: 1 addition & 1 deletion examples/pt2/torch_compile_hpu/README.md
@@ -74,7 +74,7 @@ PT_HPU_LAZY_MODE=0 torch-model-archiver --model-name resnet-50 --version 1.0 --m

Start the TorchServe server using the following command:
```bash
-PT_HPU_LAZY_MODE=0 torchserve --start --ncs --disable-token --model-store model_store --models resnet-50.mar --disable-token-auth --enable-model-api
+PT_HPU_LAZY_MODE=0 torchserve --start --ncs --model-store model_store --models resnet-50.mar --disable-token-auth --enable-model-api
```
`--disable-token-auth` - this option disables token authorization. It is used here only for example purposes. Please refer to the TorchServe [documentation](https://github.com/pytorch/serve/blob/master/docs/token_authorization_api.md), which describes the process of serving the model using tokens.

56 changes: 56 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/README.md
@@ -0,0 +1,56 @@
# TorchServe inference with torch.compile using the TensorRT backend

This example shows how to run TorchServe inference with a T5 model compiled with [Torch-TensorRT](https://github.com/pytorch/TensorRT).

[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#inference) is an encoder-decoder model that handles a variety of text-to-text tasks out of the box by prepending a task-specific prefix to the input text. In this example, we use T5 to translate from English to German.
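
To make the task-prefix idea concrete, here is a small standalone Transformers snippet (an editorial sketch, independent of TorchServe) that performs the same English-to-German translation by prepending the task prefix:

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

# T5 selects the task via a text prefix; here we ask for English-to-German translation.
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))  # e.g. "Das Haus ist wunderbar."
```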

### Pre-requisites

- Verified to be working with `torch-tensorrt==2.3.0`. Installation instructions can be found in [pytorch/TensorRT](https://github.com/pytorch/TensorRT).

Change directory to the example directory: `cd examples/torch_tensorrt/torchcompile/T5`

### torch.compile config

To use the `tensorrt` backend with `torch.compile`, we specify the following config in `model-config.yaml`:

```
pt2:
  compile:
    enable: True
    backend: tensorrt
```
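
Under the hood, this makes the handler wrap the model with `torch.compile` using the TensorRT backend. A minimal standalone sketch of the equivalent call (an editorial illustration with a toy module; it assumes `torch-tensorrt` is installed, which registers the `tensorrt` backend for `torch.compile`):

```python
import torch
import torch_tensorrt  # noqa: F401  (importing registers the "tensorrt" backend for torch.compile)

# A toy module stands in for the T5 model used in this example.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
example_input = torch.randn(1, 16).cuda()

compiled_model = torch.compile(model, backend="tensorrt")

with torch.no_grad():
    out = compiled_model(example_input)  # the first call triggers TensorRT compilation
print(out.shape)
```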

### Download the model

```
python ../../../large_models/Huggingface_accelerate/Download_model.py --model_name google-t5/t5-small
```

### Create the model archive
```
mkdir model_store
torch-model-archiver --model-name t5-translation --version 1.0 --handler T5_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive --export-path model_store -f
mv model model_store/t5-translation/.
```

### Start TorchServe

```
torchserve --start --ncs --ts-config config.properties --model-store model_store --models t5-translation --disable-token-auth
```

### Run Inference

```
curl -X POST http://127.0.0.1:8080/predictions/t5-translation -T sample_text.txt
```

This results in:

```
Das Haus ist wunderbar
```
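
For completeness, the same request can be issued from Python with `requests` (a minimal editorial sketch):

```python
import requests

# Equivalent of the curl command above: POST the sample text file to the endpoint.
with open("sample_text.txt", "rb") as f:
    response = requests.post("http://127.0.0.1:8080/predictions/t5-translation", data=f)
print(response.text)  # expected: "Das Haus ist wunderbar"
```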
