Skip to content

Commit

Permalink
add CPU inference support (#28)
Browse files Browse the repository at this point in the history
* cpu support

* add cpu support tests
add thread param

* improve support of CPU

* add quantization ORT

* fix quantization

* fix test

* fix imports

* use temp folder in unit tests

* refactoring

* add Pytorch quantization

* enable CPU inference on Triton server

* fix tests

* detect device if not set

* add commands to install pytorch-quantization

* add commands to install pytorch-quantization
  • Loading branch information
pommedeterresautee authored Dec 20, 2021
1 parent 6d5dbbb commit 2b369c1
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 132 deletions.
32 changes: 23 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,20 @@ To install this package locally, you need:
```shell
git clone git@github.com:ELS-RD/transformer-deploy.git
cd transformer-deploy
```

* for GPU support:

```shell
pip3 install ".[GPU]" -f https://download.pytorch.org/whl/cu113/torch_stable.html --extra-index-url https://pypi.ngc.nvidia.com
# if you want to perform GPU quantization (recommended)
pip3 install git+ssh://git@github.com/NVIDIA/TensorRT#egg=pytorch-quantization\&subdirectory=tools/pytorch-quantization/
```

* for CPU support:

```shell
pip3 install ".[CPU]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
```

To build your own version of the Docker image:
Expand All @@ -86,19 +99,20 @@ With the single command below, you will:
* **generate** configuration files for Triton inference server

```shell
convert_model -m roberta-large-mnli --backend tensorrt onnx --seq-len 128 128 128 --batch-size 1 32 32
convert_model -m philschmid/MiniLM-L6-H384-uncased-sst2 --backend onnx --seq-len 128 128 128 --batch-size 1 32 32
# ...
# Inference done on NVIDIA GeForce RTX 3090
# latencies:
# [Pytorch (FP32)] mean=129.57ms, sd=8.73ms, min=119.10ms, max=192.44ms, median=129.81ms, 95p=137.11ms, 99p=173.64ms
# [Pytorch (FP16)] mean=82.68ms, sd=3.92ms, min=76.39ms, max=97.42ms, median=83.59ms, 95p=89.58ms, 99p=94.09ms
# [TensorRT (FP16)] mean=51.84ms, sd=2.66ms, min=46.42ms, max=59.03ms, median=52.10ms, 95p=56.18ms, 99p=57.68ms
# [ONNX Runtime (vanilla)] mean=116.98ms, sd=3.67ms, min=111.96ms, max=130.20ms, median=116.23ms, 95p=127.03ms, 99p=128.58ms
# [ONNX Runtime (optimized)] mean=55.14ms, sd=2.17ms, min=52.85ms, max=61.65ms, median=53.94ms, 95p=59.45ms, 99p=60.27ms
# [Pytorch (FP32)] mean=8.75ms, sd=0.30ms, min=8.60ms, max=11.20ms, median=8.68ms, 95p=9.15ms, 99p=10.77ms
# [Pytorch (FP16)] mean=6.75ms, sd=0.22ms, min=6.66ms, max=8.99ms, median=6.71ms, 95p=6.88ms, 99p=7.95ms
# [ONNX Runtime (FP32)] mean=8.10ms, sd=0.43ms, min=7.93ms, max=11.76ms, median=8.02ms, 95p=8.39ms, 99p=11.30ms
# [ONNX Runtime (optimized)] mean=3.66ms, sd=0.23ms, min=3.57ms, max=6.46ms, median=3.62ms, 95p=3.70ms, 99p=4.95ms
```

> **16 128 128** -> minimum, optimal, maximum sequence length, to help TensorRT better optimize your model
> **1 32 32** -> batch size, same as above
> **128 128 128** -> minimum, optimal, maximum sequence length, to help TensorRT better optimize your model.
> Better to have the same value for seq len to get best performances from TensorRT (ONNX Runtime has not this limitation).
>
> **1 32 32** -> batch size, same as above. Good idea to get 1 as minimum value. No impact on TensorRT performance.
* Launch Nvidia Triton inference server to play with both ONNX and TensorRT models:

Expand All @@ -112,7 +126,7 @@ docker run -it --rm --gpus all -p8000:8000 -p8001:8001 -p8002:8002 --shm-size 25
> This is of course a bad practice, you should make your own 2 lines Dockerfile with Transformers inside.
Right now, only TensorRT 8.0.3 backend is available in Triton.
Until the TensorRT 8.2 backend is available, we advise you to only use ONNX Runtime Triton backend.
Until the TensorRT 8.2 backend is available, we advise you to only use ONNX Runtime backend.

* Query the inference server:

Expand Down
48 changes: 26 additions & 22 deletions demo/quantization_end_to_end.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Quantization is one of the most effective and generic approach to make model inference faster.\n",
"\n",
"\n",
"zation is one of the most effective and generic approach to make model inference faster.\n",
"Basically, it replaces high precision float numbers in model tensors encoded in 32 or 16 bits by lower precision ones encoded in 8 bits or less:\n",
"\n",
"* it takes less memory\n",
Expand All @@ -23,21 +25,21 @@
"\n",
"**TL;DR, we benchmarked Pytorch and Nvidia TensorRT, on both CPU and GPU, with/without quantization, our methods provide the fastest inference by large margin**.\n",
"\n",
"| Framework | Precision | Latency (ms) | Accuracy | Speedup | Hardware |\n",
"|:---------------------------|-----------|--------------|----------|:-----------|:--------:|\n",
"| Pytorch | FP32 | 4000 | 86.8 % | X 0.02 | CPU |\n",
"| Pytorch | FP16 | 4005 | 86.8 % | X 0.02 | CPU |\n",
"| Pytorch | **INT-8** | 3670 | 86.8 % | X 0.02 | **CPU** |\n",
"| Pytorch | FP32 | 80 | 86.8 % | X 1 | GPU |\n",
"| Pytorch | FP16 | 58 | 86.8 % | X 1.38 | GPU |\n",
"| ONNX Runtime | FP32 | 74 | 86.8 % | X 1.08 | GPU |\n",
"| ONNX Runtime | FP16 | 34 | 86.8 % | X 2.35 | GPU |\n",
"| ONNX Runtime | FP32 | 3767 | 86.8 % | X 0.02 | CPU |\n",
"| ONNX Runtime | FP16 | 4607 | 86.8 % | X 0.02 | CPU |\n",
"| ONNX Runtime | **INT-8** | 3712 | 86.8 % | X 0.02 | **CPU** |\n",
"| TensorRT | FP16 | 30 | 86.8 % | X 2.67 | GPU |\n",
"| TensorRT (**our method 1**)| **INT-8** | 15 | 84.4 % | **X 5.33** | **GPU** |\n",
"| TensorRT (**our method 2**)| **INT-8** | 16 | 85.8 % | **X 5.00** | **GPU** |\n",
"| Framework | Precision | Latency (ms) | Accuracy | Speedup | Hardware |\n",
"|:----------------------------|-----------|--------------|----------|:-----------|:--------:|\n",
"| Pytorch | FP32 | 4000 | 86.8 % | X 0.02 | CPU |\n",
"| Pytorch | FP16 | 4005 | 86.8 % | X 0.02 | CPU |\n",
"| Pytorch | **INT-8** | 3670 | 86.8 % | X 0.02 | **CPU** |\n",
"| Pytorch | FP32 | 80 | 86.8 % | X 1 | GPU |\n",
"| Pytorch | FP16 | 58 | 86.8 % | X 1.38 | GPU |\n",
"| ONNX Runtime | FP32 | 74 | 86.8 % | X 1.08 | GPU |\n",
"| ONNX Runtime | FP16 | 34 | 86.8 % | X 2.35 | GPU |\n",
"| ONNX Runtime | FP32 | 3767 | 86.8 % | X 0.02 | CPU |\n",
"| ONNX Runtime | FP16 | 4607 | 86.8 % | X 0.02 | CPU |\n",
"| ONNX Runtime | **INT-8** | 3712 | 86.8 % | X 0.02 | **CPU** |\n",
"| TensorRT | FP16 | 30 | 86.8 % | X 2.67 | GPU |\n",
"| TensorRT (**our method 1**) | **INT-8** | 15 | 84.4 % | **X 5.33** | **GPU** |\n",
"| TensorRT (**our method 2**) | **INT-8** | 16 | 85.8 % | **X 5.00** | **GPU** |\n",
"\n",
"> measures done on a Nvidia RTX 3090 GPU + 12 cores i7 Intel CPU (support AVX-2 instructions)\n",
">\n",
Expand Down Expand Up @@ -127,7 +129,8 @@
"#! pip install git+https://github.com/ELS-RD/transformer-deploy\n",
"#! pip install sklearn datasets\n",
"#! pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com\n",
"# or install pytorch-quantization from https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization"
"# or install pytorch-quantization from https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization\n",
"# pip3 install git+ssh://git@github.com/NVIDIA/TensorRT#egg=pytorch-quantization\\&subdirectory=tools/pytorch-quantization/"
]
},
{
Expand Down Expand Up @@ -245,7 +248,7 @@
"from pytorch_quantization import calib\n",
"import logging\n",
"from datasets import DatasetDict\n",
"from transformer_deploy.backends.trt_utils import build_engine, get_binding_idxs, infer_tensorrt, load_engine\n",
"from transformer_deploy.backends.trt_utils import build_engine, get_binding_idxs, infer_tensorrt\n",
"from transformer_deploy.backends.ort_utils import convert_to_onnx\n",
"from collections import OrderedDict\n",
"from transformer_deploy.benchmarks.utils import track_infer_time, print_timings\n",
Expand Down Expand Up @@ -2219,16 +2222,17 @@
}
],
"source": [
"from transformer_deploy.backends.ort_utils import optimize_onnx, create_model_for_provider\n",
"from transformer_deploy.backends.ort_utils import optimize_onnx, create_model_for_provider, cpu_quantization\n",
"from onnxruntime.quantization import quantize_dynamic, QuantType\n",
"\n",
"optimize_onnx(\n",
" onnx_path=\"baseline.onnx\",\n",
" onnx_optim_fp16_path=\"baseline-optimized.onnx\",\n",
" onnx_optim_model_path=\"baseline-optimized.onnx\",\n",
" fp16=True,\n",
" use_cuda=True,\n",
")\n",
"onnx_model = create_model_for_provider(path=\"baseline-optimized.onnx\", provider_to_use=\"CUDAExecutionProvider\")\n",
"quantize_dynamic(\"baseline-optimized.onnx\", \"baseline-quantized.onnx\", weight_type=QuantType.QUInt8)"
"\n",
"cpu_quantization(input_model_path=\"baseline-optimized.onnx\", output_model_path=\"baseline-quantized.onnx\")"
]
},
{
Expand Down
6 changes: 4 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ known_third_party =
transformers
sentencepiece
tokenizers
onnxruntime-tools
onnxruntime
onnx
nvidia-pyindex
tritonclient[all]
tritonclient
tensorrt
pycuda
numpy
fastapi
requests
Expand Down
1 change: 0 additions & 1 deletion src/transformer_deploy/QDQModels/QDQRoberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if position_ids is None:
# TODO here?
if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
Expand Down
36 changes: 27 additions & 9 deletions src/transformer_deploy/backends/ort_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,27 @@
from typing import Union

import torch
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from onnxruntime import ExecutionMode, GraphOptimizationLevel, InferenceSession, SessionOptions
from onnxruntime.quantization import QuantType, quantize_dynamic
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.onnx_model_bert import BertOnnxModel
from torch.onnx import TrainingMode
from transformers import PreTrainedModel


def create_model_for_provider(path: str, provider_to_use: Union[str, List]) -> InferenceSession:
def create_model_for_provider(
path: str, provider_to_use: Union[str, List], nb_threads: int = multiprocessing.cpu_count(), nb_instances: int = 0
) -> InferenceSession:
options = SessionOptions()
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
if type(provider_to_use) != list:
provider_to_use = [provider_to_use]
if provider_to_use == ["CPUExecutionProvider"]:
options.intra_op_num_threads = multiprocessing.cpu_count()
options.execution_mode = ExecutionMode.ORT_SEQUENTIAL if nb_instances <= 1 else ExecutionMode.ORT_PARALLEL
options.intra_op_num_threads = nb_threads
if nb_instances > 1:
options.inter_op_num_threads = nb_instances
return InferenceSession(path, options, providers=provider_to_use)


Expand All @@ -61,19 +67,31 @@ def convert_to_onnx(
)


def optimize_onnx(onnx_path: str, onnx_optim_fp16_path: str, use_cuda: bool) -> None:
def optimize_onnx(onnx_path: str, onnx_optim_model_path: str, fp16: bool, use_cuda: bool) -> None:
optimization_options = FusionOptions("bert")
optimization_options.enable_gelu_approximation = True # additional optimization
optimization_options.enable_gelu_approximation = False # additional optimization
optimized_model: BertOnnxModel = optimizer.optimize_model(
input=onnx_path,
model_type="bert",
use_gpu=use_cuda,
opt_level=1,
num_heads=0, # automatic detection don't work with opset 13
num_heads=0, # automatic detection may not work with opset 13
hidden_size=0, # automatic detection
optimization_options=optimization_options,
)

optimized_model.convert_float_to_float16() # FP32 -> FP16
if fp16:
optimized_model.convert_float_to_float16() # FP32 -> FP16
logging.info(f"optimizations applied: {optimized_model.get_fused_operator_statistics()}")
optimized_model.save_model_to_file(onnx_optim_fp16_path)
optimized_model.save_model_to_file(onnx_optim_model_path)


def cpu_quantization(input_model_path: str, output_model_path: str):
quantize_dynamic(
model_input=input_model_path,
model_output=output_model_path,
op_types_to_quantize=["MatMul", "Attention"],
weight_type=QuantType.QInt8,
per_channel=True,
reduce_range=True,
extra_options={"WeightSymmetric": False, "MatMulConstBOnly": True},
)
38 changes: 0 additions & 38 deletions src/transformer_deploy/backends/trt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,44 +35,6 @@
)


class Calibrator(trt.IInt8Calibrator):
def __init__(self):
trt.IInt8Calibrator.__init__(self)
self.algorithm = trt.CalibrationAlgoType.MINMAX_CALIBRATION
self.batch_size = 32

input_list: List[ndarray] = [np.zeros((32, 512), dtype=np.int32) for _ in range(3)]
# allocate GPU memory for input tensors
self.device_inputs: List[DeviceAllocation] = [cuda.mem_alloc(tensor.nbytes) for tensor in input_list]
for h_input, d_input in zip(input_list, self.device_inputs):
cuda.memcpy_htod_async(d_input, h_input) # host to GPU
self.count = 0

def get_algorithm(self):
return trt.CalibrationAlgoType.MINMAX_CALIBRATION

def get_batch_size(self):
return self.batch_size

def get_batch(self, names, p_str=None):
self.count += 1
if self.count > 20:
return []
# return pointers to arrays
return [int(d) for d in self.device_inputs]

def read_calibration_cache(self):
return None

def write_calibration_cache(self, cache):
with open("calibration_cache.bin", "wb") as f:
f.write(cache)

def free(self):
for dinput in self.device_inputs:
dinput.free()


def setup_binding_shapes(
context: trt.IExecutionContext,
host_inputs: List[np.ndarray],
Expand Down
1 change: 1 addition & 0 deletions src/transformer_deploy/benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def track_infer_time(buffer: [int]):
def generate_input(
seq_len: int, batch_size: int, include_token_ids: bool, device: str = "cuda"
) -> Tuple[Dict[str, torch.Tensor], Dict[str, np.ndarray]]:
assert device in ["cuda", "cpu"]
shape = (batch_size, seq_len)
inputs_pytorch: OrderedDict[str, torch.Tensor] = OrderedDict()
inputs_pytorch["input_ids"] = torch.randint(high=100, size=shape, dtype=torch.long, device=device)
Expand Down
Loading

0 comments on commit 2b369c1

Please sign in to comment.