diff --git a/torchvision/csrc/ROIPool.h b/torchvision/csrc/ROIPool.h index 38748c7f57b..7950005f1bd 100644 --- a/torchvision/csrc/ROIPool.h +++ b/torchvision/csrc/ROIPool.h @@ -3,59 +3,64 @@ #include "cpu/vision_cpu.h" #ifdef WITH_CUDA +#include "autocast.h" #include "cuda/vision_cuda.h" #endif #ifdef WITH_HIP +#include "autocast.h" #include "hip/vision_cuda.h" #endif -std::tuple ROIPool_forward( +// TODO: put this stuff in torchvision namespace + +std::tuple roi_pool( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width) { - if (input.is_cuda()) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + #if defined(WITH_CUDA) || defined(WITH_HIP) - return ROIPool_forward_cuda( - input, rois, spatial_scale, pooled_height, pooled_width); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); -#endif - } - return ROIPool_forward_cpu( - input, rois, spatial_scale, pooled_height, pooled_width); +std::tuple ROIPool_autocast( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + auto result = roi_pool( + at::autocast::cached_cast(at::kFloat, input), + at::autocast::cached_cast(at::kFloat, rois), + spatial_scale, + pooled_height, + pooled_width); + + return std::make_tuple( + std::get<0>(result).to(input.scalar_type()), + std::get<1>(result).to(input.scalar_type())); } +#endif -at::Tensor ROIPool_backward( +at::Tensor _roi_pool_backward( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width) { - if (grad.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return ROIPool_backward_cuda( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); -#endif - } - return ROIPool_backward_cpu( + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_pool_backward", "") + .typed(); + return op.call( grad, rois, argmax, @@ -72,33 +77,36 @@ class ROIPoolFunction : public torch::autograd::Function { public: static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, - torch::autograd::Variable input, - torch::autograd::Variable rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width) { + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { ctx->saved_data["spatial_scale"] = spatial_scale; ctx->saved_data["pooled_height"] = pooled_height; ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["input_shape"] = input.sizes(); - auto result = ROIPool_forward( - input, rois, spatial_scale, pooled_height, pooled_width); + at::AutoNonVariableTypeMode g; + auto result = + roi_pool(input, rois, spatial_scale, pooled_height, pooled_width); + auto output = std::get<0>(result); auto argmax = std::get<1>(result); ctx->save_for_backward({rois, argmax}); ctx->mark_non_differentiable({argmax}); + return {output, argmax}; } static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output) { + const torch::autograd::variable_list& grad_output) { // Use data saved in forward auto saved = ctx->get_saved_variables(); auto rois = saved[0]; auto argmax = saved[1]; auto input_shape = ctx->saved_data["input_shape"].toIntList(); - auto grad_in = ROIPool_backward( + auto grad_in = _roi_pool_backward( grad_output[0], rois, argmax, @@ -109,6 +117,7 @@ class ROIPoolFunction : public torch::autograd::Function { input_shape[1], input_shape[2], input_shape[3]); + return {grad_in, torch::autograd::Variable(), torch::autograd::Variable(), @@ -117,13 +126,77 @@ class ROIPoolFunction : public torch::autograd::Function { } }; -std::tuple roi_pool( +// TODO: There should be an easier way to do this +class ROIPoolBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + const torch::autograd::Variable& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + at::AutoNonVariableTypeMode g; + auto grad_in = _roi_pool_backward( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on roi_pool not supported"); + } +}; + +std::tuple ROIPool_autograd( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { auto result = ROIPoolFunction::apply( input, rois, spatial_scale, pooled_height, pooled_width); - return std::tuple(result[0], result[1]); + + return std::make_tuple(result[0], result[1]); +} + +at::Tensor ROIPool_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + return ROIPoolBackwardFunction::apply( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width)[0]; } diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp index b13f1de6646..34da4f1d1cc 100644 --- a/torchvision/csrc/cpu/ROIPool_cpu.cpp +++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp @@ -12,13 +12,13 @@ template void RoIPoolForward( const T* input, const T spatial_scale, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, const T* rois, - const int num_rois, + int num_rois, T* output, int* argmax_data) { for (int n = 0; n < num_rois; ++n) { @@ -81,18 +81,18 @@ template void RoIPoolBackward( const T* grad_output, const int* argmax_data, - const int num_rois, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, + int num_rois, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, T* grad_input, const T* rois, - const int n_stride, - const int c_stride, - const int h_stride, - const int w_stride) { + int n_stride, + int c_stride, + int h_stride, + int w_stride) { for (int n = 0; n < num_rois; ++n) { const T* offset_rois = rois + n * 5; int roi_batch_ind = offset_rois[0]; @@ -123,9 +123,9 @@ void RoIPoolBackward( std::tuple ROIPool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); @@ -172,13 +172,13 @@ at::Tensor ROIPool_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { // Check if input tensors are CPU tensors TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 69b1bbf555d..4f887202d39 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -5,21 +5,21 @@ VISION_API std::tuple ROIPool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); VISION_API at::Tensor ROIPool_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); VISION_API at::Tensor ROIAlign_forward_cpu( const at::Tensor& input, diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu index a35dabbeb39..3131b9eea7e 100644 --- a/torchvision/csrc/cuda/ROIPool_cuda.cu +++ b/torchvision/csrc/cuda/ROIPool_cuda.cu @@ -8,14 +8,14 @@ template __global__ void RoIPoolForward( - const int nthreads, + int nthreads, const T* input, const T spatial_scale, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, const T* rois, T* output, int* argmax_data) { @@ -73,22 +73,22 @@ __global__ void RoIPoolForward( template __global__ void RoIPoolBackward( - const int nthreads, + int nthreads, const T* grad_output, const int* argmax_data, - const int num_rois, + int num_rois, const T spatial_scale, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, T* grad_input, const T* rois, - const int n_stride, - const int c_stride, - const int h_stride, - const int w_stride) { + int n_stride, + int c_stride, + int h_stride, + int w_stride) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -118,9 +118,9 @@ __global__ void RoIPoolBackward( std::tuple ROIPool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); TORCH_CHECK( @@ -182,13 +182,13 @@ at::Tensor ROIPool_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width) { + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { // Check if input tensors are CUDA tensors TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 2481cfc63c2..f25aac11fdd 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -5,43 +5,43 @@ VISION_API at::Tensor ROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t sampling_ratio, - const bool aligned); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned); VISION_API at::Tensor ROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t batch_size, - const int64_t channels, - const int64_t height, - const int64_t width, - const int64_t sampling_ratio, - const bool aligned); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned); VISION_API std::tuple ROIPool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width); VISION_API at::Tensor ROIPool_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width); VISION_API std::tuple PSROIPool_forward_cuda( const at::Tensor& input, diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index f56a671d6e5..eea1cf2ec9c 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -50,7 +50,10 @@ TORCH_LIBRARY(torchvision, m) { "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"); m.def( "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"); - m.def("roi_pool", &roi_pool); + m.def( + "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"); + m.def( + "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"); m.def("_new_empty_tensor_op", &new_empty_tensor); m.def("ps_roi_align", &ps_roi_align); m.def("ps_roi_pool", &ps_roi_pool); @@ -64,6 +67,8 @@ TORCH_LIBRARY(torchvision, m) { TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl("roi_align", ROIAlign_forward_cpu); m.impl("_roi_align_backward", ROIAlign_backward_cpu); + m.impl("roi_pool", ROIPool_forward_cpu); + m.impl("_roi_pool_backward", ROIPool_backward_cpu); m.impl("deform_conv2d", DeformConv2d_forward_cpu); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu); m.impl("nms", nms_cpu); @@ -74,6 +79,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl("roi_align", ROIAlign_forward_cuda); m.impl("_roi_align_backward", ROIAlign_backward_cuda); + m.impl("roi_pool", ROIPool_forward_cuda); + m.impl("_roi_pool_backward", ROIPool_backward_cuda); m.impl("deform_conv2d", DeformConv2d_forward_cuda); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda); m.impl("nms", nms_cuda); @@ -84,6 +91,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { #if defined(WITH_CUDA) || defined(WITH_HIP) TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl("roi_align", ROIAlign_autocast); + m.impl("roi_pool", ROIPool_autocast); m.impl("deform_conv2d", DeformConv2d_autocast); m.impl("nms", nms_autocast); } @@ -92,6 +100,8 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("roi_align", ROIAlign_autograd); m.impl("_roi_align_backward", ROIAlign_backward_autograd); + m.impl("roi_pool", ROIPool_autograd); + m.impl("_roi_pool_backward", ROIPool_backward_autograd); m.impl("deform_conv2d", DeformConv2d_autograd); m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd); }