Skip to content

Commit

Permalink
Exposing public forward/backward methods to the C++ API
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Nov 30, 2020
1 parent f3f8469 commit f6782a0
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 40 deletions.
76 changes: 38 additions & 38 deletions torchvision/csrc/deform_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#include <ATen/autocast_mode.h>
#endif

namespace {

at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
Expand Down Expand Up @@ -81,6 +79,44 @@ _deform_conv2d_backward(
use_mask);
}

#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor deform_conv2d_autocast(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return deform_conv2d(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, weight),
at::autocast::cached_cast(at::kFloat, offset),
at::autocast::cached_cast(at::kFloat, mask),
at::autocast::cached_cast(at::kFloat, bias),
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask)
.to(input.scalar_type());
}
#endif

namespace {

class DeformConv2dFunction
: public torch::autograd::Function<DeformConv2dFunction> {
public:
Expand Down Expand Up @@ -257,42 +293,6 @@ class DeformConv2dBackwardFunction

} // namespace

#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor deform_conv2d_autocast(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return deform_conv2d(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, weight),
at::autocast::cached_cast(at::kFloat, offset),
at::autocast::cached_cast(at::kFloat, mask),
at::autocast::cached_cast(at::kFloat, bias),
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask)
.to(input.scalar_type());
}
#endif

at::Tensor deform_conv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
Expand Down
39 changes: 37 additions & 2 deletions torchvision/csrc/deform_conv2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,42 @@
#include "hip/deform_conv2d_kernel.h"
#endif

// Autocast Registration
// C++ Forward and Backward API
at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask);

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask);

// Autocast Forward
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor deform_conv2d_autocast(
const at::Tensor& input,
Expand All @@ -28,7 +63,7 @@ at::Tensor deform_conv2d_autocast(
bool use_mask);
#endif

// Autograd Registration
// Autograd Forward and Backward
at::Tensor deform_conv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
Expand Down

0 comments on commit f6782a0

Please sign in to comment.