From 20c3f858d516515e140d8958682b76fcf4fc5477 Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Wed, 28 Oct 2020 17:09:27 +0000 Subject: [PATCH 1/3] Fixing types. --- torchvision/csrc/PSROIAlign.h | 24 ++++++++++++------------ torchvision/csrc/cpu/PSROIAlign_cpu.cpp | 24 ++++++++++++------------ torchvision/csrc/cpu/vision_cpu.h | 24 ++++++++++++------------ torchvision/csrc/cuda/PSROIAlign_cuda.cu | 24 ++++++++++++------------ torchvision/csrc/cuda/vision_cuda.h | 24 ++++++++++++------------ 5 files changed, 60 insertions(+), 60 deletions(-) diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h index ce8e49363c0..2c52ae4dfb0 100644 --- a/torchvision/csrc/PSROIAlign.h +++ b/torchvision/csrc/PSROIAlign.h @@ -14,10 +14,10 @@ std::tuple PSROIAlign_forward( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { if (input.is_cuda()) { #if defined(WITH_CUDA) || defined(WITH_HIP) return PSROIAlign_forward_cuda( @@ -39,14 +39,14 @@ at::Tensor PSROIAlign_backward( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& mapping_channel, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { if (grad.is_cuda()) { #if defined(WITH_CUDA) || defined(WITH_HIP) return PSROIAlign_backward_cuda( diff --git a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp index e5eb051cb91..c0a15318f8f 100644 --- a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp @@ -301,10 +301,10 @@ void PSROIAlignBackwardCPU( std::tuple PSROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { // Check if input tensors are CPU tensors TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); @@ -361,14 +361,14 @@ at::Tensor PSROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { // Check if input tensors are CPU tensors TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 69b1bbf555d..877c04b6c57 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cpu( VISION_API std::tuple PSROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio); VISION_API at::Tensor PSROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& mapping_channel, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width); VISION_API at::Tensor nms_cpu( const at::Tensor& dets, diff --git a/torchvision/csrc/cuda/PSROIAlign_cuda.cu b/torchvision/csrc/cuda/PSROIAlign_cuda.cu index 709b0bda208..05e9982543a 100644 --- a/torchvision/csrc/cuda/PSROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/PSROIAlign_cuda.cu @@ -295,10 +295,10 @@ __global__ void PSROIAlignBackwardCUDA( std::tuple PSROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { // Check if input tensors are CUDA tensors TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); @@ -369,14 +369,14 @@ at::Tensor PSROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { // Check if input tensors are CUDA tensors TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 2481cfc63c2..3dd43421bdf 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cuda( VISION_API std::tuple PSROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio); VISION_API at::Tensor PSROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& mapping_channel, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width); VISION_API at::Tensor nms_cuda( const at::Tensor& dets, From 462739f5c8f101ba393530a739364e96b139d3d9 Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Wed, 28 Oct 2020 17:41:17 +0000 Subject: [PATCH 2/3] Dispatcher + Autocast. --- torchvision/csrc/PSROIAlign.h | 99 +++++++++++++---------------- torchvision/csrc/cpu/vision_cpu.h | 2 +- torchvision/csrc/cuda/vision_cuda.h | 2 +- torchvision/csrc/vision.cpp | 10 ++- 4 files changed, 54 insertions(+), 59 deletions(-) diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h index 2c52ae4dfb0..a390d5d1f8c 100644 --- a/torchvision/csrc/PSROIAlign.h +++ b/torchvision/csrc/PSROIAlign.h @@ -2,43 +2,55 @@ #include "cpu/vision_cpu.h" -#ifdef WITH_CUDA -#include "cuda/vision_cuda.h" -#endif -#ifdef WITH_HIP -#include "hip/vision_cuda.h" +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include "autocast.h" #endif #include -std::tuple PSROIAlign_forward( +// TODO: put this stuff in torchvision namespace + +std::tuple ps_roi_align( const at::Tensor& input, const at::Tensor& rois, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, const int64_t sampling_ratio) { - if (input.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return PSROIAlign_forward_cuda( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); -#endif - } - return PSROIAlign_forward_cpu( + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_align", "") + .typed(); + return op.call( input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); } -at::Tensor PSROIAlign_backward( +#if defined(WITH_CUDA) || defined(WITH_HIP) +std::tuple PSROIAlign_autocast( + const at::Tensor& input, + const at::Tensor& rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + auto result = ps_roi_align( + at::autocast::cached_cast(at::kFloat, input), + at::autocast::cached_cast(at::kFloat, rois), + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio); + + return std::make_tuple( + std::get<0>(result).to(input.scalar_type()), + std::get<1>(result).to(input.scalar_type())); +} +#endif + +at::Tensor _ps_roi_align_backward( const at::Tensor& grad, const at::Tensor& rois, - const at::Tensor& mapping_channel, + const at::Tensor& channel_mapping, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, @@ -47,28 +59,14 @@ at::Tensor PSROIAlign_backward( const int64_t channels, const int64_t height, const int64_t width) { - if (grad.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return PSROIAlign_backward_cuda( - grad, - rois, - mapping_channel, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); -#endif - } - return PSROIAlign_backward_cpu( + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") + .typed(); + return op.call( grad, rois, - mapping_channel, + channel_mapping, spatial_scale, pooled_height, pooled_width, @@ -95,7 +93,8 @@ class PSROIAlignFunction ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["sampling_ratio"] = sampling_ratio; ctx->saved_data["input_shape"] = input.sizes(); - auto result = PSROIAlign_forward( + at::AutoNonVariableTypeMode g; + auto result = ps_roi_align( input, rois, spatial_scale, @@ -117,7 +116,7 @@ class PSROIAlignFunction auto rois = saved[0]; auto channel_mapping = saved[1]; auto input_shape = ctx->saved_data["input_shape"].toIntList(); - auto grad_in = PSROIAlign_backward( + auto grad_in = _ps_roi_align_backward( grad_output[0], rois, channel_mapping, @@ -137,15 +136,3 @@ class PSROIAlignFunction torch::autograd::Variable()}; } }; - -std::tuple ps_roi_align( - const at::Tensor& input, - const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio) { - auto result = PSROIAlignFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); - return std::tuple(result[0], result[1]); -} diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 877c04b6c57..265cff347d8 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -73,7 +73,7 @@ VISION_API std::tuple PSROIAlign_forward_cpu( VISION_API at::Tensor PSROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, - const at::Tensor& mapping_channel, + const at::Tensor& channel_mapping, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 3dd43421bdf..7eaef181dd3 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -73,7 +73,7 @@ VISION_API std::tuple PSROIAlign_forward_cuda( VISION_API at::Tensor PSROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, - const at::Tensor& mapping_channel, + const at::Tensor& channel_mapping, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index f56a671d6e5..de0d6ae1c84 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -52,7 +52,10 @@ TORCH_LIBRARY(torchvision, m) { "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"); m.def("roi_pool", &roi_pool); m.def("_new_empty_tensor_op", &new_empty_tensor); - m.def("ps_roi_align", &ps_roi_align); + m.def( + "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"); + m.def( + "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor mapping_channel, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"); m.def("ps_roi_pool", &ps_roi_pool); m.def( "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor"); @@ -67,6 +70,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl("deform_conv2d", DeformConv2d_forward_cpu); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu); m.impl("nms", nms_cpu); + m.impl("ps_roi_align", PSROIAlign_forward_cpu); + m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu); } // TODO: Place this in a hypothetical separate torchvision_cuda library @@ -77,6 +82,8 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl("deform_conv2d", DeformConv2d_forward_cuda); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda); m.impl("nms", nms_cuda); + m.impl("ps_roi_align", PSROIAlign_forward_cuda); + m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda); } #endif @@ -86,6 +93,7 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl("roi_align", ROIAlign_autocast); m.impl("deform_conv2d", DeformConv2d_autocast); m.impl("nms", nms_autocast); + m.impl("ps_roi_align", PSROIAlign_autocast); } #endif From fedc2e57b516905faa091f82fb757a0c71fc525e Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Wed, 28 Oct 2020 17:42:59 +0000 Subject: [PATCH 3/3] + Autograd. --- torchvision/csrc/PSROIAlign.h | 80 +++++++++++++++++++++++++++++++++++ torchvision/csrc/vision.cpp | 4 +- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h index a390d5d1f8c..b99d359fa61 100644 --- a/torchvision/csrc/PSROIAlign.h +++ b/torchvision/csrc/PSROIAlign.h @@ -136,3 +136,83 @@ class PSROIAlignFunction torch::autograd::Variable()}; } }; + +// TODO: There should be an easier way to do this +class PSROIAlignBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + torch::autograd::Variable grad, + torch::autograd::Variable rois, + torch::autograd::Variable channel_mapping, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { + at::AutoNonVariableTypeMode g; + auto grad_in = _ps_roi_align_backward( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + TORCH_CHECK(0, "double backwards on ps_roi_align not supported"); + } +}; + +std::tuple PSROIAlign_autograd( + const at::Tensor& input, + const at::Tensor& rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { + auto result = PSROIAlignFunction::apply( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); + + return std::make_tuple(result[0], result[1]); +} + +at::Tensor PSROIAlign_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { + return PSROIAlignBackwardFunction::apply( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width)[0]; +} \ No newline at end of file diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index de0d6ae1c84..d4f5ca2c917 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -55,7 +55,7 @@ TORCH_LIBRARY(torchvision, m) { m.def( "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"); m.def( - "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor mapping_channel, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"); + "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"); m.def("ps_roi_pool", &ps_roi_pool); m.def( "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor"); @@ -102,4 +102,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("_roi_align_backward", ROIAlign_backward_autograd); m.impl("deform_conv2d", DeformConv2d_autograd); m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd); + m.impl("ps_roi_align", PSROIAlign_autograd); + m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd); }