forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Matmul.cpp
268 lines (237 loc) · 9.48 KB
/
Matmul.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/native/mkldnn/Matmul.h>
#if !AT_MKLDNN_ENABLED()
namespace at {
namespace native {
// Build-without-MKLDNN fallback for mkldnn_matmul.
// The signature mirrors the MKLDNN-enabled implementation so callers link in
// either configuration; invoking it is always an error.
// @throws c10::Error unconditionally (via TORCH_CHECK).
void mkldnn_matmul(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result,
    float beta,
    float alpha) {
  // None of the operands are needed: there is no oneDNN backend to call.
  (void)mat1;
  (void)mat2;
  (void)result;
  (void)beta;
  (void)alpha;
  TORCH_CHECK(false, "mkldnn_matmul: ATen not compiled with MKLDNN support");
}
// Build-without-MKLDNN fallback: the oneDNN matmul path never exists, so it
// is never selected regardless of the operands.
// @return always false.
bool use_mkldnn_bf16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result_opt) {
  // Operands are irrelevant to the answer in this configuration.
  (void)mat1;
  (void)mat2;
  (void)result_opt;
  return false;
}
// Build-without-MKLDNN fallback for the bf16 gemm hook.
// Reports "not handled" so the caller uses its own gemm implementation.
// @return always false; no buffer is read or written.
bool mkldnn_bf16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::BFloat16 *a, int64_t lda,
    const c10::BFloat16 *b, int64_t ldb,
    float beta,
    c10::BFloat16 *c, int64_t ldc) {
  // Every argument is intentionally ignored.
  (void)transa;
  (void)transb;
  (void)m;
  (void)n;
  (void)k;
  (void)alpha;
  (void)a;
  (void)lda;
  (void)b;
  (void)ldb;
  (void)beta;
  (void)c;
  (void)ldc;
  return false;
}
} // namespace native
} // namespace at
#else // AT_MKLDNN_ENABLED
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
namespace at {
namespace native {
// Global gate shared by the oneDNN bf16 paths in this file.
// Requires at::globalContext().userEnabledMkldnn() to report true, and only
// then consults mkldnn_bf16_device_check() for CPU support.
static bool use_mkldnn_bf16_matmul() {
  if (!at::globalContext().userEnabledMkldnn()) {
    return false;
  }
  return mkldnn_bf16_device_check();
}
// oneDNN (ideep) implementation of a BLAS-style bf16 gemm:
//   C = alpha * op(A) * op(B) + beta * C
// where op(X) is X or X^T according to transa / transb, op(A) is m x k,
// op(B) is k x n and C is m x n. The {k, m} / {n, k} / {n, m} views built
// below imply column-major (BLAS) storage with leading dimensions
// lda / ldb / ldc -- NOTE(review): confirm against the callers.
//
// Returns true when the product was computed here. Returns false (nothing
// written) when the caller must fall back to another gemm: the oneDNN bf16
// gate is off, the problem is small (m * n * k <= 16^3), or alpha == 0.
bool mkldnn_bf16_gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
float alpha,
const c10::BFloat16 *a_data, int64_t lda,
const c10::BFloat16 *b_data, int64_t ldb,
float beta,
c10::BFloat16 *c_data, int64_t ldc) {
// Bail out before touching any buffer so the caller's fallback gemm runs.
if (!use_mkldnn_bf16_matmul() ||
(m * n * k <= 16 * 16 * 16) ||
(alpha == 0.0f)) {
return false;
}
ideep::attr_t op_attr;
// Use mkldnn post ops to perform the add.
// A sum post-op accumulates into the existing contents of C; beta itself is
// forwarded to ideep::matmul_forward::compute next to alpha.
if (beta != 0.0f) {
op_attr = ideep::attr_t::fuse_sum();
}
// NOTE: View as c-contiguous to avoid extra reordering in mkldnn
// Use identity: C = AB <=> C^T = B^T A^T
// A column-major buffer with leading dimension ld, read as row-major with
// strides {ld, 1}, is the transpose of that matrix. The views below are
// therefore op(A)^T (k x m), op(B)^T (n x k) and C^T (n x m), and computing
// op(B)^T * op(A)^T writes C^T, i.e. C in its original column-major layout.
// A transposed operand is expressed by swapping its strides to {1, ld}.
ideep::tensor::dims a_strides{{lda, 1}}, b_strides{{ldb, 1}}, c_strides{{ldc, 1}};
if (transa != TransposeType::NoTranspose) {
std::swap(a_strides[0], a_strides[1]);
}
if (transb != TransposeType::NoTranspose) {
std::swap(b_strides[0], b_strides[1]);
}
// op(A)^T: k x m view over a_data (const_cast only because ideep::tensor
// takes a mutable handle; A is used as an input).
ideep::tensor a({
/*sizes=*/{k, m},
ideep::tensor::data_type::bf16,
/*strides=*/a_strides},
const_cast<c10::BFloat16*>(a_data));
// op(B)^T: n x k view over b_data (input, same const_cast reasoning).
ideep::tensor b({
/*sizes=*/{n, k},
ideep::tensor::data_type::bf16,
/*strides=*/b_strides},
const_cast<c10::BFloat16*>(b_data));
// C^T: n x m view over the caller's output buffer.
ideep::tensor c({
/*sizes=*/{n, m},
ideep::tensor::data_type::bf16,
/*strides=*/c_strides},
c_data);
// Operand order realises C^T = op(B)^T * op(A)^T.
ideep::matmul_forward::compute(
b, a, c, alpha, beta,
ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);
if (c.get_data_handle() != c_data){
// ideep will query onednn expect format of output
// if given output format is not expected, ideep will re-init an output buffer
// under this case, we need copy the re-inited buffer back to given buffer
// The copy-back target is the same C^T view of c_data used above.
// NOTE(review): when beta != 0, confirm the re-initialised buffer was
// seeded from C before the sum post-op ran.
ideep::tensor real_output({
/*sizes=*/{n, m},
ideep::tensor::data_type::bf16,
/*strides=*/c_strides},
c_data);
c.reorder_to(real_output);
}
return true;
}
void mkldnn_matmul(
const Tensor &mat1,
const Tensor &mat2,
const Tensor &result,
float beta,
float alpha) {
TORCH_CHECK((mat1.dim() == 2 && mat2.dim() == 2) || // aten::addmm
(mat1.dim() == 3 && mat2.dim() == 3) || // aten::bmm, aten::baddbmm
(mat1.dim() == 2 && mat2.dim() == 1) || // aten::mv
(mat1.dim() == 1 && mat2.dim() == 1), // aten::dot
"mkldnn_matmul: unsupported dims for mat and mat2");
#if defined(__aarch64__)
// oneDNN fast-maths mode (enabled by setting the environment variable ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch
// fp32 inputs to bf16 kernels where HW permits. So, both fp32 and bf16 inputs are permitted.
TORCH_CHECK((mat1.scalar_type() == mat2.scalar_type()) && (mat1.scalar_type() == result.scalar_type()) &&
((mat1.scalar_type() == at::kFloat) || (mat1.scalar_type() == at::kBFloat16)),
"mkldnn_matmul: only enabled for fp32 and bf16 path");
// device needs to support bf16 if the inputs are of bf16 type
if (mat1.scalar_type() == at::kBFloat16) {
TORCH_CHECK(mkldnn_bf16_device_check_arm(),
"mkldnn_matmul: mkldnn_matmul bf16 path needs a cpu with bf16 support");
}
#else
TORCH_CHECK(mkldnn_bf16_device_check(),
"mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq, or AWS Graviton3");
TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 &&
mat2.scalar_type() == at::kBFloat16 &&
result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path");
#endif
auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result;
ideep::attr_t op_attr;
// "addmm", "addbmm" "baddbmm" in pytorch allow bias to be 2-D or 3-D tensor
// but mkldnn matmul primitive only support bias be 1-D tensors
// to address their differences, we use mkldnn post ops to perform a fused "add" after matrix multiplication is over
if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum();
// If alpha = 0, dose not need actually do gemm computation
if (alpha == 0)
return;
auto is_mkldnn_optimized_format = [&](const Tensor& t) {
if (t.is_contiguous()) return true;
const auto sizes = t.sizes();
const auto strides = t.strides();
if (t.dim() == 2){
return strides[0] == 1 && strides[1] == sizes[0];
} else {
// dim = 3
return strides[0] == sizes[1] * sizes[2] && strides[1] == 1 && strides[2] == sizes[1];
}
};
// Mkldnn only optimized for contiguous or transposed (transpose last 2 dim if 3-D tensor) format now
// Will remove this "contiguous" after mkldnn have fully supported
Tensor mat1_ = is_mkldnn_optimized_format(mat1_unsqueezed) ? mat1_unsqueezed : mat1_unsqueezed.contiguous();
Tensor mat2_ = is_mkldnn_optimized_format(mat2_unsqueezed) ? mat2_unsqueezed : mat2_unsqueezed.contiguous();
// Make sure mat1 and mat2 have default contiguous strides if they are contiguous tensors for better performance.
auto mat1_sizes = mat1_.sizes();
IntArrayRef mat1_default_contiguous_strides = c10::contiguous_strides(mat1_sizes);
if (mat1_.is_contiguous() && mat1_.strides() != mat1_default_contiguous_strides) {
mat1_ = mat1_.as_strided(mat1_sizes, mat1_default_contiguous_strides);
}
auto mat2_sizes = mat2_.sizes();
IntArrayRef mat2_default_contiguous_strides = c10::contiguous_strides(mat2_sizes);
if (mat2_.is_contiguous() && mat2_.strides() != mat2_default_contiguous_strides) {
mat2_ = mat2_.as_strided(mat2_sizes, mat2_default_contiguous_strides);
}
// mkldnn_matmul only proceed CPU tensor
const ideep::tensor x = itensor_view_from_dense(mat1_);
const ideep::tensor w = itensor_view_from_dense(mat2_);
ideep::tensor y = itensor_view_from_dense(result_unsqueezed);
ideep::matmul_forward::compute(x, w, y, alpha, beta,
ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);
if (y.get_data_handle() != result.data_ptr()){
// ideep will query onednn expect format of output
// if given output format is not expected, ideep will re-init an output buffer
// under this case, we need copy the re-inited buffer back to given buffer
ideep::tensor public_y = itensor_view_from_dense(result);
y.reorder_to(public_y);
}
if (mat1.dim() == 1 && mat2.dim() == 1){
// aten::dot
result.squeeze_();
}
}
// Heuristic size gate for the oneDNN matmul path: true only when the amount
// of work (multiply-accumulate count) is strictly above 16 * 16 * 16.
//   1-D x 1-D  aten::dot              mat1 (n),      mat2 (n)       -> n
//   2-D x 1-D  aten::mv               mat1 (m, n),   mat2 (n)       -> m * n
//   2-D x 2-D  aten::addmm            mat1 (m, n),   mat2 (n, k)    -> m * n * k
//   3-D x 3-D  aten::bmm / baddbmm    mat1 (b, m, n), mat2 (b, n, k) -> b * m * n * k
// Any other rank combination is rejected by mkldnn_matmul, so it returns
// false here instead of indexing dimensions that may not exist.
inline bool checksize(const Tensor& mat1, const Tensor& mat2){
  constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
  const auto rank1 = mat1.dim();
  const auto rank2 = mat2.dim();
  if (rank1 == 1 && rank2 == 1) {
    // aten::dot
    return mat1.size(0) > mkldnn_gemm_min_size;
  }
  if (rank1 == 2 && rank2 == 1) {
    // aten::mv
    return mat1.size(0) * mat1.size(1) > mkldnn_gemm_min_size;
  }
  if (rank1 == 2 && rank2 == 2) {
    // aten::addmm (previously tested mat2.dim() twice)
    return mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_gemm_min_size;
  }
  if (rank1 == 3 && rank2 == 3) {
    // aten::bmm, aten::baddbmm
    return mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) > mkldnn_gemm_min_size;
  }
  // Unsupported by mkldnn_matmul: never pick the oneDNN path.
  return false;
}
// Decides whether mkldnn_matmul should handle mat1 @ mat2 (into `result`,
// which may be undefined when no output exists yet).
// Requirements, checked in this order:
//   1. the global oneDNN bf16 gate (use_mkldnn_bf16_matmul()) is on;
//   2. dtypes: on aarch64 CPUs passing mkldnn_bf16_device_check_arm(), all
//      operands share one dtype that is float or bfloat16 (oneDNN fast-math
//      may run fp32 on bf16 hardware); everywhere else all operands are
//      bfloat16;
//   3. neither input is empty and checksize() says the problem is big enough.
bool use_mkldnn_bf16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {
  // Shape criteria common to both dtype policies; evaluated last.
  const auto shapes_qualify = [&]() {
    if (mat1.numel() == 0 || mat2.numel() == 0) {
      return false;
    }
    return checksize(mat1, mat2);
  };
#if defined(__aarch64__)
  if (mkldnn_bf16_device_check_arm()) {
    //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
    //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well
    if (!use_mkldnn_bf16_matmul()) {
      return false;
    }
    const auto dtype = mat1.scalar_type();
    const bool operands_agree = dtype == mat2.scalar_type() &&
        (!result.defined() || dtype == result.scalar_type());
    const bool dtype_supported = dtype == kFloat || dtype == kBFloat16;
    return operands_agree && dtype_supported && shapes_qualify();
  }
#endif
  if (!use_mkldnn_bf16_matmul()) {
    return false;
  }
  const bool all_bf16 = mat1.scalar_type() == kBFloat16 &&
      mat2.scalar_type() == kBFloat16 &&
      (!result.defined() || result.scalar_type() == kBFloat16);
  return all_bf16 && shapes_qualify();
}
} // namespace native
} // namespace at
#endif // AT_MKLDNN_ENABLED