forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AveragePool2d.cpp
216 lines (186 loc) · 6.7 KB
/
AveragePool2d.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/native/Pool.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#endif
namespace at::meta {
using namespace ::at::native;

// Meta (shape-inference) function for the avg_pool2d forward structured kernel.
//
// It normalizes the pooling hyper-parameters:
//   kernel_size      : one int (used for both H and W) or two ints (H, W).
//   stride           : empty (falls back to the kernel size), one int (both
//                      axes) or two ints (H, W).
//   padding          : one int (both axes) or two ints (H, W).
//   divisor_override : if present, must be non-zero.
// It then validates them against the input via pool2d_shape_check and
// allocates the output. A 3-D input yields a (C, outH, outW) output. Any other
// rank yields (N, C, outH, outW) in the input's suggested memory format.
// Finally, the normalized values are forwarded to the impl function via the
// precompute struct.
TORCH_PRECOMPUTE_META_FUNC(avg_pool2d)
(const Tensor& input,
 IntArrayRef kernel_size,
 IntArrayRef stride,
 IntArrayRef padding,
 bool ceil_mode,
 bool count_include_pad,
 c10::optional<int64_t> divisor_override) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int64_t kernelH = kernel_size[0];
  // Arity is 1 or 2 here, so "has a second entry" <=> "size() == 2".
  const int64_t kernelW = (kernel_size.size() == 2) ? kernel_size[1] : kernelH;

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
    "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  // An omitted stride defaults to the kernel extent on each axis.
  int64_t strideH = kernelH;
  int64_t strideW = kernelW;
  if (!stride.empty()) {
    strideH = stride[0];
    strideW = (stride.size() == 2) ? stride[1] : strideH;
  }

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "avg_pool2d: padding must either be a single int, or a tuple of two ints");
  const int64_t paddingH = padding[0];
  const int64_t paddingW = (padding.size() == 2) ? padding[1] : paddingH;

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
    "divisor must be not zero");

  // Trailing dims are read as (C, H, W); a batch size is read only for 4-D input.
  const int64_t batchSize = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t channels = input.size(-3);
  const int64_t inH = input.size(-2);
  const int64_t inW = input.size(-1);

  // The literal 1 arguments below are presumably the dilation values
  // (average pooling has no dilation) — confirm against Pool.h.
  const int64_t outH = pooling_output_shape<int64_t>(
      inH, kernelH, paddingH, strideH, 1, ceil_mode);
  const int64_t outW = pooling_output_shape<int64_t>(
      inW, kernelW, paddingW, strideW, 1, ceil_mode);

  auto memFormat = input.suggest_memory_format();
  pool2d_shape_check(
      input,
      kernelH, kernelW,
      strideH, strideW,
      paddingH, paddingW,
      1, 1,
      channels,
      inH, inW,
      outH, outW,
      memFormat);

  /* resize output */
  if (input.ndimension() != 3) {
    set_output_raw_strided(
        0,
        {batchSize, channels, outH, outW},
        {},
        input.options().memory_format(memFormat));
  } else {
    set_output_raw_strided(0, {channels, outH, outW}, {}, input.options());
  }

  return TORCH_PRECOMPUTE_STRUCT(avg_pool2d)()
      .set_kH(kernelH)
      .set_kW(kernelW)
      .set_dH(strideH)
      .set_dW(strideW)
      .set_padH(paddingH)
      .set_padW(paddingW);
}

// Meta function for avg_pool2d_backward.
//
// It re-derives the forward hyper-parameters, narrowed to int via
// safe_downcast, and recomputes the forward output extent. It then hands
// everything to avg_pool2d_backward_shape_check, which presumably verifies
// gradOutput_ against that extent (not visible here). Finally it allocates the
// gradient output with input's sizes in input's suggested memory format.
TORCH_META_FUNC(avg_pool2d_backward) (
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  c10::optional<int64_t> divisor_override
) {
  // int64_t -> int narrowing; safe_downcast presumably rejects out-of-range
  // values (defined outside this file).
  const auto toInt = [](int64_t v) { return safe_downcast<int, int64_t>(v); };

  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kernelH = toInt(kernel_size[0]);
  const int kernelW = (kernel_size.size() == 2) ? toInt(kernel_size[1]) : kernelH;

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
    "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  int strideH = kernelH;
  int strideW = kernelW;
  if (!stride.empty()) {
    strideH = toInt(stride[0]);
    strideW = (stride.size() == 2) ? toInt(stride[1]) : strideH;
  }

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "avg_pool2d: padding must either be a single int, or a tuple of two ints");
  const int paddingH = toInt(padding[0]);
  const int paddingW = (padding.size() == 2) ? toInt(padding[1]) : paddingH;

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero");

  /* sizes */
  const int64_t batchSize = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t channels = input.size(-3); // number of channels (or colors)
  const int64_t inH = input.size(-2);
  const int64_t inW = input.size(-1);
  // Width is evaluated before height; keep this order as-is.
  const int64_t outW = pooling_output_shape<int64_t>(inW, kernelW, paddingW, strideW, 1, ceil_mode);
  const int64_t outH = pooling_output_shape<int64_t>(inH, kernelH, paddingH, strideH, 1, ceil_mode);

  auto memFormat = input.suggest_memory_format();
  avg_pool2d_backward_shape_check(
      input, gradOutput_, batchSize,
      kernelH, kernelW, strideH, strideW, paddingH, paddingW,
      channels,
      inH, inW,
      outH, outW,
      memFormat);

  /* resize output: the gradient w.r.t. input has exactly input's sizes */
  set_output_raw_strided(0, input.sizes(), {}, input.options().memory_format(memFormat));
}
} // namespace at::meta
namespace at::native {

// CPU implementation of the avg_pool2d forward structured kernel.
//
// kH..padW arrive already normalized (they are the values the meta function
// stores in its precompute struct). ceil_mode is not forwarded. It only
// affects the output extent, which the meta function has already used to size
// `output`. Note that the dispatch stub takes the spatial arguments
// width-first: (kW, kH, dW, dH, padW, padH).
TORCH_IMPL_FUNC(avg_pool2d_out_cpu)
(const Tensor& input,
 int64_t kH,
 int64_t kW,
 int64_t dH,
 int64_t dW,
 int64_t padH,
 int64_t padW,
 bool ceil_mode,
 bool count_include_pad,
 c10::optional<int64_t> divisor_override,
 const Tensor& output) {
  avg_pool2d_kernel(
      kCPU, output, input,
      kW, kH,
      dW, dH,
      padW, padH,
      count_include_pad, divisor_override);
}

// CPU implementation of the avg_pool2d_backward structured kernel.
//
// It re-derives the int-narrowed kernel/stride/padding values from the raw
// arguments (arity checks live in the meta function and are not repeated
// here). It then checks that the divisor is non-zero and that gradOutput
// matches input's dtype, clears gradInput, and runs the backward dispatch stub
// (width-first argument order).
TORCH_IMPL_FUNC(avg_pool2d_backward_out_cpu) (
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  c10::optional<int64_t> divisor_override,
  const Tensor& gradInput
) {
  // int64_t -> int narrowing; safe_downcast presumably rejects out-of-range
  // values (defined outside this file).
  const auto toInt = [](int64_t v) { return safe_downcast<int, int64_t>(v); };

  const int kernelH = toInt(kernel_size[0]);
  const int kernelW = kernel_size.size() == 1 ? kernelH : toInt(kernel_size[1]);

  // An omitted stride defaults to the kernel extent on each axis.
  int strideH = kernelH;
  int strideW = kernelW;
  if (!stride.empty()) {
    strideH = toInt(stride[0]);
    strideW = stride.size() == 1 ? strideH : toInt(stride[1]);
  }

  const int paddingH = toInt(padding[0]);
  const int paddingW = padding.size() == 1 ? paddingH : toInt(padding[1]);

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero");
  TORCH_CHECK(input.dtype() == gradOutput.dtype(),
    "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());

  /* zero the gradient */
  // NOTE(review): presumably required because the backward kernel accumulates
  // into gradInput (overlapping windows) — confirm in the kernel.
  gradInput.zero_();

  avg_pool2d_backward_kernel(
      kCPU, gradInput, gradOutput,
      kernelW, kernelH,
      strideW, strideH,
      paddingW, paddingH,
      count_include_pad, divisor_override);
}

DEFINE_DISPATCH(avg_pool2d_kernel);
DEFINE_DISPATCH(avg_pool2d_backward_kernel);
} // namespace at::native