[backends] Add functionality to TRT backend #1753

Closed · wants to merge 4 commits
Changes from 3 commits

118 changes: 98 additions & 20 deletions torchbenchmark/util/backends/trt.py
@@ -1,26 +1,68 @@
from typing import List
import torch
import argparse

from torchbenchmark.util.backends import create_backend
from torchbenchmark.util.env_check import is_hf_model


def parse_torch_trt_args(backend_args: List[str]):
    """Parses CLI-provided backend arguments to extract Torch-TRT keywords

    Returns kwargs dictionary and remainder arguments which were unrecognized
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--truncate_long_and_double",
        default=None,
        action="store_true",
        help="Whether to automatically truncate long and double operations",
    )
    arg_parser.add_argument(
        "--workspace_size", type=int, help="Size of workspace allotted to TensorRT"
    )
    arg_parser.add_argument(
        "--min_block_size",
        type=int,
        help="Minimum number of operations in an accelerated TRT block",
    )
    arg_parser.add_argument(
        "--ir",
        type=str,
        help="Which internal representation to use: {'ts', 'dynamo_compile', 'fx_ts_compat', ...}",
    )
    args, unknown = arg_parser.parse_known_args(backend_args)

    # Remove unspecified arguments from the args dictionary
    # (Only pass through user-specified args)
    parsed_args = vars(args)
    for key in list(parsed_args.keys()):
        if parsed_args[key] is None:
            del parsed_args[key]

    return parsed_args, unknown
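
For reference, a quick sketch of how this parser behaves. This is illustrative only; the extra flag --some_other_flag is hypothetical, and the import assumes the module path of the file being modified here:

# Hypothetical usage of parse_torch_trt_args defined above
from torchbenchmark.util.backends.trt import parse_torch_trt_args

kwargs, remainder = parse_torch_trt_args(
    ["--truncate_long_and_double", "--ir", "dynamo_compile", "--some_other_flag"]
)
# kwargs    -> {"truncate_long_and_double": True, "ir": "dynamo_compile"}
# remainder -> ["--some_other_flag"]  (unrecognized args are handed back to the caller)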


@create_backend
def fx2trt(model: "torchbenchmark.util.model.BenchmarkModel", backend_args: List[str]):
    FP16 = True if model.dargs.precision == "fp16" else False
    HF_MODEL = True if is_hf_model(model) else False
    assert (
        model.device == "cuda" and model.test == "eval"
    ), f"fx2trt only works on CUDA inference tests."

    def _fx2trt():
        from torch_tensorrt.fx import compile
        from torch_tensorrt.fx.utils import LowerPrecision

        module, example_inputs = model.get_module()
        precision = LowerPrecision.FP16 if FP16 else LowerPrecision.FP32

        if HF_MODEL:
            from transformers.utils.fx import symbolic_trace as hf_symbolic_trace

            traced_model = hf_symbolic_trace(
                module, batch_size=model.batch_size, sequence_length=model.max_length
            )
            trt_model = compile(
                traced_model,
@@ -31,27 +73,63 @@ def _fx2trt():
                max_workspace_size=20 << 30,
            )
        else:
            trt_model = compile(
                module=module,
                input=example_inputs,
                max_batch_size=model.batch_size,
                lower_precision=precision,
            )
        model.set_module(trt_model)

    return _fx2trt, backend_args


@create_backend
def torch_trt(
    model: "torchbenchmark.util.model.BenchmarkModel", backend_args: List[str]
):
    """Backend for Torch-TRT

    Can be directly invoked from the command line, for example via:
    python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double

    Options include:
        --truncate_long_and_double: Whether to automatically truncate long and double operations
        --min_block_size: Minimum number of operations in an accelerated TRT block
        --workspace_size: Size of workspace allotted to TensorRT
        --ir: Which internal representation to use: {"ts", "dynamo_compile", "fx_ts_compat", ...}
    """
    FP16 = True if model.dargs.precision == "fp16" else False
    assert (
        model.device == "cuda" and model.test == "eval"
    ), f"Torch-TRT only works on CUDA inference tests."

    # Extract relevant Torch-TRT arguments from the provided CLI arguments
    torch_trt_kwargs, backend_args = parse_torch_trt_args(backend_args)

    def _torch_trt():
        """Helper function for invoking Torch-TRT"""
        import torch_tensorrt

        module, example_inputs = model.get_module()
        torch_dtype_precision = torch.half if FP16 else torch.float32

        trt_input = [
            torch_tensorrt.Input(shape=input_.shape, dtype=input_.dtype)
            for input_ in example_inputs
        ]

        print(
            f"Compiling {model.name} with batch size {model.batch_size}, precision {model.dargs.precision}, "
            + f"and {'default' if 'ir' not in torch_trt_kwargs else torch_trt_kwargs['ir']} IR"
        )

        trt_module = torch_tensorrt.compile(
            module,
            inputs=trt_input,
            enabled_precisions={torch_dtype_precision},
            **torch_trt_kwargs,
        )
        model.set_module(trt_module)

    return _torch_trt, backend_args

Review thread on the trt_input construction above:

Review comment (Contributor): Is it necessary to add a check that the type of example_inputs is List[Tensor]? Many models actually have different types of inputs.

Review comment (Contributor): One suggestion is to use pytree to traverse the inputs and cast them to torch_tensorrt.Input, similar to this:

from torch.utils._pytree import tree_map

Reply (Contributor, author): Since the different ir choices for Torch-TRT process inputs differently, but all can handle Torch Tensor inputs, I passed the example inputs directly to the compiler instead, so the selected ir can handle the inputs/casting as necessary.
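
For reference, a minimal sketch of the tree_map suggestion from the thread above. It is not part of this PR; it assumes example_inputs comes from model.get_module() as in the backend code, and that only torch.Tensor leaves need converting:

# Hedged sketch of the reviewer's pytree idea, not the approach merged here.
import torch
import torch_tensorrt
from torch.utils._pytree import tree_map

def _to_trt_input(leaf):
    # Convert tensor leaves into torch_tensorrt.Input specs; pass everything else through.
    if isinstance(leaf, torch.Tensor):
        return torch_tensorrt.Input(shape=leaf.shape, dtype=leaf.dtype)
    return leaf

trt_inputs = tree_map(_to_trt_input, example_inputs)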
Empty file.
223 changes: 223 additions & 0 deletions userbenchmark/torch_trt/run.py
@@ -0,0 +1,223 @@
import argparse
import traceback
import torch

import numpy as np

import json
import os
import time
from datetime import datetime
from typing import List

from torchbenchmark import (
    load_canary_model_by_name,
    load_model_by_name,
    list_models,
    ModelNotFoundError,
)


def cli(args: List[str]):
    """Parse input arguments, extracting model specification and batch size"""
    arg_parser = argparse.ArgumentParser(args)
    arg_parser.add_argument(
        "--model",
        help="Full or partial name of a model to run. If partial, picks the first match.",
        default="",
        type=str,
    )
    arg_parser.add_argument(
        "--bs",
        help="Input batch size to test.",
        default=1,
        type=int,
    )
    arg_parser.add_argument(
        "--num_warmup",
        help="Number of inference warmup iterations.",
        default=10,
        type=int,
    )
    arg_parser.add_argument(
        "--num_iter",
        help="Number of inference iterations for benchmarking.",
        default=100,
        type=int,
    )
    parsed_args, unknown = arg_parser.parse_known_args()

    return vars(parsed_args), unknown


def save_metrics(metrics):
    """Save metrics to a JSON file with formatted filename"""
    metrics_json = {
        "name": "torch_trt",
        "environ": {
            "metrics_version": "v0.1",
            "pytorch_git_version": torch.version.git_version,
        },
        "metrics": metrics,
    }

    # Obtain target save directory for JSON metrics from current save directory
    current_dir = os.path.dirname(os.path.abspath(__file__))
    target_dir = os.path.normpath(
        os.path.join(current_dir, "../../.userbenchmark/torch_trt/")
    )

    os.makedirs(target_dir, exist_ok=True)

    # Format filename and path to save metrics
    metrics_file = "metrics-{}.json".format(
        datetime.fromtimestamp(time.time()).strftime("%Y%m%d%H%M%S")
    )
    metrics_save_path = os.path.join(target_dir, metrics_file)

    with open(metrics_save_path, "w") as f:
        json.dump(metrics_json, f, indent=4)


def run_single_model(
    Model,
    batch_size: int,
    extra_args: List[str],
    selected_ir: str,
    num_warmup: int,
    num_iter: int,
):
    """Run inference benchmarking on a single model"""
    # Build TorchBench model instance, with backend having the userbenchmark name
    # This invokes the torch_trt backend functionality directly
    model = Model(
        device="cuda",
        test="eval",
        jit=False,
        batch_size=batch_size,
        extra_args=[
            "--backend",
        ]
        + extra_args,
    )

    metrics = run_one_step(model.invoke, model, num_warmup, num_iter, selected_ir)

    # Print dynamo compilation metrics, if there are any.
    try:
        if model.pt2_compilation_time:
            metrics[
                f"{model.name}.bs_{model.batch_size}.precision_{model.dargs.precision}."
                + f"ir_{selected_ir}.pt2_compilation_time"
            ] = model.pt2_compilation_time
    except:
        pass

    return metrics


def run_one_step(func, model, num_warmup, num_iter, selected_ir):
    # Warmup model inference
    for _ in range(num_warmup):
        func()

    result_summary = []

    # Run inference for the specified number of iterations
    for _ in range(num_iter):
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        # Collect time_ns() instead of time(), which does not provide better precision
        # than 1 second according to https://docs.python.org/3/library/time.html#time.time.
        t0 = time.time_ns()
        start_event.record()
        func()
        end_event.record()
        torch.cuda.synchronize()
        t1 = time.time_ns()
        result_summary.append(
            (start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000)
        )

    # Get median times for GPU and CPU walltime
    gpu_time = np.median(list(map(lambda x: x[0], result_summary)))
    cpu_walltime = np.median(list(map(lambda x: x[1], result_summary)))

    if hasattr(model, "NUM_BATCHES"):
        median_gpu_time_per_batch = gpu_time / model.NUM_BATCHES
        median_cpu_walltime_per_batch = cpu_walltime / model.NUM_BATCHES
    else:
        median_gpu_time_per_batch = gpu_time
        median_cpu_walltime_per_batch = cpu_walltime

    metrics = {
        f"{model.name}.bs_{model.batch_size}.precision_{model.dargs.precision}."
        + f"ir_{selected_ir}.median_gpu_time_per_batch": median_gpu_time_per_batch,
        f"{model.name}.bs_{model.batch_size}.precision_{model.dargs.precision}."
        + f"ir_{selected_ir}.median_cpu_walltime_per_batch": median_cpu_walltime_per_batch,
    }

    return metrics
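
For a concrete sense of the metric key naming, the dictionary returned above would look roughly like the following. The model name and timings are hypothetical; both values are in milliseconds (CUDA event elapsed_time and the nanosecond delta divided by 1_000_000):

# Hypothetical output of run_one_step for an fp16 run with the dynamo_compile IR
{
    "resnet18.bs_1.precision_fp16.ir_dynamo_compile.median_gpu_time_per_batch": 1.92,
    "resnet18.bs_1.precision_fp16.ir_dynamo_compile.median_cpu_walltime_per_batch": 2.07,
}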


def run(args: List[str]):
    """Run inference and extract requested metrics"""
    parsed_args, unknown_args = cli(args)

    # Attempt to extract specified IR for logging purposes
    try:
        ir_idx = unknown_args.index("--ir")
        selected_ir = unknown_args[ir_idx + 1]
    except (ValueError, IndexError):
        selected_ir = "default"

    # Parse model string if specified, otherwise run all models
    # Adapted from benchmark/run.py
    if parsed_args["model"]:
        try:
            Model = load_model_by_name(parsed_args["model"])
        except ModuleNotFoundError:
            traceback.print_exc()
            exit(-1)
        except ModelNotFoundError:
            print(
                f"Warning: The model {parsed_args['model']} cannot be found at core set."
            )
        if not Model:
            try:
                Model = load_canary_model_by_name(parsed_args["model"])
            except ModuleNotFoundError:
                traceback.print_exc()
                exit(-1)
            except ModelNotFoundError:
                print(
                    f"Error: The model {parsed_args['model']} cannot be found at either core or canary model set."
                )
                exit(-1)

        all_metrics = run_single_model(
            Model,
            parsed_args["bs"],
            unknown_args,
            selected_ir,
            parsed_args["num_warmup"],
            parsed_args["num_iter"],
        )

    else:
        all_metrics = {}

        for Model in list_models():
            metrics = run_single_model(
                Model,
                parsed_args["bs"],
                unknown_args,
                selected_ir,
                parsed_args["num_warmup"],
                parsed_args["num_iter"],
            )
            all_metrics = {**all_metrics, **metrics}

    save_metrics(all_metrics)

Review comment on the run-all-models loop above (Contributor): This code will work for running a single model. However, it won't work for running a batch of models.

This is because there is no isolation between the model runs. For example, model 1 might set some global torch configuration that causes model 2 to be very slow or even to crash (for example, torch.cudnn.benchmark). Some models also have a benign "memory leak" that won't cause problems in model training, but will cause problems when benchmarking multiple models in the same process.

We suggest using the ModelTask() approach used by the torch-nightly userbenchmark: https://github.com/pytorch/benchmark/blob/main/userbenchmark/torch-nightly/run.py#L163
It runs each model in an isolated process and doesn't have the limitations mentioned above.
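
As a companion to the review comment above, here is a minimal sketch of per-model process isolation using plain multiprocessing. It only illustrates the principle; the ModelTask() helper linked in the comment is the recommended mechanism. The helpers run_isolated and _benchmark_in_subprocess are hypothetical additions to this run.py (reusing its load_model_by_name and run_single_model), not part of the PR:

# Illustrative sketch only: run each model's benchmark in a fresh child process
# so global torch state and leaked memory cannot bleed into the next model.
import multiprocessing as mp

def _benchmark_in_subprocess(
    model_name, bs, extra_args, selected_ir, num_warmup, num_iter, queue
):
    # Executed inside the child process with a clean torch/CUDA context.
    Model = load_model_by_name(model_name)
    queue.put(
        run_single_model(Model, bs, extra_args, selected_ir, num_warmup, num_iter)
    )

def run_isolated(model_name, bs, extra_args, selected_ir, num_warmup, num_iter):
    # "spawn" avoids inheriting CUDA context and other state from the parent.
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    proc = ctx.Process(
        target=_benchmark_in_subprocess,
        args=(model_name, bs, extra_args, selected_ir, num_warmup, num_iter, queue),
    )
    proc.start()
    metrics = queue.get()
    proc.join()
    return metrics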