diff --git a/CITATION.cff b/CITATION.cff index 37cd60dfa5..ac0501e0b6 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -64,6 +64,8 @@ preferred-citation: given-names: Achim - family-names: Streit given-names: Achim + - family-names: Vaithinathan Aravindan + given-names: Ashwath year: 2020 collection-title: 2020 IEEE International Conference on Big Data (IEEE Big Data 2020) collection-doi: 10.1109/BigData50022.2020.9378050 diff --git a/heat/sparse/__init__.py b/heat/sparse/__init__.py index 538cb9fd48..86e2929704 100644 --- a/heat/sparse/__init__.py +++ b/heat/sparse/__init__.py @@ -1,7 +1,7 @@ """add sparse heat function to the ht.sparse namespace""" from .arithmetics import * -from .dcsr_matrix import * +from .dcsx_matrix import * from .factories import * from ._operations import * from .manipulations import * diff --git a/heat/sparse/_operations.py b/heat/sparse/_operations.py index f480a15da8..9efc57ca62 100644 --- a/heat/sparse/_operations.py +++ b/heat/sparse/_operations.py @@ -3,11 +3,10 @@ import torch import numpy as np -from heat.sparse.dcsr_matrix import DCSR_matrix +from heat.sparse.dcsx_matrix import DCSC_matrix, DCSR_matrix, __DCSX_matrix from . import factories from ..core.communication import MPI -from ..core.dndarray import DNDarray from ..core import types from typing import Callable, Optional, Dict @@ -15,13 +14,14 @@ __all__ = [] -def __binary_op_csr( +def __binary_op_csx( operation: Callable, - t1: DCSR_matrix, - t2: DCSR_matrix, - out: Optional[DCSR_matrix] = None, + t1: __DCSX_matrix, + t2: __DCSX_matrix, + out: Optional[__DCSX_matrix] = None, + orientation: str = "row", fn_kwargs: Optional[Dict] = {}, -) -> DCSR_matrix: +) -> __DCSX_matrix: """ Generic wrapper for element-wise binary operations of two operands. Takes the operation function and the two operands involved in the operation as arguments. @@ -31,37 +31,60 @@ def __binary_op_csr( operation : PyTorch function The operation to be performed. Function that performs operation elements-wise on the involved tensors, e.g. add values from other to self - t1: DCSR_matrix + t1: __DCSX_matrix or scalar The first operand involved in the operation. - t2: DCSR_matrix + t2: __DCSX_matrix or scalar The second operand involved in the operation. - out: DCSR_matrix, optional + out: __DCSX_matrix, optional Output buffer in which the result is placed. If not provided, a freshly allocated matrix is returned. + orientation: str, optional + The orientation of the operation. Options: 'row' or 'col' + Default: 'row' fn_kwargs: Dict, optional keyword arguments used for the given operation Default: {} (empty dictionary) Returns ------- - result: ht.sparse.DCSR_matrix - A DCSR_matrix containing the results of element-wise operation. + result: ht.sparse.__DCSX_matrix + A __DCSX_matrix containing the results of element-wise operation. + + Raises + ------ + ValueError + If the orientation is invalid + ValueError + If the input types are not supported + ValueError + If the input shapes are not compatible + ValueError + If the output buffer shape is not compatible with the result """ - if not np.isscalar(t1) and not isinstance(t1, DCSR_matrix): + if orientation not in ["row", "col"]: + raise ValueError(f"Invalid orientation: '{orientation}'. Options: 'row' or 'col'") + + if not np.isscalar(t1) and not isinstance(t1, __DCSX_matrix): raise TypeError( f"Only Dcsr_matrices and numeric scalars are supported, but input was {type(t1)}" ) - if not np.isscalar(t2) and not isinstance(t2, DCSR_matrix): + if not np.isscalar(t2) and not isinstance(t2, __DCSX_matrix): raise TypeError( f"Only Dcsr_matrices and numeric scalars are supported, but input was {type(t2)}" ) - if not isinstance(t1, DCSR_matrix) and not isinstance(t2, DCSR_matrix): + if not isinstance(t1, __DCSX_matrix) and not isinstance(t2, __DCSX_matrix): raise TypeError( f"Operator only to be used with Dcsr_matrices, but input types were {type(t1)} and {type(t2)}" ) promoted_type = types.result_type(t1, t2).torch_type() + torch_constructor = torch.sparse_csr_tensor if orientation == "row" else torch.sparse_csc_tensor + factory_method = ( + factories.sparse_csr_matrix if orientation == "row" else factories.sparse_csc_matrix + ) + split_axis = 0 if orientation == "row" else 1 + # If one of the inputs is a scalar # just perform the operation on the data tensor # and create a new sparse matrix @@ -74,15 +97,15 @@ def __binary_op_csr( scalar = t1 res_values = operation(matrix.larray.values().to(promoted_type), scalar, **fn_kwargs) - res_torch_sparse_csr = torch.sparse_csr_tensor( + res_torch_sparse_csx = torch_constructor( matrix.lindptr, matrix.lindices, res_values, size=matrix.lshape, device=matrix.device.torch_device, ) - return factories.sparse_csr_matrix( - res_torch_sparse_csr, is_split=matrix.split, comm=matrix.comm, device=matrix.device + return factory_method( + res_torch_sparse_csx, is_split=matrix.split, comm=matrix.comm, device=matrix.device ) if t1.shape != t2.shape: @@ -93,10 +116,10 @@ def __binary_op_csr( if t1.split is not None or t2.split is not None: if t1.split is None: - t1 = factories.sparse_csr_matrix(t1.larray, split=0) + t1 = factory_method(t1.larray, split=split_axis) if t2.split is None: - t2 = factories.sparse_csr_matrix(t2.larray, split=0) + t2 = factory_method(t2.larray, split=split_axis) output_split = t1.split output_device = t1.device @@ -113,10 +136,10 @@ def __binary_op_csr( if out.split != output_split: if out.split is None: - out = factories.sparse_csr_matrix(out.larray, split=0) + out = factory_method(out.larray, split=split_axis) else: - out = factories.sparse_csr_matrix( - torch.sparse_csr_tensor( + out = factory_method( + torch_constructor( torch.tensor(out.indptr, dtype=torch.int64), torch.tensor(out.indices, dtype=torch.int64), torch.tensor(out.data), @@ -146,21 +169,38 @@ def __binary_op_csr( output_type = types.canonical_heat_type(result.dtype) if out is None: - return DCSR_matrix( - array=torch.sparse_csr_tensor( - result.crow_indices().to(torch.int64), - result.col_indices().to(torch.int64), - result.values(), - size=output_lshape, - ), - gnnz=output_gnnz, - gshape=output_shape, - dtype=output_type, - split=output_split, - device=output_device, - comm=output_comm, - balanced=output_balanced, - ) + if orientation == "row": + return DCSR_matrix( + array=torch_constructor( + result.crow_indices().to(torch.int64), + result.col_indices().to(torch.int64), + result.values(), + size=output_lshape, + ), + gnnz=output_gnnz, + gshape=output_shape, + dtype=output_type, + split=output_split, + device=output_device, + comm=output_comm, + balanced=output_balanced, + ) + else: + return DCSC_matrix( + array=torch_constructor( + result.ccol_indices().to(torch.int64), + result.row_indices().to(torch.int64), + result.values(), + size=output_lshape, + ), + gnnz=output_gnnz, + gshape=output_shape, + dtype=output_type, + split=output_split, + device=output_device, + comm=output_comm, + balanced=output_balanced, + ) out.larray.copy_(result) out.gnnz = output_gnnz diff --git a/heat/sparse/arithmetics.py b/heat/sparse/arithmetics.py index 6aef13c93c..6497b015e9 100644 --- a/heat/sparse/arithmetics.py +++ b/heat/sparse/arithmetics.py @@ -4,7 +4,7 @@ import torch -from .dcsr_matrix import DCSR_matrix +from .dcsx_matrix import DCSC_matrix, DCSR_matrix from . import _operations @@ -14,7 +14,7 @@ ] -def add(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix: +def add(t1: DCSR_matrix, t2: DCSR_matrix, orientation: str = "row") -> DCSR_matrix: """ Element-wise addition of values from two operands, commutative. Takes the first and second operand (scalar or :class:`~heat.sparse.DCSR_matrix`) whose elements are to be added @@ -26,6 +26,9 @@ def add(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix: The first operand involved in the addition t2: DCSR_matrix The second operand involved in the addition + orientation: str, optional + The orientation of the operation. Options: 'row' or 'col' + Default: 'row' Examples -------- @@ -43,16 +46,16 @@ def add(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix: DNDarray([[2., 0., 4.], [0., 0., 6.]], dtype=ht.float32, device=cpu:0, split=0) """ - return _operations.__binary_op_csr(torch.add, t1, t2) + return _operations.__binary_op_csx(torch.add, t1, t2, orientation=orientation) -DCSR_matrix.__add__ = lambda self, other: add(self, other) +DCSR_matrix.__add__ = lambda self, other: add(self, other, orientation="row") DCSR_matrix.__add__.__doc__ = add.__doc__ -DCSR_matrix.__radd__ = lambda self, other: add(self, other) +DCSR_matrix.__radd__ = lambda self, other: add(self, other, orientation="row") DCSR_matrix.__radd__.__doc__ = add.__doc__ -def mul(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix: +def mul(t1: DCSR_matrix, t2: DCSR_matrix, orientation: str = "row") -> DCSR_matrix: """ Element-wise multiplication (NOT matrix multiplication) of values from two operands, commutative. Takes the first and second operand (scalar or :class:`~heat.sparse.DCSR_matrix`) whose elements are to be @@ -64,6 +67,9 @@ def mul(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix: The first operand involved in the multiplication t2: DCSR_matrix The second operand involved in the multiplication + orientation: str, optional + The orientation of the operation. Options: 'row' or 'col' + Default: 'row' Examples -------- @@ -81,10 +87,10 @@ def mul(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix: DNDarray([[1., 0., 4.], [0., 0., 9.]], dtype=ht.float32, device=cpu:0, split=0) """ - return _operations.__binary_op_csr(torch.mul, t1, t2) + return _operations.__binary_op_csx(torch.mul, t1, t2, orientation=orientation) -DCSR_matrix.__mul__ = lambda self, other: mul(self, other) +DCSR_matrix.__mul__ = lambda self, other: mul(self, other, orientation="row") DCSR_matrix.__mul__.__doc__ = mul.__doc__ -DCSR_matrix.__rmul__ = lambda self, other: mul(self, other) +DCSR_matrix.__rmul__ = lambda self, other: mul(self, other, orientation="row") DCSR_matrix.__rmul__.__doc__ = mul.__doc__ diff --git a/heat/sparse/dcsr_matrix.py b/heat/sparse/dcsx_matrix.py similarity index 59% rename from heat/sparse/dcsr_matrix.py rename to heat/sparse/dcsx_matrix.py index a5d6187fd6..2d1f754a1d 100644 --- a/heat/sparse/dcsr_matrix.py +++ b/heat/sparse/dcsx_matrix.py @@ -11,19 +11,18 @@ from ..core.factories import array from ..core.types import datatype, canonical_heat_type -__all__ = ["DCSR_matrix"] +__all__ = ["DCSR_matrix", "DCSC_matrix"] Communication = TypeVar("Communication") -class DCSR_matrix: +class __DCSX_matrix: """ - Distributed Compressed Sparse Row Matrix. It is composed of - PyTorch sparse_csr_tensors local to each process. + Distributed Compressed Sparse Matrix. Base class for DCSR_matrix and DCSC_matrix. Parameters ---------- - array : torch.Tensor (layout ==> torch.sparse_csr) + array : torch.Tensor (layout ==> torch.sparse_csr | torch.sparse_csc) Local sparse array gnnz: int Total number of non-zero elements across all processes @@ -53,18 +52,18 @@ def __init__( comm: Communication, balanced: bool, ): - self.__array = array - self.__gnnz = gnnz - self.__gshape = gshape - self.__dtype = dtype - self.__split = split - self.__device = device - self.__comm = comm - self.__balanced = balanced + self._array = array + self._gnnz = gnnz + self._gshape = gshape + self._dtype = dtype + self._split = split + self._device = device + self._comm = comm + self._balanced = balanced def global_indptr(self) -> DNDarray: """ - Global indptr of the ``DCSR_matrix`` as a ``DNDarray`` + Global indptr of the ``__DCSX_matrix`` as a ``DNDarray`` """ if self.split is None: raise ValueError("This method works only for distributed matrices") @@ -92,41 +91,43 @@ def global_indptr(self) -> DNDarray: dtype=self.lindptr.dtype, device=self.device, comm=self.comm, - is_split=self.split, + is_split=( + 0 if self.split is not None else None + ), # is_split for the indptr is either 0 or None because it is always 1 dimensional ) @property def balanced(self) -> bool: """ - Boolean value indicating if the DCSR_matrix is balanced between the MPI processes + Boolean value indicating if the __DCSX_matrix is balanced between the MPI processes """ - return self.__balanced + return self._balanced @property def comm(self) -> Communication: """ - The :class:`~heat.core.communication.Communication` of the ``DCSR_matrix`` + The :class:`~heat.core.communication.Communication` of the ``__DCSX_matrix`` """ - return self.__comm + return self._comm @property def device(self) -> Device: """ - The :class:`~heat.core.devices.Device` of the ``DCSR_matrix`` + The :class:`~heat.core.devices.Device` of the ``__DCSX_matrix`` """ - return self.__device + return self._device @property def larray(self) -> torch.Tensor: """ - Local data of the ``DCSR_matrix`` + Local data of the ``__DCSX_matrix`` """ - return self.__array + return self._array @property def data(self) -> torch.Tensor: """ - Global data of the ``DCSR_matrix`` + Global data of the ``__DCSX_matrix`` """ if self.split is None: return self.ldata @@ -141,21 +142,21 @@ def data(self) -> torch.Tensor: @property def gdata(self) -> torch.Tensor: """ - Global data of the ``DCSR_matrix`` + Global data of the ``__DCSX_matrix`` """ return self.data @property def ldata(self) -> torch.Tensor: """ - Local data of the ``DCSR_matrix`` + Local data of the ``__DCSX_matrix`` """ - return self.__array.values() + return self._array.values() @property def indptr(self) -> torch.Tensor: """ - Global indptr of the ``DCSR_matrix`` + Global indptr of the ``__DCSX_matrix`` """ if self.split is None: return self.lindptr @@ -165,21 +166,21 @@ def indptr(self) -> torch.Tensor: @property def gindptr(self) -> torch.Tensor: """ - Global indptr of the ``DCSR_matrix`` + Global indptr of the ``__DCSX_matrix`` """ return self.indptr @property def lindptr(self) -> torch.Tensor: """ - Local indptr of the ``DCSR_matrix`` + Local indptr of the ``__DCSX_matrix`` """ - return self.__array.crow_indices() + raise NotImplementedError("Local indptr is not implemented for __DCSX_matrix") @property def indices(self) -> torch.Tensor: """ - Global indices of the ``DCSR_matrix`` + Global indices of the ``__DCSX_matrix`` """ if self.split is None: return self.lindices @@ -194,89 +195,89 @@ def indices(self) -> torch.Tensor: @property def gindices(self) -> torch.Tensor: """ - Global indices of the ``DCSR_matrix`` + Global indices of the ``__DCSX_matrix`` """ return self.indices @property def lindices(self) -> torch.Tensor: """ - Local indices of the ``DCSR_matrix`` + Local indices of the ``__DCSX_matrix`` """ - return self.__array.col_indices() + raise NotImplementedError("Local indices is not implemented for __DCSX_matrix") @property def ndim(self) -> int: """ - Number of dimensions of the ``DCSR_matrix`` + Number of dimensions of the ``__DCSX_matrix`` """ - return len(self.__gshape) + return len(self._gshape) @property def nnz(self) -> int: """ - Total number of non-zero elements of the ``DCSR_matrix`` + Total number of non-zero elements of the ``__DCSX_matrix`` """ - return self.__gnnz + return self._gnnz @property def gnnz(self) -> int: """ - Total number of non-zero elements of the ``DCSR_matrix`` + Total number of non-zero elements of the ``__DCSX_matrix`` """ return self.nnz @property def lnnz(self) -> int: """ - Number of non-zero elements on the local process of the ``DCSR_matrix`` + Number of non-zero elements on the local process of the ``__DCSX_matrix`` """ - return self.__array._nnz() + return self._array._nnz() @property def shape(self) -> Tuple[int, ...]: """ - Global shape of the ``DCSR_matrix`` + Global shape of the ``__DCSX_matrix`` """ - return self.__gshape + return self._gshape @property def gshape(self) -> Tuple[int, ...]: """ - Global shape of the ``DCSR_matrix`` + Global shape of the ``__DCSX_matrix`` """ return self.shape @property def lshape(self) -> Tuple[int, ...]: """ - Local shape of the ``DCSR_matrix`` + Local shape of the ``__DCSX_matrix`` """ - return tuple(self.__array.size()) + return tuple(self._array.size()) @property def dtype(self): """ - The :class:`~heat.core.types.datatype` of the ``DCSR_matrix`` + The :class:`~heat.core.types.datatype` of the ``__DCSX_matrix`` """ - return self.__dtype + return self._dtype @property def split(self) -> int: """ - Returns the axis on which the ``DCSR_matrix`` is split + Returns the axis on which the ``__DCSX_matrix`` is split """ - return self.__split + return self._split def is_distributed(self) -> bool: """ - Determines whether the data of this ``DCSR_matrix`` is distributed across multiple processes. + Determines whether the data of this ``__DCSX_matrix`` is distributed across multiple processes. """ return self.split is not None and self.comm.is_distributed() def counts_displs_nnz(self) -> Tuple[Tuple[int], Tuple[int]]: """ - Returns actual counts (number of non-zero items per process) and displacements (offsets) of the DCSR_matrix. + Returns actual counts (number of non-zero items per process) and displacements (offsets) of the __DCSX_matrix. Does not assume load balance. """ if self.split is not None: @@ -287,10 +288,10 @@ def counts_displs_nnz(self) -> Tuple[Tuple[int], Tuple[int]]: return tuple(counts.tolist()), tuple(displs) else: raise ValueError( - "Non-distributed DCSR_matrix. Cannot calculate counts and displacements." + f"Non-distributed {self.__class__.__name__}. Cannot calculate counts and displacements." ) - def astype(self, dtype, copy=True) -> DCSR_matrix: + def astype(self, dtype, copy=True) -> __DCSX_matrix: """ Returns a casted version of this matrix. Casted matrix is a new matrix of the same shape but with given type of this matrix. If copy is ``True``, the @@ -305,9 +306,9 @@ def astype(self, dtype, copy=True) -> DCSR_matrix: in-place and this matrix is returned """ dtype = canonical_heat_type(dtype) - casted_matrix = self.__array.type(dtype.torch_type()) + casted_matrix = self._array.to(dtype.torch_type(), copy=copy) if copy: - return DCSR_matrix( + return self.__class__( casted_matrix, self.gnnz, self.gshape, @@ -318,14 +319,14 @@ def astype(self, dtype, copy=True) -> DCSR_matrix: self.balanced, ) - self.__array = casted_matrix - self.__dtype = dtype + self._array = casted_matrix + self._dtype = dtype return self def __repr__(self) -> str: """ - Computes a printable representation of the passed DCSR_matrix. + Computes a printable representation of the passed __DCSX_matrix. """ print_string = ( f"(indptr: {self.indptr}, indices: {self.indices}, data: {self.data}, " @@ -338,3 +339,85 @@ def __repr__(self) -> str: if self.comm.rank != 0: return "" return print_string + + +class DCSR_matrix(__DCSX_matrix): + """ + Distributed Compressed Sparse Row Matrix. It is composed of + PyTorch sparse_csr_tensors local to each process. + + Parameters + ---------- + array : torch.Tensor (layout ==> torch.sparse_csr) + Local sparse array + gnnz: int + Total number of non-zero elements across all processes + gshape : Tuple[int,...] + The global shape of the array + dtype : datatype + The datatype of the array + split : int or None + If split is not None, it denotes the axis on which the array is divided between processes. + DCSR_matrix only supports distribution along axis 0. + device : Device + The device on which the local arrays are using (cpu or gpu) + comm : Communication + The communications object for sending and receiving data + balanced: bool or None + Describes whether the data are evenly distributed across processes. + """ + + @property + def lindptr(self) -> torch.Tensor: + """ + Local indptr of the ``DCSR_matrix`` + """ + return self._array.crow_indices() + + @property + def lindices(self) -> torch.Tensor: + """ + Local indices of the ``DCSR_matrix`` + """ + return self._array.col_indices() + + +class DCSC_matrix(__DCSX_matrix): + """ + Distributed Compressed Sparse Column Matrix. It is composed of + PyTorch sparse_csc_tensors local to each process. + + Parameters + ---------- + array : torch.Tensor (layout ==> torch.sparse_csc) + Local sparse array + gnnz: int + Total number of non-zero elements across all processes + gshape : Tuple[int,...] + The global shape of the array + dtype : datatype + The datatype of the array + split : int or None + If split is not None, it denotes the axis on which the array is divided between processes. + DCSR_matrix only supports distribution along axis 0. + device : Device + The device on which the local arrays are using (cpu or gpu) + comm : Communication + The communications object for sending and receiving data + balanced: bool or None + Describes whether the data are evenly distributed across processes. + """ + + @property + def lindptr(self) -> torch.Tensor: + """ + Local indptr of the ``DCSC_matrix`` + """ + return self._array.ccol_indices() + + @property + def lindices(self) -> torch.Tensor: + """ + Local indices of the ``DCSC_matrix`` + """ + return self._array.row_indices() diff --git a/heat/sparse/factories.py b/heat/sparse/factories.py index a9a51c5a26..0966785cdf 100644 --- a/heat/sparse/factories.py +++ b/heat/sparse/factories.py @@ -3,6 +3,7 @@ import torch import numpy as np from scipy.sparse import csr_matrix as scipy_csr_matrix +from scipy.sparse import csc_matrix as scipy_csc_matrix from typing import Optional, Type, Iterable import warnings @@ -13,10 +14,11 @@ from ..core.devices import Device from ..core.types import datatype -from .dcsr_matrix import DCSR_matrix +from .dcsx_matrix import DCSC_matrix, DCSR_matrix, __DCSX_matrix __all__ = [ "sparse_csr_matrix", + "sparse_csc_matrix", ] @@ -39,7 +41,7 @@ def sparse_csr_matrix( dtype : datatype, optional The desired data-type for the sparse matrix. If not given, then the type will be determined as the minimum type required to hold the objects in the sequence. This argument can only be used to ‘upcast’ the array. For downcasting, use - the :func:`~heat.sparse.dcsr_matrix.astype` method. + the :func:`~heat.sparse.DCSR_matrix.astype` method. split : int or None, optional The axis along which the passed array content ``obj`` is split and distributed in memory. DCSR_matrix only supports distribution along axis 0. Mutually exclusive with ``is_split``. @@ -93,6 +95,107 @@ def sparse_csr_matrix( >>> ht.sparse.sparse_csr_matrix([[0, 0, 1], [1, 0, 2], [0, 0, 3]]) (indptr: tensor([0, 1, 3, 4]), indices: tensor([2, 0, 2, 2]), data: tensor([1, 1, 2, 3]), dtype=ht.int64, device=cpu:0, split=None) """ + return __sparse_matrix(obj, dtype, split, is_split, device, comm, orientation="row") + + +def sparse_csc_matrix( + obj: Iterable, + dtype: Optional[Type[datatype]] = None, + split: Optional[int] = None, + is_split: Optional[int] = None, + device: Optional[Device] = None, + comm: Optional[Communication] = None, +) -> DCSC_matrix: + """ + Create a :class:`~heat.sparse.DCSC_matrix`. + + Parameters + ---------- + obj : array_like + A tensor or array, any object exposing the array interface, an object whose ``__array__`` method returns an + array, or any (nested) sequence. Sparse tensor that needs to be distributed. + dtype : datatype, optional + The desired data-type for the sparse matrix. If not given, then the type will be determined as the minimum type required + to hold the objects in the sequence. This argument can only be used to ‘upcast’ the array. For downcasting, use + the :func:`~heat.sparse.DCSC_matrix.astype` method. + split : int or None, optional + The axis along which the passed array content ``obj`` is split and distributed in memory. DCSC_matrix only supports + distribution along axis 1. Mutually exclusive with ``is_split``. + is_split : int or None, optional + Specifies the axis along which the local data portions, passed in obj, are split across all machines. DCSC_matrix only + supports distribution along axis 1. Useful for interfacing with other distributed-memory code. The shape of the global + array is automatically inferred. Mutually exclusive with ``split``. + device : str or Device, optional + Specifies the :class:`~heat.core.devices.Device` the array shall be allocated on (i.e. globally set default + device). + comm : Communication, optional + Handle to the nodes holding distributed array chunks. + + Raises + ------ + ValueError + If split and is_split parameters are not one of 1 or None. + + Examples + -------- + Create a :class:`~heat.sparse.DCSC_matrix` from :class:`torch.Tensor` (layout ==> torch.sparse_csc) + >>> indptr = torch.tensor([0, 2, 3, 6]) + >>> indices = torch.tensor([0, 2, 2, 0, 1, 2]) + >>> data = torch.tensor([1., 4., 5., 2., 3., 6.], dtype=torch.float) + >>> torch_sparse_csc = torch.sparse_csc_tensor(indptr, indices, data) + >>> heat_sparse_csc = ht.sparse.sparse_csc_matrix(torch_sparse_csc, split=1) + >>> heat_sparse_csc + (indptr: tensor([0, 2, 3, 6]), indices: tensor([0, 2, 2, 0, 1, 2]), data: tensor([1., 4., 5., 2., 3., 6.]), dtype=ht.float32, device=cpu:0, split=1) + + Create a :class:`~heat.sparse.DCSC_matrix` from :class:`scipy.sparse.csc_matrix` + >>> scipy_sparse_csc = scipy.sparse.csc_matrix((data, indices, indptr)) + >>> heat_sparse_csc = ht.sparse.sparse_csc_matrix(scipy_sparse_csc, split=1) + >>> heat_sparse_csc + (indptr: tensor([0, 2, 3, 6], dtype=torch.int32), indices: tensor([0, 2, 2, 0, 1, 2], dtype=torch.int32), data: tensor([1., 4., 5., 2., 3., 6.]), dtype=ht.float32, device=cpu:0, split=1) + + Create a :class:`~heat.sparse.DCSC_matrix` using data that is already distributed (with `is_split`) + >>> indptrs = [torch.tensor([0, 2, 3]), torch.tensor([0, 3])] + >>> indices = [torch.tensor([0, 2, 2]), torch.tensor([0, 1, 2])] + >>> data = [torch.tensor([1, 2, 3], dtype=torch.float), + torch.tensor([4, 5, 6], dtype=torch.float)] + >>> rank = ht.MPI_WORLD.rank + >>> local_indptr = indptrs[rank] + >>> local_indices = indices[rank] + >>> local_data = data[rank] + >>> local_torch_sparse_csr = torch.sparse_csr_tensor(local_indptr, local_indices, local_data) + >>> heat_sparse_csr = ht.sparse.sparse_csr_matrix(local_torch_sparse_csr, is_split=0) + >>> heat_sparse_csr + (indptr: tensor([0, 2, 3, 6]), indices: tensor([0, 2, 2, 0, 1, 2]), data: tensor([1., 2., 3., 4., 5., 6.]), dtype=ht.float32, device=cpu:0, split=1) + + Create a :class:`~heat.sparse.DCSC_matrix` from List + >>> ht.sparse.sparse_csc_matrix([[0, 0, 1], [1, 0, 2], [0, 0, 3]]) + (indptr: tensor([0, 1, 1, 4]), indices: tensor([1, 0, 1, 2]), data: tensor([1, 1, 2, 3]), dtype=ht.int64, device=cpu:0, split=None) + """ + return __sparse_matrix(obj, dtype, split, is_split, device, comm, orientation="col") + + +def __sparse_matrix( + obj: Iterable, + dtype: Optional[Type[datatype]] = None, + split: Optional[int] = None, + is_split: Optional[int] = None, + device: Optional[Device] = None, + comm: Optional[Communication] = None, + orientation: str = "row", +) -> __DCSX_matrix: + """ + Create a :class:`~heat.sparse.__DCSX_matrix`. + This is a common method for converting a distributed array to a sparse matrix representation. + + Raises + ------ + ValueError + If the orientation is not ``'row'`` or ``'col'``. + ValueError + If the number of dimensions of the input array is not 2. + ValueError + If the split or is_split axis is not supported for the type. + """ # version check if int(torch.__version__.split(".")[0]) <= 1 and int(torch.__version__.split(".")[1]) < 10: raise RuntimeError(f"ht.sparse requires torch >= 1.10. Found version {torch.__version__}.") @@ -105,9 +208,14 @@ def sparse_csr_matrix( if device is not None: device = devices.sanitize_device(device) + if orientation not in ["row", "col"]: + raise ValueError(f"Invalid orientation: '{orientation}'. Options: 'row' or 'col'") + + torch_class = torch.sparse_csr_tensor if orientation == "row" else torch.sparse_csc_tensor # Convert input into torch.Tensor (layout ==> torch.sparse_csr) - if isinstance(obj, scipy_csr_matrix): - obj = torch.sparse_csr_tensor( + # TODO: Check if conversion works across types + if isinstance(obj, (scipy_csr_matrix, scipy_csc_matrix)): + obj = torch_class( obj.indptr, obj.indices, obj.data, @@ -129,8 +237,12 @@ def sparse_csr_matrix( if obj.ndim != 2: raise ValueError(f"The number of dimensions must be 2, found {str(obj.ndim)}") - if obj.layout != torch.sparse_csr: - obj = obj.to_sparse_csr() + torch_layout = torch.sparse_csr if orientation == "row" else torch.sparse_csc + if obj.layout != torch_layout: + if torch_layout == torch.sparse_csr: + obj = obj.to_sparse_csr() + else: + obj = obj.to_sparse_csc() # infer dtype from obj if not explicitly given if dtype is None: @@ -156,19 +268,24 @@ def sparse_csr_matrix( lshape = gshape gnnz = obj.values().shape[0] - if split == 0: + compressed_indices = obj.crow_indices() if orientation == "row" else obj.ccol_indices() + element_indices = obj.col_indices() if orientation == "row" else obj.row_indices() + cls = DCSR_matrix if orientation == "row" else DCSC_matrix + + if (split == 0 and orientation == "row") or (split == 1 and orientation == "col"): start, end = comm.chunk(gshape, split, sparse=True) # Find the starting and ending indices for - # col_indices and values tensors for this process - indices_start = obj.crow_indices()[start] - indices_end = obj.crow_indices()[end] + # element_indices and values tensors for this process + indices_start = compressed_indices[start] + indices_end = compressed_indices[end] # Slice the data belonging to this process data = obj.values()[indices_start:indices_end] # start:(end + 1) because indptr is of size (n + 1) for array with n rows - indptr = obj.crow_indices()[start : end + 1] - indices = obj.col_indices()[indices_start:indices_end] + indptr = compressed_indices[start : end + 1] + + indices = element_indices[indices_start:indices_end] indptr = indptr - indptr[0] @@ -177,9 +294,9 @@ def sparse_csr_matrix( lshape = tuple(lshape) elif split is not None: - raise ValueError(f"Split axis {split} not supported for class DCSR_matrix") + raise ValueError(f"Split axis {split} not supported for class {cls.__name__}") - elif is_split == 0: + elif (is_split == 0 and orientation == "row") or (is_split == 1 and orientation == "col"): # Check whether the distributed data matches in # all dimensions other than axis 0 neighbour_shape = np.array(gshape) @@ -208,12 +325,12 @@ def sparse_csr_matrix( comm.Allreduce(MPI.IN_PLACE, reduction_buffer, MPI.MIN) if reduction_buffer < 0: raise ValueError( - "Unable to construct DCSR_matrix. Local data slices have inconsistent shapes or dimensions." + f"Unable to construct {cls.__name__}. Local data slices have inconsistent shapes or dimensions." ) data = obj.values() - indptr = obj.crow_indices() - indices = obj.col_indices() + indptr = compressed_indices + indices = element_indices # Calculate gshape gshape_split = torch.tensor(gshape[is_split]) @@ -231,14 +348,14 @@ def sparse_csr_matrix( split = is_split elif is_split is not None: - raise ValueError(f"Split axis {split} not supported for class DCSR_matrix") + raise ValueError(f"Split axis {is_split} not supported for class {cls.__name__}") else: # split is None and is_split is None data = obj.values() - indptr = obj.crow_indices() - indices = obj.col_indices() + indptr = compressed_indices + indices = element_indices - sparse_array = torch.sparse_csr_tensor( + sparse_array = torch_class( indptr.to(torch.int64), indices.to(torch.int64), data, @@ -247,7 +364,7 @@ def sparse_csr_matrix( device=device.torch_device, ) - return DCSR_matrix( + return cls( array=sparse_array, gnnz=gnnz, gshape=gshape, diff --git a/heat/sparse/manipulations.py b/heat/sparse/manipulations.py index b276abf243..355b04cbd1 100644 --- a/heat/sparse/manipulations.py +++ b/heat/sparse/manipulations.py @@ -2,19 +2,53 @@ from __future__ import annotations -from heat.sparse.dcsr_matrix import DCSR_matrix -from heat.sparse.factories import sparse_csr_matrix +from heat.sparse.dcsx_matrix import DCSC_matrix, DCSR_matrix, __DCSX_matrix +from heat.sparse.factories import sparse_csc_matrix, sparse_csr_matrix from ..core.memory import sanitize_memory_layout from ..core.dndarray import DNDarray from ..core.factories import empty __all__ = [ "to_dense", - "to_sparse", + "to_sparse_csr", + "to_sparse_csc", ] -def to_sparse(array: DNDarray) -> DCSR_matrix: +def __to_sparse(array: DNDarray, orientation="row") -> __DCSX_matrix: + """ + Convert the distributed array to a sparse DCSX_matrix representation. + This is a common method for converting a distributed array to a sparse matrix representation. + + Parameters + ---------- + array : DNDarray + The distributed array to be converted to a sparse matrix. + + orientation : str + The orientation of the sparse matrix. Options: ``'row'`` or ``'col'``. Default is ``'row'``. + + Returns + ------- + DCSX_matrix + + Raises + ------ + ValueError + If the orientation is not ``'row'`` or ``'col'``. + """ + if orientation not in ["row", "col"]: + raise ValueError(f"Invalid orientation: {orientation}. Options: 'row' or 'col'") + + array.balance_() + method = sparse_csr_matrix if orientation == "row" else sparse_csc_matrix + result = method( + array.larray, dtype=array.dtype, is_split=array.split, device=array.device, comm=array.comm + ) + return result + + +def to_sparse_csr(array: DNDarray) -> DCSR_matrix: """ Convert the distributed array to a sparse DCSR_matrix representation. @@ -28,31 +62,49 @@ def to_sparse(array: DNDarray) -> DCSR_matrix: DCSR_matrix A sparse DCSR_matrix representation of the input DNDarray. - Notes - ----- - This method allows for the conversion of a DNDarray into a sparse DCSR_matrix representation, - which is useful for handling large and sparse datasets efficiently. - Examples -------- >>> dense_array = ht.array([[1, 0, 0], [0, 0, 2], [0, 3, 0]]) - >>> sparse_matrix = dense_array.to_sparse() + >>> dense_array.to_sparse_csr() + (indptr: tensor([0, 1, 2, 3]), indices: tensor([0, 2, 1]), data: tensor([1, 2, 3]), dtype=ht.int64, device=cpu:0, split=None) + """ + return __to_sparse(array, orientation="row") + + +DNDarray.to_sparse_csr = to_sparse_csr +DNDarray.to_sparse_csr.__doc__ = to_sparse_csr.__doc__ + +def to_sparse_csc(array: DNDarray) -> DCSC_matrix: """ - array.balance_() - result = sparse_csr_matrix( - array.larray, dtype=array.dtype, device=array.device, comm=array.comm, is_split=array.split - ) - return result + Convert the distributed array to a sparse DCSC_matrix representation. + + Parameters + ---------- + array : DNDarray + The distributed array to be converted to a sparse DCSC_matrix. + + Returns + ------- + DCSC_matrix + A sparse DCSC_matrix representation of the input DNDarray. + + Examples + -------- + >>> dense_array = ht.array([[1, 0, 0], [0, 0, 2], [0, 3, 0]]) + >>> dense_array.to_sparse_csc() + (indptr: tensor([0, 1, 2, 3]), indices: tensor([0, 2, 1]), data: tensor([1, 3, 2]), dtype=ht.int64, device=cpu:0, split=None) + """ + return __to_sparse(array, orientation="col") -DNDarray.to_sparse = to_sparse -DNDarray.to_sparse.__doc__ = to_sparse.__doc__ +DNDarray.to_sparse_csc = to_sparse_csc +DNDarray.to_sparse_csc.__doc__ = to_sparse_csc.__doc__ -def to_dense(sparse_matrix: DCSR_matrix, order="C", out: DNDarray = None) -> DNDarray: +def to_dense(sparse_matrix: __DCSX_matrix, order="C", out: DNDarray = None) -> DNDarray: """ - Convert :class:`~heat.sparse.DCSR_matrix` to a dense :class:`~heat.core.DNDarray`. + Convert :class:`~heat.sparse.DCSX_matrix` to a dense :class:`~heat.core.DNDarray`. Output follows the same distribution among processes as the input Parameters @@ -113,5 +165,5 @@ def to_dense(sparse_matrix: DCSR_matrix, order="C", out: DNDarray = None) -> DND return out -DCSR_matrix.todense = lambda self, order="C", out=None: to_dense(self, order, out) -DCSR_matrix.to_dense = lambda self, order="C", out=None: to_dense(self, order, out) +__DCSX_matrix.todense = lambda self, order="C", out=None: to_dense(self, order, out) +__DCSX_matrix.to_dense = lambda self, order="C", out=None: to_dense(self, order, out) diff --git a/heat/sparse/tests/__init__.py b/heat/sparse/tests/__init__.py index da01ac5e0f..671573cf50 100644 --- a/heat/sparse/tests/__init__.py +++ b/heat/sparse/tests/__init__.py @@ -1,4 +1,4 @@ -from .test_arithmetics import * +from .test_arithmetics_csr import * from .test_dcsrmatrix import * from .test_factories import * from .test_manipulations import * diff --git a/heat/sparse/tests/test_arithmetics.py b/heat/sparse/tests/test_arithmetics_csr.py similarity index 99% rename from heat/sparse/tests/test_arithmetics.py rename to heat/sparse/tests/test_arithmetics_csr.py index 9c38b60b8b..aac8ced1d5 100644 --- a/heat/sparse/tests/test_arithmetics.py +++ b/heat/sparse/tests/test_arithmetics_csr.py @@ -13,10 +13,10 @@ int(torch.__version__.split(".")[0]) <= 1 and int(torch.__version__.split(".")[1]) < 12, f"ht.sparse requires torch >= 1.12. Found version {torch.__version__}.", ) -class TestArithmetics(TestCase): +class TestArithmeticsCSR(TestCase): @classmethod def setUpClass(self): - super(TestArithmetics, self).setUpClass() + super(TestArithmeticsCSR, self).setUpClass() """ A = [[0, 0, 1, 0, 2] diff --git a/heat/sparse/tests/test_dcscmatrix.py b/heat/sparse/tests/test_dcscmatrix.py new file mode 100644 index 0000000000..595ae483cc --- /dev/null +++ b/heat/sparse/tests/test_dcscmatrix.py @@ -0,0 +1,288 @@ +import unittest +import heat as ht +import torch + +from heat.core.tests.test_suites.basic_test import TestCase + +from typing import Tuple + + +@unittest.skipIf( + int(torch.__version__.split(".")[0]) <= 1 and int(torch.__version__.split(".")[1]) < 12, + f"ht.sparse requires torch >= 2.0. Found version {torch.__version__}.", +) +class TestDCSC_matrix(TestCase): + @classmethod + def setUpClass(self): + super(TestDCSC_matrix, self).setUpClass() + """ + A = [[0, 0, 1, 0, 2] + [0, 0, 0, 0, 0] + [0, 3, 0, 0, 0] + [4, 0, 0, 5, 0] + [0, 0, 0, 0, 6]] + """ + self.ref_indptr = torch.tensor( + [0, 1, 2, 3, 4, 6], dtype=torch.int, device=self.device.torch_device + ) + self.ref_indices = torch.tensor( + [3, 2, 0, 3, 0, 4], dtype=torch.int, device=self.device.torch_device + ) + self.ref_data = torch.tensor( + [4, 3, 1, 5, 2, 6], dtype=torch.float, device=self.device.torch_device + ) + self.ref_torch_sparse_csc = torch.sparse_csc_tensor( + self.ref_indptr, self.ref_indices, self.ref_data, device=self.device.torch_device + ) + + self.world_size = ht.communication.MPI_WORLD.size + self.rank = ht.communication.MPI_WORLD.rank + + def test_larray(self): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc) + + self.assertIsInstance(heat_sparse_csc.larray, torch.Tensor) + self.assertEqual(heat_sparse_csc.larray.layout, torch.sparse_csc) + self.assertEqual(tuple(heat_sparse_csc.larray.shape), heat_sparse_csc.lshape) + self.assertEqual(tuple(heat_sparse_csc.larray.shape), heat_sparse_csc.gshape) + + # Distributed case + if self.world_size > 1: + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc, split=1) + + self.assertIsInstance(heat_sparse_csc.larray, torch.Tensor) + self.assertEqual(heat_sparse_csc.larray.layout, torch.sparse_csc) + self.assertEqual(tuple(heat_sparse_csc.larray.shape), heat_sparse_csc.lshape) + self.assertNotEqual(tuple(heat_sparse_csc.larray.shape), heat_sparse_csc.gshape) + + def test_nnz(self): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc) + + self.assertIsInstance(heat_sparse_csc.nnz, int) + self.assertIsInstance(heat_sparse_csc.gnnz, int) + self.assertIsInstance(heat_sparse_csc.lnnz, int) + + self.assertEqual(heat_sparse_csc.nnz, self.ref_torch_sparse_csc._nnz()) + self.assertEqual(heat_sparse_csc.nnz, heat_sparse_csc.gnnz) + self.assertEqual(heat_sparse_csc.nnz, heat_sparse_csc.lnnz) + + # Distributed case + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc, split=1) + + if self.world_size == 2: + nnz_dist = [3, 3] + self.assertEqual(heat_sparse_csc.nnz, self.ref_torch_sparse_csc._nnz()) + self.assertEqual(heat_sparse_csc.lnnz, nnz_dist[self.rank]) + + if self.world_size == 3: + nnz_dist = [2, 2, 2] + self.assertEqual(heat_sparse_csc.nnz, self.ref_torch_sparse_csc._nnz()) + self.assertEqual(heat_sparse_csc.lnnz, nnz_dist[self.rank]) + + # Number of processes > Number of rows + if self.world_size == 6: + nnz_dist = [1, 1, 1, 1, 2, 0] + self.assertEqual(heat_sparse_csc.nnz, self.ref_torch_sparse_csc._nnz()) + self.assertEqual(heat_sparse_csc.lnnz, nnz_dist[self.rank]) + + def test_shape(self): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc) + + self.assertIsInstance(heat_sparse_csc.shape, Tuple) + self.assertIsInstance(heat_sparse_csc.gshape, Tuple) + self.assertIsInstance(heat_sparse_csc.lshape, Tuple) + + self.assertEqual(heat_sparse_csc.shape, self.ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.shape, heat_sparse_csc.gshape) + self.assertEqual(heat_sparse_csc.shape, heat_sparse_csc.lshape) + + # Distributed case + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc, split=1) + + if self.world_size == 2: + lshape_dist = [(5, 3), (5, 2)] + + self.assertEqual(heat_sparse_csc.shape, self.ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + + if self.world_size == 3: + lshape_dist = [(5, 2), (5, 2), (5, 1)] + + self.assertEqual(heat_sparse_csc.shape, self.ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + + # Number of processes > Number of rows + if self.world_size == 6: + lshape_dist = [(5, 1), (5, 1), (5, 1), (5, 1), (5, 1), (5, 0)] + + self.assertEqual(heat_sparse_csc.shape, self.ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + + def test_dtype(self): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc) + self.assertEqual(heat_sparse_csc.dtype, ht.float32) + + def test_data(self): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc) + + self.assertTrue((heat_sparse_csc.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csc.data == heat_sparse_csc.gdata).all()) + self.assertTrue((heat_sparse_csc.data == heat_sparse_csc.ldata).all()) + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc, split=1) + if self.world_size == 2: + data_dist = [[4, 3, 1], [5, 2, 6]] + + self.assertTrue((heat_sparse_csc.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csc.data == heat_sparse_csc.gdata).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + if self.world_size == 3: + data_dist = [[4, 3], [1, 5], [2, 6]] + + self.assertTrue((heat_sparse_csc.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csc.data == heat_sparse_csc.gdata).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + # Number of processes > Number of rows + if self.world_size == 6: + data_dist = [[4], [3], [1], [5], [2, 6], []] + + self.assertTrue((heat_sparse_csc.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csc.data == heat_sparse_csc.gdata).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + def test_indices(self): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc) + + self.assertTrue((heat_sparse_csc.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csc.indices == heat_sparse_csc.gindices).all()) + self.assertTrue((heat_sparse_csc.indices == heat_sparse_csc.lindices).all()) + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc, split=1) + if self.world_size == 2: + indices_dist = [[3, 2, 0], [3, 0, 4]] + + self.assertTrue((heat_sparse_csc.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csc.indices == heat_sparse_csc.gindices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + if self.world_size == 3: + indices_dist = [[3, 2], [0, 3], [0, 4]] + + self.assertTrue((heat_sparse_csc.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csc.indices == heat_sparse_csc.gindices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + # Number of processes > Number of rows + if self.world_size == 6: + indices_dist = [[3], [2], [0], [3], [0, 4], []] + + self.assertTrue((heat_sparse_csc.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csc.indices == heat_sparse_csc.gindices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + def test_indptr(self): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc) + + self.assertTrue((heat_sparse_csc.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csc.indptr == heat_sparse_csc.gindptr).all()) + self.assertTrue((heat_sparse_csc.indptr == heat_sparse_csc.lindptr).all()) + """ + A = [[0, 0, 1, 0, 2] + [0, 0, 0, 0, 0] + [0, 3, 0, 0, 0] + [4, 0, 0, 5, 0] + [0, 0, 0, 0, 6]] + """ + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc, split=1) + if self.world_size == 2: + indptr_dist = [[0, 1, 2, 3], [0, 1, 3]] + + self.assertTrue((heat_sparse_csc.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csc.indptr == heat_sparse_csc.gindptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + if self.world_size == 3: + indptr_dist = [[0, 1, 2], [0, 1, 2], [0, 2]] + + self.assertTrue((heat_sparse_csc.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csc.indptr == heat_sparse_csc.gindptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + # Number of processes > Number of rows + if self.world_size == 6: + indptr_dist = [[0, 1], [0, 1], [0, 1], [0, 1], [0, 2], [0]] + + self.assertTrue((heat_sparse_csc.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csc.indptr == heat_sparse_csc.gindptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + @unittest.skipIf( + int(torch.__version__.split(".")[0]) <= 1, + f"ht.sparse requires torch >= 2.0. Found version {torch.__version__}.", + ) + def test_astype(self): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_torch_sparse_csc) + + # check starting invariant + self.assertEqual(heat_sparse_csc.dtype, ht.float32) + + # check the copy case for uint8 + as_uint8 = heat_sparse_csc.astype(ht.uint8) + self.assertIsInstance(as_uint8, ht.sparse.DCSC_matrix) + self.assertEqual(as_uint8.dtype, ht.uint8) + self.assertEqual(as_uint8.larray.dtype, torch.uint8) + self.assertIsNot(as_uint8, heat_sparse_csc) + + # check the copy case for uint8 + as_float64 = heat_sparse_csc.astype(ht.float64, copy=False) + self.assertIsInstance(as_float64, ht.sparse.DCSC_matrix) + self.assertEqual(as_float64.dtype, ht.float64) + self.assertEqual(as_float64.larray.dtype, torch.float64) + self.assertIs(as_float64, heat_sparse_csc) diff --git a/heat/sparse/tests/test_factories.py b/heat/sparse/tests/test_factories.py index 5728534605..b9422f1d3f 100644 --- a/heat/sparse/tests/test_factories.py +++ b/heat/sparse/tests/test_factories.py @@ -22,27 +22,34 @@ def setUpClass(self): [4, 0, 0, 5, 0] [0, 0, 0, 0, 6]] """ - self.matrix_list = [ + self.arr = [ [0, 0, 1, 0, 2], [0, 0, 0, 0, 0], [0, 3, 0, 0, 0], [4, 0, 0, 5, 0], [0, 0, 0, 0, 6], ] - self.ref_indptr = torch.tensor( + + self.world_size = ht.communication.MPI_WORLD.size + self.rank = ht.communication.MPI_WORLD.rank + + def test_sparse_csr_matrix(self): + ref_indptr = torch.tensor( [0, 2, 2, 3, 5, 6], dtype=torch.int, device=self.device.torch_device ) - self.ref_indices = torch.tensor( + ref_indices = torch.tensor( [2, 4, 1, 0, 3, 4], dtype=torch.int, device=self.device.torch_device ) - self.ref_data = torch.tensor( + ref_data = torch.tensor( [1, 2, 3, 4, 5, 6], dtype=torch.float, device=self.device.torch_device ) - self.ref_torch_sparse_csr = torch.sparse_csr_tensor( - self.ref_indptr, self.ref_indices, self.ref_data, device=self.device.torch_device + ref_torch_sparse_csr = torch.sparse_csr_tensor( + ref_indptr, ref_indices, ref_data, device=self.device.torch_device ) - + """ + Input sparse: torch.Tensor + """ self.ref_scipy_sparse_csr = scipy.sparse.csr_matrix( ( torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float, device="cpu"), @@ -50,32 +57,24 @@ def setUpClass(self): torch.tensor([0, 2, 2, 3, 5, 6], dtype=torch.int, device="cpu"), ) ) - - self.world_size = ht.communication.MPI_WORLD.size - self.rank = ht.communication.MPI_WORLD.rank - - def test_sparse_csr_matrix(self): - """ - Input sparse: torch.Tensor - """ - heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.ref_torch_sparse_csr) + heat_sparse_csr = ht.sparse.sparse_csr_matrix(ref_torch_sparse_csr) self.assertIsInstance(heat_sparse_csr, ht.sparse.DCSR_matrix) self.assertEqual(heat_sparse_csr.dtype, ht.float32) self.assertEqual(heat_sparse_csr.indptr.dtype, torch.int64) self.assertEqual(heat_sparse_csr.indices.dtype, torch.int64) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) - self.assertEqual(heat_sparse_csr.lshape, self.ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.lshape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.split, None) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) - self.assertTrue((heat_sparse_csr.lindptr == self.ref_indptr).all()) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) - self.assertTrue((heat_sparse_csr.lindices == self.ref_indices).all()) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) - self.assertTrue((heat_sparse_csr.ldata == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csr.lindptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) + self.assertTrue((heat_sparse_csr.lindices == ref_indices).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) + self.assertTrue((heat_sparse_csr.ldata == ref_data).all()) heat_sparse_csr = ht.sparse.sparse_csr_matrix( - self.ref_torch_sparse_csr, dtype=ht.float32, device=self.device + ref_torch_sparse_csr, dtype=ht.float32, device=self.device ) self.assertEqual(heat_sparse_csr.dtype, ht.float32) self.assertEqual(heat_sparse_csr.device, self.device) @@ -88,25 +87,25 @@ def test_sparse_csr_matrix(self): lshape_dist = [(3, 5), (2, 5)] - heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.ref_torch_sparse_csr, split=0) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) + heat_sparse_csr = ht.sparse.sparse_csr_matrix(ref_torch_sparse_csr, split=0) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -121,25 +120,25 @@ def test_sparse_csr_matrix(self): lshape_dist = [(2, 5), (2, 5), (1, 5)] - heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.ref_torch_sparse_csr, split=0) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) + heat_sparse_csr = ht.sparse.sparse_csr_matrix(ref_torch_sparse_csr, split=0) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -155,25 +154,25 @@ def test_sparse_csr_matrix(self): lshape_dist = [(1, 5), (1, 5), (1, 5), (1, 5), (1, 5), (0, 5)] - heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.ref_torch_sparse_csr, split=0) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) + heat_sparse_csr = ht.sparse.sparse_csr_matrix(ref_torch_sparse_csr, split=0) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -197,24 +196,24 @@ def test_sparse_csr_matrix(self): ) heat_sparse_csr = ht.sparse.sparse_csr_matrix(dist_torch_sparse_csr, is_split=0) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -237,24 +236,24 @@ def test_sparse_csr_matrix(self): ) heat_sparse_csr = ht.sparse.sparse_csr_matrix(dist_torch_sparse_csr, is_split=0) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -274,12 +273,12 @@ def test_sparse_csr_matrix(self): self.assertEqual(heat_sparse_csr.shape, self.ref_scipy_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, self.ref_scipy_sparse_csr.shape) self.assertEqual(heat_sparse_csr.split, None) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) - self.assertTrue((heat_sparse_csr.lindptr == self.ref_indptr).all()) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) - self.assertTrue((heat_sparse_csr.lindices == self.ref_indices).all()) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) - self.assertTrue((heat_sparse_csr.ldata == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csr.lindptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) + self.assertTrue((heat_sparse_csr.lindices == ref_indices).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) + self.assertTrue((heat_sparse_csr.ldata == ref_data).all()) # Distributed case (split) if self.world_size == 2: @@ -293,21 +292,21 @@ def test_sparse_csr_matrix(self): self.assertEqual(heat_sparse_csr.shape, self.ref_scipy_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -326,21 +325,21 @@ def test_sparse_csr_matrix(self): self.assertEqual(heat_sparse_csr.shape, self.ref_scipy_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -360,21 +359,21 @@ def test_sparse_csr_matrix(self): self.assertEqual(heat_sparse_csr.shape, self.ref_scipy_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -400,24 +399,24 @@ def test_sparse_csr_matrix(self): ) heat_sparse_csr = ht.sparse.sparse_csr_matrix(dist_scipy_sparse_csr, is_split=0) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -441,24 +440,24 @@ def test_sparse_csr_matrix(self): shape=lshape_dist[self.rank], ) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.lshape, lshape_dist[self.rank]) self.assertEqual(heat_sparse_csr.split, 0) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) self.assertTrue( ( heat_sparse_csr.lindptr == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) self.assertTrue( ( heat_sparse_csr.lindices == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) ).all() ) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) self.assertTrue( ( heat_sparse_csr.ldata @@ -469,43 +468,41 @@ def test_sparse_csr_matrix(self): """ Input: torch.Tensor """ - torch_tensor = torch.tensor( - self.matrix_list, dtype=torch.float, device=self.device.torch_device - ) + torch_tensor = torch.tensor(self.arr, dtype=torch.float, device=self.device.torch_device) heat_sparse_csr = ht.sparse.sparse_csr_matrix(torch_tensor) self.assertIsInstance(heat_sparse_csr, ht.sparse.DCSR_matrix) self.assertEqual(heat_sparse_csr.dtype, ht.float32) self.assertEqual(heat_sparse_csr.indptr.dtype, torch.int64) self.assertEqual(heat_sparse_csr.indices.dtype, torch.int64) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) - self.assertEqual(heat_sparse_csr.lshape, self.ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.lshape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.split, None) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) - self.assertTrue((heat_sparse_csr.lindptr == self.ref_indptr).all()) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) - self.assertTrue((heat_sparse_csr.lindices == self.ref_indices).all()) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) - self.assertTrue((heat_sparse_csr.ldata == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csr.lindptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) + self.assertTrue((heat_sparse_csr.lindices == ref_indices).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) + self.assertTrue((heat_sparse_csr.ldata == ref_data).all()) """ Input: List[int] """ - heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.matrix_list) + heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.arr) self.assertIsInstance(heat_sparse_csr, ht.sparse.DCSR_matrix) self.assertEqual(heat_sparse_csr.dtype, ht.int64) self.assertEqual(heat_sparse_csr.indptr.dtype, torch.int64) self.assertEqual(heat_sparse_csr.indices.dtype, torch.int64) - self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape) - self.assertEqual(heat_sparse_csr.lshape, self.ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.shape, ref_torch_sparse_csr.shape) + self.assertEqual(heat_sparse_csr.lshape, ref_torch_sparse_csr.shape) self.assertEqual(heat_sparse_csr.split, None) - self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all()) - self.assertTrue((heat_sparse_csr.lindptr == self.ref_indptr).all()) - self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all()) - self.assertTrue((heat_sparse_csr.lindices == self.ref_indices).all()) - self.assertTrue((heat_sparse_csr.data == self.ref_data).all()) - self.assertTrue((heat_sparse_csr.ldata == self.ref_data).all()) + self.assertTrue((heat_sparse_csr.indptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csr.lindptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csr.indices == ref_indices).all()) + self.assertTrue((heat_sparse_csr.lindices == ref_indices).all()) + self.assertTrue((heat_sparse_csr.data == ref_data).all()) + self.assertTrue((heat_sparse_csr.ldata == ref_data).all()) with self.assertRaises(TypeError): # Passing an object which cant be converted into a torch.Tensor @@ -513,7 +510,7 @@ def test_sparse_csr_matrix(self): # Errors (torch.Tensor) with self.assertRaises(ValueError): - heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.ref_torch_sparse_csr, split=1) + heat_sparse_csr = ht.sparse.sparse_csr_matrix(ref_torch_sparse_csr, split=1) with self.assertRaises(ValueError): dist_torch_sparse_csr = torch.sparse_csr_tensor( @@ -557,3 +554,525 @@ def test_sparse_csr_matrix(self): self.assertRaises(ValueError, ht.sparse.sparse_csr_matrix, torch.tensor([0, 1])) self.assertRaises(ValueError, ht.sparse.sparse_csr_matrix, torch.tensor([[[1, 0, 3]]])) + + def test_sparse_csc_matrix(self): + ref_indptr = torch.tensor( + [0, 1, 2, 3, 4, 6], dtype=torch.int, device=self.device.torch_device + ) + ref_indices = torch.tensor( + [3, 2, 0, 3, 0, 4], dtype=torch.int, device=self.device.torch_device + ) + ref_data = torch.tensor( + [4, 3, 1, 5, 2, 6], dtype=torch.float, device=self.device.torch_device + ) + + ref_torch_sparse_csc = torch.sparse_csc_tensor( + ref_indptr, ref_indices, ref_data, device=self.device.torch_device + ) + """ + Input sparse: torch.Tensor + """ + self.ref_scipy_sparse_csc = scipy.sparse.csc_matrix( + ( + torch.tensor([4, 3, 1, 5, 2, 6], dtype=torch.float, device="cpu"), + torch.tensor([3, 2, 0, 3, 0, 4], dtype=torch.int, device="cpu"), + torch.tensor([0, 1, 2, 3, 4, 6], dtype=torch.int, device="cpu"), + ) + ) + heat_sparse_csc = ht.sparse.sparse_csc_matrix(ref_torch_sparse_csc) + + self.assertIsInstance(heat_sparse_csc, ht.sparse.DCSC_matrix) + self.assertEqual(heat_sparse_csc.dtype, ht.float32) + self.assertEqual(heat_sparse_csc.indptr.dtype, torch.int64) + self.assertEqual(heat_sparse_csc.indices.dtype, torch.int64) + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.split, None) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csc.lindptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue((heat_sparse_csc.lindices == ref_indices).all()) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue((heat_sparse_csc.ldata == ref_data).all()) + + heat_sparse_csc = ht.sparse.sparse_csc_matrix( + ref_torch_sparse_csc, dtype=ht.float32, device=self.device + ) + self.assertEqual(heat_sparse_csc.dtype, ht.float32) + self.assertEqual(heat_sparse_csc.device, self.device) + + # Distributed case (split) + if self.world_size == 2: + indptr_dist = [[0, 1, 2, 3], [0, 1, 3]] + indices_dist = [[3, 2, 0], [3, 0, 4]] + data_dist = [[4, 3, 1], [5, 2, 6]] + + lshape_dist = [(5, 3), (5, 2)] + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(ref_torch_sparse_csc, split=1) + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + if self.world_size == 3: + indptr_dist = [[0, 1, 2], [0, 1, 2], [0, 2]] + indices_dist = [[3, 2], [0, 3], [0, 4]] + data_dist = [[4, 3], [1, 5], [2, 6]] + + lshape_dist = [(5, 2), (5, 2), (5, 1)] + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(ref_torch_sparse_csc, split=1) + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + # Number of processes > Number of rows + if self.world_size == 6: + indptr_dist = [[0, 1], [0, 1], [0, 1], [0, 1], [0, 2], [0]] + indices_dist = [[3], [2], [0], [3], [0, 4], []] + data_dist = [[4], [3], [1], [5], [2, 6], []] + + lshape_dist = [(5, 1), (5, 1), (5, 1), (5, 1), (5, 1), (5, 0)] + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(ref_torch_sparse_csc, split=1) + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + # Distributed case (is_split) + if self.world_size == 2: + indptr_dist = [[0, 1, 2, 3], [0, 1, 3]] + indices_dist = [[3, 2, 0], [3, 0, 4]] + data_dist = [[4, 3, 1], [5, 2, 6]] + + lshape_dist = [(5, 3), (5, 2)] + + dist_torch_sparse_csc = torch.sparse_csc_tensor( + torch.tensor(indptr_dist[self.rank], device=self.device.torch_device), + torch.tensor(indices_dist[self.rank], device=self.device.torch_device), + torch.tensor(data_dist[self.rank], device=self.device.torch_device), + size=lshape_dist[self.rank], + ) + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(dist_torch_sparse_csc, is_split=1) + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + if self.world_size == 3: + indptr_dist = [[0, 1, 2], [0, 1, 2], [0, 2]] + indices_dist = [[3, 2], [0, 3], [0, 4]] + data_dist = [[4, 3], [1, 5], [2, 6]] + + lshape_dist = [(5, 2), (5, 2), (5, 1)] + + dist_torch_sparse_csc = torch.sparse_csc_tensor( + torch.tensor(indptr_dist[self.rank], device=self.device.torch_device), + torch.tensor(indices_dist[self.rank], device=self.device.torch_device), + torch.tensor(data_dist[self.rank], device=self.device.torch_device), + size=lshape_dist[self.rank], + ) + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(dist_torch_sparse_csc, is_split=1) + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + """ + Input sparse: scipy.sparse.csc_matrix + """ + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_scipy_sparse_csc) + + self.assertIsInstance(heat_sparse_csc, ht.sparse.DCSC_matrix) + self.assertEqual(heat_sparse_csc.dtype, ht.float32) + self.assertEqual(heat_sparse_csc.indptr.dtype, torch.int64) + self.assertEqual(heat_sparse_csc.indices.dtype, torch.int64) + self.assertEqual(heat_sparse_csc.shape, self.ref_scipy_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, self.ref_scipy_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.split, None) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csc.lindptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue((heat_sparse_csc.lindices == ref_indices).all()) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue((heat_sparse_csc.ldata == ref_data).all()) + + # Distributed case (split) + if self.world_size == 2: + indptr_dist = [[0, 1, 2, 3], [0, 1, 3]] + indices_dist = [[3, 2, 0], [3, 0, 4]] + data_dist = [[4, 3, 1], [5, 2, 6]] + + lshape_dist = [(5, 3), (5, 2)] + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_scipy_sparse_csc, split=1) + self.assertEqual(heat_sparse_csc.shape, self.ref_scipy_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + if self.world_size == 3: + indptr_dist = [[0, 1, 2], [0, 1, 2], [0, 2]] + indices_dist = [[3, 2], [0, 3], [0, 4]] + data_dist = [[4, 3], [1, 5], [2, 6]] + + lshape_dist = [(5, 2), (5, 2), (5, 1)] + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_scipy_sparse_csc, split=1) + self.assertEqual(heat_sparse_csc.shape, self.ref_scipy_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + # Number of processes > Number of rows + if self.world_size == 6: + indptr_dist = [[0, 1], [0, 1], [0, 1], [0, 1], [0, 2], [0]] + indices_dist = [[3], [2], [0], [3], [0, 4], []] + data_dist = [[4], [3], [1], [5], [2, 6], []] + + lshape_dist = [(5, 1), (5, 1), (5, 1), (5, 1), (5, 1), (5, 0)] + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_scipy_sparse_csc, split=1) + self.assertEqual(heat_sparse_csc.shape, self.ref_scipy_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + # Distributed case (is_split) + if self.world_size == 2: + indptr_dist = [[0, 1, 2, 3], [0, 1, 3]] + indices_dist = [[3, 2, 0], [3, 0, 4]] + data_dist = [[4, 3, 1], [5, 2, 6]] + + lshape_dist = [(5, 3), (5, 2)] + + dist_scipy_sparse_csc = scipy.sparse.csc_matrix( + ( + torch.tensor(data_dist[self.rank], device="cpu"), + torch.tensor(indices_dist[self.rank], device="cpu"), + torch.tensor(indptr_dist[self.rank], device="cpu"), + ), + shape=lshape_dist[self.rank], + ) + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(dist_scipy_sparse_csc, is_split=1) + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + if self.world_size == 3: + indptr_dist = [[0, 1, 2], [0, 1, 2], [0, 2]] + indices_dist = [[3, 2], [0, 3], [0, 4]] + data_dist = [[4, 3], [1, 5], [2, 6]] + + lshape_dist = [(5, 2), (5, 2), (5, 1)] + + dist_scipy_sparse_csc = scipy.sparse.csc_matrix( + ( + torch.tensor(data_dist[self.rank], device="cpu"), + torch.tensor(indices_dist[self.rank], device="cpu"), + torch.tensor(indptr_dist[self.rank], device="cpu"), + ), + shape=lshape_dist[self.rank], + ) + + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, lshape_dist[self.rank]) + self.assertEqual(heat_sparse_csc.split, 1) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue( + ( + heat_sparse_csc.lindptr + == torch.tensor(indptr_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue( + ( + heat_sparse_csc.lindices + == torch.tensor(indices_dist[self.rank], device=self.device.torch_device) + ).all() + ) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue( + ( + heat_sparse_csc.ldata + == torch.tensor(data_dist[self.rank], device=self.device.torch_device) + ).all() + ) + + """ + Input: torch.Tensor + """ + torch_tensor = torch.tensor(self.arr, dtype=torch.float, device=self.device.torch_device) + heat_sparse_csc = ht.sparse.sparse_csc_matrix(torch_tensor) + + self.assertIsInstance(heat_sparse_csc, ht.sparse.DCSC_matrix) + self.assertEqual(heat_sparse_csc.dtype, ht.float32) + self.assertEqual(heat_sparse_csc.indptr.dtype, torch.int64) + self.assertEqual(heat_sparse_csc.indices.dtype, torch.int64) + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.split, None) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csc.lindptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue((heat_sparse_csc.lindices == ref_indices).all()) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue((heat_sparse_csc.ldata == ref_data).all()) + + """ + Input: List[int] + """ + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.arr) + + self.assertIsInstance(heat_sparse_csc, ht.sparse.DCSC_matrix) + self.assertEqual(heat_sparse_csc.dtype, ht.int64) + self.assertEqual(heat_sparse_csc.indptr.dtype, torch.int64) + self.assertEqual(heat_sparse_csc.indices.dtype, torch.int64) + self.assertEqual(heat_sparse_csc.shape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.lshape, ref_torch_sparse_csc.shape) + self.assertEqual(heat_sparse_csc.split, None) + self.assertTrue((heat_sparse_csc.indptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csc.lindptr == ref_indptr).all()) + self.assertTrue((heat_sparse_csc.indices == ref_indices).all()) + self.assertTrue((heat_sparse_csc.lindices == ref_indices).all()) + self.assertTrue((heat_sparse_csc.data == ref_data).all()) + self.assertTrue((heat_sparse_csc.ldata == ref_data).all()) + + with self.assertRaises(TypeError): + # Passing an object which cant be converted into a torch.Tensor + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self) + + # Errors (torch.Tensor) + with self.assertRaises(ValueError): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(ref_torch_sparse_csc, split=0) + + with self.assertRaises(ValueError): + dist_torch_sparse_csc = torch.sparse_csc_tensor( + torch.tensor([0, 0, 0], device=self.device.torch_device), # indptr + torch.tensor([], dtype=torch.int64, device=self.device.torch_device), # indices + torch.tensor([], dtype=torch.int64, device=self.device.torch_device), # data + size=(2, 2), + ) + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(dist_torch_sparse_csc, is_split=0) + + # Errors (scipy.sparse.csc_matrix) + with self.assertRaises(ValueError): + heat_sparse_csc = ht.sparse.sparse_csc_matrix(self.ref_scipy_sparse_csc, split=0) + + with self.assertRaises(ValueError): + dist_scipy_sparse_csc = scipy.sparse.csc_matrix( + ( + torch.tensor([], dtype=torch.int64, device="cpu"), # data + torch.tensor([], dtype=torch.int64, device="cpu"), # indices + torch.tensor([0, 0, 0], device="cpu"), # indptr + ), + shape=(2, 2), + ) + + heat_sparse_csc = ht.sparse.sparse_csc_matrix(dist_torch_sparse_csc, is_split=0) + + # Invalid distribution for is_split + if self.world_size > 1: + with self.assertRaises(ValueError): + dist_torch_sparse_csc = torch.sparse_csc_tensor( + torch.tensor( + [0] * ((self.rank + 1) + 1), device=self.device.torch_device + ), # indptr + torch.tensor([], dtype=torch.int64, device=self.device.torch_device), # indices + torch.tensor([], dtype=torch.int64, device=self.device.torch_device), # data + size=(self.rank + 1, self.rank + 1), + ) + + heat_sparse_csc = ht.sparse.sparse_csr_matrix(dist_torch_sparse_csc, is_split=1) + + self.assertRaises(ValueError, ht.sparse.sparse_csc_matrix, torch.tensor([0, 1])) + self.assertRaises(ValueError, ht.sparse.sparse_csc_matrix, torch.tensor([[[1, 0, 3]]])) diff --git a/heat/sparse/tests/test_manipulations.py b/heat/sparse/tests/test_manipulations.py index 77705e636e..1de090a871 100644 --- a/heat/sparse/tests/test_manipulations.py +++ b/heat/sparse/tests/test_manipulations.py @@ -20,24 +20,17 @@ def setUpClass(self): [4, 0, 0, 5, 0] [0, 0, 0, 0, 6]] """ - self.ref_indptr = torch.tensor( - [0, 2, 2, 3, 5, 6], dtype=torch.int, device=self.device.torch_device - ) - self.ref_indices = torch.tensor( - [2, 4, 1, 0, 3, 4], dtype=torch.int, device=self.device.torch_device - ) - self.ref_data = torch.tensor( - [1, 2, 3, 4, 5, 6], dtype=torch.float, device=self.device.torch_device - ) - self.ref_torch_sparse_csr = torch.sparse_csr_tensor( - self.ref_indptr, self.ref_indices, self.ref_data, device=self.device.torch_device - ) - - def test_to_sparse(self): - arr = [[0, 0, 1, 0, 2], [0, 0, 0, 0, 0], [0, 3, 0, 0, 0], [4, 0, 0, 5, 0], [0, 0, 0, 0, 6]] - - A = ht.array(arr, split=0) - B = A.to_sparse() + self.arr = [ + [0, 0, 1, 0, 2], + [0, 0, 0, 0, 0], + [0, 3, 0, 0, 0], + [4, 0, 0, 5, 0], + [0, 0, 0, 0, 6], + ] + + def test_to_sparse_csr(self): + A = ht.array(self.arr, split=0) + B = A.to_sparse_csr() indptr_B = [0, 2, 2, 3, 5, 6] indices_B = [2, 4, 1, 0, 3, 4] @@ -54,18 +47,41 @@ def test_to_sparse(self): self.assertEqual(B.shape, A.shape) self.assertEqual(B.dtype, A.dtype) - def test_to_dense(self): - heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.ref_torch_sparse_csr) - - ref_dense_array = ht.array( - [ - [0, 0, 1, 0, 2], - [0, 0, 0, 0, 0], - [0, 3, 0, 0, 0], - [4, 0, 0, 5, 0], - [0, 0, 0, 0, 6], - ] + def test_to_sparse_csc(self): + A = ht.array(self.arr, split=1) + B = A.to_sparse_csc() + + indptr_B = [0, 1, 2, 3, 4, 6] + indices_B = [3, 2, 0, 3, 0, 4] + data_B = [4, 3, 1, 5, 2, 6] + + self.assertIsInstance(B, ht.sparse.DCSC_matrix) + self.assertTrue((B.indptr == torch.tensor(indptr_B, device=self.device.torch_device)).all()) + self.assertTrue( + (B.indices == torch.tensor(indices_B, device=self.device.torch_device)).all() ) + self.assertTrue((B.data == torch.tensor(data_B, device=self.device.torch_device)).all()) + self.assertEqual(B.nnz, len(data_B)) + self.assertEqual(B.split, 1) + self.assertEqual(B.shape, A.shape) + self.assertEqual(B.dtype, A.dtype) + + def test_to_dense_csr(self): + ref_indptr = torch.tensor( + [0, 2, 2, 3, 5, 6], dtype=torch.int, device=self.device.torch_device + ) + ref_indices = torch.tensor( + [2, 4, 1, 0, 3, 4], dtype=torch.int, device=self.device.torch_device + ) + ref_data = torch.tensor( + [1, 2, 3, 4, 5, 6], dtype=torch.float, device=self.device.torch_device + ) + ref_torch_sparse_csr = torch.sparse_csr_tensor( + ref_indptr, ref_indices, ref_data, device=self.device.torch_device + ) + heat_sparse_csr = ht.sparse.sparse_csr_matrix(ref_torch_sparse_csr) + + ref_dense_array = ht.array(self.arr) dense_array = heat_sparse_csr.todense() @@ -84,7 +100,7 @@ def test_to_dense(self): self.assertEqual(out_buffer.shape, heat_sparse_csr.shape) # Distributed case - heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.ref_torch_sparse_csr, split=0) + heat_sparse_csr = ht.sparse.sparse_csr_matrix(ref_torch_sparse_csr, split=0) dense_array = heat_sparse_csr.todense() ref_dense_array = ht.array(ref_dense_array, split=0) @@ -110,3 +126,64 @@ def test_to_dense(self): with self.assertRaises(ValueError): out_buffer = ht.empty(shape=[5, 5], split=None) heat_sparse_csr.todense(out=out_buffer) + + def test_to_dense_csc(self): + ref_indptr = torch.tensor( + [0, 1, 2, 3, 4, 6], dtype=torch.int, device=self.device.torch_device + ) + ref_indices = torch.tensor( + [3, 2, 0, 3, 0, 4], dtype=torch.int, device=self.device.torch_device + ) + ref_data = torch.tensor( + [4, 3, 1, 5, 2, 6], dtype=torch.float, device=self.device.torch_device + ) + ref_torch_sparse_csc = torch.sparse_csc_tensor( + ref_indptr, ref_indices, ref_data, device=self.device.torch_device + ) + heat_sparse_csc = ht.sparse.sparse_csc_matrix(ref_torch_sparse_csc) + + ref_dense_array = ht.array(self.arr) + + dense_array = heat_sparse_csc.todense() + + self.assertTrue(ht.equal(ref_dense_array, dense_array)) + self.assertEqual(dense_array.split, None) + self.assertEqual(dense_array.dtype, heat_sparse_csc.dtype) + self.assertEqual(dense_array.shape, heat_sparse_csc.shape) + + # with output buffer + out_buffer = ht.empty(shape=[5, 5]) + heat_sparse_csc.todense(out=out_buffer) + + self.assertTrue(ht.equal(ref_dense_array, out_buffer)) + self.assertEqual(out_buffer.split, None) + self.assertEqual(out_buffer.dtype, heat_sparse_csc.dtype) + self.assertEqual(out_buffer.shape, heat_sparse_csc.shape) + + # Distributed case + heat_sparse_csc = ht.sparse.sparse_csc_matrix(ref_torch_sparse_csc, split=1) + + dense_array = heat_sparse_csc.todense() + ref_dense_array = ht.array(ref_dense_array, split=1) + + self.assertTrue(ht.equal(ref_dense_array, dense_array)) + self.assertEqual(dense_array.split, 1) + self.assertEqual(dense_array.dtype, heat_sparse_csc.dtype) + self.assertEqual(dense_array.shape, heat_sparse_csc.shape) + + # with output buffer + out_buffer = ht.empty(shape=[5, 5], split=1) + heat_sparse_csc.todense(out=out_buffer) + + self.assertTrue(ht.equal(ref_dense_array, out_buffer)) + self.assertEqual(out_buffer.split, 1) + self.assertEqual(out_buffer.dtype, heat_sparse_csc.dtype) + self.assertEqual(out_buffer.shape, heat_sparse_csc.shape) + + with self.assertRaises(ValueError): + out_buffer = ht.empty(shape=[3, 3], split=1) + heat_sparse_csc.todense(out=out_buffer) + + with self.assertRaises(ValueError): + out_buffer = ht.empty(shape=[5, 5], split=None) + heat_sparse_csc.todense(out=out_buffer)