Skip to content

Commit

Permalink
[fbsync] Minor cleanup of roi_align_forward_kernel_impl (#3619)
Browse files Browse the repository at this point in the history
Summary:
* Minor cleanup

* Do the same for ps_roi_align

Reviewed By: NicolasHug

Differential Revision: D27706957

fbshipit-source-id: 3320466f6a8b12445f4c901460d3b6f39e6760ea

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
  • Loading branch information
2 people authored and facebook-github-bot committed Apr 13, 2021
1 parent 7a3176a commit 250d4ab
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
8 changes: 3 additions & 5 deletions torchvision/csrc/ops/cpu/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ T bilinear_interpolate(

template <typename T>
void ps_roi_align_forward_kernel_impl(
int nthreads,
int num_rois,
const T* input,
const T spatial_scale,
int channels,
Expand All @@ -75,7 +75,6 @@ void ps_roi_align_forward_kernel_impl(
int channels_out,
T* output,
int* channel_mapping) {
int num_rois = nthreads / channels_out / pooled_width / pooled_height;
for (int n = 0; n < num_rois; n++) {
// [start, end) interval for spatial sampling
const T* offset_rois = rois + n * 5;
Expand Down Expand Up @@ -335,16 +334,15 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
auto channel_mapping =
at::zeros(output.sizes(), input.options().dtype(at::kInt));

auto output_size = output.numel();
if (output_size == 0) {
if (output.numel() == 0) {
return std::make_tuple(output, channel_mapping);
}

auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "ps_roi_align_forward_kernel", [&] {
ps_roi_align_forward_kernel_impl<scalar_t>(
output_size,
num_rois,
input_.data_ptr<scalar_t>(),
spatial_scale,
channels,
Expand Down
7 changes: 2 additions & 5 deletions torchvision/csrc/ops/cpu/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ void pre_calc_for_bilinear_interpolate(

template <typename T>
void roi_align_forward_kernel_impl(
int nthreads,
int n_rois,
const T* input,
const T& spatial_scale,
int channels,
Expand All @@ -129,7 +129,6 @@ void roi_align_forward_kernel_impl(
bool aligned,
const T* rois,
T* output) {
int n_rois = nthreads / channels / pooled_width / pooled_height;
// (n, c, ph, pw) is an element in the pooled output
// can be parallelized using omp
// #pragma omp parallel for num_threads(32)
Expand Down Expand Up @@ -414,16 +413,14 @@ at::Tensor roi_align_forward_kernel(
at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());

auto output_size = num_rois * pooled_height * pooled_width * channels;

if (output.numel() == 0)
return output;

auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "roi_align_forward_kernel", [&] {
roi_align_forward_kernel_impl<scalar_t>(
output_size,
num_rois,
input_.data_ptr<scalar_t>(),
spatial_scale,
channels,
Expand Down

0 comments on commit 250d4ab

Please sign in to comment.