-
Notifications
You must be signed in to change notification settings - Fork 395
/
roi_pooling_layer.cu
188 lines (163 loc) · 7.45 KB
/
roi_pooling_layer.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
// ------------------------------------------------------------------
// Fast R-CNN
// Copyright (c) 2015 Microsoft
// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
// Written by Ross Girshick
// ------------------------------------------------------------------
#include <cfloat>
#include "caffe/fast_rcnn_layers.hpp"
using std::max;
using std::min;
namespace caffe {
// Max-pools every region of interest (ROI) into a fixed
// pooled_height x pooled_width grid. Each loop index produces exactly one
// pooled output element.
//
// Parameters:
//   nthreads       number of output elements to produce (loop bound)
//   bottom_data    input feature maps, addressed as
//                  ((batch * channels + c) * height + h) * width + w
//   spatial_scale  factor applied to the ROI coordinates before rounding
//   channels, height, width
//                  extents used to address bottom_data
//   pooled_height, pooled_width
//                  extents of the pooled grid produced per ROI and channel
//   bottom_rois    5 values per ROI:
//                  [batch_index, start_w, start_h, end_w, end_h]
//                  (end coordinates are inclusive)
//   top_data       receives the pooled maximum for each output index
//   argmax_data    receives h * width + w of the selected input element
//                  within its (batch, channel) plane, or -1 when the
//                  pooling window is empty
template <typename Dtype>
__global__ void ROIPoolForward(const int nthreads, const Dtype* bottom_data,
    const Dtype spatial_scale, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    const Dtype* bottom_rois, Dtype* top_data, int* argmax_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    const int pw = index % pooled_width;
    const int ph = (index / pooled_width) % pooled_height;
    const int c = (index / pooled_width / pooled_height) % channels;
    const int n = index / pooled_width / pooled_height / channels;

    // Address the ROI through a local pointer instead of advancing the
    // kernel parameter. NOTE(review): CUDA_KERNEL_LOOP is defined outside this
    // file and is assumed to be a grid-stride loop; if one thread executes the
    // body more than once, advancing the parameter itself would accumulate
    // offsets across iterations and read the wrong ROI.
    const Dtype* offset_bottom_rois = bottom_rois + n * 5;
    const int roi_batch_ind = offset_bottom_rois[0];
    const int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
    const int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
    const int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
    const int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);

    // Force malformed ROIs to be 1x1
    const int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    const int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    const Dtype bin_size_h = static_cast<Dtype>(roi_height)
                             / static_cast<Dtype>(pooled_height);
    const Dtype bin_size_w = static_cast<Dtype>(roi_width)
                             / static_cast<Dtype>(pooled_width);

    // Pooling window of bin (ph, pw), relative to the ROI origin.
    int hstart = static_cast<int>(floor(static_cast<Dtype>(ph)
                                        * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<Dtype>(pw)
                                        * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<Dtype>(ph + 1)
                                     * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<Dtype>(pw + 1)
                                     * bin_size_w));

    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    const bool is_empty = (hend <= hstart) || (wend <= wstart);

    // Define an empty pooling region to be zero
    Dtype maxval = is_empty ? 0 : -FLT_MAX;
    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
    int maxidx = -1;

    // Same reasoning as for the ROI pointer: keep the plane offset local to
    // this iteration rather than mutating bottom_data.
    const Dtype* offset_bottom_data =
        bottom_data + (roi_batch_ind * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        const int bottom_index = h * width + w;
        if (offset_bottom_data[bottom_index] > maxval) {
          maxval = offset_bottom_data[bottom_index];
          maxidx = bottom_index;
        }
      }
    }
    top_data[index] = maxval;
    argmax_data[index] = maxidx;
  }
}
// GPU forward pass: launches ROIPoolForward with one loop index per element
// of top[0]. bottom[0] is handed to the kernel as the feature data and
// bottom[1] as the ROI records; the kernel writes the pooled values into
// top[0] and the matching argmax positions into max_idx_.
template <typename Dtype>
void ROIPoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // Device pointers for the kernel's inputs ...
  const Dtype* const feature_ptr = bottom[0]->gpu_data();
  const Dtype* const roi_ptr = bottom[1]->gpu_data();
  // ... and for its two outputs.
  Dtype* const pooled_ptr = top[0]->mutable_gpu_data();
  int* const argmax_ptr = max_idx_.mutable_gpu_data();

  // One loop index per pooled output element.
  const int num_outputs = top[0]->count();
  const int num_blocks = CAFFE_GET_BLOCKS(num_outputs);

  // NOLINT_NEXT_LINE(whitespace/operators)
  ROIPoolForward<Dtype><<<num_blocks, CAFFE_CUDA_NUM_THREADS>>>(
      num_outputs, feature_ptr, spatial_scale_, channels_, height_, width_,
      pooled_height_, pooled_width_, roi_ptr, pooled_ptr, argmax_ptr);
  CUDA_POST_KERNEL_CHECK;
}
// Backward pass of ROI max pooling, written as a gather: each loop index owns
// one element (n, c, h, w) of bottom_diff and sums the top gradients of every
// pooled unit whose recorded argmax is that element. Each index writes only
// its own bottom_diff entry.
//
// Parameters:
//   nthreads       number of bottom_diff elements to produce (loop bound)
//   top_diff       gradients of the pooled outputs, addressed as
//                  ((roi * channels + c) * pooled_height + ph) * pooled_width
//                  + pw
//   argmax_data    same addressing as top_diff; holds h * width + w of the
//                  input element selected by the forward pass (-1 if none)
//   num_rois       number of 5-value records in bottom_rois
//   spatial_scale  factor applied to the ROI coordinates before rounding
//   channels, height, width
//                  extents used to decompose the bottom index
//   pooled_height, pooled_width
//                  extents of the pooled grid per ROI and channel
//   bottom_diff    receives the accumulated gradient for every index
//   bottom_rois    5 values per ROI:
//                  [batch_index, start_w, start_h, end_w, end_h]
template <typename Dtype>
__global__ void ROIPoolBackward(const int nthreads, const Dtype* top_diff,
    const int* argmax_data, const int num_rois, const Dtype spatial_scale,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, Dtype* bottom_diff,
    const Dtype* bottom_rois) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // (n, c, h, w) coords in bottom data
    const int w = index % width;
    const int h = (index / width) % height;
    const int c = (index / width / height) % channels;
    const int n = index / width / height / channels;
    // Value the forward pass stores in argmax_data when it selects (h, w).
    const int bottom_offset = h * width + w;

    Dtype gradient = 0;
    // Accumulate gradient over all ROIs that pooled this element
    for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
      const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5;
      const int roi_batch_ind = offset_bottom_rois[0];
      // Skip if ROI's batch index doesn't match n
      if (n != roi_batch_ind) {
        continue;
      }

      const int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
      const int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
      const int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
      const int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);

      // Force malformed ROIs to be 1x1, exactly as ROIPoolForward does.
      const int roi_width = max(roi_end_w - roi_start_w + 1, 1);
      const int roi_height = max(roi_end_h - roi_start_h + 1, 1);

      // Skip if ROI doesn't include (h, w). The test uses the clamped extent
      // rather than the raw end coordinates: for a malformed ROI
      // (end < start) the forward pass still pools the single element at
      // (roi_start_h, roi_start_w), and a test against the raw end would
      // reject it and drop its gradient. For well-formed ROIs both tests
      // are identical.
      const bool in_roi = (w >= roi_start_w && w < roi_start_w + roi_width &&
                           h >= roi_start_h && h < roi_start_h + roi_height);
      if (!in_roi) {
        continue;
      }

      const int offset = (roi_n * channels + c) * pooled_height * pooled_width;
      const Dtype* offset_top_diff = top_diff + offset;
      const int* offset_argmax_data = argmax_data + offset;

      // Compute feasible set of pooled units that could have pooled
      // this bottom unit
      const Dtype bin_size_h = static_cast<Dtype>(roi_height)
                               / static_cast<Dtype>(pooled_height);
      const Dtype bin_size_w = static_cast<Dtype>(roi_width)
                               / static_cast<Dtype>(pooled_width);

      int phstart = floor(static_cast<Dtype>(h - roi_start_h) / bin_size_h);
      int phend = ceil(static_cast<Dtype>(h - roi_start_h + 1) / bin_size_h);
      int pwstart = floor(static_cast<Dtype>(w - roi_start_w) / bin_size_w);
      int pwend = ceil(static_cast<Dtype>(w - roi_start_w + 1) / bin_size_w);

      phstart = min(max(phstart, 0), pooled_height);
      phend = min(max(phend, 0), pooled_height);
      pwstart = min(max(pwstart, 0), pooled_width);
      pwend = min(max(pwend, 0), pooled_width);

      // Only pooled units whose argmax is exactly this element contribute,
      // so the candidate range above only needs to be a superset.
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          const int pooled_index = ph * pooled_width + pw;
          if (offset_argmax_data[pooled_index] == bottom_offset) {
            gradient += offset_top_diff[pooled_index];
          }
        }
      }
    }
    bottom_diff[index] = gradient;
  }
}
// GPU backward pass: when propagate_down[0] is set, zero-fills the diff of
// bottom[0] and launches ROIPoolBackward with one loop index per element of
// bottom[0]. Does nothing otherwise.
template <typename Dtype>
void ROIPoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  if (propagate_down[0]) {
    const Dtype* const roi_ptr = bottom[1]->gpu_data();
    const Dtype* const top_grad = top[0]->gpu_diff();
    Dtype* const bottom_grad = bottom[0]->mutable_gpu_diff();

    // One loop index per element of bottom[0].
    const int num_inputs = bottom[0]->count();
    // Start from an all-zero gradient buffer before the kernel runs.
    caffe_gpu_set(num_inputs, Dtype(0.), bottom_grad);

    // Argmax positions recorded in max_idx_ by the forward pass.
    const int* const argmax_ptr = max_idx_.gpu_data();
    const int num_blocks = CAFFE_GET_BLOCKS(num_inputs);

    // NOLINT_NEXT_LINE(whitespace/operators)
    ROIPoolBackward<Dtype><<<num_blocks, CAFFE_CUDA_NUM_THREADS>>>(
        num_inputs, top_grad, argmax_ptr, top[0]->num(), spatial_scale_,
        channels_, height_, width_, pooled_height_, pooled_width_,
        bottom_grad, roi_ptr);
    CUDA_POST_KERNEL_CHECK;
  }
}
// NOTE(review): INSTANTIATE_LAYER_GPU_FUNCS is defined outside this file; it
// presumably emits the explicit template instantiations of
// ROIPoolingLayer<Dtype>::Forward_gpu / Backward_gpu -- confirm.
INSTANTIATE_LAYER_GPU_FUNCS(ROIPoolingLayer);
} // namespace caffe