Skip to content

Commit

Permalink
[Fix] Fix pybind exporting bug in fused_bias_leakyrelu and upfirdn2d (#…
Browse files Browse the repository at this point in the history
…1005)

* fix export bug in pybind

* fix type bug in fused_bias_leakyrelu backward
  • Loading branch information
luopeichao authored Apr 29, 2021
1 parent ab973df commit 3c7321c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
2 changes: 2 additions & 0 deletions mmcv/ops/csrc/parrots/roi_pool_parrots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "roi_pool_pytorch.h"
using namespace parrots;

#ifdef MMCV_WITH_CUDA
void roi_pool_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
Expand Down Expand Up @@ -62,3 +63,4 @@ PARROTS_EXTENSION_REGISTER(roi_pool_backward)
.output(1)
.apply(roi_pool_backward_cuda_parrots)
.done();
#endif
9 changes: 7 additions & 2 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,14 @@ void roi_align_rotated_backward(Tensor grad_output, Tensor rois,
int sample_num, bool aligned, bool clockwise);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
py::arg("down_y"), py::arg("pad_x0"), py::arg("pad_x1"),
py::arg("pad_y0"), py::arg("pad_y1"));
m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu,
"fused_bias_leakyrelu (CUDA)");
"fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"),
py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"),
py::arg("scale"));
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
"get_compiling_cuda_version");
Expand Down
2 changes: 1 addition & 1 deletion mmcv/ops/fused_bias_leakyrelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def backward(ctx, gradgrad_input, gradgrad_bias):
# which is similar with the first order deviation in implementation.
gradgrad_out = ext_module.fused_bias_leakyrelu(
gradgrad_input,
gradgrad_bias,
gradgrad_bias.to(out.dtype),
out,
act=3,
grad=1,
Expand Down

0 comments on commit 3c7321c

Please sign in to comment.