merge lr_triangular
gengenkai committed May 10, 2021
2 parents d7c684a + 934b549 commit 9d8709c
Showing 19 changed files with 970 additions and 12 deletions.
112 changes: 112 additions & 0 deletions docs/onnxruntime_custom_ops.md
@@ -27,6 +27,24 @@
- [Inputs](#inputs-3)
- [Outputs](#outputs-3)
- [Type Constraints](#type-constraints-3)
- [CornerPool](#cornerpool)
- [Description](#description-4)
- [Parameters](#parameters-4)
- [Inputs](#inputs-4)
- [Outputs](#outputs-4)
- [Type Constraints](#type-constraints-4)
- [cummax](#cummax)
- [Description](#description-5)
- [Parameters](#parameters-5)
- [Inputs](#inputs-5)
- [Outputs](#outputs-5)
- [Type Constraints](#type-constraints-5)
- [cummin](#cummin)
- [Description](#description-6)
- [Parameters](#parameters-6)
- [Inputs](#inputs-6)
- [Outputs](#outputs-6)
- [Type Constraints](#type-constraints-6)

<!-- TOC -->

@@ -171,3 +189,97 @@ Perform sample from `input` with pixel locations from `grid`.
### Type Constraints

- T:tensor(float32, Linear)

## CornerPool

### Description

Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as Paired Keypoints](https://arxiv.org/abs/1808.01244) for more details.

### Parameters

| Type | Parameter | Description |
| ------- | --------------- | ---------------------------------------------------------------- |
| `int`   | `mode`          | corner pool mode (0: `top`, 1: `bottom`, 2: `left`, 3: `right`)  |

### Inputs

<dl>
<dt><tt>input</tt>: T</dt>
<dd>Input features. 4-D tensor of shape (N, C, H, W). N is the batch size.</dd>
</dl>

### Outputs

<dl>
<dt><tt>output</tt>: T</dt>
<dd>The pooled features; a 4-D tensor of shape (N, C, H, W).</dd>
</dl>

### Type Constraints

- T:tensor(float32)
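As a rough illustration of the pooling semantics only (not the actual CUDA or onnxruntime implementation), the `left` mode computes a running maximum scanned from right to left along each row of the feature map. A minimal plain-Python sketch on a single row:

```python
def left_pool(row):
    # Running maximum scanned from right to left:
    # out[j] = max(row[j], out[j + 1]), as in CornerNet's left corner pooling.
    out = list(row)
    for j in range(len(out) - 2, -1, -1):
        out[j] = max(out[j], out[j + 1])
    return out

print(left_pool([1, 3, 2, 5, 4]))  # [5, 5, 5, 5, 4]
```

The `top`, `bottom`, and `right` modes apply the same running maximum along the other direction or axis.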

## cummax

### Description

Returns a tuple (`values`, `indices`) where `values` contains the cumulative maximum elements of `input` in the dimension `dim`, and `indices` is the index location of each maximum value found in that dimension. Read [torch.cummax](https://pytorch.org/docs/stable/generated/torch.cummax.html) for more details.

### Parameters

| Type | Parameter | Description |
| ------- | --------------- | ---------------------------------------------------------------- |
| `int` | `dim` | the dimension to do the operation over |

### Inputs

<dl>
<dt><tt>input</tt>: T</dt>
<dd>The input tensor, which may have any shape. Empty tensors are also supported.</dd>
</dl>

### Outputs

<dl>
<dt><tt>output</tt>: T</dt>
<dd>The cumulative maximum elements of `input` in the dimension `dim`, with the same shape and dtype as `input`.</dd>
<dt><tt>indices</tt>: tensor(int64)</dt>
<dd>The index location of each cumulative maximum value found in the dimension `dim`, with the same shape as `input`.</dd>
</dl>

### Type Constraints

- T:tensor(float32)
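The semantics can be sketched in plain Python (a sketch of the running-maximum logic only; tie handling here keeps the earlier index, which is not guaranteed to match `torch.cummax`):

```python
def cummax(seq):
    # values[i] = max(seq[:i + 1]); indices[i] = position of that maximum.
    values, indices = [], []
    for i, v in enumerate(seq):
        if not values or v > values[-1]:
            values.append(v)
            indices.append(i)
        else:
            values.append(values[-1])
            indices.append(indices[-1])
    return values, indices

print(cummax([1, 3, 2, 5, 4]))  # ([1, 3, 3, 5, 5], [0, 1, 1, 3, 3])
```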

## cummin

### Description

Returns a tuple (`values`, `indices`) where `values` contains the cumulative minimum elements of `input` in the dimension `dim`, and `indices` is the index location of each minimum value found in that dimension. Read [torch.cummin](https://pytorch.org/docs/stable/generated/torch.cummin.html) for more details.

### Parameters

| Type | Parameter | Description |
| ------- | --------------- | ---------------------------------------------------------------- |
| `int` | `dim` | the dimension to do the operation over |

### Inputs

<dl>
<dt><tt>input</tt>: T</dt>
<dd>The input tensor, which may have any shape. Empty tensors are also supported.</dd>
</dl>

### Outputs

<dl>
<dt><tt>output</tt>: T</dt>
<dd>The cumulative minimum elements of `input` in the dimension `dim`, with the same shape and dtype as `input`.</dd>
<dt><tt>indices</tt>: tensor(int64)</dt>
<dd>The index location of each cumulative minimum value found in the dimension `dim`, with the same shape as `input`.</dd>
</dl>

### Type Constraints

- T:tensor(float32)
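The logic mirrors `cummax` with the comparison reversed; a plain-Python sketch (running-minimum semantics only, with tie handling that keeps the earlier index, which is not guaranteed to match `torch.cummin`):

```python
def cummin(seq):
    # values[i] = min(seq[:i + 1]); indices[i] = position of that minimum.
    values, indices = [], []
    for i, v in enumerate(seq):
        if not values or v < values[-1]:
            values.append(v)
            indices.append(i)
        else:
            values.append(values[-1])
            indices.append(indices[-1])
    return values, indices

print(cummin([3, 1, 2, 0, 4]))  # ([3, 1, 1, 0, 0], [0, 1, 1, 3, 3])
```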
7 changes: 6 additions & 1 deletion docs/onnxruntime_op.md
@@ -21,6 +21,9 @@
| [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
| [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
| [cummax](onnxruntime_custom_ops.md#cummax) | Y | N | master |
| [cummin](onnxruntime_custom_ops.md#cummin) | Y | N | master |

## How to build custom operators for ONNX Runtime

@@ -114,7 +117,9 @@ Take custom operator `soft_nms` for example.

## Known Issues

- None
- "RuntimeError: tuple appears in op that does not forward tuples, unsupported kind: `prim::PythonOp`."
  1. In general, `cummax` and `cummin` are exportable to ONNX as long as the torch version is >= 1.5.0, since `torch.cummax` is only supported from torch 1.5.0 onwards. However, when `cummax` or `cummin` serves as an intermediate component whose outputs are used as inputs to other modules, torch >= 1.7.0 is required; otherwise the above error may arise when running the exported ONNX model with onnxruntime.
  2. Solution: update torch to 1.7.0 or higher.

## References

3 changes: 2 additions & 1 deletion mmcv/engine/test.py
@@ -32,7 +32,8 @@ def single_gpu_test(model, data_loader):
result = model(return_loss=False, **data)
results.extend(result)

# Assume result has the same length of batch_size, refer to https://github.com/open-mmlab/mmcv/issues/985
# Assume result has the same length of batch_size
# refer to https://github.com/open-mmlab/mmcv/issues/985
batch_size = len(result)
for _ in range(batch_size):
prog_bar.update()
12 changes: 12 additions & 0 deletions mmcv/onnx/symbolic.py
@@ -396,6 +396,16 @@ def grid_sampler(g,
align_corners_i=align_corners)


@parse_args('v', 'i')
def cummax(g, input, dim):
return g.op('mmcv::cummax', input, dim_i=dim, outputs=2)


@parse_args('v', 'i')
def cummin(g, input, dim):
return g.op('mmcv::cummin', input, dim_i=dim, outputs=2)


def register_extra_symbolics(opset=11):
register_op('one_hot', one_hot, '', opset)
register_op('im2col', im2col, '', opset)
@@ -421,3 +431,5 @@ def register_extra_symbolics(opset=11):
register_op('upsample_bicubic2d', upsample_bicubic2d, '', opset)
register_op('new_full', new_full, '', opset)
register_op('grid_sampler', grid_sampler, '', opset)
register_op('cummax', cummax, '', opset)
register_op('cummin', cummin, '', opset)
35 changes: 35 additions & 0 deletions mmcv/ops/corner_pool.py
@@ -10,9 +10,17 @@
'right_pool_forward', 'right_pool_backward'
])

_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}


class TopPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.top_pool_forward(input)
@@ -28,6 +36,12 @@ def backward(ctx, grad_output):

class BottomPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.bottom_pool_forward(input)
@@ -43,6 +57,12 @@ def backward(ctx, grad_output):

class LeftPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.left_pool_forward(input)
@@ -58,6 +78,12 @@ def backward(ctx, grad_output):

class RightPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.right_pool_forward(input)
@@ -114,6 +140,15 @@ def __init__(self, mode):

def forward(self, x):
if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
if torch.onnx.is_in_onnx_export():
assert torch.__version__ >= '1.7.0', \
'When `cummax` serves as an intermediate component whose '\
'outputs is used as inputs for another modules, it\'s '\
'expected that pytorch version must be >= 1.7.0, '\
'otherwise Error appears like: `RuntimeError: tuple '\
'appears in op that does not forward tuples, unsupported '\
'kind: prim::PythonOp`.'

dim, flip = self.cummax_dim_flip[self.mode]
if flip:
x = x.flip(dim)
45 changes: 45 additions & 0 deletions mmcv/ops/csrc/onnxruntime/corner_pool.h
@@ -0,0 +1,45 @@
#ifndef ONNXRUNTIME_CORNER_POOL_H
#define ONNXRUNTIME_CORNER_POOL_H

#include <assert.h>
#include <onnxruntime_cxx_api.h>

struct MMCVCornerPoolKernel {
public:
MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
: ort_(ort) {
mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "mode");
}

void Compute(OrtKernelContext* context);

private:
Ort::CustomOpApi ort_;

int64_t mode_;
};

struct MMCVCornerPoolCustomOp
: Ort::CustomOpBase<MMCVCornerPoolCustomOp, MMCVCornerPoolKernel> {
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
return new MMCVCornerPoolKernel(api, info);
}

const char* GetName() const { return "MMCVCornerPool"; }

size_t GetInputTypeCount() const { return 1; }
ONNXTensorElementDataType GetInputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

size_t GetOutputTypeCount() const { return 1; }
ONNXTensorElementDataType GetOutputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

// force cpu
const char* GetExecutionProviderType() const {
return "CPUExecutionProvider";
}
};
#endif // ONNXRUNTIME_CORNER_POOL_H