forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
OpContext.cpp
47 lines (38 loc) · 1.26 KB
/
OpContext.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <ATen/native/mkldnn/ConvPrepack.h>
#include <ATen/native/mkldnn/OpContext.h>
#if AT_MKLDNN_ENABLED()

namespace at {
namespace native {
namespace mkldnn {

// Factory for an MKLDNN-backed convolution op context.
//
// The backend op state is built first by
// `mkldnn::internal::convolution::create`, using the caller's arguments while
// they are still intact. Only afterwards are `weight`, `bias`, `padding`,
// `stride`, `dilation` and `input_size` moved into the new
// `MkldnnConvOpContext`, together with that op state. The result is handed
// back through the `ConvOpContext` base interface.
//
// NOTE(review): the arguments are left in a moved-from state on return, so
// callers must not reuse them.
c10::intrusive_ptr<ConvOpContext> MkldnnConvOpContext::create_context(
    at::Tensor&& weight,
    c10::optional<at::Tensor>&& bias,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    std::vector<int64_t>&& input_size,
    const ideep::attr_t& attr) {
  namespace conv = mkldnn::internal::convolution;

  // Must happen before any std::move below: create() reads the arguments.
  auto backend_state = conv::create(
      weight, bias, padding, stride, dilation, groups, input_size, attr);

  return c10::make_intrusive<MkldnnConvOpContext>(
      std::move(weight),
      std::move(bias),
      std::move(padding),
      std::move(stride),
      std::move(dilation),
      groups,
      std::move(input_size),
      std::move(backend_state));
}

// Runs the stored convolution op state on `input` and returns the tensor
// produced by `mkldnn::internal::convolution::run`.
Tensor MkldnnConvOpContext::run(const Tensor& input) {
  namespace conv = mkldnn::internal::convolution;
  return conv::run(op_context_, input);
}

// Runs the stored convolution op state on `input`, forwarding the raw
// `output` pointer to `mkldnn::internal::convolution::run`.
// NOTE(review): presumably the result is written into the memory behind
// `output` -- confirm size/ownership expectations in ConvPrepack.h.
void MkldnnConvOpContext::run(const Tensor& input, void* output) {
  namespace conv = mkldnn::internal::convolution;
  conv::run(op_context_, input, output);
}

} // namespace mkldnn
} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED()