Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CustomOps] TensorRT Gather Topk Ops #1033

Merged
merged 6 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "gather_topk.hpp"

#include <assert.h>
#include <stdio.h>

#include <chrono>

#include "NvInferVersion.h"
#include "gather_topk_kernel.hpp"
#include "trt_serialize.hpp"

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"GatherTopk"};
} // namespace

GatherTopk::GatherTopk(const std::string &name) : TRTPluginBase(name) {}

GatherTopk::GatherTopk(const std::string name, const void *data, size_t length)
: TRTPluginBase(name) {}

nvinfer1::IPluginV2DynamicExt *GatherTopk::clone() const TRT_NOEXCEPT {
GatherTopk *plugin = new GatherTopk(mLayerName);
plugin->setPluginNamespace(getPluginNamespace());

return plugin;
}

nvinfer1::DimsExprs GatherTopk::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
assert(inputs[0].nbDims >= inputs[1].nbDims);
nvinfer1::DimsExprs ret;
ret.nbDims = inputs[0].nbDims;
for (int i = 0; i < inputs[1].nbDims; ++i) {
ret.d[i] = inputs[1].d[i];
}
for (int i = inputs[1].nbDims; i < inputs[0].nbDims; ++i) {
ret.d[i] = inputs[0].d[i];
}
return ret;
}

bool GatherTopk::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
int nbInputs, int nbOutputs) TRT_NOEXCEPT {
switch (pos) {
case 0:
// data
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
(ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
case 1:
// indices
return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
case 2:
// output
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
default:
return true;
}
return true;
}

void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) TRT_NOEXCEPT {}

size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT {
return 0;
}

int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
const int *dims = &(inputDesc[0].dims.d[0]);
const int *indices_dims = &(inputDesc[1].dims.d[0]);
int nbDims = inputDesc[0].dims.nbDims;
int indice_nbDims = inputDesc[1].dims.nbDims;

const void *data = inputs[0];
const void *indices = inputs[1];
void *output = outputs[0];

auto data_type = inputDesc[0].type;

switch (data_type) {
case nvinfer1::DataType::kFLOAT:
gather_topk_impl<float>((float *)data, (int *)indices, dims, nbDims, indices_dims,
indice_nbDims, (float *)output, stream);
break;

case nvinfer1::DataType::kINT32:
gather_topk_impl<int>((int *)data, (int *)indices, dims, nbDims, indices_dims, indice_nbDims,
(int *)output, stream);
break;
default:
break;
}

return 0;
}

nvinfer1::DataType GatherTopk::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0];
}

// IPluginV2 Methods
const char *GatherTopk::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GatherTopk::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

int GatherTopk::getNbOutputs() const TRT_NOEXCEPT { return 1; }

size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT { return 0; }

void GatherTopk::serialize(void *buffer) const TRT_NOEXCEPT {}

GatherTopkCreator::GatherTopkCreator() {
mPluginAttributes.clear();
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}

const char *GatherTopkCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

nvinfer1::IPluginV2 *GatherTopkCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
auto *plugin = new GatherTopk(name);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}

nvinfer1::IPluginV2 *GatherTopkCreator::deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT {
auto plugin = new GatherTopk(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}

REGISTER_TENSORRT_PLUGIN(GatherTopkCreator);
} // namespace mmdeploy
64 changes: 64 additions & 0 deletions csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCATTERND_HPP
#define TRT_SCATTERND_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
class GatherTopk : public TRTPluginBase {
public:
GatherTopk(const std::string &name);

GatherTopk(const std::string name, const void *data, size_t length);

GatherTopk() = delete;

// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;

// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;
};

class GatherTopkCreator : public TRTPluginCreatorBase {
public:
GatherTopkCreator();

const char *getPluginName() const TRT_NOEXCEPT override;

const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;

nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_SCATTERND_HPP
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <functional>
#include <numeric>
#include <vector>

#include "common_cuda_helper.hpp"
#include "gather_topk_kernel.hpp"
#include "trt_plugin_helper.hpp"

template <typename scalar_t>
__global__ void gather_topk_kernel(const scalar_t* input, const int* indices, scalar_t* output,
int batch, int num_input, int num_indices, int channel) {
CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel) {
const int b_id = index / (num_indices * channel);
const int n_id = (index / channel) % num_indices;
const int c_id = index % channel;

const int input_n_id = indices[b_id * num_indices + n_id];
const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id];
output[b_id * num_indices * channel + n_id * channel + c_id] = value;
}
}

template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
const int* indices_dims, int indice_nbDims, scalar_t* output,
cudaStream_t stream) {
int batch = 1;
for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i];
int num_input = dims[indice_nbDims - 1];
int num_indices = indices_dims[indice_nbDims - 1];
int channel = 1;
for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i];
const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK);
gather_topk_kernel<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(input, indices, output, batch,
num_input, num_indices, channel);
}

template void gather_topk_impl<float>(const float* input, const int* indices, const int* dims,
int nbDims, const int* indices_dims, int indice_nbDims,
float* output, cudaStream_t stream);

template void gather_topk_impl<int32_t>(const int32_t* input, const int* indices, const int* dims,
int nbDims, const int* indices_dims, int indice_nbDims,
int32_t* output, cudaStream_t stream);
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_GRID_SAMPLER_KERNEL_HPP
#define TRT_GRID_SAMPLER_KERNEL_HPP
#include <cuda_runtime.h>

template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
const int* indices_dims, int indice_nbDims, scalar_t* output,
cudaStream_t stream);
#endif // TRT_GRID_SAMPLER_KERNEL_HPP
42 changes: 42 additions & 0 deletions docs/en/06-custom-ops/tensorrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@
- [Inputs](#inputs-9)
- [Outputs](#outputs-9)
- [Type Constraints](#type-constraints-9)
- [GatherTopk](#gathertopk)
- [Description](#description-10)
- [Parameters](#parameters-10)
- [Inputs](#inputs-10)
- [Outputs](#outputs-10)
- [Type Constraints](#type-constraints-10)

<!-- TOC -->

Expand Down Expand Up @@ -447,3 +453,39 @@ None
#### Type Constraints

- T:tensor(float32, Linear)

### GatherTopk

#### Description

TensorRT 8.2~8.4 would give unexpected result for multi-index gather.

```python
data[batch_index, bbox_index, ...]
```

Read [this](https://github.com/NVIDIA/TensorRT/issues/2299) for more details.

#### Parameters

None

#### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Tensor to be gathered, with shape (A0, ..., An, G0, C0, ...).</dd>

<dt><tt>inputs[1]</tt>: tensor(int32, Linear)</dt>
<dd>Tensor of index. with shape (A0, ..., An, G1)</dd>

#### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Tensor of output. With shape (A0, ..., An, G1, C0, ...)</dd>
</dl>

#### Type Constraints

- T:tensor(float32, Linear), tensor(int32, Linear)
42 changes: 42 additions & 0 deletions docs/zh_cn/06-custom-ops/tensorrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@
- [Inputs](#inputs-9)
- [Outputs](#outputs-9)
- [Type Constraints](#type-constraints-9)
- [GatherTopk](#gathertopk)
- [Description](#description-10)
- [Parameters](#parameters-10)
- [Inputs](#inputs-10)
- [Outputs](#outputs-10)
- [Type Constraints](#type-constraints-10)

<!-- TOC -->

Expand Down Expand Up @@ -447,3 +453,39 @@ None
#### Type Constraints

- T:tensor(float32, Linear)

### GatherTopk

#### Description

TensorRT 8.2~8.4 would give unexpected result for multi-index gather.

```python
data[batch_index, bbox_index, ...]
```

Read [this](https://github.com/NVIDIA/TensorRT/issues/2299) for more details.

#### Parameters

None

#### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Tensor to be gathered, with shape (A0, ..., An, G0, C0, ...).</dd>

<dt><tt>inputs[1]</tt>: tensor(int32, Linear)</dt>
<dd>Tensor of index. with shape (A0, ..., An, G1)</dd>

#### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Tensor of output. With shape (A0, ..., An, G1, C0, ...)</dd>
</dl>

#### Type Constraints

- T:tensor(float32, Linear), tensor(int32, Linear)
7 changes: 7 additions & 0 deletions mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ def multiclass_nms_static(ctx,
pre_top_k, keep_top_k, iou_threshold,
score_threshold, -1)

# retain shape info
batch_size = boxes.size(0)

dets_shape = dets.shape
label_shape = labels.shape
dets = dets.reshape([batch_size, *dets_shape[1:]])
labels = labels.reshape([batch_size, *label_shape[1:]])
return dets, labels


Expand Down
Loading