[Feature] Support mmcv ext with DIOPI impl #2790

Merged 29 commits on Jun 13, 2023
Changes shown from 24 of the 29 commits

Commits
a3d6413
Support mmcv ext+dipu
CokeDong May 5, 2023
ac359bc
Fix
CokeDong May 8, 2023
7fd3844
Fix
CokeDong May 10, 2023
f1f3bc7
Fix lint
CokeDong May 10, 2023
e1a0514
Fix cuda lint
CokeDong May 10, 2023
4cf8def
Fix dipu
CokeDong May 11, 2023
641a368
Add support for ops nms, sigmoid_focal_loss and voxelize
CokeDong May 18, 2023
83837db
Fix cpp lint
CokeDong May 18, 2023
388e547
Add support for op modulated_deform_conv(diopi and fallback)
CokeDong May 18, 2023
2498e3c
Fix
CokeDong May 18, 2023
f3b1351
Fix and refactor testcase
CokeDong May 18, 2023
2ab2e30
Refactor test_roi_align
CokeDong May 23, 2023
61dedd7
Merge branch 'main' into support_dipu
CokeDong May 23, 2023
b3d2c66
Fix lint
CokeDong May 23, 2023
6cea160
Fix testcase
CokeDong May 23, 2023
b1f17ba
Add dipu test for voxelization
CokeDong May 24, 2023
fc79564
Fix nms for none output
CokeDong May 24, 2023
f0b059d
Fix lint
CokeDong May 24, 2023
444035d
Separate testcases
CokeDong May 24, 2023
4fa0404
Fix mlu dipu device focalloss error
CokeDong May 26, 2023
5865135
Fix setup.py
CokeDong May 26, 2023
0b34562
Bugfix for roi_align
CokeDong Jun 7, 2023
f7f21d4
fix for focal_loss
CokeDong Jun 8, 2023
e53f4ae
fix for focal_loss2
CokeDong Jun 8, 2023
9189735
Fix reviews: remove unused code and clear diopi impl functions
CokeDong Jun 12, 2023
bb9a00e
Add diopi support since mmengine 0.7.4
CokeDong Jun 12, 2023
346094c
Fix setup according to dipu
CokeDong Jun 13, 2023
2712d69
Fix version
CokeDong Jun 13, 2023
5fb42ba
add comment for test_voxelization.py
CokeDong Jun 13, 2023
35 changes: 35 additions & 0 deletions mmcv/ops/csrc/pytorch/bbox_overlaps.cpp
@@ -1,6 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#ifdef MMCV_WITH_DIOPI
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_mmcv.h>

#include "csrc_dipu/diopirt/diopirt_impl.h"

using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
#endif

void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset) {
@@ -10,5 +20,30 @@ void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,

void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset) {
#ifdef MMCV_WITH_DIOPI
auto bboxes1_p = toDiopiTensorHandle(bboxes1);
diopiDevice_t device;
diopiGetTensorDevice(bboxes1_p, &device);
if (device == diopi_host) {
bbox_overlaps_impl(bboxes1, bboxes2, ious, mode, aligned, offset);
return;
}
diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
diopiContextHandle_t ch = &ctx;
auto bboxes2_p = toDiopiTensorHandle(bboxes2);
auto ious_p = toDiopiTensorHandle(ious);
if (reinterpret_cast<void *>(diopiBboxOverlapsMmcv) != nullptr) {
auto ret = diopiBboxOverlapsMmcv(ch, ious_p, bboxes1_p, bboxes2_p, mode,
offset, aligned);
if (ret == diopiSuccess) return;
}
LOG(WARNING) << "Fallback to cpu: mmcv ext op bbox_overlaps";
auto bboxes1_cpu = bboxes1.cpu();
auto bboxes2_cpu = bboxes2.cpu();
auto ious_cpu = ious.cpu();
bbox_overlaps_impl(bboxes1_cpu, bboxes2_cpu, ious_cpu, mode, aligned, offset);
ious.copy_(ious_cpu);
#else
bbox_overlaps_impl(bboxes1, bboxes2, ious, mode, aligned, offset);
#endif
}
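
The guard reinterpret_cast<void *>(diopiBboxOverlapsMmcv) != nullptr only makes sense if an unimplemented DIOPI entry point resolves to a null address instead of causing a link error, which is what a weak-symbol declaration provides. The snippet below is a hedged illustration of that assumed convention using a hypothetical diopiExampleMmcv; it is not a symbol from this PR or from DIOPI itself.

// Assumed convention, shown with a hypothetical symbol: when no vendor
// implementation is linked in, the weak declaration leaves the address null,
// so a runtime check like the one above can skip the DIOPI path.
extern "C" __attribute__((weak)) diopiError_t diopiExampleMmcv(
    diopiContextHandle_t ctx, diopiTensorHandle_t out,
    diopiConstTensorHandle_t in);

inline bool has_example_impl() {
  return reinterpret_cast<void *>(diopiExampleMmcv) != nullptr;
}
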
72 changes: 72 additions & 0 deletions mmcv/ops/csrc/pytorch/focal_loss.cpp
@@ -1,6 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#ifdef MMCV_WITH_DIOPI
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_mmcv.h>

#include "csrc_dipu/diopirt/diopirt_impl.h"

using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
#endif

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
@@ -31,13 +41,75 @@ void softmax_focal_loss_backward_impl(Tensor input, Tensor target,

void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
#ifdef MMCV_WITH_DIOPI
auto input_p = toDiopiTensorHandle(input);
diopiDevice_t device;
diopiGetTensorDevice(input_p, &device);
if (device == diopi_host) {
sigmoid_focal_loss_forward_impl(input, target, weight, output, gamma,
alpha);
return;
}
diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
diopiContextHandle_t ch = &ctx;
auto target_p = toDiopiTensorHandle(target);
auto weight_p = toDiopiTensorHandle(weight);
auto output_p = toDiopiTensorHandle(output);
if (reinterpret_cast<void *>(diopiSigmoidFocalLossMmcv) != nullptr) {
auto ret = diopiSigmoidFocalLossMmcv(ch, output_p, input_p, target_p,
weight_p, gamma, alpha);
if (ret == diopiSuccess) return;
}
LOG(WARNING)
<< "Fallback to cpu: mmcv ext op sigmoid_focal_loss_forward_impl";
auto input_cpu = input.cpu();
auto target_cpu = target.cpu();
auto weight_cpu = weight.cpu();
auto output_cpu = output.cpu();
sigmoid_focal_loss_forward_impl(input_cpu, target_cpu, weight_cpu, output_cpu,
gamma, alpha);
output.copy_(output_cpu);
return;
#else
sigmoid_focal_loss_forward_impl(input, target, weight, output, gamma, alpha);
#endif
}

void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor grad_input, float gamma, float alpha) {
#ifdef MMCV_WITH_DIOPI
auto input_p = toDiopiTensorHandle(input);
diopiDevice_t device;
diopiGetTensorDevice(input_p, &device);
if (device == diopi_host) {
sigmoid_focal_loss_backward_impl(input, target, weight, grad_input, gamma,
alpha);
return;
}
diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
diopiContextHandle_t ch = &ctx;
auto target_p = toDiopiTensorHandle(target);
auto weight_p = toDiopiTensorHandle(weight);
auto grad_input_p = toDiopiTensorHandle(grad_input);
if (reinterpret_cast<void *>(diopiSigmoidFocalLossBackwardMmcv) != nullptr) {
auto ret = diopiSigmoidFocalLossBackwardMmcv(
ch, grad_input_p, input_p, target_p, weight_p, gamma, alpha);
if (ret == diopiSuccess) return;
}
LOG(WARNING)
<< "Fallback to cpu: mmcv ext op sigmoid_focal_loss_backward_impl";
auto input_cpu = input.cpu();
auto target_cpu = target.cpu();
auto weight_cpu = weight.cpu();
auto grad_input_cpu = grad_input.cpu();
sigmoid_focal_loss_backward_impl(input_cpu, target_cpu, weight_cpu,
grad_input_cpu, gamma, alpha);
grad_input.copy_(grad_input_cpu);
return;
#else
sigmoid_focal_loss_backward_impl(input, target, weight, grad_input, gamma,
alpha);
#endif
}

void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
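
As in bbox_overlaps, both focal-loss fallbacks run the CPU kernel on host copies and then write the result back into the caller's tensor, because Tensor::cpu() returns a detached host copy when the tensor lives on a device. The standalone snippet below only illustrates that copy-back step; the function name and fill value are made up.

#include <ATen/ATen.h>

// Illustration only: writes to the host copy do not affect the caller's
// device tensor, so a CPU fallback has to copy the result back explicitly.
void cpu_fallback_copy_back(at::Tensor output) {
  auto output_cpu = output.cpu();  // detached host copy when output is on a device
  output_cpu.fill_(1.0);           // stand-in for the real CPU kernel
  output.copy_(output_cpu);        // propagate the result back to output
}
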
148 changes: 146 additions & 2 deletions mmcv/ops/csrc/pytorch/modulated_deform_conv.cpp
@@ -1,6 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#ifdef MMCV_WITH_DIOPI
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_mmcv.h>

#include "csrc_dipu/diopirt/diopirt_impl.h"

using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
#endif

void modulated_deformable_im2col_impl(
const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
@@ -45,7 +55,7 @@ void modulated_deformable_col2im_coord_impl(
dilation_w, deformable_group, grad_offset, grad_mask);
}

void modulated_deform_conv_forward(
void modulated_deform_conv_forward_fallthrough(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
@@ -123,7 +133,63 @@ void modulated_deform_conv_forward(
}
}

void modulated_deform_conv_backward(
void modulated_deform_conv_forward(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, const int group,
const int deformable_group, const bool with_bias) {
#ifdef MMCV_WITH_DIOPI
auto input_p = toDiopiTensorHandle(input);
diopiDevice_t device;
diopiGetTensorDevice(input_p, &device);
if (device == diopi_host) {
modulated_deform_conv_forward_fallthrough(
input, weight, bias, ones, offset, mask, output, columns, kernel_h,
kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
group, deformable_group, with_bias);
return;
}
diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
diopiContextHandle_t ch = &ctx;
auto weight_p = toDiopiTensorHandle(weight);
auto bias_p = toDiopiTensorHandle(bias);
auto ones_p = toDiopiTensorHandle(ones);
auto offset_p = toDiopiTensorHandle(offset);
auto mask_p = toDiopiTensorHandle(mask);
auto output_p = toDiopiTensorHandle(output);
auto columns_p = toDiopiTensorHandle(columns);
if (reinterpret_cast<void*>(diopiModulatedDeformConvMmcv) != nullptr) {
auto ret = diopiModulatedDeformConvMmcv(
ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p,
mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, deformable_group, with_bias);
if (ret == diopiSuccess) return;
}
LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward";
auto input_cpu = input.cpu();
auto weight_cpu = weight.cpu();
auto bias_cpu = bias.cpu();
auto ones_cpu = ones.cpu();
auto offset_cpu = offset.cpu();
auto mask_cpu = mask.cpu();
auto output_cpu = output.cpu();
auto columns_cpu = columns.cpu();
modulated_deform_conv_forward_fallthrough(
input_cpu, weight_cpu, bias_cpu, ones_cpu, offset_cpu, mask_cpu,
output_cpu, columns_cpu, kernel_h, kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group, deformable_group, with_bias);
output.copy_(output_cpu);
return;
#else
modulated_deform_conv_forward_fallthrough(
input, weight, bias, ones, offset, mask, output, columns, kernel_h,
kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
deformable_group, with_bias);
#endif
}

void modulated_deform_conv_backward_fallthrough(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
@@ -235,3 +301,81 @@ void modulated_deform_conv_backward(
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}

void modulated_deform_conv_backward(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
#ifdef MMCV_WITH_DIOPI
auto input_p = toDiopiTensorHandle(input);
diopiDevice_t device;
diopiGetTensorDevice(input_p, &device);
if (device == diopi_host) {
modulated_deform_conv_backward_fallthrough(
input, weight, bias, ones, offset, mask, columns, grad_input,
grad_weight, grad_bias, grad_offset, grad_mask, grad_output, kernel_h,
kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
group, deformable_group, with_bias);
return;
}
diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
diopiContextHandle_t ch = &ctx;
auto weight_p = toDiopiTensorHandle(weight);
auto bias_p = toDiopiTensorHandle(bias);
auto ones_p = toDiopiTensorHandle(ones);
auto offset_p = toDiopiTensorHandle(offset);
auto mask_p = toDiopiTensorHandle(mask);
auto columns_p = toDiopiTensorHandle(columns);
auto grad_input_p = toDiopiTensorHandle(grad_input);
auto grad_weight_p = toDiopiTensorHandle(grad_weight);
auto grad_bias_p = toDiopiTensorHandle(grad_bias);
auto grad_offset_p = toDiopiTensorHandle(grad_offset);
auto grad_mask_p = toDiopiTensorHandle(grad_mask);
auto grad_output_p = toDiopiTensorHandle(grad_output);

if (reinterpret_cast<void*>(diopiModulatedDeformConvBackwardMmcv) !=
nullptr) {
auto ret = diopiModulatedDeformConvBackwardMmcv(
ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group, deformable_group, with_bias);
if (ret == diopiSuccess) return;
}
LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward";
auto input_cpu = input.cpu();
auto weight_cpu = weight.cpu();
auto bias_cpu = bias.cpu();
auto ones_cpu = ones.cpu();
auto offset_cpu = offset.cpu();
auto mask_cpu = mask.cpu();
auto columns_cpu = columns.cpu();
auto grad_input_cpu = grad_input.cpu();
auto grad_weight_cpu = grad_weight.cpu();
auto grad_bias_cpu = grad_bias.cpu();
auto grad_offset_cpu = grad_offset.cpu();
auto grad_mask_cpu = grad_mask.cpu();
auto grad_output_cpu = grad_output.cpu();
modulated_deform_conv_backward_fallthrough(
input_cpu, weight_cpu, bias_cpu, ones_cpu, offset_cpu, mask_cpu,
columns_cpu, grad_input_cpu, grad_weight_cpu, grad_bias_cpu,
grad_offset_cpu, grad_mask_cpu, grad_output_cpu, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
deformable_group, with_bias);
grad_input.copy_(grad_input_cpu);
grad_weight.copy_(grad_weight_cpu);
grad_bias.copy_(grad_bias_cpu);
grad_offset.copy_(grad_offset_cpu);
grad_mask.copy_(grad_mask_cpu);
return;
#else
modulated_deform_conv_backward_fallthrough(
input, weight, bias, ones, offset, mask, columns, grad_input, grad_weight,
grad_bias, grad_offset, grad_mask, grad_output, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
deformable_group, with_bias);
#endif
}
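
The backward fallback above moves thirteen tensors to the host one by one and then copies the five gradient outputs back. A hypothetical helper, not part of this PR, that bundles the copy-back step could look like the sketch below.

#include <ATen/ATen.h>
#include <initializer_list>
#include <utility>

// Hypothetical convenience helper (not in this PR): copy each CPU result back
// into the corresponding device tensor after a CPU fallback run.
inline void copy_back(
    std::initializer_list<std::pair<at::Tensor, at::Tensor>> pairs) {
  for (const auto &p : pairs) {
    p.first.copy_(p.second);  // p.first: caller's tensor, p.second: CPU result
  }
}

// Sketch of how the backward fallback could use it:
// copy_back({{grad_input, grad_input_cpu},
//            {grad_weight, grad_weight_cpu},
//            {grad_bias, grad_bias_cpu},
//            {grad_offset, grad_offset_cpu},
//            {grad_mask, grad_mask_cpu}});
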
37 changes: 37 additions & 0 deletions mmcv/ops/csrc/pytorch/nms.cpp
@@ -1,6 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#ifdef MMCV_WITH_DIOPI
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_mmcv.h>

#include "csrc_dipu/diopirt/diopirt_impl.h"

using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
#endif

Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
return DISPATCH_DEVICE_IMPL(nms_impl, boxes, scores, iou_threshold, offset);
@@ -19,7 +29,34 @@ std::vector<std::vector<int> > nms_match_impl(Tensor dets,
}

Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
#ifdef MMCV_WITH_DIOPI
auto boxes_p = toDiopiTensorHandle(boxes);
diopiDevice_t device;
diopiGetTensorDevice(boxes_p, &device);
if (device == diopi_host) {
return nms_impl(boxes, scores, iou_threshold, offset);
}
diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
diopiContextHandle_t ch = &ctx;
Tensor out;
auto outp = toDiopiTensorHandle(out);
diopiTensorHandle_t* outhandle = &outp;
auto scores_p = toDiopiTensorHandle(scores);
if (reinterpret_cast<void*>(diopiNmsMmcv) != nullptr) {
auto ret =
diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
if (ret == diopiSuccess) {
auto tensorhandle = reinterpret_cast<Tensor*>(*outhandle);
return *tensorhandle;
}
}
LOG(WARNING) << "Fallback to cpu: mmcv ext op nms";
auto boxes_cpu = boxes.cpu();
auto scores_cpu = scores.cpu();
return nms_impl(boxes_cpu, scores_cpu, iou_threshold, offset);
#else
return nms_impl(boxes, scores, iou_threshold, offset);
#endif
}

Tensor softnms(Tensor boxes, Tensor scores, Tensor dets, float iou_threshold,
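
One behavioral note on the nms wrapper above: its CPU fallback branch returns the host index tensor directly rather than copying it back to the device, so a caller that needs the indices on the original device would move them back itself. The wrapper below is illustrative only and not part of this PR; it assumes the nms declaration from this file is visible.

#include <ATen/ATen.h>

// Illustrative wrapper (not in this PR): keep the returned indices on the
// same device as the input boxes, even when the CPU fallback path was taken.
at::Tensor nms_keep_on_device(const at::Tensor &boxes, const at::Tensor &scores,
                              float iou_threshold, int offset) {
  at::Tensor keep = nms(boxes, scores, iou_threshold, offset);
  return keep.device() == boxes.device() ? keep : keep.to(boxes.device());
}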