Skip to content

Commit

Permalink
Add aten::_nnz/_values for SparseXPU and SparsCsrXPU dispatch key (#428)
Browse files Browse the repository at this point in the history
Signed-off-by: Feng Yuan <feng1.yuan@intel.com>
  • Loading branch information
fengyuan14 authored Jun 22, 2024
1 parent 01fc85f commit 31c4001
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/aten/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ATen XPU sources

file(GLOB xpu_cpp "*.cpp")
file(GLOB xpu_cpp "*.cpp", "sparse/*.cpp")
file(GLOB xpu_sycl "sycl/*.cpp")

list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
Expand Down
23 changes: 23 additions & 0 deletions src/aten/sparse/SparseCsrTensor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Basic functions on sparse tensors
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_nnz_native.h>
#endif

namespace at::native::xpu {

int64_t _nnz_csr(const Tensor& self) {
return at::native::_nnz_sparse_csr(self);
}

TORCH_LIBRARY_IMPL(aten, SparseCsrXPU, m) {
m.impl(TORCH_SELECTIVE_NAME("_nnz"), TORCH_FN(_nnz_csr));
}

} // namespace at::native::xpu
17 changes: 13 additions & 4 deletions src/aten/SparseTensor.cpp → src/aten/sparse/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_nnz_native.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_native.h>
#endif

namespace at {
namespace native::xpu {
namespace at::native::xpu {

Tensor _sparse_coo_tensor_with_dims_and_tensors(
int64_t sparse_dim,
Expand Down Expand Up @@ -50,11 +50,20 @@ Tensor _sparse_coo_tensor_with_dims_and_tensors(
is_coalesced);
}

int64_t _nnz(const Tensor& self) {
return at::native::_nnz_sparse(self);
}

Tensor _values(const Tensor& self) {
return at::native::_values_sparse(self);
}

TORCH_LIBRARY_IMPL(aten, SparseXPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("_sparse_coo_tensor_with_dims_and_tensors"),
TORCH_FN(_sparse_coo_tensor_with_dims_and_tensors));
m.impl(TORCH_SELECTIVE_NAME("_nnz"), TORCH_FN(_nnz));
m.impl(TORCH_SELECTIVE_NAME("_values"), TORCH_FN(_values));
}

} // namespace native::xpu
} // namespace at
} // namespace at::native::xpu

0 comments on commit 31c4001

Please sign in to comment.