forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FlattenIndicesKernel.cpp
31 lines (24 loc) · 998 Bytes
/
FlattenIndicesKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/FlattenIndicesCommon.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/AccumulateType.h>
namespace at::native {
namespace {
// Launcher policy type for the CPU backend. It is supplied to
// _flatten_indices as a template-template argument (see
// flatten_indices_cpu_kernel in this file); its only member is a static
// `launch` that forwards its two arguments to cpu_kernel.
template <typename op_t>
struct CPUKernelLauncher {
  // Invokes cpu_kernel with the given iterator and callable, unchanged.
  static void launch(TensorIteratorBase& iterator, const op_t& op) {
    cpu_kernel(iterator, op);
  }
};
// CPU implementation registered for flatten_indices_stub below.
// Delegates entirely to the shared _flatten_indices template, instantiated
// with CPUKernelLauncher so the work is launched through cpu_kernel.
// NOTE(review): the exact contract of `indices` / `size` (shapes, dtype,
// validation) lives in FlattenIndicesCommon.h and is not visible here.
Tensor flatten_indices_cpu_kernel(const Tensor& indices, IntArrayRef size) {
  Tensor flattened = _flatten_indices<CPUKernelLauncher>(indices, size);
  return flattened;
}
}
// Register flatten_indices_cpu_kernel as the implementation of
// flatten_indices_stub. The same function pointer is registered for the
// DEFAULT slot and for each of the AVX512, AVX2, VSX and ZVECTOR slots, so
// every listed CPU capability resolves to the one kernel defined above.
// NOTE(review): if the dispatch header in this tree defines registration
// macros for further CPU capabilities, they presumably need a matching line
// here as well — confirm against DispatchStub.h.
REGISTER_ARCH_DISPATCH(flatten_indices_stub, DEFAULT, &flatten_indices_cpu_kernel);
REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
} // namespace at::native