Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DONOTMERGE] Port PSROIAlign to use the Dispatcher and support Autocast #2927

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 125 additions & 58 deletions torchvision/csrc/PSROIAlign.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,71 @@

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include "autocast.h"
#endif

#include <iostream>

std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward(
// TODO: put this stuff in torchvision namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_align(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return PSROIAlign_forward_cuda(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return PSROIAlign_forward_cpu(
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::ps_roi_align", "")
.typed<decltype(ps_roi_align)>();
return op.call(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

at::Tensor PSROIAlign_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return PSROIAlign_backward_cuda(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
std::tuple<at::Tensor, at::Tensor> PSROIAlign_autocast(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto result = ps_roi_align(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, rois),
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);

return std::make_tuple(
std::get<0>(result).to(input.scalar_type()),
std::get<1>(result).to(input.scalar_type()));
}
#endif
}
return PSROIAlign_backward_cpu(

at::Tensor _ps_roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
.typed<decltype(_ps_roi_align_backward)>();
return op.call(
grad,
rois,
mapping_channel,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
Expand All @@ -95,7 +93,8 @@ class PSROIAlignFunction
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes();
auto result = PSROIAlign_forward(
at::AutoNonVariableTypeMode g;
auto result = ps_roi_align(
input,
rois,
spatial_scale,
Expand All @@ -117,7 +116,7 @@ class PSROIAlignFunction
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = PSROIAlign_backward(
auto grad_in = _ps_roi_align_backward(
grad_output[0],
rois,
channel_mapping,
Expand All @@ -138,7 +137,48 @@ class PSROIAlignFunction
}
};

std::tuple<at::Tensor, at::Tensor> ps_roi_align(
// TODO: There should be an easier way to do this
class PSROIAlignBackwardFunction
: public torch::autograd::Function<PSROIAlignBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable grad,
torch::autograd::Variable rois,
torch::autograd::Variable channel_mapping,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = _ps_roi_align_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);

return {grad_in};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
TORCH_CHECK(0, "double backwards on ps_roi_align not supported");
}
};

std::tuple<at::Tensor, at::Tensor> PSROIAlign_autograd(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
Expand All @@ -147,5 +187,32 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align(
const int64_t sampling_ratio) {
auto result = PSROIAlignFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);

return std::make_tuple(result[0], result[1]);
}

at::Tensor PSROIAlign_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width) {
return PSROIAlignBackwardFunction::apply(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width)[0];
}
24 changes: 12 additions & 12 deletions torchvision/csrc/cpu/PSROIAlign_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,10 @@ void PSROIAlignBackwardCPU(
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
// Check if input tensors are CPU tensors
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
Expand Down Expand Up @@ -361,14 +361,14 @@ at::Tensor PSROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width) {
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width) {
// Check if input tensors are CPU tensors
TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
Expand Down
26 changes: 13 additions & 13 deletions torchvision/csrc/cpu/vision_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cpu(
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio);
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio);

VISION_API at::Tensor PSROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width);
const at::Tensor& channel_mapping,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width);

VISION_API at::Tensor nms_cpu(
const at::Tensor& dets,
Expand Down
24 changes: 12 additions & 12 deletions torchvision/csrc/cuda/PSROIAlign_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,10 @@ __global__ void PSROIAlignBackwardCUDA(
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
// Check if input tensors are CUDA tensors
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
Expand Down Expand Up @@ -369,14 +369,14 @@ at::Tensor PSROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width) {
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width) {
// Check if input tensors are CUDA tensors
TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
Expand Down
26 changes: 13 additions & 13 deletions torchvision/csrc/cuda/vision_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cuda(
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio);
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio);

VISION_API at::Tensor PSROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width);
const at::Tensor& channel_mapping,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width);

VISION_API at::Tensor nms_cuda(
const at::Tensor& dets,
Expand Down
Loading