From 919b58546ad5465a34d79e8e88a68e2ff24d6ab1 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Tue, 10 Mar 2020 15:05:54 +0100 Subject: [PATCH 1/7] implement flip() --- CHANGELOG.md | 1 + heat/core/manipulations.py | 57 +++++++++++++++++++++++++++ heat/core/tests/test_manipulations.py | 12 ++++++ 3 files changed, 70 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a84eda8ff0..aecaa9f8f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - [#429](https://github.com/helmholtz-analytics/heat/pull/429) Added PyTorch Jitter to inner function of matmul for increased speed - [#483](https://github.com/helmholtz-analytics/heat/pull/483) Bugfix: Underlying torch tensor moves to the right device on array initialisation - [#483](https://github.com/helmholtz-analytics/heat/pull/483) Bugfix:DNDarray.cpu() changes heat device to cpu +- [#498](https://github.com/helmholtz-analytics/heat/pull/498) Feature: flip() # v0.3.0 diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index ebd5f4882c..00d588d305 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -14,6 +14,7 @@ "diag", "diagonal", "expand_dims", + "flip", "hstack", "resplit", "sort", @@ -556,6 +557,62 @@ def expand_dims(a, axis): ) +def flip(a, axis=None): + """ + Reverse the order of elements in an array along the given axis. + + The shape of the array is preserved, but the elements are reordered. + + Parameters + ---------- + a: ht.DNDarray + Input array to be flipped + axis: tuple + a list of axes to be flipped + + Returns + ------- + res: ht.DNDarray + The flipped array. + + Examples + -------- + >>> a = ht.array([[0,1],[2,3]]) + >>> ht.flip(a, [0]) + tensor([[2, 3], + [0, 1]]) + + >>> ht.flip(a, [0,1]) + tensor([[3, 2], + [1, 0]]) + """ + # Nothing to do + if a.numdims <= 1: + return a + + flipped = torch.flip(a._DNDarray__array, axis) + + if a.split not in axis: + return factories.array( + flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm + ) + + # Need to redistribute tensors on split axis + lshape_map = a.create_lshape_map() + dest_proc = a.comm.size - 1 - a.comm.rank + + req = a.comm.Isend(flipped, dest=dest_proc) + received = torch.empty( + tuple(lshape_map[dest_proc]), dtype=a._DNDarray__array.dtype, device=a.device.torch_device + ) + a.comm.Recv(received, source=dest_proc) + + res = factories.array(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) + res.balance_() # after swapping, first processes may be empty + req.Wait() + return res + + def hstack(tup): """ Stack arrays in sequence horizontally (column wise). diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 5e77ee6a03..eb25dbd02e 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -763,6 +763,18 @@ def test_expand_dims(self): with self.assertRaises(ValueError): ht.empty((3, 4, 5), device=ht_device).expand_dims(-5) + def test_flip(self): + a = ht.array([1, 2]) + self.assertTrue(ht.equal(ht.flip(a, [0]), a)) + + a = ht.array([[2, 3], [4, 5], [6, 7], [8, 9]], split=1, dtype=ht.float32) + r_a = ht.array([[9, 8], [7, 6], [5, 4], [3, 2]], split=1, dtype=ht.float32) + self.assertTrue(ht.equal(ht.flip(a, [0, 1]), r_a)) + + a = ht.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], split=0, dtype=ht.uint8) + r_a = ht.array([[[3, 2], [1, 0]], [[7, 6], [5, 4]]], split=0, dtype=ht.uint8) + self.assertTrue(ht.equal(ht.flip(a, [1, 2]), r_a)) + def test_hstack(self): # cases to test: # MM=================================== From 6e47e85d382aa1fcf1c858e6de341906e0bc6a83 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Wed, 11 Mar 2020 09:32:49 +0100 Subject: [PATCH 2/7] more numpy api --- heat/core/manipulations.py | 12 ++++++++---- heat/core/tests/test_manipulations.py | 7 ++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 00d588d305..d7ad303ccb 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -567,7 +567,7 @@ def flip(a, axis=None): ---------- a: ht.DNDarray Input array to be flipped - axis: tuple + axis: int, tuple a list of axes to be flipped Returns @@ -586,9 +586,13 @@ def flip(a, axis=None): tensor([[3, 2], [1, 0]]) """ - # Nothing to do - if a.numdims <= 1: - return a + # flip all dimensions + if axis is None: + axis = tuple(range(a.numdims)) + + # torch.flip only accepts tuples + if isinstance(axis, int): + axis = [axis] flipped = torch.flip(a._DNDarray__array, axis) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index eb25dbd02e..24d35edd03 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -765,7 +765,12 @@ def test_expand_dims(self): def test_flip(self): a = ht.array([1, 2]) - self.assertTrue(ht.equal(ht.flip(a, [0]), a)) + r_a = ht.array([2, 1]) + self.assertTrue(ht.equal(ht.flip(a, 0), r_a)) + + a = ht.array([[1, 2], [3, 4]]) + r_a = ht.array([[4, 3], [2, 1]]) + self.assertTrue(ht.equal(ht.flip(a), r_a)) a = ht.array([[2, 3], [4, 5], [6, 7], [8, 9]], split=1, dtype=ht.float32) r_a = ht.array([[9, 8], [7, 6], [5, 4], [3, 2]], split=1, dtype=ht.float32) From 06965ef230506d25d133cabe0a5b898f9d1e6f35 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Thu, 12 Mar 2020 09:22:31 +0100 Subject: [PATCH 3/7] add device parameter in test --- heat/core/tests/test_manipulations.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 24d35edd03..cc42d128dd 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -764,20 +764,26 @@ def test_expand_dims(self): ht.empty((3, 4, 5), device=ht_device).expand_dims(-5) def test_flip(self): - a = ht.array([1, 2]) - r_a = ht.array([2, 1]) + a = ht.array([1, 2], device=ht_device) + r_a = ht.array([2, 1], device=ht_device) self.assertTrue(ht.equal(ht.flip(a, 0), r_a)) - a = ht.array([[1, 2], [3, 4]]) - r_a = ht.array([[4, 3], [2, 1]]) + a = ht.array([[1, 2], [3, 4]], device=ht_device) + r_a = ht.array([[4, 3], [2, 1]], device=ht_device) self.assertTrue(ht.equal(ht.flip(a), r_a)) - a = ht.array([[2, 3], [4, 5], [6, 7], [8, 9]], split=1, dtype=ht.float32) - r_a = ht.array([[9, 8], [7, 6], [5, 4], [3, 2]], split=1, dtype=ht.float32) + a = ht.array([[2, 3], [4, 5], [6, 7], [8, 9]], split=1, dtype=ht.float32, device=ht_device) + r_a = ht.array( + [[9, 8], [7, 6], [5, 4], [3, 2]], split=1, dtype=ht.float32, device=ht_device + ) self.assertTrue(ht.equal(ht.flip(a, [0, 1]), r_a)) - a = ht.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], split=0, dtype=ht.uint8) - r_a = ht.array([[[3, 2], [1, 0]], [[7, 6], [5, 4]]], split=0, dtype=ht.uint8) + a = ht.array( + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], split=0, dtype=ht.uint8, device=ht_device + ) + r_a = ht.array( + [[[3, 2], [1, 0]], [[7, 6], [5, 4]]], split=0, dtype=ht.uint8, device=ht_device + ) self.assertTrue(ht.equal(ht.flip(a, [1, 2]), r_a)) def test_hstack(self): From ca4491630ce91ad4793ee35ea7c9a6b2296a5c3a Mon Sep 17 00:00:00 2001 From: mtar Date: Mon, 23 Mar 2020 08:01:52 +0100 Subject: [PATCH 4/7] add split example --- heat/core/manipulations.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index d7ad303ccb..f5f3a92bb6 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -568,7 +568,7 @@ def flip(a, axis=None): a: ht.DNDarray Input array to be flipped axis: int, tuple - a list of axes to be flipped + A list of axes to be flipped Returns ------- @@ -581,10 +581,11 @@ def flip(a, axis=None): >>> ht.flip(a, [0]) tensor([[2, 3], [0, 1]]) - + + >>> b = ht.array([[0,1,2],[3,4,5]], split=1) >>> ht.flip(a, [0,1]) - tensor([[3, 2], - [1, 0]]) + (1/2) tensor([5,4,3]) + (2/2) tensor([2,1,0]) """ # flip all dimensions if axis is None: From 4b4880ca7df40b0132270ee0eb7e47a7ab871426 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Mon, 23 Mar 2020 08:11:09 +0100 Subject: [PATCH 5/7] black formatting --- heat/core/manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index f5f3a92bb6..bc64370ba8 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -581,7 +581,7 @@ def flip(a, axis=None): >>> ht.flip(a, [0]) tensor([[2, 3], [0, 1]]) - + >>> b = ht.array([[0,1,2],[3,4,5]], split=1) >>> ht.flip(a, [0,1]) (1/2) tensor([5,4,3]) From a38434c90939b456ddf57e7cece7084e699e7f28 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Wed, 1 Apr 2020 08:51:27 +0200 Subject: [PATCH 6/7] replace lshape_map --- heat/core/manipulations.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index bc64370ba8..33222f0bc8 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -603,12 +603,19 @@ def flip(a, axis=None): ) # Need to redistribute tensors on split axis - lshape_map = a.create_lshape_map() + old_lshape = torch.tensor(a.lshape, device=a.device.torch_device) + new_lshape = torch.empty((len(a.gshape),), dtype=int, device=a.device.torch_device) dest_proc = a.comm.size - 1 - a.comm.rank + # Exchange lshapes + request = a.comm.Irecv(new_lshape, source=dest_proc) + a.comm.Send(old_lshape, dest_proc) + request.Wait() + + # Exchange local tensors req = a.comm.Isend(flipped, dest=dest_proc) received = torch.empty( - tuple(lshape_map[dest_proc]), dtype=a._DNDarray__array.dtype, device=a.device.torch_device + tuple(new_lshape), dtype=a._DNDarray__array.dtype, device=a.device.torch_device ) a.comm.Recv(received, source=dest_proc) From 2efb1db08928bf4d9f4dc8e59cf0455204c50d99 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Wed, 1 Apr 2020 10:30:06 +0200 Subject: [PATCH 7/7] use sendrecv for tuples --- heat/core/manipulations.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 33222f0bc8..2fe2b89629 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -603,20 +603,14 @@ def flip(a, axis=None): ) # Need to redistribute tensors on split axis - old_lshape = torch.tensor(a.lshape, device=a.device.torch_device) - new_lshape = torch.empty((len(a.gshape),), dtype=int, device=a.device.torch_device) + # Get local shapes + old_lshape = a.lshape dest_proc = a.comm.size - 1 - a.comm.rank - - # Exchange lshapes - request = a.comm.Irecv(new_lshape, source=dest_proc) - a.comm.Send(old_lshape, dest_proc) - request.Wait() + new_lshape = a.comm.sendrecv(old_lshape, dest=dest_proc, source=dest_proc) # Exchange local tensors req = a.comm.Isend(flipped, dest=dest_proc) - received = torch.empty( - tuple(new_lshape), dtype=a._DNDarray__array.dtype, device=a.device.torch_device - ) + received = torch.empty(new_lshape, dtype=a._DNDarray__array.dtype, device=a.device.torch_device) a.comm.Recv(received, source=dest_proc) res = factories.array(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)