From fc2dbe023792301557df5d1d1d8d1d5bb7251d7a Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Sat, 8 May 2021 06:35:53 -0700
Subject: [PATCH] [BYOC][TensorRT] Add nn.batch_matmul, nn.layer_norm, erf
 (#8005)

---
 docs/deploy/tensorrt.rst                     |  6 ++
 python/tvm/relay/op/contrib/tensorrt.py      | 29 ++++++
 src/runtime/contrib/tensorrt/tensorrt_ops.cc | 93 ++++++++++++++++++++
 tests/python/contrib/test_tensorrt.py        | 44 ++++++++-
 4 files changed, 171 insertions(+), 1 deletion(-)

diff --git a/docs/deploy/tensorrt.rst b/docs/deploy/tensorrt.rst
index 7eb0e20d1590..6ec2eb469b44 100644
--- a/docs/deploy/tensorrt.rst
+++ b/docs/deploy/tensorrt.rst
@@ -189,6 +189,8 @@ Operator support
 +------------------------+------------------------------------+
 | nn.batch_norm          |                                    |
 +------------------------+------------------------------------+
+| nn.layer_norm          |                                    |
++------------------------+------------------------------------+
 | nn.softmax             |                                    |
 +------------------------+------------------------------------+
 | nn.conv2d              |                                    |
@@ -261,6 +263,8 @@ Operator support
 +------------------------+------------------------------------+
 | nn.adaptive_avg_pool2d |                                    |
 +------------------------+------------------------------------+
+| nn.batch_matmul        |                                    |
++------------------------+------------------------------------+
 | clip                   | Requires TensorRT 5.1.5 or greater |
 +------------------------+------------------------------------+
 | nn.leaky_relu          | Requires TensorRT 5.1.5 or greater |
@@ -285,6 +289,8 @@ Operator support
 +------------------------+------------------------------------+
 | nn.conv3d_transpose    | Requires TensorRT 6.0.1 or greater |
 +------------------------+------------------------------------+
+| erf                    | Requires TensorRT 7.0.0 or greater |
++------------------------+------------------------------------+
 
 
 Adding a new operator
diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index afdea9712342..817d7e0908e9 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -304,6 +304,7 @@ def _func_wrapper(attrs, args, op_name):
 _register_external_op_helper_with_checker("cos", trt_version_annotate_fn((5, 1, 5)))
 _register_external_op_helper_with_checker("atan", trt_version_annotate_fn((5, 1, 5)))
 _register_external_op_helper_with_checker("ceil", trt_version_annotate_fn((5, 1, 5)))
+_register_external_op_helper_with_checker("erf", trt_version_annotate_fn((7, 0, 0)))
 
 
 @_register_external_dynamic_check_func("add")
@@ -410,6 +411,34 @@ def dense_annotate_fn(expr):  # pylint: disable=unused-variable
     return True
 
 
+@_register_external_dynamic_check_func("nn.batch_matmul")
+def batch_matmul_annotate_fn(expr):
+    """Check if nn.batch_matmul is supported by TensorRT."""
+
+    if any([x.checked_type.dtype != "float32" for x in expr.args]):
+        logger.info("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len(
+        expr.args[1].checked_type.shape
+    ):
+        logger.info("nn.batch_matmul: requires use_implicit_batch=False.")
+        return False
+    return True
+
+
+@_register_external_dynamic_check_func("nn.layer_norm")
+def layer_norm_annotate_fn(expr):
+    """Check if nn.layer_norm is supported by TensorRT."""
+
+    if any([x.checked_type.dtype != "float32" for x in expr.args]):
+        logger.info("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
+        logger.info("nn.layer_norm: requires use_implicit_batch=False.")
+        return False
+    return True
+
+
 @_register_external_dynamic_check_func("nn.bias_add")
 def bias_add_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.bias_add is supported by TensorRT."""
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index 824178eaa619..c59591a87a22 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -463,6 +463,78 @@ class BatchNormOpConverter : public TensorRTOpConverter {
   }
 };
 
+class LayerNormOpConverter : public TensorRTOpConverter {
+ public:
+  LayerNormOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight}) {}
+
+  void Convert(TensorRTOpConverterParams* params) const {
+    auto input = params->inputs.at(0).tensor;
+    auto gamma_input = params->inputs.at(1).weight;
+    auto beta_input = params->inputs.at(2).weight;
+    ICHECK_EQ(gamma_input.count, beta_input.count);
+
+    const float epsilon = std::stof(params->node.GetAttr<std::vector<std::string>>("epsilon")[0]);
+    const bool scale = std::stoi(params->node.GetAttr<std::vector<std::string>>("scale")[0]);
+    const bool center = std::stoi(params->node.GetAttr<std::vector<std::string>>("center")[0]);
+    const int input_rank = input->getDimensions().nbDims;
+    const int original_axis = std::stoi(params->node.GetAttr<std::vector<std::string>>("axis")[0]);
+    const int axis = ConvertAxis(params, original_axis, input_rank);
+
+    std::vector<int> weight_shape(input_rank, 1);
+    weight_shape[axis] = gamma_input.count;
+    auto gamma =
+        params->network->addConstant(VectorToTrtDims(weight_shape), gamma_input)->getOutput(0);
+    auto beta =
+        params->network->addConstant(VectorToTrtDims(weight_shape), beta_input)->getOutput(0);
+
+    // Compute mean
+    auto mean_layer = params->network->addReduce(*input, nvinfer1::ReduceOperation::kAVG, 1 << axis,
+                                                 /*keepdims=*/true);
+    ICHECK(mean_layer != nullptr);
+    auto mean = mean_layer->getOutput(0);
+    // Compute variance
+    auto diff_layer =
+        params->network->addElementWise(*input, *mean, nvinfer1::ElementWiseOperation::kSUB);
+    ICHECK(diff_layer != nullptr);
+    auto square_layer =
+        params->network->addElementWise(*diff_layer->getOutput(0), *diff_layer->getOutput(0),
+                                        nvinfer1::ElementWiseOperation::kPROD);
+    ICHECK(square_layer != nullptr);
+    auto var_layer = params->network->addReduce(
+        *square_layer->getOutput(0), nvinfer1::ReduceOperation::kAVG, 1 << axis, /*keepdims=*/true);
+    ICHECK(var_layer != nullptr);
+    auto var = var_layer->getOutput(0);
+    // sqrt(var + epsilon)
+    auto epsilon_tensor = CreateScalar(params, epsilon, var->getDimensions());
+    auto denom_add_layer = params->network->addElementWise(*var, *epsilon_tensor,
+                                                           nvinfer1::ElementWiseOperation::kSUM);
+    ICHECK(denom_add_layer != nullptr);
+    auto denom_layer =
+        params->network->addUnary(*denom_add_layer->getOutput(0), nvinfer1::UnaryOperation::kSQRT);
+    ICHECK(denom_layer != nullptr);
+    // (input - mean) / sqrt(var + epsilon)
+    auto output_layer =
+        params->network->addElementWise(*diff_layer->getOutput(0), *denom_layer->getOutput(0),
+                                        nvinfer1::ElementWiseOperation::kDIV);
+    ICHECK(output_layer != nullptr);
+    auto output = output_layer->getOutput(0);
+
+    if (scale) {
+      auto scale_layer =
+          params->network->addElementWise(*output, *gamma, nvinfer1::ElementWiseOperation::kPROD);
+      ICHECK(scale_layer != nullptr);
+      output = scale_layer->getOutput(0);
+    }
+    if (center) {
+      auto center_layer =
+          params->network->addElementWise(*output, *beta, nvinfer1::ElementWiseOperation::kSUM);
+      ICHECK(center_layer != nullptr);
+      output = center_layer->getOutput(0);
+    }
+    params->outputs.push_back(output);
+  }
+};
+
 class BatchFlattenOpConverter : public TensorRTOpConverter {
  public:
   BatchFlattenOpConverter() : TensorRTOpConverter({kTensor}) {}
@@ -686,6 +758,9 @@ class UnaryOpConverter : public TensorRTOpConverter {
         {"atan", nvinfer1::UnaryOperation::kATAN},
         {"ceil", nvinfer1::UnaryOperation::kCEIL},
         {"floor", nvinfer1::UnaryOperation::kFLOOR},
+#endif
+#if TRT_VERSION_GE(7, 0, 0)
+        {"erf", nvinfer1::UnaryOperation::kERF},
 #endif
     };
     auto it = op_map.find(params->op_name);
@@ -1039,6 +1114,19 @@ class AdaptivePoolingOpConverter : public TensorRTOpConverter {
   }
 };
 
+class BatchMatmulOpConverter : public TensorRTOpConverter {
+ public:
+  BatchMatmulOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {}
+
+  void Convert(TensorRTOpConverterParams* params) const {
+    nvinfer1::IMatrixMultiplyLayer* matmul_layer = params->network->addMatrixMultiply(
+        *params->inputs.at(0).tensor, nvinfer1::MatrixOperation::kNONE,
+        *params->inputs.at(1).tensor, nvinfer1::MatrixOperation::kTRANSPOSE);
+    ICHECK(matmul_layer != nullptr);
+    params->outputs.push_back(matmul_layer->getOutput(0));
+  }
+};
+
 const std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<TensorRTOpConverter>>>
 GetOpConverters() {
   static auto map =
@@ -1048,6 +1136,7 @@ GetOpConverters() {
   map->emplace("sigmoid", std::make_shared<ActivationOpConverter>());
   map->emplace("tanh", std::make_shared<ActivationOpConverter>());
   map->emplace("nn.batch_norm", std::make_shared<BatchNormOpConverter>());
+  map->emplace("nn.layer_norm", std::make_shared<LayerNormOpConverter>());
   map->emplace("nn.softmax", std::make_shared<SoftmaxOpConverter>());
   map->emplace("nn.conv2d", std::make_shared<Conv2DOpConverter>());
   map->emplace("nn.dense", std::make_shared<DenseOpConverter>());
@@ -1084,6 +1173,7 @@ GetOpConverters() {
   map->emplace("mean", std::make_shared<ReduceOpConverter>());
   map->emplace("nn.adaptive_max_pool2d", std::make_shared<AdaptivePoolingOpConverter>());
   map->emplace("nn.adaptive_avg_pool2d", std::make_shared<AdaptivePoolingOpConverter>());
+  map->emplace("nn.batch_matmul", std::make_shared<BatchMatmulOpConverter>());
 #if TRT_VERSION_GE(5, 1, 5)
   map->emplace("clip", std::make_shared<ActivationOpConverter>());
   map->emplace("nn.leaky_relu", std::make_shared<ActivationOpConverter>());
@@ -1100,6 +1190,9 @@ GetOpConverters() {
   map->emplace("nn.avg_pool3d", std::make_shared<Pooling3DOpConverter>());
   map->emplace("nn.conv3d_transpose", std::make_shared<Conv3DTransposeOpConverter>());
 #endif  // TRT_VERSION_GE(6, 0, 1)
+#if TRT_VERSION_GE(7, 0, 0)
+  map->emplace("erf", std::make_shared<UnaryOpConverter>());
+#endif  // TRT_VERSION_GE(7, 0, 0)
 
   return map;
 }
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 7005f8647809..e18fe7961941 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -227,7 +227,7 @@ def test_tensorrt_not_compatible():
     x = relay.var("x", shape=(xshape), dtype=dtype)
     y = relay.add(x, x)
-    z = relay.erf(y)
+    z = relay.cast(relay.cast(y, "int32"), "float32")
     out = relay.nn.relu(z)
     f = relay.Function([x], out)
     mod = tvm.IRModule()
@@ -461,6 +461,17 @@ def get_graph(x_shape=(1, 16), k_shape=(32, 16)):
     run_and_verify_func(get_graph(k_shape=(1, 16)))
 
 
+def test_batch_matmul():
+    def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64)):
+        x = relay.var("x", shape=(x_shape), dtype="float32")
+        y = relay.var("y", shape=(y_shape), dtype="float32")
+        out = relay.nn.batch_matmul(x, y)
+        f = relay.Function([x, y], out)
+        return f, {"x": x_shape, "y": y_shape}, []
+
+    run_and_verify_func(get_graph())
+
+
 def test_bias_add():
     def get_graph(x_shape=(1, 16), channels=16):
         x = relay.var("x", shape=(x_shape), dtype="float32")
@@ -821,6 +832,36 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
     run_and_verify_func(get_graph((1, 3, 8), (8,), axis=2))
 
 
+def test_layer_norm():
+    def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
+        x = relay.var("x", shape=(x_shape), dtype="float32")
+        gamma = relay.var("gamma", shape=(param_shape), dtype="float32")
+        beta = relay.var("beta", shape=(param_shape), dtype="float32")
+        out = relay.nn.layer_norm(
+            x,
+            gamma=gamma,
+            beta=beta,
+            axis=axis,
+            epsilon=epsilon,
+            center=True,
+            scale=True,
+        )
+        f = relay.Function([x, gamma, beta], out)
+        return (
+            f,
+            {
+                "x": x_shape,
+                "beta": param_shape,
+                "gamma": param_shape,
+            },
+            ["beta", "gamma"],
+        )
+
+    run_and_verify_func(get_graph((1, 32, 8, 8), (32,)))
+    run_and_verify_func(get_graph((1, 8, 8, 32), (32,), axis=3, epsilon=1.001e-05))
+    run_and_verify_func(get_graph((1, 8), (8,), axis=1))
+
+
 def test_unary():
     def get_graph(op, x_shape=(1, 8, 3, 3)):
         x = relay.var("x", shape=(x_shape), dtype="float32")
@@ -842,6 +883,7 @@ def get_graph(op, x_shape=(1, 8, 3, 3)):
         relay.atan,
         relay.ceil,
         relay.floor,
+        relay.erf,
     ]:
         run_and_verify_func(get_graph(op))
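
Note (not part of the patch): the sketch below shows how the three newly supported operators can be exercised end to end through the TensorRT BYOC flow described in docs/deploy/tensorrt.rst. It assumes a TensorRT-enabled build of TVM from roughly this era; partition_for_tensorrt is assumed here to return the partitioned module together with a config dict, and its exact signature may differ in later releases.

# A minimal, hedged sketch of offloading nn.batch_matmul, nn.layer_norm, and erf
# to TensorRT. Requires a TensorRT-enabled TVM build; erf additionally needs
# TensorRT 7.0.0 or greater (see the operator table above).
import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

x = relay.var("x", shape=(12, 128, 64), dtype="float32")
y = relay.var("y", shape=(12, 128, 64), dtype="float32")
gamma = relay.var("gamma", shape=(128,), dtype="float32")
beta = relay.var("beta", shape=(128,), dtype="float32")

out = relay.nn.batch_matmul(x, y)  # batched A * B^T -> (12, 128, 128)
out = relay.nn.layer_norm(out, gamma=gamma, beta=beta)  # normalize over the last axis
out = relay.erf(out)
mod = tvm.IRModule.from_expr(relay.Function([x, y, gamma, beta], out))

# Partition supported subgraphs for TensorRT, then build for CUDA.
mod, config = partition_for_tensorrt(mod)
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    lib = relay.build(mod, target="cuda")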