diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 84d187abde..3cf2ab0eac 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -166,6 +166,29 @@ def split(self): """ return self.__split + @property + def stride(self): + """ + Returns + ------- + tuple of ints: steps in each dimension when traversing a tensor. + torch-like usage: self.stride() + """ + return self.__array.stride + + @property + def strides(self): + """ + Returns + ------- + tuple of ints: bytes to step in each dimension when traversing a tensor. + numpy-like usage: self.strides + """ + steps = list(self._DNDarray__array.stride()) + itemsize = self._DNDarray__array.storage().element_size() + strides = tuple(step * itemsize for step in steps) + return strides + @property def T(self, axes=None): return linalg.transpose(self, axes) diff --git a/heat/core/factories.py b/heat/core/factories.py index 6d574af067..7e8ee248cb 100644 --- a/heat/core/factories.py +++ b/heat/core/factories.py @@ -5,6 +5,7 @@ from .stride_tricks import sanitize_axis, sanitize_shape from . import devices from . import dndarray +from . import memory from . import types __all__ = [ @@ -131,7 +132,17 @@ def arange(*args, dtype=None, split=None, device=None, comm=None): return dndarray.DNDarray(data, gshape, htype, split, device, comm) -def array(obj, dtype=None, copy=True, ndmin=0, split=None, is_split=None, device=None, comm=None): +def array( + obj, + dtype=None, + copy=True, + ndmin=0, + order="C", + split=None, + is_split=None, + device=None, + comm=None, +): """ Create a tensor. Parameters @@ -149,6 +160,11 @@ def array(obj, dtype=None, copy=True, ndmin=0, split=None, is_split=None, device ndmin : int, optional Specifies the minimum number of dimensions that the resulting array should have. Ones will, if needed, be attached to the shape if ndim>0 and prefaced in case of ndim<0 to meet the requirement. + order: str, optional + Options: 'C' or 'F'. Specifies the memory layout of the newly created tensor. Default is order='C', meaning the array + will be stored in row-major order (C-like). If order=‘F’, the array will be stored in column-major order (Fortran-like). + Raises NotImplementedError for NumPy options 'K' and 'A'. + #TODO: implement 'K' option when torch.clone() fix to preserve memory layout is released. split : None or int, optional The axis along which the passed array content obj is split and distributed in memory. Mutually exclusive with is_split. 
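The two properties added above differ only by a unit conversion: stride() reports steps per dimension in elements (torch convention), while strides reports the same steps in bytes (numpy convention), i.e. the element strides scaled by the item size. A minimal standalone sketch of that relationship, using plain torch and an arbitrary example tensor (illustrative only, not part of the patch):

import torch

t = torch.arange(2 * 3, dtype=torch.int64).reshape(2, 3)
element_strides = t.stride()        # (3, 1): elements to step per dimension (torch-like)
itemsize = t.element_size()         # 8 bytes for int64; the patch reads this off the storage
byte_strides = tuple(s * itemsize for s in element_strides)
print(element_strides)              # (3, 1)  -> what DNDarray.stride() mirrors
print(byte_strides)                 # (24, 8) -> what DNDarray.strides mirrors (numpy-like)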
@@ -198,6 +214,71 @@ def array(obj, dtype=None, copy=True, ndmin=0, split=None, is_split=None, device (1/2) >>> ht.array([3, 4], is_split=0) (0/2) tensor([1, 2, 3, 4]) (1/2) tensor([1, 2, 3, 4]) + + Memory layout, single-node: + >>> a = np.arange(2 * 3).reshape(2, 3) + >>> a + array([[ 0, 1, 2], + [ 3, 4, 5]]) + >>> a.strides + (24, 8) + >>> b = ht.array(a) + >>> b + tensor([[0, 1, 2], + [3, 4, 5]]) + >>> b.strides + (24, 8) + >>> b._DNDarray__array.storage() #TODO: implement ht.view() + 0 + 1 + 2 + 3 + 4 + 5 + [torch.LongStorage of size 6] + >>> c = ht.array(a, order='F') + >>> c + tensor([[0, 1, 2], + [3, 4, 5]]) + >>> c.strides + (8, 16) + >>> c._DNDarray__array.storage() #TODO: implement ht.view() + 0 + 3 + 1 + 4 + 2 + 5 + [torch.LongStorage of size 6] + + Memory layout, distributed: + >>> a = np.arange(4 * 3).reshape(4, 3) + >>> a.strides + (24, 8) + >>> b = ht.array(a, order='F') + >>> b + (0/2) tensor([[0, 1, 2], + [3, 4, 5]]) + (1/2) tensor([[ 6, 7, 8], + [ 9, 10, 11]]) + >>> b.strides + (0/2) (8, 16) + (1/2) (8, 16) + >>> b._DNDarray__array.storage() #TODO: implement ht.view() + (0/2) 0 + 3 + 1 + 4 + 2 + 5 + [torch.LongStorage of size 6] + (1/2) 6 + 9 + 7 + 10 + 8 + 11 + [torch.LongStorage of size 6] """ # extract the internal tensor in case of a heat tensor if isinstance(obj, dndarray.DNDarray): @@ -210,6 +291,8 @@ def array(obj, dtype=None, copy=True, ndmin=0, split=None, is_split=None, device # initialize the array if bool(copy): if isinstance(obj, torch.Tensor): + # TODO: watch out. At the moment clone() implies losing the underlying memory layout. + # pytorch fix in progress obj = obj.clone().detach() else: try: @@ -250,8 +333,10 @@ def array(obj, dtype=None, copy=True, ndmin=0, split=None, is_split=None, device if split is not None: _, _, slices = comm.chunk(obj.shape, split) obj = obj[slices].clone() + obj = memory.sanitize_memory_layout(obj, order=order) # check with the neighboring rank whether the local shape would fit into a global shape elif is_split is not None: + obj = memory.sanitize_memory_layout(obj, order=order) if comm.rank < comm.size - 1: comm.Isend(lshape, dest=comm.rank + 1) if comm.rank != 0: @@ -284,11 +369,13 @@ def array(obj, dtype=None, copy=True, ndmin=0, split=None, is_split=None, device comm.Allreduce(MPI.IN_PLACE, ttl_shape, MPI.SUM) gshape[is_split] = ttl_shape[is_split] split = is_split + elif split is None and is_split is None: + obj = memory.sanitize_memory_layout(obj, order=order) return dndarray.DNDarray(obj, tuple(int(ele) for ele in gshape), dtype, split, device, comm) -def empty(shape, dtype=types.float32, split=None, device=None, comm=None): +def empty(shape, dtype=types.float32, split=None, device=None, comm=None, order="C"): """ Returns a new uninitialized array of given shape and data type. May be allocated split up across multiple nodes along the specified axis. @@ -305,6 +392,11 @@ def empty(shape, dtype=types.float32, split=None, device=None, comm=None): Specifies the device the tensor shall be allocated on, defaults to None (i.e. globally set default device). comm: Communication, optional Handle to the nodes holding distributed parts or copies of this tensor. + order: str, optional + Options: 'C' or 'F'. Specifies the memory layout of the newly created tensor. Default is order='C', meaning the array + will be stored in row-major order (C-like). If order=‘F’, the array will be stored in column-major order (Fortran-like). + Raises NotImplementedError for NumPy options 'K' and 'A'. 
+ #TODO: implement 'K' option when torch.clone() fix to preserve memory layout is released. Returns ------- @@ -323,10 +415,10 @@ def empty(shape, dtype=types.float32, split=None, device=None, comm=None): tensor([[ 0.0000e+00, -2.0000e+00, 3.3113e+35], [ 3.6902e+19, 1.2096e+04, 7.1846e+22]]) """ - return __factory(shape, dtype, split, torch.empty, device, comm) + return __factory(shape, dtype, split, torch.empty, device, comm, order) -def empty_like(a, dtype=None, split=None, device=None, comm=None): +def empty_like(a, dtype=None, split=None, device=None, comm=None, order="C"): """ Returns a new uninitialized array with the same type, shape and data distribution of given object. Data type and data distribution strategy can be explicitly overriden. @@ -361,10 +453,10 @@ def empty_like(a, dtype=None, split=None, device=None, comm=None): tensor([[ 0.0000e+00, -2.0000e+00, 3.3113e+35], [ 3.6902e+19, 1.2096e+04, 7.1846e+22]]) """ - return __factory_like(a, dtype, split, empty, device, comm) + return __factory_like(a, dtype, split, empty, device, comm, order=order) -def eye(shape, dtype=types.float32, split=None, device=None, comm=None): +def eye(shape, dtype=types.float32, split=None, device=None, comm=None, order="C"): """ Returns a new 2-D tensor with ones on the diagonal and zeroes elsewhere. @@ -381,6 +473,11 @@ def eye(shape, dtype=types.float32, split=None, device=None, comm=None): Specifies the device the tensor shall be allocated on, defaults to None (i.e. globally set default device). comm : Communication, optional Handle to the nodes holding distributed parts or copies of this tensor. + order: str, optional + Options: 'C' or 'F'. Specifies the memory layout of the newly created tensor. Default is order='C', meaning the array + will be stored in row-major order (C-like). If order=‘F’, the array will be stored in column-major order (Fortran-like). + Raises NotImplementedError for NumPy options 'K' and 'A'. + #TODO: implement 'K' option when torch.clone() fix to preserve memory layout is released. Returns ------- @@ -421,12 +518,13 @@ def eye(shape, dtype=types.float32, split=None, device=None, comm=None): pos_y = i if split == 1 else i + offset data[pos_x][pos_y] = 1 + data = memory.sanitize_memory_layout(data, order=order) return dndarray.DNDarray( data, gshape, types.canonical_heat_type(data.dtype), split, device, comm ) -def __factory(shape, dtype, split, local_factory, device, comm): +def __factory(shape, dtype, split, local_factory, device, comm, order): """ Abstracted factory function for HeAT tensor initialization. @@ -461,11 +559,11 @@ def __factory(shape, dtype, split, local_factory, device, comm): _, local_shape, _ = comm.chunk(shape, split) # create the torch data using the factory function data = local_factory(local_shape, dtype=dtype.torch_type(), device=device.torch_device) - + data = memory.sanitize_memory_layout(data, order=order) return dndarray.DNDarray(data, shape, dtype, split, device, comm) -def __factory_like(a, dtype, split, factory, device, comm, **kwargs): +def __factory_like(a, dtype, split, factory, device, comm, order="C", **kwargs): """ Abstracted '...-like' factory function for HeAT tensor initialization @@ -483,6 +581,12 @@ def __factory_like(a, dtype, split, factory, device, comm, **kwargs): Specifies the device the tensor shall be allocated on, defaults to None (i.e. globally set default device). comm: Communication Handle to the nodes holding distributed parts or copies of this tensor. + order: str, optional + Options: 'C' or 'F'. 
Specifies the memory layout of the newly created tensor. Default is order='C', meaning the array + will be stored in row-major order (C-like). If order=‘F’, the array will be stored in column-major order (Fortran-like). + Raises NotImplementedError for NumPy options 'K' and 'A'. + #TODO: implement 'K' option when torch.clone() fix to preserve memory layout is released. + Returns ------- @@ -517,10 +621,10 @@ def __factory_like(a, dtype, split, factory, device, comm, **kwargs): # use the default communicator, if not set comm = sanitize_comm(comm) - return factory(shape, dtype=dtype, split=split, device=device, comm=comm, **kwargs) + return factory(shape, dtype=dtype, split=split, device=device, comm=comm, order=order, **kwargs) -def full(shape, fill_value, dtype=types.float32, split=None, device=None, comm=None): +def full(shape, fill_value, dtype=types.float32, split=None, device=None, comm=None, order="C"): """ Return a new array of given shape and type, filled with fill_value. @@ -557,10 +661,10 @@ def full(shape, fill_value, dtype=types.float32, split=None, device=None, comm=N def local_factory(*args, **kwargs): return torch.full(*args, fill_value=fill_value, **kwargs) - return __factory(shape, dtype, split, local_factory, device, comm) + return __factory(shape, dtype, split, local_factory, device, comm, order=order) -def full_like(a, fill_value, dtype=types.float32, split=None, device=None, comm=None): +def full_like(a, fill_value, dtype=types.float32, split=None, device=None, comm=None, order="C"): """ Return a full array with the same shape and type as a given array. @@ -595,7 +699,7 @@ def full_like(a, fill_value, dtype=types.float32, split=None, device=None, comm= tensor([[1., 1., 1.], [1., 1., 1.]]) """ - return __factory_like(a, dtype, split, full, device, comm, fill_value=fill_value) + return __factory_like(a, dtype, split, full, device, comm, fill_value=fill_value, order=order) def linspace( @@ -755,7 +859,7 @@ def logspace( return pow(base, y).astype(dtype, copy=False) -def ones(shape, dtype=types.float32, split=None, device=None, comm=None): +def ones(shape, dtype=types.float32, split=None, device=None, comm=None, order="C"): """ Returns a new array of given shape and data type filled with one values. May be allocated split up across multiple nodes along the specified axis. @@ -772,6 +876,12 @@ def ones(shape, dtype=types.float32, split=None, device=None, comm=None): Specifies the device the tensor shall be allocated on, defaults to None (i.e. globally set default device). comm : Communication, optional Handle to the nodes holding distributed parts or copies of this tensor. + order: str, optional + Options: 'C' or 'F'. Specifies the memory layout of the newly created tensor. Default is order='C', meaning the array + will be stored in row-major order (C-like). If order=‘F’, the array will be stored in column-major order (Fortran-like). + Raises NotImplementedError for NumPy options 'K' and 'A'. + #TODO: implement 'K' option when torch.clone() fix to preserve memory layout is released. 
+ Returns ------- @@ -790,10 +900,10 @@ def ones(shape, dtype=types.float32, split=None, device=None, comm=None): tensor([[1., 1., 1.], [1., 1., 1.]]) """ - return __factory(shape, dtype, split, torch.ones, device, comm) + return __factory(shape, dtype, split, torch.ones, device, comm, order) -def ones_like(a, dtype=None, split=None, device=None, comm=None): +def ones_like(a, dtype=None, split=None, device=None, comm=None, order="C"): """ Returns a new array filled with ones with the same type, shape and data distribution of given object. Data type and data distribution strategy can be explicitly overriden. @@ -827,10 +937,10 @@ def ones_like(a, dtype=None, split=None, device=None, comm=None): tensor([[1., 1., 1.], [1., 1., 1.]]) """ - return __factory_like(a, dtype, split, ones, device, comm) + return __factory_like(a, dtype, split, ones, device, comm, order=order) -def zeros(shape, dtype=types.float32, split=None, device=None, comm=None): +def zeros(shape, dtype=types.float32, split=None, device=None, comm=None, order="C"): """ Returns a new array of given shape and data type filled with zero values. May be allocated split up across multiple nodes along the specified axis. @@ -847,6 +957,12 @@ def zeros(shape, dtype=types.float32, split=None, device=None, comm=None): Specifies the device the tensor shall be allocated on, defaults to None (i.e. globally set default device). comm: Communication, optional Handle to the nodes holding distributed parts or copies of this tensor. + order: str, optional + Options: 'C' or 'F'. Specifies the memory layout of the newly created tensor. Default is order='C', meaning the array + will be stored in row-major order (C-like). If order=‘F’, the array will be stored in column-major order (Fortran-like). + Raises NotImplementedError for NumPy options 'K' and 'A'. + #TODO: implement 'K' option when torch.clone() fix to preserve memory layout is released. + Returns ------- @@ -865,10 +981,10 @@ def zeros(shape, dtype=types.float32, split=None, device=None, comm=None): tensor([[0., 0., 0.], [0., 0., 0.]]) """ - return __factory(shape, dtype, split, torch.zeros, device, comm) + return __factory(shape, dtype, split, torch.zeros, device, comm, order=order) -def zeros_like(a, dtype=None, split=None, device=None, comm=None): +def zeros_like(a, dtype=None, split=None, device=None, comm=None, order="C"): """ Returns a new array filled with zeros with the same type, shape and data distribution of given object. Data type and data distribution strategy can be explicitly overriden. @@ -885,6 +1001,11 @@ def zeros_like(a, dtype=None, split=None, device=None, comm=None): Specifies the device the tensor shall be allocated on, defaults to None (i.e. globally set default device). comm: Communication, optional Handle to the nodes holding distributed parts or copies of this tensor. + order: str, optional + Options: 'C' or 'F'. Specifies the memory layout of the newly created tensor. Default is order='C', meaning the array + will be stored in row-major order (C-like). If order=‘F’, the array will be stored in column-major order (Fortran-like). + Raises NotImplementedError for NumPy options 'K' and 'A'. + #TODO: implement 'K' option when torch.clone() fix to preserve memory layout is released. 
Returns ------- @@ -902,4 +1023,4 @@ def zeros_like(a, dtype=None, split=None, device=None, comm=None): tensor([[0., 0., 0.], [0., 0., 0.]]) """ - return __factory_like(a, dtype, split, zeros, device, comm) + return __factory_like(a, dtype, split, zeros, device, comm, order=order) diff --git a/heat/core/memory.py b/heat/core/memory.py index acd1827b28..dc98d14ff9 100644 --- a/heat/core/memory.py +++ b/heat/core/memory.py @@ -1,6 +1,8 @@ +import numpy as np +import torch from . import dndarray -__all__ = ["copy"] +__all__ = ["copy", "sanitize_memory_layout"] def copy(a): @@ -22,3 +24,44 @@ def copy(a): return dndarray.DNDarray( a._DNDarray__array.clone(), a.shape, a.dtype, a.split, a.device, a.comm ) + + +def sanitize_memory_layout(x, order="C"): + """ + Return the given object with memory layout as defined below. + + Parameters + ----------- + + x: torch.tensor + Input data + + order: str, optional. + Default is 'C' as in C-like (row-major) memory layout. The array is stored first dimension first (rows first if ndim=2). + Alternative is 'F', as in Fortran-like (column-major) memory layout. The array is stored last dimension first (columns first if ndim=2). + """ + if x.ndim < 2: + # do nothing + return x + dims = list(range(x.ndim)) + stride = list(x.stride()) + row_major = all(np.diff(stride) <= 0) + column_major = all(np.diff(stride) >= 0) + if (order == "C" and row_major) or (order == "F" and column_major): + # do nothing + return x + if (order == "C" and column_major) or (order == "F" and row_major): + dims = tuple(reversed(dims)) + y = torch.empty_like(x) + permutation = x.permute(dims).contiguous() + y = y.set_( + permutation.storage(), + x.storage_offset(), + x.shape, + tuple(reversed(permutation.stride())), + ) + if order == "K": + raise NotImplementedError( + "Internal usage of torch.clone() means losing original memory layout for now. \n Please specify order='C' for row-major, order='F' for column-major layout." 
+ ) + return y diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index ddf1a9a275..07a420d54e 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -709,3 +709,65 @@ def test_size_gnumel(self): self.assertEqual(a.gnumel, 10 * 10 * 10) self.assertEqual(ht.array(0).size, 1) + + def test_stride_and_strides(self): + # Local, int16, row-major memory layout + torch_int16 = torch.arange(6 * 5 * 3 * 4 * 5 * 7, dtype=torch.int16).reshape( + 6, 5, 3, 4, 5, 7 + ) + heat_int16 = ht.array(torch_int16) + numpy_int16 = torch_int16.numpy() + self.assertEqual(heat_int16.stride(), torch_int16.stride()) + self.assertEqual(heat_int16.strides, numpy_int16.strides) + # Local, float32, row-major memory layout + torch_float32 = torch.arange(6 * 5 * 3 * 4 * 5 * 7, dtype=torch.float32).reshape( + 6, 5, 3, 4, 5, 7 + ) + heat_float32 = ht.array(torch_float32) + numpy_float32 = torch_float32.numpy() + self.assertEqual(heat_float32.stride(), torch_float32.stride()) + self.assertEqual(heat_float32.strides, numpy_float32.strides) + # Local, float64, column-major memory layout + torch_float64 = torch.arange(6 * 5 * 3 * 4 * 5 * 7, dtype=torch.float64).reshape( + 6, 5, 3, 4, 5, 7 + ) + heat_float64_F = ht.array(torch_float64, order="F") + numpy_float64_F = np.array(torch_float64.numpy(), order="F") + self.assertNotEqual(heat_float64_F.stride(), torch_float64.stride()) + self.assertEqual(heat_float64_F.strides, numpy_float64_F.strides) + # Distributed, int16, row-major memory layout + size = ht.communication.MPI_WORLD.size + split = 2 + torch_int16 = torch.arange(6 * 5 * 3 * size * 4 * 5 * 7, dtype=torch.int16).reshape( + 6, 5, 3 * size, 4, 5, 7 + ) + heat_int16_split = ht.array(torch_int16, split=split) + numpy_int16 = torch_int16.numpy() + if size > 1: + self.assertNotEqual(heat_int16_split.stride(), torch_int16.stride()) + numpy_int16_split_strides = ( + tuple(np.array(numpy_int16.strides[:split]) / size) + numpy_int16.strides[split:] + ) + self.assertEqual(heat_int16_split.strides, numpy_int16_split_strides) + # Distributed, float32, row-major memory layout + split = -1 + torch_float32 = torch.arange(6 * 5 * 3 * 4 * 5 * 7 * size, dtype=torch.float32).reshape( + 6, 5, 3, 4, 5, 7 * size + ) + heat_float32_split = ht.array(torch_float32, split=split) + numpy_float32 = torch_float32.numpy() + numpy_float32_split_strides = ( + tuple(np.array(numpy_float32.strides[:split]) / size) + numpy_float32.strides[split:] + ) + self.assertEqual(heat_float32_split.strides, numpy_float32_split_strides) + # Distributed, float64, column-major memory layout + split = -2 + torch_float64 = torch.arange(6 * 5 * 3 * 4 * 5 * size * 7, dtype=torch.float64).reshape( + 6, 5, 3, 4, 5 * size, 7 + ) + heat_float64_F_split = ht.array(torch_float64, order="F", split=split) + numpy_float64_F = np.array(torch_float64.numpy(), order="F") + numpy_float64_F_split_strides = numpy_float64_F.strides[: split + 1] + tuple( + np.array(numpy_float64_F.strides[split + 1 :]) / size + ) + self.assertEqual(heat_float64_F_split.strides, numpy_float64_F_split_strides) diff --git a/heat/core/tests/test_memory.py b/heat/core/tests/test_memory.py index b680311d6a..22fb3a4a55 100644 --- a/heat/core/tests/test_memory.py +++ b/heat/core/tests/test_memory.py @@ -1,9 +1,12 @@ import unittest - +import torch +import numpy as np import heat as ht +from heat.core.tests.test_suites.basic_test import BasicTest + -class TestMemory(unittest.TestCase): +class TestMemory(BasicTest): def test_copy(self): tensor = ht.ones(5) 
copied = tensor.copy() @@ -16,3 +19,58 @@ def test_copy(self): # test exceptions with self.assertRaises(TypeError): ht.copy("hello world") + + def test_sanitize_memory_layout(self): + # non distributed, 2D + a_torch = torch.arange(12).reshape(4, 3) + a_heat_C = ht.array(a_torch) + a_heat_F = ht.array(a_torch, order="F") + self.assertTrue_memory_layout(a_heat_C, "C") + self.assertTrue_memory_layout(a_heat_F, "F") + # non distributed, 5D + a_torch_5d = torch.arange(4 * 3 * 5 * 2 * 1).reshape(4, 3, 1, 2, 5) + a_heat_5d_C = ht.array(a_torch_5d) + a_heat_5d_F = ht.array(a_torch_5d, order="F") + self.assertTrue_memory_layout(a_heat_5d_C, "C") + self.assertTrue_memory_layout(a_heat_5d_F, "F") + a_heat_5d_F_sum = a_heat_5d_F.sum(-2) + a_torch_5d_sum = a_torch_5d.sum(-2) + self.assert_array_equal(a_heat_5d_F_sum, a_torch_5d_sum) + # distributed, split, 2D + size = ht.communication.MPI_WORLD.size + a_torch_2d = torch.arange(4 * size * 3 * size).reshape(4 * size, 3 * size) + a_heat_2d_C_split = ht.array(a_torch_2d, split=0) + a_heat_2d_F_split = ht.array(a_torch_2d, split=1, order="F") + self.assertTrue_memory_layout(a_heat_2d_C_split, "C") + self.assertTrue_memory_layout(a_heat_2d_F_split, "F") + a_heat_2d_F_split_sum = a_heat_2d_F_split.sum(1) + a_torch_2d_sum = a_torch_2d.sum(1) + self.assert_array_equal(a_heat_2d_F_split_sum, a_torch_2d_sum) + # distributed, split, 5D + a_torch_5d = torch.arange(4 * 3 * 5 * 2 * size * 7).reshape(4, 3, 7, 2 * size, 5) + a_heat_5d_C_split = ht.array(a_torch_5d, split=-2) + a_heat_5d_F_split = ht.array(a_torch_5d, split=-2, order="F") + self.assertTrue_memory_layout(a_heat_5d_C_split, "C") + self.assertTrue_memory_layout(a_heat_5d_F_split, "F") + a_heat_5d_F_split_sum = a_heat_5d_F_split.sum(-2) + a_torch_5d_sum = a_torch_5d.sum(-2) + self.assert_array_equal(a_heat_5d_F_split_sum, a_torch_5d_sum) + # distributed, is_split, 2D + a_heat_2d_C_issplit = ht.array(a_torch_2d, is_split=0) + a_heat_2d_F_issplit = ht.array(a_torch_2d, is_split=1, order="F") + self.assertTrue_memory_layout(a_heat_2d_C_issplit, "C") + self.assertTrue_memory_layout(a_heat_2d_F_issplit, "F") + a_heat_2d_F_issplit_sum = a_heat_2d_F_issplit.sum(1) + a_torch_2d_sum = a_torch_2d.sum(1) * size + self.assert_array_equal(a_heat_2d_F_issplit_sum, a_torch_2d_sum) + # distributed, is_split, 5D + a_heat_5d_C_issplit = ht.array(a_torch_5d, is_split=-2) + a_heat_5d_F_issplit = ht.array(a_torch_5d, is_split=-2, order="F") + self.assertTrue_memory_layout(a_heat_5d_C_issplit, "C") + self.assertTrue_memory_layout(a_heat_5d_F_issplit, "F") + a_heat_5d_F_issplit_sum = a_heat_5d_F_issplit.sum(-2) + a_torch_5d_sum = a_torch_5d.sum(-2) * size + self.assert_array_equal(a_heat_5d_F_issplit_sum, a_torch_5d_sum) + # test exceptions + with self.assertRaises(NotImplementedError): + ht.zeros_like(a_heat_5d_C_split, order="K") diff --git a/heat/core/tests/test_suites/basic_test.py b/heat/core/tests/test_suites/basic_test.py index 7f95dcee1d..27fd64a793 100644 --- a/heat/core/tests/test_suites/basic_test.py +++ b/heat/core/tests/test_suites/basic_test.py @@ -293,6 +293,24 @@ def assert_func_equal_for_tensor( else: self.assertTrue(np.array_equal(ht_res._DNDarray__array.numpy(), np_res)) + def assertTrue_memory_layout(self, tensor, order): + """ + Checks that the memory layout of a given heat tensor is as specified by argument order. + + Parameters: + ----------- + order: str, 'C' for C-like (row-major), 'F' for Fortran-like (column-major) memory layout. 
+ """ + stride = tensor._DNDarray__array.stride() + row_major = all(np.diff(list(stride)) <= 0) + column_major = all(np.diff(list(stride)) >= 0) + if order == "C": + return self.assertTrue(row_major) + elif order == "F": + return self.assertTrue(column_major) + else: + raise ValueError("expected order to be 'C' or 'F', but was {}".format(order)) + def __create_random_np_array(self, shape, dtype=np.float64, low=-10000, high=10000): """ Creates a random array based on the input parameters. diff --git a/heat/core/tests/test_suites/test_basic_test.py b/heat/core/tests/test_suites/test_basic_test.py index eaf6f3a1b7..6fc536226f 100644 --- a/heat/core/tests/test_suites/test_basic_test.py +++ b/heat/core/tests/test_suites/test_basic_test.py @@ -85,3 +85,9 @@ def test_assert_func_equal_for_tensor(self): array = ht.ones((15, 15)) with self.assertRaises(TypeError): self.assert_func_equal_for_tensor(array, heat_func=ht_func, numpy_func=np_func) + + def test_assertTrue_memory_layout(self): + data = torch.arange(3 * 4 * 5).reshape(3, 4, 5) + data_F = ht.array(data, order="F") + with self.assertRaises(ValueError): + self.assertTrue_memory_layout(data_F, order="K")