Support corner_pool related custom operators for onnxruntime in mmcv #997

Merged · 13 commits · May 1, 2021
36 changes: 36 additions & 0 deletions docs/onnxruntime_custom_ops.md
@@ -27,6 +27,12 @@
- [Inputs](#inputs-3)
- [Outputs](#outputs-3)
- [Type Constraints](#type-constraints-3)
- [CornerPool](#cornerpool)
- [Description](#description-4)
- [Parameters](#parameters-4)
- [Inputs](#inputs-4)
- [Outputs](#outputs-4)
- [Type Constraints](#type-constraints-4)

<!-- TOC -->

@@ -171,3 +177,33 @@ Perform sample from `input` with pixel locations from `grid`.
### Type Constraints

- T:tensor(float32, Linear)

## CornerPool

### Description

Perform CornerPool on `input` features. Read [CornerNet: Detecting Objects as Paired Keypoints](https://arxiv.org/abs/1808.01244) for more details.

### Parameters

| Type | Parameter | Description |
| ------- | --------------- | ---------------------------------------------------------------- |
| `int`   | `mode`          | corner pool mode (0: `top`, 1: `bottom`, 2: `left`, 3: `right`)   |

### Inputs

<dl>
<dt><tt>input</tt>: T</dt>
<dd>Input features. 4-D tensor of shape (N, C, H, W), where N is the batch size, C the number of channels, and H and W the spatial height and width.</dd>
</dl>

### Outputs

<dl>
<dt><tt>output</tt>: T</dt>
<dd>The pooled features. 4-D tensor of shape (N, C, H, W), the same shape as `input`.</dd>
</dl>

### Type Constraints

- T:tensor(float32)
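
All four modes are cumulative-maximum scans over the spatial axes: `top` and `bottom` scan along the height axis, `left` and `right` along the width axis. The snippet below is a minimal reference sketch of the same semantics, not the mmcv implementation itself; the function name `corner_pool_reference` is illustrative.

```python
import torch

def corner_pool_reference(x: torch.Tensor, mode: int) -> torch.Tensor:
    """Reference corner pooling on a (N, C, H, W) tensor.

    mode: 0 -> top, 1 -> bottom, 2 -> left, 3 -> right.
    """
    if mode == 0:  # top: max over rows h' >= h
        return x.flip(2).cummax(dim=2).values.flip(2)
    if mode == 1:  # bottom: max over rows h' <= h
        return x.cummax(dim=2).values
    if mode == 2:  # left: max over columns w' >= w
        return x.flip(3).cummax(dim=3).values.flip(3)
    if mode == 3:  # right: max over columns w' <= w
        return x.cummax(dim=3).values
    raise ValueError(f'invalid corner pool mode {mode}')
```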
1 change: 1 addition & 0 deletions docs/onnxruntime_op.md
@@ -21,6 +21,7 @@
| [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
| [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
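
A model containing the ops listed above only runs after the compiled mmcv custom-op library is registered with the ONNX Runtime session. A minimal usage sketch, assuming the library has already been built and that `model.onnx` is a placeholder for a model exported with one of these ops (whose graph input is named `input`):

```python
import numpy as np
import onnxruntime as ort
from mmcv.ops import get_onnxruntime_op_path

# Path to the compiled mmcv custom-op shared library; empty if not built.
ort_custom_op_path = get_onnxruntime_op_path()

session_options = ort.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)

# 'model.onnx' is a placeholder for a model exported with mmcv custom ops.
sess = ort.InferenceSession('model.onnx', session_options)
outputs = sess.run(None, {'input': np.random.rand(2, 3, 9, 12).astype(np.float32)})
```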

## How to build custom operators for ONNX Runtime

26 changes: 26 additions & 0 deletions mmcv/ops/corner_pool.py
@@ -10,9 +10,17 @@
'right_pool_forward', 'right_pool_backward'
])

_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}


class TopPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.top_pool_forward(input)
@@ -28,6 +36,12 @@ def backward(ctx, grad_output):

class BottomPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.bottom_pool_forward(input)
@@ -43,6 +57,12 @@ def backward(ctx, grad_output):

class LeftPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.left_pool_forward(input)
@@ -58,6 +78,12 @@ def backward(ctx, grad_output):

class RightPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.right_pool_forward(input)
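
Each `symbolic` method above makes `torch.onnx.export` emit a single `MMCVCornerPool` node in the `mmcv` domain, carrying the pool direction as an integer `mode` attribute (the `_i` suffix in `mode_i` is the `torch.onnx` convention for integer attributes). The sketch below, mirroring the unit test added in this PR, exports the `top` variant and checks that the custom node is present; the wrapper class name is illustrative:

```python
import io

import onnx
import torch

from mmcv.ops.corner_pool import CornerPool


class TopPoolExportWrapper(torch.nn.Module):
    """Illustrative wrapper that calls the autograd Function directly, so
    tracing reaches TopPoolFunction.symbolic and emits the custom node."""

    def __init__(self):
        super().__init__()
        self.pool = CornerPool('top')

    def forward(self, x):
        return self.pool.corner_pool.apply(x)


buffer = io.BytesIO()
torch.onnx.export(
    TopPoolExportWrapper(), torch.rand(1, 3, 8, 8), buffer, opset_version=11)
graph = onnx.load_from_string(buffer.getvalue()).graph
assert any(node.op_type == 'MMCVCornerPool' for node in graph.node)
```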
45 changes: 45 additions & 0 deletions mmcv/ops/csrc/onnxruntime/corner_pool.h
@@ -0,0 +1,45 @@
#ifndef ONNXRUNTIME_CORNER_POOL_H
#define ONNXRUNTIME_CORNER_POOL_H

#include <assert.h>
#include <onnxruntime_cxx_api.h>

struct MMCVCornerPoolKernel {
public:
MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
: ort_(ort) {
mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "mode");
}

void Compute(OrtKernelContext* context);

private:
Ort::CustomOpApi ort_;

int64_t mode_;
};

struct MMCVCornerPoolCustomOp
: Ort::CustomOpBase<MMCVCornerPoolCustomOp, MMCVCornerPoolKernel> {
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
return new MMCVCornerPoolKernel(api, info);
}

const char* GetName() const { return "MMCVCornerPool"; }

size_t GetInputTypeCount() const { return 1; }
ONNXTensorElementDataType GetInputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

size_t GetOutputTypeCount() const { return 1; }
ONNXTensorElementDataType GetOutputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

// force cpu
const char* GetExecutionProviderType() const {
return "CPUExecutionProvider";
}
};
#endif // ONNXRUNTIME_CORNER_POOL_H
122 changes: 122 additions & 0 deletions mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp
@@ -0,0 +1,122 @@
#include "corner_pool.h"

#include "../ort_mmcv_utils.h"

void TopPoolForwardCPU(const float *input, float *output, const int batch_size,
const int channels, const int height, const int width) {
for (int n = 0; n < batch_size; n++) {
int index_n = n * channels * width * height;
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * width * height;
for (int w = 0; w < width; w++) {
// directly copy the bottommost value from input to output
output[index_n_c + (height - 1) * width + w] =
input[index_n_c + (height - 1) * width + w];
// do top_pool
for (int h = height - 2; h >= 0; h--) {
output[index_n_c + h * width + w] =
std::max(output[index_n_c + (h + 1) * width + w],
input[index_n_c + h * width + w]);
} // for h
} // for w
} // for c
} // for n
}

void BottomPoolForwardCPU(const float *input, float *output,
const int batch_size, const int channels,
const int height, const int width) {
for (int n = 0; n < batch_size; n++) {
int index_n = n * channels * width * height;
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * width * height;
for (int w = 0; w < width; w++) {
// directly copy the topmost value from input to output
output[index_n_c + w] = input[index_n_c + w];
// do bottom_pool
for (int h = 1; h < height; h++) {
output[index_n_c + h * width + w] =
std::max(output[index_n_c + (h - 1) * width + w],
input[index_n_c + h * width + w]);
} // for h
} // for w
} // for c
} // for n
}

void LeftPoolForwardCPU(const float *input, float *output, const int batch_size,
const int channels, const int height, const int width) {
for (int n = 0; n < batch_size; n++) {
int index_n = n * channels * width * height;
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * width * height;
for (int h = 0; h < height; h++) {
// directly copy the rightmost value from input to output
output[index_n_c + h * width + width - 1] =
input[index_n_c + h * width + width - 1];
// do left_pool
for (int w = width - 2; w >= 0; w--) {
output[index_n_c + h * width + w] =
std::max(output[index_n_c + h * width + w + 1],
input[index_n_c + h * width + w]);
} // for w
} // for h
} // for c
} // for n
}

void RightPoolForwardCPU(const float *input, float *output,
const int batch_size, const int channels,
const int height, const int width) {
for (int n = 0; n < batch_size; n++) {
int index_n = n * channels * width * height;
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * width * height;
for (int h = 0; h < height; h++) {
// directly copy the leftmost value from input to output
output[index_n_c + h * width] = input[index_n_c + h * width];
// do right_pool
for (int w = 1; w < width; w++) {
output[index_n_c + h * width + w] =
std::max(output[index_n_c + h * width + w - 1],
input[index_n_c + h * width + w]);
} // for w
} // for h
} // for c
} // for n
}

void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) {
const int mode = int(mode_);
typedef float T;
const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
const T *input_data =
reinterpret_cast<const float *>(ort_.GetTensorData<T>(input));

// get output memory
OrtTensorDimensions out_dimensions(ort_, input);
OrtValue *output = ort_.KernelContext_GetOutput(
context, 0, out_dimensions.data(), out_dimensions.size());
T *output_data = ort_.GetTensorMutableData<T>(output);

// 'top': 0, 'bottom': 1, 'left': 2, 'right': 3
assert(mode == 0 || mode == 1 || mode == 2 || mode == 3);

// do corner_pool
int batch_size = out_dimensions.data()[0];
int input_channels = out_dimensions.data()[1];
int input_height = out_dimensions.data()[2];
int input_width = out_dimensions.data()[3];
if (mode == 0)
TopPoolForwardCPU(input_data, output_data, batch_size, input_channels,
input_height, input_width);
else if (mode == 1)
BottomPoolForwardCPU(input_data, output_data, batch_size, input_channels,
input_height, input_width);
else if (mode == 2)
LeftPoolForwardCPU(input_data, output_data, batch_size, input_channels,
input_height, input_width);
else
RightPoolForwardCPU(input_data, output_data, batch_size, input_channels,
input_height, input_width);
}
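
The four CPU kernels above are plain running-maximum scans: each output element depends only on its neighbour in the scan direction and the current input. As a sanity-check sketch (not part of the PR), the following NumPy function mirrors `TopPoolForwardCPU` and compares it with a vectorised formulation; all names here are illustrative.

```python
import numpy as np

def top_pool_reference(x: np.ndarray) -> np.ndarray:
    """NumPy mirror of TopPoolForwardCPU for a (N, C, H, W) float32 array."""
    out = x.copy()
    # Scan rows bottom-to-top keeping a running maximum, like the C++ loop over h.
    for h in range(x.shape[2] - 2, -1, -1):
        out[:, :, h, :] = np.maximum(out[:, :, h + 1, :], x[:, :, h, :])
    return out

x = np.random.rand(2, 3, 9, 12).astype(np.float32)
# Vectorised equivalent: reversed cumulative maximum along the height axis.
vectorised = np.flip(np.maximum.accumulate(np.flip(x, axis=2), axis=2), axis=2)
assert np.allclose(top_pool_reference(x), vectorised)
```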
7 changes: 7 additions & 0 deletions mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
@@ -1,5 +1,6 @@
#include "onnxruntime_register.h"

#include "corner_pool.h"
#include "grid_sample.h"
#include "nms.h"
#include "ort_mmcv_utils.h"
@@ -13,6 +14,7 @@ NmsOp c_NmsOp;
MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp;
GridSampleOp c_GridSampleOp;
MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;

OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
const OrtApiBase *api) {
@@ -45,5 +47,10 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
return status;
}

if (auto status =
ortApi->CustomOpDomain_Add(domain, &c_MMCVCornerPoolCustomOp)) {
return status;
}

return ortApi->AddCustomOpDomain(options, domain);
}
46 changes: 46 additions & 0 deletions tests/test_ops/test_onnx.py
@@ -448,3 +448,49 @@ def func(feat, scale_factor=2):
if os.path.exists(onnx_file):
os.remove(onnx_file)
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)


@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right'])
def test_corner_pool(mode, opset=11):
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')

from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('custom ops for onnxruntime are not compiled.')

from mmcv.ops.corner_pool import CornerPool

def corner_pool_func(input):
corner_pool_module = CornerPool(mode)
return corner_pool_module.corner_pool.apply(input)

wrapped_model = WrapFunction(corner_pool_func).eval()

input = torch.rand((2, 3, 9, 12)) # (n,c,h,w)

with torch.no_grad():
torch.onnx.export(
wrapped_model,
input,
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input'],
output_names=['output'],
opset_version=opset)

onnx_model = onnx.load(onnx_file)
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)

session_options = rt.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
sess = rt.InferenceSession(onnx_file, session_options)
ort_result = sess.run(None, {'input': input.detach().numpy()})
pytorch_results = wrapped_model(input.clone())
os.remove(onnx_file)
assert np.allclose(pytorch_results, ort_result, atol=1e-5)