From 9ba1f76005f7ad34e9c3be751bde8f03022a9feb Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Thu, 11 Mar 2021 15:41:47 +0800
Subject: [PATCH] [Feature] : Add Deformable Conv2d TensorRT Plugin (#858)

* add dcn tensorrt plugin

* prepare for fp16 support

* fix for lint

* limit column buffer

* add docstring to memcpyPermute
---
 mmcv/ops/csrc/deform_conv_cuda_kernel.cuh     |   9 +-
 .../csrc/tensorrt/plugins/trt_cuda_helper.cu  |  66 ++++
 .../csrc/tensorrt/plugins/trt_deform_conv.cpp | 320 ++++++++++++++++++
 .../plugins/trt_deform_conv_kernel.cu         | 161 +++++++++
 mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp |   2 +
 mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh    |  14 +
 mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp    | 116 +++++++
 mmcv/ops/deform_conv.py                       |   2 +-
 tests/test_ops/test_tensorrt.py               |  76 +++++
 9 files changed, 763 insertions(+), 3 deletions(-)
 create mode 100644 mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu
 create mode 100644 mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp
 create mode 100644 mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu
 create mode 100644 mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp

diff --git a/mmcv/ops/csrc/deform_conv_cuda_kernel.cuh b/mmcv/ops/csrc/deform_conv_cuda_kernel.cuh
index b6ddf34c99..1c4d684202 100644
--- a/mmcv/ops/csrc/deform_conv_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/deform_conv_cuda_kernel.cuh
@@ -66,11 +66,16 @@
 #ifndef DEFORM_CONV_CUDA_KERNEL_CUH
 #define DEFORM_CONV_CUDA_KERNEL_CUH
 
+#include <float.h>
+#ifdef MMCV_WITH_TRT
+#include "common_cuda_helper.hpp"
+#else  // MMCV_WITH_TRT
 #ifdef MMCV_USE_PARROTS
 #include "parrots_cuda_helper.hpp"
-#else
+#else  // MMCV_USE_PARROTS
 #include "pytorch_cuda_helper.hpp"
-#endif
+#endif  // MMCV_USE_PARROTS
+#endif  // MMCV_WITH_TRT
 
 template <typename T>
 __device__ T deformable_im2col_bilinear(const T *input, const int data_width,
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu
new file mode 100644
index 0000000000..5b85a4e567
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu
@@ -0,0 +1,66 @@
+#include "common_cuda_helper.hpp"
+#include "trt_cuda_helper.cuh"
+#include "trt_plugin_helper.hpp"
+
+using mmcv::TensorDesc;
+
+template <class scalar_t>
+__global__ void copy_permute_kernel(scalar_t *dst, const scalar_t *src, int n,
+                                    TensorDesc ts_src_stride,
+                                    TensorDesc ts_dst_stride,
+                                    TensorDesc ts_permute) {
+  const int src_dim = ts_src_stride.dim;
+  int *src_stride = &(ts_src_stride.stride[0]);
+  int *dst_stride = &(ts_dst_stride.stride[0]);
+  int *permute = &(ts_permute.shape[0]);
+  CUDA_1D_KERNEL_LOOP(index, n) {
+    size_t dst_index = index;
+    size_t src_index = 0;
+    for (int i = 0; i < src_dim; ++i) {
+      int dim_index = dst_index / dst_stride[i];
+      dst_index = dst_index % dst_stride[i];
+      src_index += dim_index * src_stride[permute[i]];
+    }
+    dst[index] = src[src_index];
+  }
+}
+
+template <class scalar_t>
+void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
+                   int *permute, int src_dim, cudaStream_t stream) {
+  size_t copy_size = 1;
+  TensorDesc ts_permute;
+  memcpy(&(ts_permute.shape[0]), permute, src_dim * sizeof(int));
+
+  TensorDesc ts_src_stride;
+  TensorDesc ts_dst_stride;
+  ts_src_stride.dim = src_dim;
+  ts_dst_stride.dim = src_dim;
+  int *src_stride = &(ts_src_stride.stride[0]);
+  int *dst_stride = &(ts_dst_stride.stride[0]);
+  int *dst_size = &(ts_dst_stride.shape[0]);
+  src_stride[src_dim - 1] = 1;
+  dst_stride[src_dim - 1] = 1;
+
+  for (int i = src_dim - 1; i >= 0; --i) {
+    dst_size[i] = src_size[permute[i]];
+    if (i < src_dim - 1) {
+      src_stride[i] = src_stride[i + 1] * src_size[i + 1];
+    }
+  }
+
+  for (int i = src_dim - 1; i >= 0; --i) {
+    copy_size *= dst_size[i];
+    if (i < src_dim - 1) {
+      dst_stride[i] = dst_stride[i + 1] * dst_size[i + 1];
+    }
+  }
+
+  copy_permute_kernel<scalar_t>
+      <<<GET_BLOCKS(copy_size), THREADS_PER_BLOCK, 0, stream>>>(
+          dst, src, copy_size, ts_src_stride, ts_dst_stride, ts_permute);
+}
+
+template void memcpyPermute<float>(float *dst, const float *src, int *src_size,
+                                   int *permute, int src_dim,
+                                   cudaStream_t stream);
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp
new file mode 100644
index 0000000000..988e9bc46e
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp
@@ -0,0 +1,320 @@
+#include "trt_deform_conv.hpp"
+
+#include <assert.h>
+
+#include <chrono>
+
+#include "trt_serialize.hpp"
+
+void DeformConvForwardCUDAKernelLauncher_float(
+    const float *input, const float *weight, const float *offset,
+    float *output, void *workspace, int batchSize, int nInputPlane,
+    int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, int dW,
+    int dH, int padW, int padH, int dilationW, int dilationH, int group,
+    int deformable_group, int im2col_step, cublasHandle_t cublas_handle,
+    cudaStream_t stream);
+
+namespace {
+static const char *PLUGIN_VERSION{"1"};
+static const char *PLUGIN_NAME{"MMCVDeformConv2d"};
+}  // namespace
+
+nvinfer1::PluginFieldCollection DeformableConvPluginDynamicCreator::mFC{};
+std::vector<nvinfer1::PluginField>
+    DeformableConvPluginDynamicCreator::mPluginAttributes;
+
+DeformableConvPluginDynamic::DeformableConvPluginDynamic(
+    const std::string &name, const nvinfer1::Dims &stride,
+    const nvinfer1::Dims &padding, const nvinfer1::Dims &dilation,
+    const int deformableGroup, const int group, int im2colStep)
+    : mLayerName(name),
+      mStride(stride),
+      mPadding(padding),
+      mDilation(dilation),
+      mDeformableGroup(deformableGroup),
+      mGroup(group),
+      mIm2colStep(im2colStep) {
+  cublasCreate(&m_cublas_handle);
+}
+
+DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name,
+                                                         const void *data,
+                                                         size_t length)
+    : mLayerName(name) {
+  deserialize_value(&data, &length, &mStride);
+  deserialize_value(&data, &length, &mPadding);
+  deserialize_value(&data, &length, &mDilation);
+  deserialize_value(&data, &length, &mDeformableGroup);
+  deserialize_value(&data, &length, &mGroup);
+  deserialize_value(&data, &length, &mIm2colStep);
+  cublasCreate(&m_cublas_handle);
+}
+DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {
+  // destroy cublas handle
+  cublasDestroy(m_cublas_handle);
+}
+
+nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const {
+  DeformableConvPluginDynamic *plugin =
+      new DeformableConvPluginDynamic(mLayerName, mStride, mPadding, mDilation,
+                                      mDeformableGroup, mGroup, mIm2colStep);
+  plugin->setPluginNamespace(getPluginNamespace());
+
+  return plugin;
+}
+
+nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions(
+    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+    nvinfer1::IExprBuilder &exprBuilder) {
+  nvinfer1::DimsExprs ret;
+  ret.nbDims = 4;
+  ret.d[0] = inputs[0].d[0];
+  ret.d[1] = inputs[2].d[0];
+
+  ret.d[2] = inputs[1].d[2];
+  ret.d[3] = inputs[1].d[3];
+
+  return ret;
+}
+
+bool DeformableConvPluginDynamic::supportsFormatCombination(
+    int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
+    int nbOutputs) {
+  if (pos == 0) {
+    return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
+            inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
+
+  } else {
+    return inOut[pos].type == inOut[0].type &&
+           inOut[pos].format == inOut[0].format;
+  }
+}
+
+void DeformableConvPluginDynamic::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
+    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
+
+size_t DeformableConvPluginDynamic::getWorkspaceSize(
+    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
+    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
+  int sizeof_dtype = mmcv::getElementSize(outputs[0].type);
+
+  int batch_size = inputs[0].dims.d[0];
+  int nInputPlane = inputs[0].dims.d[1];
+  int inputHeight = inputs[0].dims.d[2];
+  int inputWidth = inputs[0].dims.d[3];
+
+  int nOutputPlane = outputs[0].dims.d[1];
+  int outputHeight = outputs[0].dims.d[2];
+  int outputWidth = outputs[0].dims.d[3];
+
+  int kW = inputs[2].dims.d[2];
+  int kH = inputs[2].dims.d[3];
+  int im2col_step = std::min(batch_size, mIm2colStep);
+
+  size_t col_size =
+      mmcv::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight *
+                           outputWidth * sizeof_dtype);
+
+  size_t out_size = 0;
+  if (im2col_step != 1)
+    out_size = mmcv::getAlignedSize(batch_size * nOutputPlane * outputHeight *
+                                    outputWidth * sizeof_dtype);
+
+  return col_size + out_size;
+}
+
+int DeformableConvPluginDynamic::enqueue(
+    const nvinfer1::PluginTensorDesc *inputDesc,
+    const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
+    void *const *outputs, void *workSpace, cudaStream_t stream) {
+  if (m_cuda_stream != stream) {
+    cublasSetStream(m_cublas_handle, stream);
+    m_cuda_stream = stream;
+  }
+
+  int batch_size = inputDesc[0].dims.d[0];
+  int inputChannel = inputDesc[0].dims.d[1];
+  int inputHeight = inputDesc[0].dims.d[2];
+  int inputWidth = inputDesc[0].dims.d[3];
+  int outputChannel = outputDesc[0].dims.d[1];
+  int kernelHeight = inputDesc[2].dims.d[2];
+  int kernelWidth = inputDesc[2].dims.d[3];
+
+  const void *x = inputs[0];
+  const void *offset = inputs[1];
+  const void *weight = inputs[2];
+  void *output = outputs[0];
+  int im2col_step = std::min(batch_size, mIm2colStep);
+
+  // TODO: add fp16 support
+  auto data_type = inputDesc[0].type;
+  switch (data_type) {
+    case nvinfer1::DataType::kFLOAT:
+      DeformConvForwardCUDAKernelLauncher_float(
+          (float *)x, (float *)weight, (float *)offset, (float *)output,
+          workSpace, batch_size, inputChannel, inputHeight, inputWidth,
+          outputChannel, kernelWidth, kernelHeight, mStride.d[0], mStride.d[1],
+          mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup,
+          mDeformableGroup, im2col_step, m_cublas_handle, stream);
+      break;
+    default:
+      return 1;
+      break;
+  }
+
+  return 0;
+}
+
+nvinfer1::DataType DeformableConvPluginDynamic::getOutputDataType(
+    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
+  return inputTypes[0];
+}
+
+// IPluginV2 Methods
+const char *DeformableConvPluginDynamic::getPluginType() const {
+  return PLUGIN_NAME;
+}
+
+const char *DeformableConvPluginDynamic::getPluginVersion() const {
+  return PLUGIN_VERSION;
+}
+
+int DeformableConvPluginDynamic::getNbOutputs() const { return 1; }
+
+int DeformableConvPluginDynamic::initialize() { return 0; }
+
+void DeformableConvPluginDynamic::terminate() {}
+
+size_t DeformableConvPluginDynamic::getSerializationSize() const {
+  return sizeof(mStride) + sizeof(mPadding) + sizeof(mDilation) +
+         sizeof(mDeformableGroup) + sizeof(mGroup) + sizeof(mIm2colStep);
+}
+
+void DeformableConvPluginDynamic::serialize(void *buffer) const {
+  serialize_value(&buffer, mStride);
+  serialize_value(&buffer, mPadding);
+  serialize_value(&buffer, mDilation);
+  serialize_value(&buffer, mDeformableGroup);
+  serialize_value(&buffer, mGroup);
+  serialize_value(&buffer, mIm2colStep);
+}
+
+void DeformableConvPluginDynamic::destroy() {
+  // This gets called when the network containing plugin is destroyed
+  delete this;
+}
+
+void DeformableConvPluginDynamic::setPluginNamespace(const char *libNamespace) {
+  mNamespace = libNamespace;
+}
+
+const char *DeformableConvPluginDynamic::getPluginNamespace() const {
+  return mNamespace.c_str();
+}
+
+////////////////////// creator /////////////////////////////
+
+DeformableConvPluginDynamicCreator::DeformableConvPluginDynamicCreator() {
+  mPluginAttributes.emplace_back(nvinfer1::PluginField("stride"));
+  mPluginAttributes.emplace_back(nvinfer1::PluginField("padding"));
+  mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation"));
+  mPluginAttributes.emplace_back(nvinfer1::PluginField("groups"));
+  mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups"));
+  mPluginAttributes.emplace_back(nvinfer1::PluginField("bias"));
+  mPluginAttributes.emplace_back(nvinfer1::PluginField("im2col_step"));
+  mFC.nbFields = mPluginAttributes.size();
+  mFC.fields = mPluginAttributes.data();
+}
+
+const char *DeformableConvPluginDynamicCreator::getPluginName() const {
+  return PLUGIN_NAME;
+}
+
+const char *DeformableConvPluginDynamicCreator::getPluginVersion() const {
+  return PLUGIN_VERSION;
+}
+
+const nvinfer1::PluginFieldCollection *
+DeformableConvPluginDynamicCreator::getFieldNames() {
+  return &mFC;
+}
+
+nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::createPlugin(
+    const char *name, const nvinfer1::PluginFieldCollection *fc) {
+  nvinfer1::Dims stride{2, {1, 1}};
+  nvinfer1::Dims padding{2, {0, 0}};
+  nvinfer1::Dims dilation{2, {1, 1}};
+  int deformableGroup = 1;
+  int group = 1;
+  int im2col_step = 32;
+
+  for (int i = 0; i < fc->nbFields; i++) {
+    if (fc->fields[i].data == nullptr) {
+      continue;
+    }
+    std::string field_name(fc->fields[i].name);
+
+    if (field_name.compare("stride") == 0) {
+      stride.nbDims = 2;
+      stride.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
+      if (fc->fields[i].length == 1) {
+        stride.d[1] = stride.d[0];
+      } else {
+        stride.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
+      }
+    }
+
+    if (field_name.compare("padding") == 0) {
+      padding.nbDims = 2;
+      padding.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
+      if (fc->fields[i].length == 1) {
+        padding.d[1] = padding.d[0];
+      } else {
+        padding.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
+      }
+    }
+
+    if (field_name.compare("dilation") == 0) {
+      dilation.nbDims = 2;
+      dilation.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
+      if (fc->fields[i].length == 1) {
+        dilation.d[1] = dilation.d[0];
+      } else {
+        dilation.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
+      }
+    }
+
+    if (field_name.compare("deformable_group") == 0) {
+      deformableGroup = static_cast<const int *>(fc->fields[i].data)[0];
+    }
+
+    if (field_name.compare("group") == 0) {
+      group = static_cast<const int *>(fc->fields[i].data)[0];
+    }
+
+    if (field_name.compare("im2col_step") == 0) {
+      im2col_step = static_cast<const int *>(fc->fields[i].data)[0];
+    }
+  }
+
+  DeformableConvPluginDynamic *plugin = new DeformableConvPluginDynamic(
+      name, stride, padding, dilation, deformableGroup, group, im2col_step);
+  plugin->setPluginNamespace(getPluginNamespace());
+  return plugin;
+}
+
+nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::deserializePlugin(
+    const char *name, const void *serialData, size_t serialLength) {
+  auto plugin = new DeformableConvPluginDynamic(name, serialData, serialLength);
+  plugin->setPluginNamespace(getPluginNamespace());
+  return plugin;
+}
+
+void DeformableConvPluginDynamicCreator::setPluginNamespace(
+    const char *libNamespace) {
+  mNamespace = libNamespace;
+}
+
+const char *DeformableConvPluginDynamicCreator::getPluginNamespace() const {
+  return mNamespace.c_str();
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu
new file mode 100644
index 0000000000..36a63dea9d
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu
@@ -0,0 +1,161 @@
+#include <cublas_v2.h>
+#include <cuda_fp16.h>
+
+#include "common_cuda_helper.hpp"
+#include "deform_conv_cuda_kernel.cuh"
+#include "trt_cuda_helper.cuh"
+#include "trt_plugin_helper.hpp"
+
+template <typename T>
+void trt_deformable_im2col(const T* data_input, const T* data_offset,
+                           const int channels, const int height,
+                           const int width, const int ksize_h,
+                           const int ksize_w, const int pad_h, const int pad_w,
+                           const int stride_h, const int stride_w,
+                           const int dilation_h, const int dilation_w,
+                           const int parallel_imgs, const int deformable_group,
+                           T* data_col, cudaStream_t stream) {
+  int height_col =
+      (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col =
+      (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = channels * height_col * width_col * parallel_imgs;
+  int channel_per_deformable_group = channels / deformable_group;
+
+  deformable_im2col_gpu_kernel<T>
+      <<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
+          num_kernels, data_input, data_offset, height, width, ksize_h,
+          ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+          channel_per_deformable_group, parallel_imgs, channels,
+          deformable_group, height_col, width_col, data_col);
+
+  cudaCheckError();
+}
+
+// used to switch gemm between fp32 and fp16
+template <typename scalar_t>
+cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
+                              cublasOperation_t transb, int m, int n, int k,
+                              const scalar_t* alpha, const scalar_t* A, int lda,
+                              const scalar_t* B, int ldb, const scalar_t* beta,
+                              scalar_t* C, int ldc) {
+  return CUBLAS_STATUS_INTERNAL_ERROR;
+}
+
+template <>
+cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle,
+                                     cublasOperation_t transa,
+                                     cublasOperation_t transb, int m, int n,
+                                     int k, const float* alpha, const float* A,
+                                     int lda, const float* B, int ldb,
+                                     const float* beta, float* C, int ldc) {
+  return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
+                     beta, C, ldc);
+}
+
+template <>
+cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle,
+                                    cublasOperation_t transa,
+                                    cublasOperation_t transb, int m, int n,
+                                    int k, const half* alpha, const half* A,
+                                    int lda, const half* B, int ldb,
+                                    const half* beta, half* C, int ldc) {
+  return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
+                     beta, C, ldc);
+}
+
+template <typename scalar_t>
+void DeformConvForwardCUDAKernelLauncher(
+    const scalar_t* input, const scalar_t* weight, const scalar_t* offset,
+    scalar_t* output, void* workspace, int batchSize, int nInputPlane,
+    int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, int dW,
+    int dH, int padW, int padH, int dilationW, int dilationH, int group,
+    int deformable_group, int im2col_step, cublasHandle_t cublas_handle,
+    cudaStream_t stream) {
+  size_t word_size = sizeof(scalar_t);
+
+  im2col_step = std::min(int(batchSize), im2col_step);
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  long long columns_size =
+      mmcv::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight *
+                           outputWidth * word_size);
+
+  // column buffer for img2col
+  scalar_t* columns = (scalar_t*)workspace;
+  workspace = workspace + columns_size;
+
+  scalar_t* output_buffer;
+  long long output_buffer_size = 0;
+  if (im2col_step == 1) {
+    output_buffer = output;
+  } else {
+    // output need permute when im2col_step!=1
+    output_buffer = (scalar_t*)workspace;
+    output_buffer_size = batchSize * nOutputPlane * outputWidth * outputHeight;
+  }
+
+  long long input_elt_step =
+      im2col_step * nInputPlane * inputHeight * inputWidth;
+  long long offset_elt_step =
+      im2col_step * deformable_group * 2 * kH * kW * outputHeight * outputWidth;
+  long long out_buffer_step =
+      nOutputPlane * im2col_step * outputHeight * outputWidth;
+  long long col_g_step =
+      nInputPlane * kW * kH / group * im2col_step * outputHeight * outputWidth;
+  long long weight_g_step =
+      nOutputPlane / group * nInputPlane / group * kH * kW;
+  long long out_buffer_g_step =
+      nOutputPlane / group * im2col_step * outputHeight * outputWidth;
+  int m = nOutputPlane / group;
+  int n = im2col_step * outputHeight * outputWidth;
+  int k = nInputPlane / group * kH * kW;
+  scalar_t alpha = 1.;
+  scalar_t beta = 0.;
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    const scalar_t* input_start = input + elt * input_elt_step;
+    const scalar_t* offset_start = offset + elt * offset_elt_step;
+
+    trt_deformable_im2col<scalar_t>(input_start, offset_start, nInputPlane,
+                                    inputHeight, inputWidth, kH, kW, padH,
+                                    padW, dH, dW, dilationH, dilationW,
+                                    im2col_step, deformable_group, columns,
+                                    stream);
+
+    for (int g = 0; g < group; ++g) {
+      const scalar_t* weight_start = weight + g * weight_g_step;
+      scalar_t* col_start = columns + g * col_g_step;
+      scalar_t* out_buffer_start =
+          output_buffer + elt * out_buffer_step + g * out_buffer_g_step;
+
+      cublasGemmWrap<scalar_t>(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m,
+                               k, &alpha, col_start, n, weight_start, k, &beta,
+                               out_buffer_start, n);
+      cudaCheckError();
+    }
+  }
+
+  if (im2col_step != 1) {
+    int output_buffer_shape[5] = {batchSize / im2col_step, nOutputPlane,
+                                  im2col_step, outputHeight, outputWidth};
+    int output_buffer_permute[5] = {0, 2, 1, 3, 4};
+    memcpyPermute<scalar_t>(output, output_buffer, &output_buffer_shape[0],
+                            &output_buffer_permute[0], 5, stream);
+  }
+}
+
+void DeformConvForwardCUDAKernelLauncher_float(
+    const float* input, const float* weight, const float* offset, float* output,
+    void* workspace, int batchSize, int nInputPlane, int inputHeight,
+    int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW,
+    int padH, int dilationW, int dilationH, int group, int deformable_group,
+    int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream) {
+  DeformConvForwardCUDAKernelLauncher<float>(
+      input, weight, offset, output, workspace, batchSize, nInputPlane,
+      inputHeight, inputWidth, nOutputPlane, kW, kH, dW, dH, padW, padH,
+      dilationW, dilationH, group, deformable_group, im2col_step,
+      cublas_handle, stream);
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
index 6f8489d1a8..4cccc22a76 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
@@ -1,9 +1,11 @@
 #include "trt_plugin.hpp"
 
+#include "trt_deform_conv.hpp"
 #include "trt_nms.hpp"
 #include "trt_roi_align.hpp"
 #include "trt_scatternd.hpp"
 
+REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
 REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
 REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
 REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
diff --git a/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh b/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh
index fd73f99d61..a4635dcdd5 100644
--- a/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh
+++ b/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh
@@ -13,4 +13,18 @@
     }                                                                 \
   }
 
+/**
+ * Copy the source tensor into dst with its dimensions permuted.
+ *
+ * @param[out] dst pointer to the destination tensor
+ * @param[in] src pointer to the source tensor
+ * @param[in] src_size shape of the src tensor
+ * @param[in] permute the desired ordering of dimensions
+ * @param[in] src_dim dim of the src tensor
+ * @param[in] stream cuda stream handle
+ */
+template <class scalar_t>
+void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
+                   int *permute, int src_dim, cudaStream_t stream = 0);
+
 #endif  // TRT_CUDA_HELPER_HPP
diff --git a/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp b/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp
new file mode 100644
index 0000000000..b8762f7868
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp
@@ -0,0 +1,116 @@
+#ifndef TRT_DEFORM_CONV_HPP
+#define TRT_DEFORM_CONV_HPP
+#include <cublas_v2.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "trt_plugin_helper.hpp"
+
+class DeformableConvPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
+ public:
+  DeformableConvPluginDynamic(const std::string &name,
+                              const nvinfer1::Dims &stride,
+                              const nvinfer1::Dims &padding,
+                              const nvinfer1::Dims &dilation,
+                              const int deformableGroup, const int group,
+                              int im2colStep);
+
+  DeformableConvPluginDynamic(const std::string name, const void *data,
+                              size_t length);
+
+  DeformableConvPluginDynamic() = delete;
+
+  ~DeformableConvPluginDynamic();
+
+  // IPluginV2DynamicExt Methods
+  nvinfer1::IPluginV2DynamicExt *clone() const override;
+  nvinfer1::DimsExprs getOutputDimensions(
+      int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+      nvinfer1::IExprBuilder &exprBuilder) override;
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc *inOut,
+                                 int nbInputs, int nbOutputs) override;
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
+                       int nbInputs,
+                       const nvinfer1::DynamicPluginTensorDesc *out,
+                       int nbOutputs) override;
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
+                          int nbInputs,
+                          const nvinfer1::PluginTensorDesc *outputs,
+                          int nbOutputs) const override;
+  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
+              const nvinfer1::PluginTensorDesc *outputDesc,
+              const void *const *inputs, void *const *outputs, void *workspace,
+              cudaStream_t stream) override;
+
+  // IPluginV2Ext Methods
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType *inputTypes,
+                                       int nbInputs) const override;
+
+  // IPluginV2 Methods
+  const char *getPluginType() const override;
+  const char *getPluginVersion() const override;
+  int getNbOutputs() const override;
+  int initialize() override;
+  void terminate() override;
+  size_t getSerializationSize() const override;
+  void serialize(void *buffer) const override;
+  void destroy() override;
+  void setPluginNamespace(const char *pluginNamespace) override;
+  const char *getPluginNamespace() const override;
+
+ private:
+  const std::string mLayerName;
+  std::string mNamespace;
+
+  nvinfer1::Dims mStride;
+  nvinfer1::Dims mPadding;
+  nvinfer1::Dims mDilation;
+  int mDeformableGroup;
+  int mGroup;
+  int mIm2colStep;
+
+  cublasHandle_t m_cublas_handle;
+  cudaStream_t m_cuda_stream;
+
+ protected:
+  // To prevent compiler warnings.
+  using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
+  using nvinfer1::IPluginV2DynamicExt::configurePlugin;
+  using nvinfer1::IPluginV2DynamicExt::enqueue;
+  using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
+  using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
+  using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
+  using nvinfer1::IPluginV2DynamicExt::supportsFormat;
+};
+
+class DeformableConvPluginDynamicCreator : public nvinfer1::IPluginCreator {
+ public:
+  DeformableConvPluginDynamicCreator();
+
+  const char *getPluginName() const override;
+
+  const char *getPluginVersion() const override;
+
+  const nvinfer1::PluginFieldCollection *getFieldNames() override;
+
+  nvinfer1::IPluginV2 *createPlugin(
+      const char *name, const nvinfer1::PluginFieldCollection *fc) override;
+
+  nvinfer1::IPluginV2 *deserializePlugin(const char *name,
+                                         const void *serialData,
+                                         size_t serialLength) override;
+
+  void setPluginNamespace(const char *pluginNamespace) override;
+
+  const char *getPluginNamespace() const override;
+
+ private:
+  static nvinfer1::PluginFieldCollection mFC;
+  static std::vector<nvinfer1::PluginField> mPluginAttributes;
+  std::string mNamespace;
+};
+#endif  // TRT_DEFORM_CONV_HPP
diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py
index 250e096a59..2029de2bf6 100644
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -32,7 +32,7 @@ def symbolic(g,
                  bias=False,
                  im2col_step=32):
         return g.op(
-            'MMCVDeformConv2d',
+            'mmcv::MMCVDeformConv2d',
             input,
             offset,
             weight,
diff --git a/tests/test_ops/test_tensorrt.py b/tests/test_ops/test_tensorrt.py
index 09f706f206..95a7b86e3d 100644
--- a/tests/test_ops/test_tensorrt.py
+++ b/tests/test_ops/test_tensorrt.py
@@ -326,3 +326,79 @@ def func(data):
     if os.path.exists(trt_file):
         os.remove(trt_file)
     assert torch.allclose(pytorch_results, trt_results)
+
+
+def test_deform_conv():
+    try:
+        from mmcv.ops import DeformConv2dPack
+    except (ImportError, ModuleNotFoundError):
+        pytest.skip('test requires compilation')
+
+    input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
+    offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
+                     [[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
+                     [[0.3, 0.1, 0.2, 0.5]], [[0.3, 0.7, 0.5, 0.3]],
+                     [[0.6, 0.2, 0.5, 0.3]], [[0.4, 0.1, 0.8, 0.4]]]
+    offset_bias = [0.7, 0.1, 0.8, 0.5, 0.6, 0.5, 0.4, 0.7]
+    deform_weight = [[[0.4, 0.2, 0.1, 0.9]]]
+
+    c_in = 1
+    c_out = 1
+    x = torch.Tensor(input).cuda()
+    x.requires_grad = True
+    model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
+    model.conv_offset.weight.data = torch.nn.Parameter(
+        torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
+    model.conv_offset.bias.data = torch.nn.Parameter(
+        torch.Tensor(offset_bias).reshape(8))
+    model.weight.data = torch.nn.Parameter(
+        torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
+    model.cuda().eval()
+
+    input_names = ['input']
+    output_names = ['output']
+
+    with torch.no_grad():
+        torch.onnx.export(
+            model, (x.clone(), ),
+            onnx_file,
+            export_params=True,
+            keep_initializers_as_inputs=True,
+            input_names=input_names,
+            output_names=output_names,
+            opset_version=11)
+
+    onnx_model = onnx.load(onnx_file)
+
+    # create trt engine and wrapper
+    opt_shape_dict = {
+        'input': [list(x.shape), list(x.shape),
+                  list(x.shape)],
+    }
+    # trt config
+    fp16_mode = False
+    max_workspace_size = 1 << 30
+
+    trt_engine = onnx2trt(
+        onnx_model,
+        opt_shape_dict,
+        fp16_mode=fp16_mode,
+        max_workspace_size=max_workspace_size)
+
+    save_trt_engine(trt_engine, trt_file)
+    trt_model = TRTWraper(trt_file, input_names, output_names)
+
+    with torch.no_grad():
+        trt_outputs = trt_model({'input': x.clone()})
+        trt_results = trt_outputs['output']
+
+    # compute pytorch_output
+    with torch.no_grad():
+        pytorch_results = model(x.clone())
+
+    # allclose
+    if os.path.exists(onnx_file):
+        os.remove(onnx_file)
+    if os.path.exists(trt_file):
+        os.remove(trt_file)
+    assert torch.allclose(pytorch_results, trt_results)
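
Usage note (reviewer sketch, not part of the patch): once the plugin is registered, a model containing DeformConv2d is deployed like the other MMCV TensorRT ops, by exporting to ONNX (the symbolic now emits mmcv::MMCVDeformConv2d) and building an engine with mmcv.tensorrt. The snippet below mirrors test_deform_conv above; the channel counts, input shape and file names are illustrative placeholders, not values taken from this patch.

# Sketch of end-to-end usage of the MMCVDeformConv2d TensorRT plugin.
# Assumes mmcv is built with both CUDA ops and TensorRT support; file names
# ('tmp.onnx', 'tmp.engine') and shapes are placeholders.
import onnx
import torch

from mmcv.ops import DeformConv2dPack
from mmcv.tensorrt import TRTWraper, onnx2trt, save_trt_engine

model = DeformConv2dPack(3, 8, kernel_size=3, padding=1).cuda().eval()
x = torch.rand(1, 3, 32, 32).cuda()

# Export through the mmcv::MMCVDeformConv2d symbolic registered in this patch.
torch.onnx.export(
    model, (x, ),
    'tmp.onnx',
    input_names=['input'],
    output_names=['output'],
    opset_version=11)

# min/opt/max shapes for the dynamic-shape engine (all equal here).
opt_shape_dict = {'input': [list(x.shape)] * 3}
trt_engine = onnx2trt(
    onnx.load('tmp.onnx'),
    opt_shape_dict,
    fp16_mode=False,
    max_workspace_size=1 << 30)
save_trt_engine(trt_engine, 'tmp.engine')

trt_model = TRTWraper('tmp.engine', ['input'], ['output'])
with torch.no_grad():
    trt_out = trt_model({'input': x})['output']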