Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PSROIAlign + Dispatcher + Autocast + Code Cleanup #2928

Merged
merged 6 commits into from
Oct 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 139 additions & 64 deletions torchvision/csrc/PSROIAlign.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,72 +3,75 @@
#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif

#include <iostream>

std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward(
// TODO: put this stuff in torchvision namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_align(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return PSROIAlign_forward_cuda(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return PSROIAlign_forward_cpu(
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::ps_roi_align", "")
.typed<decltype(ps_roi_align)>();
return op.call(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

at::Tensor PSROIAlign_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return PSROIAlign_backward_cuda(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
std::tuple<at::Tensor, at::Tensor> PSROIAlign_autocast(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto result = ps_roi_align(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, rois),
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);

return std::make_tuple(
std::get<0>(result).to(input.scalar_type()),
std::get<1>(result).to(input.scalar_type()));
}
#endif
}
return PSROIAlign_backward_cpu(

at::Tensor _ps_roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
.typed<decltype(_ps_roi_align_backward)>();
return op.call(
grad,
rois,
mapping_channel,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
Expand All @@ -84,40 +87,43 @@ class PSROIAlignFunction
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes();
auto result = PSROIAlign_forward(
at::AutoNonVariableTypeMode g;
auto result = ps_roi_align(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);

auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result);
ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping});

return {output, channel_mapping};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = PSROIAlign_backward(
auto grad_in = _ps_roi_align_backward(
grad_output[0],
rois,
channel_mapping,
Expand All @@ -129,6 +135,7 @@ class PSROIAlignFunction
input_shape[1],
input_shape[2],
input_shape[3]);

return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
Expand All @@ -138,14 +145,82 @@ class PSROIAlignFunction
}
};

std::tuple<at::Tensor, at::Tensor> ps_roi_align(
// TODO: There should be an easier way to do this
class PSROIAlignBackwardFunction
: public torch::autograd::Function<PSROIAlignBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
const torch::autograd::Variable& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = _ps_roi_align_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);

return {grad_in};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on ps_roi_align not supported");
}
};

std::tuple<at::Tensor, at::Tensor> PSROIAlign_autograd(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
auto result = PSROIAlignFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);

return std::make_tuple(result[0], result[1]);
}

at::Tensor PSROIAlign_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
return PSROIAlignBackwardFunction::apply(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width)[0];
}
datumbox marked this conversation as resolved.
Show resolved Hide resolved
Loading