#pragma once
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
namespace at { namespace native {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ConvParams {
  std::vector<int64_t> stride;
  std::vector<int64_t> padding;
  std::vector<int64_t> dilation;
  bool transposed;
  std::vector<int64_t> output_padding;
  int groups;
  bool benchmark;
  bool deterministic;
  bool cudnn_enabled;
  bool allow_tf32;

  bool is_strided() const;
  bool is_dilated() const;
  bool is_padded() const;
  bool is_output_padding_neg() const;
  bool is_output_padding_big() const;
  bool is_padding_neg() const;
  bool is_stride_nonpos() const;
  void view1d_as_2d();
  bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) const;
  bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const;
  bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const;
  bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
  bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const;
  bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const;
  bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const;
  bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) const;
  bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
};
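
// Illustrative only (not part of the original header): a minimal sketch of how
// a caller might populate ConvParams for a plain, non-transposed 2-D
// convolution before querying the use_* predicates. The tensor names `input`
// and `weight` are assumptions for the example.
//
//   ConvParams params;
//   params.stride = {1, 1};
//   params.padding = {1, 1};
//   params.dilation = {1, 1};
//   params.transposed = false;
//   params.output_padding = {0, 0};
//   params.groups = 1;
//   params.benchmark = false;
//   params.deterministic = false;
//   params.cudnn_enabled = true;
//   params.allow_tf32 = true;
//   bool cudnn_ok = params.use_cudnn(input, weight);
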
enum class ConvBackend {
  CudaDepthwise2d,
  CudaDepthwise3d,
  Cudnn,
  CudnnTranspose,
  Empty,
  Miopen,
  MiopenDepthwise,
  MiopenTranspose,
  Mkldnn,
  MkldnnEmpty,
  NnpackSpatial,
  Overrideable,
  Slow2d,
  Slow3d,
  SlowDilated2d,
  SlowDilated3d,
  SlowTranspose2d,
  SlowTranspose3d,
  Winograd3x3Depthwise,
  Xnnpack2d
};

// Function to select the convolution backend based on the inputs and params.
// This overload is used within the convolution internals but not exposed to Python.
TORCH_API ConvBackend select_conv_backend(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const ConvParams& params);

// Overload for selecting the convolution backend from the full set of convolution inputs.
// This overload is exposed to Python for testing, etc.
TORCH_API ConvBackend select_conv_backend(
    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups);
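
// Illustrative only: a hedged sketch of dispatching on the selected backend.
// The call-site variable names (input, weight, bias_opt, ...) are assumptions,
// not part of this header.
//
//   ConvBackend backend = select_conv_backend(
//       input, weight, bias_opt, stride, padding, dilation,
//       /*transposed=*/false, output_padding, /*groups=*/1);
//   switch (backend) {
//     case ConvBackend::Cudnn:
//       // dispatch to the cuDNN convolution kernel
//       break;
//     default:
//       // fall back to a slower reference implementation
//       break;
//   }
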
// ---------------------------------------------------------------------
//
// Math
//
// ---------------------------------------------------------------------
constexpr int input_batch_size_dim = 0; // also grad_input
constexpr int input_channels_dim = 1;
constexpr int output_batch_size_dim = 0; // also grad_output
constexpr int output_channels_dim = 1;
constexpr int weight_output_channels_dim = 0;
constexpr int weight_input_channels_dim = 1;
// Often written as 2 + max_dim (extra dims for batch size and channels)
constexpr int max_dim = 3;
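
// Illustrative only: for a 4-D NCHW input of shape [N, C, H, W],
// input_size[input_batch_size_dim] == N and input_size[input_channels_dim] == C;
// for a weight of shape [C_out, C_in / groups, kH, kW],
// weight_size[weight_output_channels_dim] == C_out and
// weight_size[weight_input_channels_dim] == C_in / groups.
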
// NB: conv_output_size and conv_input_size are not bijections,
// as conv_output_size loses information; this is why conv_input_size
// takes an extra output_padding argument to resolve the ambiguity.
static inline std::vector<int64_t> conv_output_size(
    IntArrayRef input_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
  // ASSERT(input_size.size() > 2)
  // ASSERT(input_size.size() == weight_size.size())
  bool has_dilation = dilation.size() > 0;
  auto dim = input_size.size();
  std::vector<int64_t> output_size(dim);
  output_size[0] = input_size[input_batch_size_dim];
  output_size[1] = weight_size[weight_output_channels_dim];
  for (const auto d : c10::irange(2, dim)) {
    auto dilation_ = has_dilation ? dilation[d - 2] : 1;
    auto kernel = dilation_ * (weight_size[d] - 1) + 1;
    output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
  }
  return output_size;
}
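
// Worked example (illustrative): for input_size = {1, 3, 32, 32},
// weight_size = {16, 3, 3, 3}, padding = {1, 1}, stride = {1, 1}, and no
// dilation, the effective kernel per spatial dim is 1 * (3 - 1) + 1 = 3, so
//   output_size[d] = (32 + 2*1 - 3) / 1 + 1 = 32
// and conv_output_size returns {1, 16, 32, 32}: "same" spatial size for a
// 3x3 kernel with padding 1.
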
static inline std::vector<int64_t> conv_input_size(
    IntArrayRef output_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  // ASSERT(output_size.size() > 2)
  // ASSERT(output_size.size() == weight_size.size())
  auto dim = output_size.size();
  std::vector<int64_t> input_size(dim);
  input_size[0] = output_size[output_batch_size_dim];
  input_size[1] = weight_size[weight_input_channels_dim] * groups;
  for (const auto d : c10::irange(2, dim)) {
    int kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) +
                    kernel + output_padding[d - 2];
  }
  return input_size;
}
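
// Worked example (illustrative): inverting the conv_output_size example above
// with output_size = {1, 16, 32, 32}, weight_size = {16, 3, 3, 3},
// padding = {1, 1}, output_padding = {0, 0}, stride = {1, 1},
// dilation = {1, 1}, groups = 1:
//   input_size[d] = (32 - 1) * 1 - 2*1 + 3 + 0 = 32
// recovering {1, 3, 32, 32}. The output_padding term resolves the rounding
// ambiguity introduced by the integer division in conv_output_size.
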
static inline std::vector<int64_t> conv_weight_size(
    IntArrayRef input_size, IntArrayRef output_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  auto dim = input_size.size();
  std::vector<int64_t> weight_size(dim);
  weight_size[0] = output_size[1];
  weight_size[1] = input_size[1] / groups;
  for (const auto d : c10::irange(2, dim)) {
    int kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
        + 2 * padding[d - 2] - output_padding[d - 2];
    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
  }
  return weight_size;
}
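
// Worked example (illustrative): continuing the examples above with
// input_size = {1, 3, 32, 32}, output_size = {1, 16, 32, 32},
// padding = {1, 1}, output_padding = {0, 0}, stride = {1, 1},
// dilation = {1, 1}, groups = 1:
//   kernel         = 32 - (32 - 1) * 1 + 2*1 - 0 = 3
//   weight_size[d] = (3 - 1) / 1 + 1 = 3
// giving weight_size = {16, 3, 3, 3}, consistent with the other two helpers.
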
static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
  std::vector<int64_t> shape(dim, 1);
  shape[1] = -1;
  return bias.reshape(shape);
}
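
// Illustrative only: for dim = 4, a bias of shape [C] is reshaped to
// [1, C, 1, 1] (the -1 infers C), so it broadcasts over the batch and spatial
// dimensions when added to an NCHW convolution output.
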
static inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
  // Fall back to contiguous if cuDNN is unavailable; also disable NHWC for
  // float64 input or weight.
  if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return at::MemoryFormat::Contiguous;
  }
  long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();
  auto weight_ndim = weight.ndimension();

  // Channels-last 2d requires cuDNN >= 7.6.3 and a 4-d weight.
  bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
    (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
    (weight_memory_format == at::MemoryFormat::ChannelsLast)
  );
  if (can_use_cudnn_channels_last_2d) {
    return at::MemoryFormat::ChannelsLast;
  }

  // Channels-last 3d requires cuDNN >= 8.0.5 and a 5-d weight.
  bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
    (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
    (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
  );
  if (can_use_cudnn_channels_last_3d) {
    return at::MemoryFormat::ChannelsLast3d;
  }
  return at::MemoryFormat::Contiguous;
}
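
// Illustrative only: a typical call-site pattern (the tensor names are
// assumptions, not part of this header) is to materialize both tensors in the
// suggested format before invoking the cuDNN convolution:
//
//   auto memory_format = cudnn_conv_suggest_memory_format(input, weight);
//   input = input.contiguous(memory_format);
//   weight = weight.contiguous(memory_format);
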
static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  // Disable NHWC when MIOpen is unavailable, and for float64 input or weight.
  if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  bool can_use_miopen_channels_last_2d = false;
#if defined(USE_ROCM) && (ROCM_VERSION >= 40300)
  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
  // See #64427
  static c10::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
    (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
    (weight_memory_format == at::MemoryFormat::ChannelsLast)
  );
#endif

  bool can_use_miopen_channels_last_3d = false;
  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
}
}} // namespace at::native