diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index a78e8f3336..63d6a316ba 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1967,7 +1967,8 @@ def roll( size = x.comm.Get_size() rank = x.comm.Get_rank() - lshape_map = x.create_lshape_map()[:, x.split] # local elements along axis + # local elements along axis: + lshape_map = x.create_lshape_map(force_check=False)[:, x.split] cumsum_map = torch.cumsum(lshape_map, dim=0) # cumulate along axis indices = torch.arange(size, device=x.device.torch_device) # NOTE Can be removed when min version>=1.9 diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 1cf5470cad..5a6d139559 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2452,6 +2452,41 @@ def test_roll(self): self.assertEqual(rolled.split, a.split) self.assertTrue(np.array_equal(rolled.numpy(), compare)) + # added 3D test, only a quick test for functionality + a = ht.arange(4 * 5 * 6, dtype=ht.complex64).reshape((4, 5, 6), new_split=2) + + rolled = ht.roll(a, -1) + compare = np.roll(a.numpy(), -1) + self.assertEqual(rolled.device, a.device) + self.assertEqual(rolled.size, a.size) + self.assertEqual(rolled.dtype, a.dtype) + self.assertEqual(rolled.split, a.split) + self.assertTrue(np.array_equal(rolled.numpy(), compare)) + + rolled = ht.roll(a, 1, 0) + compare = np.roll(a.numpy(), 1, 0) + self.assertEqual(rolled.device, a.device) + self.assertEqual(rolled.size, a.size) + self.assertEqual(rolled.dtype, a.dtype) + self.assertEqual(rolled.split, a.split) + self.assertTrue(np.array_equal(rolled.numpy(), compare)) + + rolled = ht.roll(a, -2, [0, 1]) + compare = np.roll(a.numpy(), -2, [0, 1]) + self.assertEqual(rolled.device, a.device) + self.assertEqual(rolled.size, a.size) + self.assertEqual(rolled.dtype, a.dtype) + self.assertEqual(rolled.split, a.split) + self.assertTrue(np.array_equal(rolled.numpy(), compare)) + + rolled = ht.roll(a, [1, 2, 1], [0, 1, -2]) + compare = np.roll(a.numpy(), [1, 2, 1], [0, 1, -2]) + self.assertEqual(rolled.device, a.device) + self.assertEqual(rolled.size, a.size) + self.assertEqual(rolled.dtype, a.dtype) + self.assertEqual(rolled.split, a.split) + self.assertTrue(np.array_equal(rolled.numpy(), compare)) + with self.assertRaises(TypeError): ht.roll(a, 1.0, 0) with self.assertRaises(TypeError):