Skip to content

Commit

Permalink
Merge 2c496b6 into 375605f
Browse files Browse the repository at this point in the history
  • Loading branch information
SFMDI authored Apr 10, 2021
2 parents 375605f + 2c496b6 commit d0cb25e
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 30 deletions.
12 changes: 6 additions & 6 deletions mmcv/ops/csrc/deform_conv_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ __device__ T deformable_im2col_bilinear(const T *input, const int data_width,
return 0;
}

- int h_low = floor(h);
- int w_low = floor(w);
+ int h_low = floorf(h);
+ int w_low = floorf(w);
int h_high = h_low + 1;
int w_high = w_low + 1;

Expand Down Expand Up @@ -122,8 +122,8 @@ __device__ T get_gradient_weight(T argmax_h, T argmax_w, const int h,
return 0;
}

- int argmax_h_low = floor(argmax_h);
- int argmax_w_low = floor(argmax_w);
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;

Expand All @@ -149,8 +149,8 @@ __device__ T get_coordinate_weight(T argmax_h, T argmax_w, const int height,
return 0;
}

- int argmax_h_low = floor(argmax_h);
- int argmax_w_low = floor(argmax_w);
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;

Expand Down
18 changes: 10 additions & 8 deletions mmcv/ops/csrc/deform_roi_pool_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ __global__ void deform_roi_pool_forward_cuda_kernel(
int roi_bin_grid_h =
(sampling_ratio > 0)
? sampling_ratio
- : static_cast<int>(ceil(roi_height / pooled_height));
- int roi_bin_grid_w = (sampling_ratio > 0)
- ? sampling_ratio
- : static_cast<int>(ceil(roi_width / pooled_width));
+ : static_cast<int>(ceilf(roi_height / pooled_height));
+ int roi_bin_grid_w =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_width / pooled_width));

// Compute roi offset
if (offset != NULL) {
Expand Down Expand Up @@ -113,10 +114,11 @@ __global__ void deform_roi_pool_backward_cuda_kernel(
int roi_bin_grid_h =
(sampling_ratio > 0)
? sampling_ratio
- : static_cast<int>(ceil(roi_height / pooled_height));
- int roi_bin_grid_w = (sampling_ratio > 0)
- ? sampling_ratio
- : static_cast<int>(ceil(roi_width / pooled_width));
+ : static_cast<int>(ceilf(roi_height / pooled_height));
+ int roi_bin_grid_w =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_width / pooled_width));

// Compute roi offset
if (offset != NULL) {
Expand Down
12 changes: 6 additions & 6 deletions mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@
template <typename T>
__device__ T dmcn_im2col_bilinear(const T *input, const int data_width,
const int height, const int width, T h, T w) {
- int h_low = floor(h);
- int w_low = floor(w);
+ int h_low = floorf(h);
+ int w_low = floorf(w);
int h_high = h_low + 1;
int w_high = w_low + 1;

Expand Down Expand Up @@ -112,8 +112,8 @@ __device__ T dmcn_get_gradient_weight(T argmax_h, T argmax_w, const int h,
return 0;
}

- int argmax_h_low = floor(argmax_h);
- int argmax_w_low = floor(argmax_w);
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;

Expand All @@ -140,8 +140,8 @@ __device__ T dmcn_get_coordinate_weight(T argmax_h, T argmax_w,
return 0;
}

- int argmax_h_low = floor(argmax_h);
- int argmax_w_low = floor(argmax_w);
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;

Expand Down
13 changes: 7 additions & 6 deletions mmcv/ops/csrc/roi_align_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ __global__ void roi_align_forward_cuda_kernel(
int roi_bin_grid_h =
(sampling_ratio > 0)
? sampling_ratio
- : static_cast<int>(ceil(roi_height / pooled_height));
- int roi_bin_grid_w = (sampling_ratio > 0)
- ? sampling_ratio
- : static_cast<int>(ceil(roi_width / pooled_width));
+ : static_cast<int>(ceilf(roi_height / pooled_height));
+ int roi_bin_grid_w =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_width / pooled_width));

if (pool_mode == 0) {
// We do max pooling inside a bin
Expand Down Expand Up @@ -168,11 +169,11 @@ __global__ void roi_align_backward_cuda_kernel(
int roi_bin_grid_h =
(sampling_ratio > 0)
? sampling_ratio
- : static_cast<int>(ceil(roi_height / pooled_height));
+ : static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_w =
(sampling_ratio > 0)
? sampling_ratio
- : static_cast<int>(ceil(roi_width / pooled_width));
+ : static_cast<int>(ceilf(roi_width / pooled_width));

// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
Expand Down
8 changes: 4 additions & 4 deletions mmcv/ops/csrc/roi_pool_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ __global__ void roi_pool_forward_cuda_kernel(
T bin_size_h = roi_h / static_cast<T>(pooled_height);

// the corresponding bin region
- int bin_x1 = floor(static_cast<T>(pw) * bin_size_w + roi_x1);
- int bin_y1 = floor(static_cast<T>(ph) * bin_size_h + roi_y1);
- int bin_x2 = ceil(static_cast<T>(pw + 1) * bin_size_w + roi_x1);
- int bin_y2 = ceil(static_cast<T>(ph + 1) * bin_size_h + roi_y1);
+ int bin_x1 = floorf(static_cast<T>(pw) * bin_size_w + roi_x1);
+ int bin_y1 = floorf(static_cast<T>(ph) * bin_size_h + roi_y1);
+ int bin_x2 = ceilf(static_cast<T>(pw + 1) * bin_size_w + roi_x1);
+ int bin_y2 = ceilf(static_cast<T>(ph + 1) * bin_size_h + roi_y1);

// add roi offsets and clip to input boundaries
bin_x1 = min(max(bin_x1, 0), width);
Expand Down

0 comments on commit d0cb25e

Please sign in to comment.