Add conv1d support in BYOC TRT by converting conv1d to conv2d #9324

Merged: 1 commit, merged on Oct 20, 2021
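The core of this change: a 1-D convolution is expressed as a 2-D convolution whose second spatial dimension is fixed at 1 (reshape NCW to NCHW with W=1, convolve, reshape back). That identity can be sanity-checked outside of TVM; the following is a minimal NumPy sketch with naive loop convolutions (groups=1 and dilation=1 assumed), illustrative only and not part of this PR:

import numpy as np

def conv1d_naive(x, w, stride=1, pad=0):
    # Naive NCW conv1d: x is (N, C, W), w is (O, I, KW).
    x = np.pad(x, ((0, 0), (0, 0), (pad, pad)))
    n, _, width = x.shape
    o, _, kw = w.shape
    out_w = (width - kw) // stride + 1
    y = np.zeros((n, o, out_w), dtype=x.dtype)
    for i in range(out_w):
        window = x[:, :, i * stride:i * stride + kw]  # (N, C, KW)
        y[:, :, i] = np.einsum("ncw,ocw->no", window, w)
    return y

def conv2d_naive(x, w, stride=(1, 1), pad=(0, 0)):
    # Naive NCHW conv2d: x is (N, C, H, W), w is (O, I, KH, KW).
    x = np.pad(x, ((0, 0), (0, 0), (pad[0], pad[0]), (pad[1], pad[1])))
    n, _, h, width = x.shape
    o, _, kh, kw = w.shape
    out_h = (h - kh) // stride[0] + 1
    out_w = (width - kw) // stride[1] + 1
    y = np.zeros((n, o, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, :, i * stride[0]:i * stride[0] + kh,
                       j * stride[1]:j * stride[1] + kw]
            y[:, :, i, j] = np.einsum("nchw,ochw->no", window, w)
    return y

x = np.random.rand(1, 3, 16).astype("float32")   # NCW data
w = np.random.rand(10, 3, 3).astype("float32")   # OIW kernel

direct = conv1d_naive(x, w, stride=1, pad=1)
# Add a unit spatial dim (NCW -> NCW1, OIW -> OIW1), convolve in 2-D, squeeze back.
as_2d = conv2d_naive(x[..., None], w[..., None], stride=(1, 1), pad=(1, 0))[..., 0]
assert np.allclose(direct, as_2d, atol=1e-5)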
19 changes: 19 additions & 0 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -142,6 +142,7 @@ def partition_for_tensorrt(
transform.RemoveUnusedFunctions(),
transform.ConvertLayout(
{
"nn.conv1d": ["NCW", "default"],
"nn.conv2d": ["NCHW", "default"],
"nn.conv3d": ["NCDHW", "default"],
"nn.conv2d_transpose": ["NCHW", "default"],
@@ -374,6 +375,23 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable
return True


@_register_external_dynamic_check_func("nn.conv1d")
def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv1d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if attrs.data_layout != "NCW":
logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout)
return False
if attrs.kernel_layout != "OIW":
logger.info("nn.conv1d: kernel_layout is %s but must be OIW.", attrs.kernel_layout)
return False
return True


@_register_external_dynamic_check_func("nn.conv2d")
def conv2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d is supported by TensorRT."""
@@ -912,6 +930,7 @@ def __init__(self):
def visit_call(self, call):
compute_intensive_ops = set(
[
"nn.conv1d",
"nn.conv2d",
"nn.conv2d_transpose",
"nn.conv3d",
50 changes: 50 additions & 0 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -226,6 +226,55 @@ class ElementWiseBinaryOpConverter : public TensorRTOpConverter {
}
};

class Conv1DOpConverter : public TensorRTOpConverter {
public:
Conv1DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}

void Convert(TensorRTOpConverterParams* params) const {
auto input_tensor = params->inputs.at(0).tensor;
auto input_dims = TrtDimsToVector(input_tensor->getDimensions());
auto weight_shape = params->inputs.at(1).weight_shape;
ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("data_layout")[0], "NCW");
ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("kernel_layout")[0], "OIW");
auto str_strides = params->node.GetAttr<std::vector<std::string>>("strides");
auto str_dilation = params->node.GetAttr<std::vector<std::string>>("dilation");
auto str_padding = params->node.GetAttr<std::vector<std::string>>("padding");
int groups = std::stoi(params->node.GetAttr<std::vector<std::string>>("groups")[0]);
int channels = weight_shape[0];
if (params->node.HasAttr("channels") &&
!params->node.GetAttr<std::vector<std::string>>("channels")[0].empty()) {
channels = std::stoi(params->node.GetAttr<std::vector<std::string>>("channels")[0]);
}

auto shuffle_layer = params->network->addShuffle(*input_tensor);
std::vector<int> new_shape = {input_dims[0], input_dims[1], 1};
shuffle_layer->setReshapeDimensions(VectorToTrtDims(new_shape));
input_tensor = shuffle_layer->getOutput(0);

const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], 1);
nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0};

auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size,
params->inputs.at(1).weight, bias);
ICHECK(conv_layer != nullptr);
conv_layer->setPadding(nvinfer1::DimsHW(std::stoi(str_padding[0]), 0));
ICHECK_EQ(str_strides.size(), 1);
const auto strides = nvinfer1::DimsHW(std::stoi(str_strides[0]), 1);
conv_layer->setStride(strides);
ICHECK_EQ(str_dilation.size(), 1);
const auto dilation = nvinfer1::DimsHW(std::stoi(str_dilation[0]), 1);
conv_layer->setDilation(dilation);
conv_layer->setNbGroups(groups);
input_tensor = conv_layer->getOutput(0);

auto conv_output_dims = TrtDimsToVector(input_tensor->getDimensions());
std::vector<int> back_shape = {0, 0};
auto shuffle_back_layer = params->network->addShuffle(*input_tensor);
shuffle_back_layer->setReshapeDimensions(VectorToTrtDims(back_shape));
params->outputs.push_back(shuffle_back_layer->getOutput(0));
}
};

class Conv2DOpConverter : public TensorRTOpConverter {
public:
Conv2DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
@@ -1198,6 +1247,7 @@ GetOpConverters() {
map->emplace("nn.batch_norm", std::make_shared<BatchNormOpConverter>());
map->emplace("nn.layer_norm", std::make_shared<LayerNormOpConverter>());
map->emplace("nn.softmax", std::make_shared<SoftmaxOpConverter>());
map->emplace("nn.conv1d", std::make_shared<Conv1DOpConverter>());
map->emplace("nn.conv2d", std::make_shared<Conv2DOpConverter>());
map->emplace("nn.dense", std::make_shared<DenseOpConverter>());
map->emplace("nn.bias_add", std::make_shared<BiasAddOpConverter>());
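The converter above keeps the extra spatial axis inert: the kernel becomes (KW, 1), stride and dilation get a second component of 1, and padding gets a second component of 0 so the unit axis is never padded. A tiny, hypothetical Python helper (names are illustrative, not part of TVM or TensorRT) summarizing that mapping:

def conv1d_attrs_to_2d(strides, dilation, padding, kernel_w):
    # Mirror Conv1DOpConverter: map 1-D conv attributes onto the 2-D
    # parameters handed to TensorRT, with the added axis fixed at size 1.
    return {
        "kernel_size": (kernel_w, 1),   # DimsHW(weight_shape[2], 1)
        "stride": (strides[0], 1),      # DimsHW(stride_w, 1)
        "dilation": (dilation[0], 1),   # DimsHW(dilation_w, 1)
        "padding": (padding[0], 0),     # DimsHW(pad_w, 0): never pad the unit axis
    }

print(conv1d_attrs_to_2d(strides=[1], dilation=[1], padding=[1], kernel_w=3))
# {'kernel_size': (3, 1), 'stride': (1, 1), 'dilation': (1, 1), 'padding': (1, 0)}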
28 changes: 28 additions & 0 deletions tests/python/contrib/test_tensorrt.py
@@ -355,6 +355,34 @@ def load_vm():
assert_result_dict_holds(result_dict)


def test_conv1d(run_module):
def get_graph(
x_shape=((1, 3, 224)),
k_shape=(10, 3, 3),
groups=1,
padding=(1, 1),
strides=(1),
dilation=(1),
channels=None,
):
x = relay.var("x", shape=(x_shape), dtype="float32")
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
out = relay.nn.conv1d(
x,
kernel,
kernel_size=k_shape[2:3],
groups=groups,
padding=padding,
strides=strides,
dilation=dilation,
channels=channels,
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]

run_and_verify_func(get_graph(channels=10), run_module=run_module)


def test_conv2d(run_module):
def get_graph(
x_shape=(1, 32, 8, 8),
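End-to-end, the new test_conv1d exercises the op through run_and_verify_func. Below is a hedged sketch of partitioning and building a conv1d workload for TensorRT by hand, assuming the TVM version from this PR's timeframe (where partition_for_tensorrt returns the partitioned module plus a config dict; other versions may differ):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

x = relay.var("x", shape=(1, 3, 224), dtype="float32")
kernel = relay.var("kernel", shape=(10, 3, 3), dtype="float32")
out = relay.nn.conv1d(x, kernel, kernel_size=3, channels=10, padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([x, kernel], out))
params = {"kernel": np.random.rand(10, 3, 3).astype("float32")}

# Partition conv1d (now a supported, compute-intensive op) into a TensorRT region.
mod, config = partition_for_tensorrt(mod, params)
with tvm.transform.PassContext(opt_level=3,
                               config={"relay.ext.tensorrt.options": config}):
    lib = relay.build(mod, target="cuda", params=params)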