This is a branch of torch2trt with dynamic input support.
Here are some examples
from torch2trt_dynamic import module2trt, BuildEngineConfig
import torch
from torchvision.models import resnet18
# create some regular pytorch model...
model = resnet18().cuda().eval()
# create example data
x = torch.ones((1, 3, 224, 224)).cuda()
# convert to TensorRT feeding sample data as input
config = BuildEngineConfig(
shape_ranges=dict(
x=dict(
min=(1, 3, 224, 224),
opt=(2, 3, 224, 224),
max=(4, 3, 224, 224),
)
))
trt_model = module2trt(
model,
args=[x],
config=config)
We can execute the returned TRTModule
just like the original PyTorch model
x = torch.rand(1, 3, 224, 224).cuda()
with torch.no_grad():
y = model(x)
y_trt = trt_model(x)
# check the output against PyTorch
torch.testing.assert_close(y, y_trt)
We can save the model as a state_dict
.
torch.save(trt_model.state_dict(), 'my_engine.pth')
We can load the saved model into a TRTModule
from torch2trt_dynamic import TRTModule
trt_model = TRTModule()
trt_model.load_state_dict(torch.load('my_engine.pth'))
To install without compiling plugins, call the following
git clone https://github.com/grimoire/torch2trt_dynamic.git torch2trt_dynamic
cd torch2trt_dynamic
pip install .
Some layers such as GN
need c++ plugins. Install the plugin project below
DO NOT FORGET to export the environment variable AMIRSTAN_LIBRARY_PATH
Here we show how to add a converter for the ReLU
module using the TensorRT Python API.
import tensorrt as trt
from torch2trt_dynamic import tensorrt_converter
@tensorrt_converter('torch.nn.ReLU.forward')
def convert_ReLU(ctx):
input = ctx.method_args[1]
output = ctx.method_return
layer = ctx.network.add_activation(input=input._trt, type=trt.ActivationType.RELU)
output._trt = layer.get_output(0)
The converter takes one argument, a ConversionContext
, which will contain
the following
-
ctx.network
- The TensorRT network that is being constructed. -
ctx.method_args
- Positional arguments that were passed to the specified PyTorch function. The_trt
attribute is set for relevant input tensors. -
ctx.method_kwargs
- Keyword arguments that were passed to the specified PyTorch function. -
ctx.method_return
- The value returned by the specified PyTorch function. The converter must set the_trt
attribute where relevant.
Please see this folder for more examples.