Skip to content

Commit

Permalink
DeformConv code cleanup (#2905)
Browse files Browse the repository at this point in the history
* Clean up and refactor the DeformConv implementation:
- Remove `const` qualifiers from primitive-type parameters in method signatures.
- Pass arguments by const reference instead of by value where possible.
- Align method names between the CPU and CUDA implementations.

* Add a trailing newline.

* Add back the include needed for the CPU build.

* Restore the names of private methods to avoid conflicts.

* Restore the include headers.
  • Loading branch information
datumbox authored Oct 30, 2020
1 parent 45e027c commit 0e5aee4
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 237 deletions.
130 changes: 69 additions & 61 deletions torchvision/csrc/DeformConv.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
#pragma once

#if defined(WITH_CUDA) || defined(WITH_HIP)
#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif

// TODO: put this stuff in torchvision namespace
Expand All @@ -11,14 +18,14 @@ at::Tensor deform_conv2d(
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::deform_conv2d", "")
.typed<decltype(deform_conv2d)>();
Expand All @@ -43,14 +50,14 @@ at::Tensor DeformConv2d_autocast(
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return deform_conv2d(
at::autocast::cached_cast(at::kFloat, input),
Expand All @@ -76,14 +83,14 @@ _deform_conv2d_backward(
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
Expand All @@ -109,10 +116,10 @@ class DeformConv2dFunction
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input,
torch::autograd::Variable weight,
torch::autograd::Variable offset,
torch::autograd::Variable bias,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
Expand All @@ -121,7 +128,7 @@ class DeformConv2dFunction
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
at::AutoNonVariableTypeMode g; // TODO_vv: check if necessary
at::AutoNonVariableTypeMode g;
auto output = deform_conv2d(
input,
weight,
Expand Down Expand Up @@ -153,7 +160,7 @@ class DeformConv2dFunction

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
const torch::autograd::variable_list& grad_output) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto weight = saved[1];
Expand Down Expand Up @@ -211,19 +218,19 @@ class DeformConv2dBackwardFunction
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable grad,
torch::autograd::Variable input,
torch::autograd::Variable weight,
torch::autograd::Variable offset,
torch::autograd::Variable bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
const torch::autograd::Variable& grad,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
at::AutoNonVariableTypeMode g;
auto result = _deform_conv2d_backward(
grad,
Expand Down Expand Up @@ -255,7 +262,7 @@ class DeformConv2dBackwardFunction

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
}
};
Expand All @@ -265,14 +272,14 @@ at::Tensor DeformConv2d_autograd(
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
return DeformConv2dFunction::apply(
input,
weight,
Expand All @@ -295,14 +302,14 @@ DeformConv2d_backward_autograd(
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
input,
Expand All @@ -317,5 +324,6 @@ DeformConv2d_backward_autograd(
dilation_w,
groups,
offset_groups);

return std::make_tuple(result[0], result[1], result[2], result[3]);
}
Loading

0 comments on commit 0e5aee4

Please sign in to comment.