forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ValidateCompressedIndicesKernel.cpp
49 lines (40 loc) · 1.25 KB
/
ValidateCompressedIndicesKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <ATen/native/sparse/ValidateCompressedIndicesCommon.h>
#include <ATen/native/cpu/Loops.h>
#ifdef AT_PER_OPERATOR_HEADERS
#include <ATen/ops/_validate_compressed_sparse_indices_native.h>
#endif
namespace at::native {

namespace {

// Launch policies consumed by validate_compressed_sparse_indices_kernel
// (from ATen/native/sparse/ValidateCompressedIndicesCommon.h). Each policy
// exposes a static `launch` that runs an element-wise functor over a
// TensorIterator on CPU. NOTE(review): the exact template contract these
// policies must satisfy is defined in that header — confirm there.

// Scalar CPU launcher: forwards the element-wise functor `f` to cpu_kernel.
template <typename func_t>
struct CPUKernel {
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    cpu_kernel(iter, f);
  }
};

// No-op launcher with the same signature as CPUKernel; it runs nothing.
// Not referenced in this file today — kept for the vectorized instantiation
// mentioned inside _validate_compressed_sparse_indices_cpu below.
// Parameter names are commented out to avoid -Wunused-parameter.
template <typename func_t>
struct EmptyKernel {
  static void launch(TensorIteratorBase& /*iter*/, const func_t& /*f*/) {}
};

// Vectorized CPU launcher: forwards a scalar functor `f` together with its
// vectorized counterpart `vec_f` to cpu_kernel_vec.
template <typename func_t, typename vec_func_t>
struct CPUVecKernel {
  static void launch(
      TensorIteratorBase& iter,
      const func_t& f,
      const vec_func_t& vec_f) {
    cpu_kernel_vec(iter, f, vec_f);
  }
};

} // namespace

// CPU implementation of _validate_compressed_sparse_indices.
//
// Forwards every argument unchanged to validate_compressed_sparse_indices_kernel,
// instantiated with the scalar CPUKernel launch policy. The checks themselves,
// and any errors they raise, live in
// ATen/native/sparse/ValidateCompressedIndicesCommon.h.
//
// NOTE(review): the parameter descriptions below are inferred from their names;
// confirm against the common header.
//   is_crow - whether `cidx` compresses rows (true) or columns (false).
//   cidx    - compressed indices tensor.
//   idx     - plain indices tensor.
//   cdim    - size of the compressed dimension.
//   dim     - size of the plain dimension.
//   nnz     - number of specified elements.
void _validate_compressed_sparse_indices_cpu(
    const bool is_crow,
    const Tensor& cidx,
    const Tensor& idx,
    const int64_t cdim,
    const int64_t dim,
    const int64_t nnz) {
  // Only the scalar launcher is used for now. Vectorized checks can be enabled
  // once all the conditions for that are met by instantiating
  //   validate_compressed_sparse_indices_kernel<EmptyKernel, CPUVecKernel, Vectorized>
  // instead (NOTE(review): confirm the template parameter order); see
  // ATen/native/sparse/ValidateCompressedIndicesCommon.h for details.
  validate_compressed_sparse_indices_kernel<CPUKernel>(
      is_crow, cidx, idx, cdim, dim, nnz);
}

} // namespace at::native