Skip to content

Commit

Permalink
stopped force_check in lshape map in roll, added 3D test for roll
Browse files Browse the repository at this point in the history
  • Loading branch information
coquelin77 committed Aug 2, 2021
1 parent c614253 commit b215c59
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
3 changes: 2 additions & 1 deletion heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1967,7 +1967,8 @@ def roll(
size = x.comm.Get_size()
rank = x.comm.Get_rank()

lshape_map = x.create_lshape_map()[:, x.split] # local elements along axis
# local elements along axis:
lshape_map = x.create_lshape_map(force_check=False)[:, x.split]
cumsum_map = torch.cumsum(lshape_map, dim=0) # cumulate along axis
indices = torch.arange(size, device=x.device.torch_device)
# NOTE Can be removed when min version>=1.9
Expand Down
35 changes: 35 additions & 0 deletions heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2452,6 +2452,41 @@ def test_roll(self):
self.assertEqual(rolled.split, a.split)
self.assertTrue(np.array_equal(rolled.numpy(), compare))

# added 3D test, only a quick test for functionality
a = ht.arange(4 * 5 * 6, dtype=ht.complex64).reshape((4, 5, 6), new_split=2)

rolled = ht.roll(a, -1)
compare = np.roll(a.numpy(), -1)
self.assertEqual(rolled.device, a.device)
self.assertEqual(rolled.size, a.size)
self.assertEqual(rolled.dtype, a.dtype)
self.assertEqual(rolled.split, a.split)
self.assertTrue(np.array_equal(rolled.numpy(), compare))

rolled = ht.roll(a, 1, 0)
compare = np.roll(a.numpy(), 1, 0)
self.assertEqual(rolled.device, a.device)
self.assertEqual(rolled.size, a.size)
self.assertEqual(rolled.dtype, a.dtype)
self.assertEqual(rolled.split, a.split)
self.assertTrue(np.array_equal(rolled.numpy(), compare))

rolled = ht.roll(a, -2, [0, 1])
compare = np.roll(a.numpy(), -2, [0, 1])
self.assertEqual(rolled.device, a.device)
self.assertEqual(rolled.size, a.size)
self.assertEqual(rolled.dtype, a.dtype)
self.assertEqual(rolled.split, a.split)
self.assertTrue(np.array_equal(rolled.numpy(), compare))

rolled = ht.roll(a, [1, 2, 1], [0, 1, -2])
compare = np.roll(a.numpy(), [1, 2, 1], [0, 1, -2])
self.assertEqual(rolled.device, a.device)
self.assertEqual(rolled.size, a.size)
self.assertEqual(rolled.dtype, a.dtype)
self.assertEqual(rolled.split, a.split)
self.assertTrue(np.array_equal(rolled.numpy(), compare))

with self.assertRaises(TypeError):
ht.roll(a, 1.0, 0)
with self.assertRaises(TypeError):
Expand Down

0 comments on commit b215c59

Please sign in to comment.