From 20c3f858d516515e140d8958682b76fcf4fc5477 Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Wed, 28 Oct 2020 17:09:27 +0000 Subject: [PATCH 1/6] Fixing types. --- torchvision/csrc/PSROIAlign.h | 24 ++++++++++++------------ torchvision/csrc/cpu/PSROIAlign_cpu.cpp | 24 ++++++++++++------------ torchvision/csrc/cpu/vision_cpu.h | 24 ++++++++++++------------ torchvision/csrc/cuda/PSROIAlign_cuda.cu | 24 ++++++++++++------------ torchvision/csrc/cuda/vision_cuda.h | 24 ++++++++++++------------ 5 files changed, 60 insertions(+), 60 deletions(-) diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h index ce8e49363c0..2c52ae4dfb0 100644 --- a/torchvision/csrc/PSROIAlign.h +++ b/torchvision/csrc/PSROIAlign.h @@ -14,10 +14,10 @@ std::tuple PSROIAlign_forward( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { if (input.is_cuda()) { #if defined(WITH_CUDA) || defined(WITH_HIP) return PSROIAlign_forward_cuda( @@ -39,14 +39,14 @@ at::Tensor PSROIAlign_backward( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& mapping_channel, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { if (grad.is_cuda()) { #if defined(WITH_CUDA) || defined(WITH_HIP) return PSROIAlign_backward_cuda( diff --git a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp index e5eb051cb91..c0a15318f8f 100644 --- a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp @@ -301,10 +301,10 @@ void PSROIAlignBackwardCPU( std::tuple PSROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { // Check if input tensors are CPU tensors TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); @@ -361,14 +361,14 @@ at::Tensor PSROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { // Check if input tensors are CPU tensors TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 69b1bbf555d..877c04b6c57 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cpu( VISION_API std::tuple PSROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio); VISION_API at::Tensor PSROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& mapping_channel, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width); VISION_API at::Tensor nms_cpu( const at::Tensor& dets, diff --git a/torchvision/csrc/cuda/PSROIAlign_cuda.cu b/torchvision/csrc/cuda/PSROIAlign_cuda.cu index 709b0bda208..05e9982543a 100644 --- a/torchvision/csrc/cuda/PSROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/PSROIAlign_cuda.cu @@ -295,10 +295,10 @@ __global__ void PSROIAlignBackwardCUDA( std::tuple PSROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { // Check if input tensors are CUDA tensors TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); @@ -369,14 +369,14 @@ at::Tensor PSROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { // Check if input tensors are CUDA tensors TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 2481cfc63c2..3dd43421bdf 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cuda( VISION_API std::tuple PSROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio); VISION_API at::Tensor PSROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& mapping_channel, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int batch_size, - const int channels, - const int height, - const int width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width); VISION_API at::Tensor nms_cuda( const at::Tensor& dets, From 462739f5c8f101ba393530a739364e96b139d3d9 Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Wed, 28 Oct 2020 17:41:17 +0000 Subject: [PATCH 2/6] Dispatcher + Autocast. --- torchvision/csrc/PSROIAlign.h | 99 +++++++++++++---------------- torchvision/csrc/cpu/vision_cpu.h | 2 +- torchvision/csrc/cuda/vision_cuda.h | 2 +- torchvision/csrc/vision.cpp | 10 ++- 4 files changed, 54 insertions(+), 59 deletions(-) diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h index 2c52ae4dfb0..a390d5d1f8c 100644 --- a/torchvision/csrc/PSROIAlign.h +++ b/torchvision/csrc/PSROIAlign.h @@ -2,43 +2,55 @@ #include "cpu/vision_cpu.h" -#ifdef WITH_CUDA -#include "cuda/vision_cuda.h" -#endif -#ifdef WITH_HIP -#include "hip/vision_cuda.h" +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include "autocast.h" #endif #include -std::tuple PSROIAlign_forward( +// TODO: put this stuff in torchvision namespace + +std::tuple ps_roi_align( const at::Tensor& input, const at::Tensor& rois, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, const int64_t sampling_ratio) { - if (input.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return PSROIAlign_forward_cuda( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); -#endif - } - return PSROIAlign_forward_cpu( + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_align", "") + .typed(); + return op.call( input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); } -at::Tensor PSROIAlign_backward( +#if defined(WITH_CUDA) || defined(WITH_HIP) +std::tuple PSROIAlign_autocast( + const at::Tensor& input, + const at::Tensor& rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + auto result = ps_roi_align( + at::autocast::cached_cast(at::kFloat, input), + at::autocast::cached_cast(at::kFloat, rois), + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio); + + return std::make_tuple( + std::get<0>(result).to(input.scalar_type()), + std::get<1>(result).to(input.scalar_type())); +} +#endif + +at::Tensor _ps_roi_align_backward( const at::Tensor& grad, const at::Tensor& rois, - const at::Tensor& mapping_channel, + const at::Tensor& channel_mapping, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, @@ -47,28 +59,14 @@ at::Tensor PSROIAlign_backward( const int64_t channels, const int64_t height, const int64_t width) { - if (grad.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return PSROIAlign_backward_cuda( - grad, - rois, - mapping_channel, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); -#endif - } - return PSROIAlign_backward_cpu( + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") + .typed(); + return op.call( grad, rois, - mapping_channel, + channel_mapping, spatial_scale, pooled_height, pooled_width, @@ -95,7 +93,8 @@ class PSROIAlignFunction ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["sampling_ratio"] = sampling_ratio; ctx->saved_data["input_shape"] = input.sizes(); - auto result = PSROIAlign_forward( + at::AutoNonVariableTypeMode g; + auto result = ps_roi_align( input, rois, spatial_scale, @@ -117,7 +116,7 @@ class PSROIAlignFunction auto rois = saved[0]; auto channel_mapping = saved[1]; auto input_shape = ctx->saved_data["input_shape"].toIntList(); - auto grad_in = PSROIAlign_backward( + auto grad_in = _ps_roi_align_backward( grad_output[0], rois, channel_mapping, @@ -137,15 +136,3 @@ class PSROIAlignFunction torch::autograd::Variable()}; } }; - -std::tuple ps_roi_align( - const at::Tensor& input, - const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio) { - auto result = PSROIAlignFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); - return std::tuple(result[0], result[1]); -} diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 877c04b6c57..265cff347d8 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -73,7 +73,7 @@ VISION_API std::tuple PSROIAlign_forward_cpu( VISION_API at::Tensor PSROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, - const at::Tensor& mapping_channel, + const at::Tensor& channel_mapping, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 3dd43421bdf..7eaef181dd3 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -73,7 +73,7 @@ VISION_API std::tuple PSROIAlign_forward_cuda( VISION_API at::Tensor PSROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, - const at::Tensor& mapping_channel, + const at::Tensor& channel_mapping, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index f56a671d6e5..de0d6ae1c84 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -52,7 +52,10 @@ TORCH_LIBRARY(torchvision, m) { "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"); m.def("roi_pool", &roi_pool); m.def("_new_empty_tensor_op", &new_empty_tensor); - m.def("ps_roi_align", &ps_roi_align); + m.def( + "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"); + m.def( + "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor mapping_channel, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"); m.def("ps_roi_pool", &ps_roi_pool); m.def( "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor"); @@ -67,6 +70,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl("deform_conv2d", DeformConv2d_forward_cpu); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu); m.impl("nms", nms_cpu); + m.impl("ps_roi_align", PSROIAlign_forward_cpu); + m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu); } // TODO: Place this in a hypothetical separate torchvision_cuda library @@ -77,6 +82,8 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl("deform_conv2d", DeformConv2d_forward_cuda); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda); m.impl("nms", nms_cuda); + m.impl("ps_roi_align", PSROIAlign_forward_cuda); + m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda); } #endif @@ -86,6 +93,7 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl("roi_align", ROIAlign_autocast); m.impl("deform_conv2d", DeformConv2d_autocast); m.impl("nms", nms_autocast); + m.impl("ps_roi_align", PSROIAlign_autocast); } #endif From fedc2e57b516905faa091f82fb757a0c71fc525e Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Wed, 28 Oct 2020 17:42:59 +0000 Subject: [PATCH 3/6] + Autograd. --- torchvision/csrc/PSROIAlign.h | 80 +++++++++++++++++++++++++++++++++++ torchvision/csrc/vision.cpp | 4 +- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h index a390d5d1f8c..b99d359fa61 100644 --- a/torchvision/csrc/PSROIAlign.h +++ b/torchvision/csrc/PSROIAlign.h @@ -136,3 +136,83 @@ class PSROIAlignFunction torch::autograd::Variable()}; } }; + +// TODO: There should be an easier way to do this +class PSROIAlignBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + torch::autograd::Variable grad, + torch::autograd::Variable rois, + torch::autograd::Variable channel_mapping, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { + at::AutoNonVariableTypeMode g; + auto grad_in = _ps_roi_align_backward( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + TORCH_CHECK(0, "double backwards on ps_roi_align not supported"); + } +}; + +std::tuple PSROIAlign_autograd( + const at::Tensor& input, + const at::Tensor& rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { + auto result = PSROIAlignFunction::apply( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); + + return std::make_tuple(result[0], result[1]); +} + +at::Tensor PSROIAlign_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { + return PSROIAlignBackwardFunction::apply( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width)[0]; +} \ No newline at end of file diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index de0d6ae1c84..d4f5ca2c917 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -55,7 +55,7 @@ TORCH_LIBRARY(torchvision, m) { m.def( "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"); m.def( - "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor mapping_channel, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"); + "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"); m.def("ps_roi_pool", &ps_roi_pool); m.def( "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor"); @@ -102,4 +102,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("_roi_align_backward", ROIAlign_backward_autograd); m.impl("deform_conv2d", DeformConv2d_autograd); m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd); + m.impl("ps_roi_align", PSROIAlign_autograd); + m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd); } From 9bf9d5e2ae5ae9efe46f8004331cd1b99ce1b4e0 Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Wed, 28 Oct 2020 18:30:13 +0000 Subject: [PATCH 4/6] Clean up and refactor PSROIAlign implementation: - Remove primitive const declaration from method names. - Using references when possible. - Sync naming of internal methods with other ops. --- torchvision/csrc/PSROIAlign.h | 97 ++++++++++++------------ torchvision/csrc/cpu/PSROIAlign_cpu.cpp | 78 +++++++++---------- torchvision/csrc/cpu/vision_cpu.h | 24 +++--- torchvision/csrc/cuda/PSROIAlign_cuda.cu | 78 +++++++++---------- torchvision/csrc/cuda/vision_cuda.h | 24 +++--- 5 files changed, 152 insertions(+), 149 deletions(-) diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h index b99d359fa61..84803f76e69 100644 --- a/torchvision/csrc/PSROIAlign.h +++ b/torchvision/csrc/PSROIAlign.h @@ -13,10 +13,10 @@ std::tuple ps_roi_align( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("torchvision::ps_roi_align", "") .typed(); @@ -28,10 +28,10 @@ std::tuple ps_roi_align( std::tuple PSROIAlign_autocast( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); auto result = ps_roi_align( at::autocast::cached_cast(at::kFloat, input), @@ -51,14 +51,14 @@ at::Tensor _ps_roi_align_backward( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio, - const int64_t batch_size, - const int64_t channels, - const int64_t height, - const int64_t width) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") @@ -82,12 +82,12 @@ class PSROIAlignFunction public: static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, - torch::autograd::Variable input, - torch::autograd::Variable rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio) { + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { ctx->saved_data["spatial_scale"] = spatial_scale; ctx->saved_data["pooled_height"] = pooled_height; ctx->saved_data["pooled_width"] = pooled_width; @@ -101,16 +101,18 @@ class PSROIAlignFunction pooled_height, pooled_width, sampling_ratio); + auto output = std::get<0>(result); auto channel_mapping = std::get<1>(result); ctx->save_for_backward({rois, channel_mapping}); ctx->mark_non_differentiable({channel_mapping}); + return {output, channel_mapping}; } static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output) { + const torch::autograd::variable_list& grad_output) { // Use data saved in forward auto saved = ctx->get_saved_variables(); auto rois = saved[0]; @@ -128,6 +130,7 @@ class PSROIAlignFunction input_shape[1], input_shape[2], input_shape[3]); + return {grad_in, torch::autograd::Variable(), torch::autograd::Variable(), @@ -143,17 +146,17 @@ class PSROIAlignBackwardFunction public: static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, - torch::autograd::Variable grad, - torch::autograd::Variable rois, - torch::autograd::Variable channel_mapping, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio, - const int64_t batch_size, - const int64_t channels, - const int64_t height, - const int64_t width) { + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + const torch::autograd::Variable& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { at::AutoNonVariableTypeMode g; auto grad_in = _ps_roi_align_backward( grad, @@ -173,7 +176,7 @@ class PSROIAlignBackwardFunction static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output) { + const torch::autograd::variable_list& grad_output) { TORCH_CHECK(0, "double backwards on ps_roi_align not supported"); } }; @@ -181,10 +184,10 @@ class PSROIAlignBackwardFunction std::tuple PSROIAlign_autograd( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { auto result = PSROIAlignFunction::apply( input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); @@ -195,14 +198,14 @@ at::Tensor PSROIAlign_backward_autograd( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio, - const int64_t batch_size, - const int64_t channels, - const int64_t height, - const int64_t width) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { return PSROIAlignBackwardFunction::apply( grad, rois, diff --git a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp index c0a15318f8f..25eec07fe78 100644 --- a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp @@ -5,11 +5,11 @@ template T bilinear_interpolate( const T* input, - const int height, - const int width, + int height, + int width, T y, T x, - const int index /* index for debug only*/) { + int index /* index for debug only*/) { // deal with cases that inverse elements are out of feature map boundary if (y < -1.0 || y > height || x < -1.0 || x > width) { // empty @@ -57,18 +57,18 @@ T bilinear_interpolate( } template -void PSROIAlignForwardCPU( - const int nthreads, +void PSROIAlignForward( + int nthreads, const T* input, const T spatial_scale, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, const T* rois, - const int channels_out, + int channels_out, T* output, int* channel_mapping) { int num_rois = nthreads / channels_out / pooled_width / pooled_height; @@ -139,8 +139,8 @@ void PSROIAlignForwardCPU( template void bilinear_interpolate_gradient( - const int height, - const int width, + int height, + int width, T y, T x, T& w1, @@ -151,7 +151,7 @@ void bilinear_interpolate_gradient( int& x_high, int& y_low, int& y_high, - const int index /* index for debug only*/) { + int index /* index for debug only*/) { // deal with cases that inverse elements are out of feature map boundary if (y < -1.0 || y > height || x < -1.0 || x > width) { // empty @@ -202,19 +202,19 @@ inline void add(T* address, const T& val) { } template -void PSROIAlignBackwardCPU( - const int nthreads, +void PSROIAlignBackward( + int nthreads, const T* grad_output, const int* channel_mapping, - const int num_rois, + int num_rois, const T spatial_scale, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int channels_out, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + int channels_out, T* grad_input, const T* rois) { for (int index = 0; index < nthreads; index++) { @@ -301,10 +301,10 @@ void PSROIAlignBackwardCPU( std::tuple PSROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { // Check if input tensors are CPU tensors TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); @@ -339,7 +339,7 @@ std::tuple PSROIAlign_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( input.scalar_type(), "PSROIAlign_forward", [&] { - PSROIAlignForwardCPU( + PSROIAlignForward( output_size, input_.data_ptr(), spatial_scale, @@ -361,14 +361,14 @@ at::Tensor PSROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio, - const int64_t batch_size, - const int64_t channels, - const int64_t height, - const int64_t width) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { // Check if input tensors are CPU tensors TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); @@ -396,7 +396,7 @@ at::Tensor PSROIAlign_backward_cpu( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "PSROIAlign_backward", [&] { - PSROIAlignBackwardCPU( + PSROIAlignBackward( grad.numel(), grad_.data_ptr(), channel_mapping.data_ptr(), diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 265cff347d8..248225479d3 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cpu( VISION_API std::tuple PSROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); VISION_API at::Tensor PSROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio, - const int64_t batch_size, - const int64_t channels, - const int64_t height, - const int64_t width); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); VISION_API at::Tensor nms_cpu( const at::Tensor& dets, diff --git a/torchvision/csrc/cuda/PSROIAlign_cuda.cu b/torchvision/csrc/cuda/PSROIAlign_cuda.cu index 05e9982543a..b3e04b533df 100644 --- a/torchvision/csrc/cuda/PSROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/PSROIAlign_cuda.cu @@ -10,11 +10,11 @@ template __device__ T bilinear_interpolate( const T* input, - const int height, - const int width, + int height, + int width, T y, T x, - const int index /* index for debug only*/) { + int index /* index for debug only*/) { // deal with cases that inverse elements are out of feature map boundary if (y < -1.0 || y > height || x < -1.0 || x > width) { // empty @@ -62,18 +62,18 @@ __device__ T bilinear_interpolate( } template -__global__ void PSROIAlignForwardCUDA( - const int nthreads, +__global__ void PSROIAlignForward( + int nthreads, const T* input, const T spatial_scale, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, const T* rois, - const int channels_out, + int channels_out, T* output, int* channel_mapping) { CUDA_1D_KERNEL_LOOP(index, nthreads) { @@ -137,8 +137,8 @@ __global__ void PSROIAlignForwardCUDA( template __device__ void bilinear_interpolate_gradient( - const int height, - const int width, + int height, + int width, T y, T x, T& w1, @@ -149,7 +149,7 @@ __device__ void bilinear_interpolate_gradient( int& x_high, int& y_low, int& y_high, - const int index /* index for debug only*/) { + int index /* index for debug only*/) { // deal with cases that inverse elements are out of feature map boundary if (y < -1.0 || y > height || x < -1.0 || x > width) { // empty @@ -195,19 +195,19 @@ __device__ void bilinear_interpolate_gradient( } template -__global__ void PSROIAlignBackwardCUDA( - const int nthreads, +__global__ void PSROIAlignBackward( + int nthreads, const T* grad_output, const int* channel_mapping, - const int num_rois, + int num_rois, const T spatial_scale, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const int channels_out, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + int channels_out, T* grad_input, const T* rois) { CUDA_1D_KERNEL_LOOP(index, nthreads) { @@ -295,10 +295,10 @@ __global__ void PSROIAlignBackwardCUDA( std::tuple PSROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { // Check if input tensors are CUDA tensors TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); @@ -345,7 +345,7 @@ std::tuple PSROIAlign_forward_cuda( rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( input.scalar_type(), "PSROIAlign_forward", [&] { - PSROIAlignForwardCUDA<<>>( + PSROIAlignForward<<>>( output_size, input_.data_ptr(), spatial_scale, @@ -369,14 +369,14 @@ at::Tensor PSROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio, - const int64_t batch_size, - const int64_t channels, - const int64_t height, - const int64_t width) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { // Check if input tensors are CUDA tensors TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); @@ -416,7 +416,7 @@ at::Tensor PSROIAlign_backward_cuda( rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "PSROIAlign_backward", [&] { - PSROIAlignBackwardCUDA<<>>( + PSROIAlignBackward<<>>( grad.numel(), grad_.data_ptr(), channel_mapping.data_ptr(), diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 7eaef181dd3..00db2cacf90 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cuda( VISION_API std::tuple PSROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); VISION_API at::Tensor PSROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio, - const int64_t batch_size, - const int64_t channels, - const int64_t height, - const int64_t width); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); VISION_API at::Tensor nms_cuda( const at::Tensor& dets, From 5238a18d1220efdd08431fcbed8f903e42955075 Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Wed, 28 Oct 2020 20:00:52 +0000 Subject: [PATCH 5/6] Restoring names of internal methods to avoid conflicts. --- torchvision/csrc/cpu/PSROIAlign_cpu.cpp | 8 ++++---- torchvision/csrc/cuda/PSROIAlign_cuda.cu | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp index 25eec07fe78..899dbb208b6 100644 --- a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp @@ -57,7 +57,7 @@ T bilinear_interpolate( } template -void PSROIAlignForward( +void PSROIAlignForwardCPU( int nthreads, const T* input, const T spatial_scale, @@ -202,7 +202,7 @@ inline void add(T* address, const T& val) { } template -void PSROIAlignBackward( +void PSROIAlignBackwardCPU( int nthreads, const T* grad_output, const int* channel_mapping, @@ -339,7 +339,7 @@ std::tuple PSROIAlign_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( input.scalar_type(), "PSROIAlign_forward", [&] { - PSROIAlignForward( + PSROIAlignForwardCPU( output_size, input_.data_ptr(), spatial_scale, @@ -396,7 +396,7 @@ at::Tensor PSROIAlign_backward_cpu( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "PSROIAlign_backward", [&] { - PSROIAlignBackward( + PSROIAlignBackwardCPU( grad.numel(), grad_.data_ptr(), channel_mapping.data_ptr(), diff --git a/torchvision/csrc/cuda/PSROIAlign_cuda.cu b/torchvision/csrc/cuda/PSROIAlign_cuda.cu index b3e04b533df..e6912d8c7ee 100644 --- a/torchvision/csrc/cuda/PSROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/PSROIAlign_cuda.cu @@ -62,7 +62,7 @@ __device__ T bilinear_interpolate( } template -__global__ void PSROIAlignForward( +__global__ void PSROIAlignForwardCUDA( int nthreads, const T* input, const T spatial_scale, @@ -195,7 +195,7 @@ __device__ void bilinear_interpolate_gradient( } template -__global__ void PSROIAlignBackward( +__global__ void PSROIAlignBackwardCUDA( int nthreads, const T* grad_output, const int* channel_mapping, @@ -345,7 +345,7 @@ std::tuple PSROIAlign_forward_cuda( rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( input.scalar_type(), "PSROIAlign_forward", [&] { - PSROIAlignForward<<>>( + PSROIAlignForwardCUDA<<>>( output_size, input_.data_ptr(), spatial_scale, @@ -416,7 +416,7 @@ at::Tensor PSROIAlign_backward_cuda( rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "PSROIAlign_backward", [&] { - PSROIAlignBackward<<>>( + PSROIAlignBackwardCUDA<<>>( grad.numel(), grad_.data_ptr(), channel_mapping.data_ptr(), From ea8a36cecaa48f6fe7ad8577754de04670620320 Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Fri, 30 Oct 2020 10:33:45 +0000 Subject: [PATCH 6/6] Restore include headers. --- torchvision/csrc/PSROIAlign.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h index 84803f76e69..92f4390a0f9 100644 --- a/torchvision/csrc/PSROIAlign.h +++ b/torchvision/csrc/PSROIAlign.h @@ -2,8 +2,13 @@ #include "cpu/vision_cpu.h" -#if defined(WITH_CUDA) || defined(WITH_HIP) +#ifdef WITH_CUDA +#include "autocast.h" +#include "cuda/vision_cuda.h" +#endif +#ifdef WITH_HIP #include "autocast.h" +#include "hip/vision_cuda.h" #endif #include