From 4e26f6de609a8e1e56e78fcabf37dbde14d6339a Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Tue, 29 Jun 2021 16:23:49 +0200 Subject: [PATCH 01/17] add sign, sgn --- CHANGELOG.md | 2 + heat/core/rounding.py | 95 +++++++++++++++++++++++++++++++- heat/core/tests/test_rounding.py | 68 +++++++++++++++++++++++ 3 files changed, 164 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c26939204..fdc833bdba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,8 @@ Example on 2 processes: - [#768](https://github.com/helmholtz-analytics/heat/pull/768) New feature: unary positive and negative operations ### Manipulations - [#796](https://github.com/helmholtz-analytics/heat/pull/796) `DNDarray.reshape(shape)`: method now allows shape elements to be passed in as single arguments. +### Rounding +- [#827](https://github.com/helmholtz-analytics/heat/pull/827) New feature: `sign`, `sgn` ### Trigonometrics / Arithmetic - [#806](https://github.com/helmholtz-analytics/heat/pull/809) New feature: `square` - [#809](https://github.com/helmholtz-analytics/heat/pull/809) New feature: `acosh`, `asinh`, `atanh` diff --git a/heat/core/rounding.py b/heat/core/rounding.py index bef9a13eab..d19e9175a2 100644 --- a/heat/core/rounding.py +++ b/heat/core/rounding.py @@ -12,7 +12,19 @@ from . import sanitation from . import types -__all__ = ["abs", "absolute", "ceil", "clip", "fabs", "floor", "modf", "round", "trunc"] +__all__ = [ + "abs", + "absolute", + "ceil", + "clip", + "fabs", + "floor", + "modf", + "round", + "sign", + "sgn", + "trunc", +] def abs( @@ -328,6 +340,87 @@ def round( DNDarray.round.__doc__ = round.__doc__ +def sign(x: DNDarray, out: Optional[DNDarray] = None): + """ + Returns an indication of the sign of a number, element-wise. The definition for complex values is equivalent to :math:`x / \\sqrt{x \\cdot x}`. + + Parameters + ---------- + x : DNDarray + Input array + out : DNDarray, optional + A location in which to store the results. + + See Also + -------- + :func:`sgn` + Equivalent function on non-complex arrays. The definition for complex values is equivalent to :math:`x / |x|`. + + Examples + -------- + >>> a = ht.array([-1, -0.5, 0, 0.5, 1]) + >>> ht.sign(a) + DNDarray([-1., -1., 0., 1., 1.], dtype=ht.float32, device=cpu:0, split=None) + >>> ht.sign(ht.array([5-2j, 3+4j])) + DNDarray([(1+0j), (1+0j)], dtype=ht.complex64, device=cpu:0, split=None) + """ + # special case for complex values + if types.heat_type_is_complexfloating(x.dtype): + sanitation.sanitize_in(x) + if out is not None: + sanitation.sanitize_out(out, x.shape, x.split, x.device) + out.larray.copy_(x.larray) + data = out.larray + else: + data = torch.clone(x.larray) + + indices = torch.nonzero(data) + pos = torch.split(indices, 1, 1) + data[pos] = x.larray[pos] / torch.sqrt(torch.square(x.larray[pos])) + + if out is not None: + out.__dtype = types.heat_type_of(data) + return out + return DNDarray( + data, + gshape=x.shape, + dtype=types.heat_type_of(data), + split=x.split, + device=x.device, + comm=x.comm, + balanced=x.balanced, + ) + + return _operations.__local_op(torch.sign, x, out) + + +def sgn(x: DNDarray, out: Optional[DNDarray] = None): + """ + Returns an indication of the sign of a number, element-wise. The definition for complex values is equivalent to :math:`x / |x|`. + + Parameters + ---------- + x : DNDarray + Input array + out : DNDarray, optional + A location in which to store the results. + + See Also + -------- + :func:`sign` + Equivalent function on non-complex arrays. The definition for complex values is equivalent to :math:`x / \\sqrt{x \\cdot x}` + + Examples + -------- + >>> a = ht.array([-1, -0.5, 0, 0.5, 1]) + >>> ht.sign(a) + DNDarray([-1., -1., 0., 1., 1.], dtype=ht.float32, device=cpu:0, split=None) + >>> ht.sgn(ht.array([5-2j, 3+4j])) + DNDarray([(0.9284766912460327-0.3713906705379486j), (0.6000000238418579+0.800000011920929j)], dtype=ht.complex64, device=cpu:0, split=None) + """ + return _operations.__local_op(torch.sgn, x, out) + + def trunc(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: """ Return the trunc of the input, element-wise. diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index 23af481e93..764f2c90b6 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -301,6 +301,74 @@ def test_round(self): self.assertEqual(float64_round_distrbd.dtype, ht.float64) self.assert_array_equal(float64_round_distrbd, comparison) + def test_sign(self): + # floats 1d + a = ht.array([-1, -0.5, 0, 0.5, 1]) + signed = ht.sign(a) + comparison = ht.array([-1.0, -1, 0, 1, 1]) + + self.assertEqual(signed.dtype, comparison.dtype) + self.assertEqual(signed.shape, comparison.shape) + self.assertEqual(signed.device, a.device) + self.assertEqual(signed.split, a.split) + self.assertTrue(ht.equal(signed, comparison)) + + # complex + 2d + split + a = ht.array([[1 - 2j, -0.5 + 1j], [0, 4 + 6j]], split=0) + signed = ht.sign(a) + comparison = ht.array([[1 + 0j, -1 + 0j], [0 + 0j, 1 + 0j]]) + + self.assertEqual(signed.dtype, comparison.dtype) + self.assertEqual(signed.shape, comparison.shape) + self.assertEqual(signed.device, a.device) + self.assertEqual(signed.split, a.split) + self.assertTrue(ht.equal(signed, comparison)) + + # complex + split + out + a = ht.array([[1 - 2j, -0.5 + 1j], [0, 4 + 6j]], split=1) + b = ht.empty_like(a) + signed = ht.sign(a, b) + comparison = ht.array([[1 + 0j, -1 + 0j], [0 + 0j, 1 + 0j]]) + + self.assertIs(b, signed) + self.assertEqual(signed.dtype, comparison.dtype) + self.assertEqual(signed.shape, comparison.shape) + self.assertEqual(signed.device, a.device) + self.assertEqual(signed.split, a.split) + self.assertTrue(ht.equal(signed, comparison)) + + # zeros + 3d + complex + split + a = ht.zeros((4, 4, 4), dtype=ht.complex128, split=2) + signed = ht.sign(a) + comparison = ht.zeros((4, 4, 4), dtype=ht.complex128) + + self.assertEqual(signed.dtype, comparison.dtype) + self.assertEqual(signed.shape, comparison.shape) + self.assertEqual(signed.device, a.device) + self.assertEqual(signed.split, a.split) + self.assertTrue(ht.equal(signed, comparison)) + + def test_sgn(self): + # floats + a = ht.array([-1, -0.5, 0, 0.5, 1]) + signed = ht.sgn(a) + comparison = ht.array([-1.0, -1, 0, 1, 1]) + + self.assertEqual(signed.dtype, comparison.dtype) + self.assertEqual(signed.shape, comparison.shape) + self.assertEqual(signed.device, a.device) + self.assertTrue(ht.equal(signed, comparison)) + + # complex + a = ht.array([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]], split=0) + signed = ht.sgn(a) + comparison = torch.sgn(torch.tensor([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]])) + + self.assertEqual(signed.dtype, ht.heat_type_of(comparison)) + self.assertEqual(signed.shape, a.shape) + self.assertEqual(signed.device, a.device) + self.assertTrue(ht.equal(signed, ht.array(comparison))) + def test_trunc(self): base_array = np.random.randn(20) From 4327a57a8e8a189e67da9fd658c6469164c8b6c9 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Tue, 29 Jun 2021 16:42:52 +0200 Subject: [PATCH 02/17] compute indices on cpu --- heat/core/rounding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/rounding.py b/heat/core/rounding.py index d19e9175a2..34ef6e5bed 100644 --- a/heat/core/rounding.py +++ b/heat/core/rounding.py @@ -372,9 +372,9 @@ def sign(x: DNDarray, out: Optional[DNDarray] = None): out.larray.copy_(x.larray) data = out.larray else: - data = torch.clone(x.larray) + data = torch.clone(x.larray.cpu()) - indices = torch.nonzero(data) + indices = torch.nonzero(data.cpu()) # nonzero_cuda not implemented for complex pos = torch.split(indices, 1, 1) data[pos] = x.larray[pos] / torch.sqrt(torch.square(x.larray[pos])) From 82a7e401548dea5ee2cf2e02248a1afcfd28c3d6 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Tue, 29 Jun 2021 16:46:02 +0200 Subject: [PATCH 03/17] revert --- heat/core/rounding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/rounding.py b/heat/core/rounding.py index 34ef6e5bed..d19e9175a2 100644 --- a/heat/core/rounding.py +++ b/heat/core/rounding.py @@ -372,9 +372,9 @@ def sign(x: DNDarray, out: Optional[DNDarray] = None): out.larray.copy_(x.larray) data = out.larray else: - data = torch.clone(x.larray.cpu()) + data = torch.clone(x.larray) - indices = torch.nonzero(data.cpu()) # nonzero_cuda not implemented for complex + indices = torch.nonzero(data) pos = torch.split(indices, 1, 1) data[pos] = x.larray[pos] / torch.sqrt(torch.square(x.larray[pos])) From 9d375a2f90403edfe41ab5caebc703db2ca420fe Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 13 Jul 2021 13:52:43 +0200 Subject: [PATCH 04/17] correcting device issue in sign tests (hopefully) --- heat/core/tests/test_rounding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index 764f2c90b6..de921de7f0 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -363,6 +363,7 @@ def test_sgn(self): a = ht.array([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]], split=0) signed = ht.sgn(a) comparison = torch.sgn(torch.tensor([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]])) + comparison = comparison.to(a.device.torch_device) self.assertEqual(signed.dtype, ht.heat_type_of(comparison)) self.assertEqual(signed.shape, a.shape) From cd1816777b94ac2378c83354cd3ad76eeed94048 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Tue, 20 Jul 2021 09:48:25 +0200 Subject: [PATCH 05/17] fix for older pytorch versions --- CHANGELOG.md | 7 +++++-- heat/core/rounding.py | 9 ++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 48261e7cf1..3b51d73144 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# Pending Additions +## Feature Additions +### Rounding +- [#827](https://github.com/helmholtz-analytics/heat/pull/827) New feature: `sign`, `sgn` + # v1.1.0 ## Highlights @@ -62,8 +67,6 @@ Example on 2 processes: ### Manipulations - [#796](https://github.com/helmholtz-analytics/heat/pull/796) `DNDarray.reshape(shape)`: method now allows shape elements to be passed in as single arguments. -### Rounding -- [#827](https://github.com/helmholtz-analytics/heat/pull/827) New feature: `sign`, `sgn` ### Trigonometrics / Arithmetic - [#806](https://github.com/helmholtz-analytics/heat/pull/809) New feature: `square` diff --git a/heat/core/rounding.py b/heat/core/rounding.py index d19e9175a2..2280a50c68 100644 --- a/heat/core/rounding.py +++ b/heat/core/rounding.py @@ -373,9 +373,12 @@ def sign(x: DNDarray, out: Optional[DNDarray] = None): data = out.larray else: data = torch.clone(x.larray) - - indices = torch.nonzero(data) - pos = torch.split(indices, 1, 1) + # NOTE remove when min version >= 1.9 + if "1.7" in torch.__version__ or "1.8" in torch.__version__: + pos = data != 0 + else: + indices = torch.nonzero(data) + pos = torch.split(indices, 1, 1) data[pos] = x.larray[pos] / torch.sqrt(torch.square(x.larray[pos])) if out is not None: From fa223502ca4e58333370620758372d2e8247afda Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Fri, 30 Jul 2021 09:57:23 +0200 Subject: [PATCH 06/17] update type hints --- heat/core/rounding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/rounding.py b/heat/core/rounding.py index 2280a50c68..1b0a3421df 100644 --- a/heat/core/rounding.py +++ b/heat/core/rounding.py @@ -340,7 +340,7 @@ def round( DNDarray.round.__doc__ = round.__doc__ -def sign(x: DNDarray, out: Optional[DNDarray] = None): +def sign(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: """ Returns an indication of the sign of a number, element-wise. The definition for complex values is equivalent to :math:`x / \\sqrt{x \\cdot x}`. @@ -397,7 +397,7 @@ def sign(x: DNDarray, out: Optional[DNDarray] = None): return _operations.__local_op(torch.sign, x, out) -def sgn(x: DNDarray, out: Optional[DNDarray] = None): +def sgn(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: """ Returns an indication of the sign of a number, element-wise. The definition for complex values is equivalent to :math:`x / |x|`. From b14de92b72d9c2e5057d893c9f25426c35e1c531 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Mon, 23 Aug 2021 15:58:13 +0200 Subject: [PATCH 07/17] fix tests --- heat/core/tests/test_rounding.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index de921de7f0..7fd2b687a0 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -316,7 +316,7 @@ def test_sign(self): # complex + 2d + split a = ht.array([[1 - 2j, -0.5 + 1j], [0, 4 + 6j]], split=0) signed = ht.sign(a) - comparison = ht.array([[1 + 0j, -1 + 0j], [0 + 0j, 1 + 0j]]) + comparison = ht.array([[1 + 0j, -1 + 0j], [0 + 0j, 1 + 0j]], split=0) self.assertEqual(signed.dtype, comparison.dtype) self.assertEqual(signed.shape, comparison.shape) @@ -328,7 +328,7 @@ def test_sign(self): a = ht.array([[1 - 2j, -0.5 + 1j], [0, 4 + 6j]], split=1) b = ht.empty_like(a) signed = ht.sign(a, b) - comparison = ht.array([[1 + 0j, -1 + 0j], [0 + 0j, 1 + 0j]]) + comparison = ht.array([[1 + 0j, -1 + 0j], [0 + 0j, 1 + 0j]], split=1) self.assertIs(b, signed) self.assertEqual(signed.dtype, comparison.dtype) @@ -340,7 +340,7 @@ def test_sign(self): # zeros + 3d + complex + split a = ht.zeros((4, 4, 4), dtype=ht.complex128, split=2) signed = ht.sign(a) - comparison = ht.zeros((4, 4, 4), dtype=ht.complex128) + comparison = ht.zeros((4, 4, 4), dtype=ht.complex128, split=2) self.assertEqual(signed.dtype, comparison.dtype) self.assertEqual(signed.shape, comparison.shape) @@ -362,13 +362,14 @@ def test_sgn(self): # complex a = ht.array([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]], split=0) signed = ht.sgn(a) - comparison = torch.sgn(torch.tensor([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]])) - comparison = comparison.to(a.device.torch_device) + comparison = ht.array( + torch.sgn(torch.tensor([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]])), split=0 + ) - self.assertEqual(signed.dtype, ht.heat_type_of(comparison)) + self.assertEqual(signed.dtype, comparison.dtype) self.assertEqual(signed.shape, a.shape) self.assertEqual(signed.device, a.device) - self.assertTrue(ht.equal(signed, ht.array(comparison))) + self.assertTrue(ht.equal(signed, comparison)) def test_trunc(self): base_array = np.random.randn(20) From 1c49a3c8a27f251d613e78c6d9bcc4d3f080410e Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Mon, 23 Aug 2021 16:39:12 +0200 Subject: [PATCH 08/17] seperate complex tests --- heat/core/tests/test_rounding.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index 7fd2b687a0..d8c9c06f93 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -322,7 +322,8 @@ def test_sign(self): self.assertEqual(signed.shape, comparison.shape) self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) - self.assertTrue(ht.equal(signed, comparison)) + self.assertTrue(ht.equal(signed.real, comparison.real)) + self.assertTrue(ht.equal(signed.imag, comparison.imag)) # complex + split + out a = ht.array([[1 - 2j, -0.5 + 1j], [0, 4 + 6j]], split=1) @@ -335,7 +336,8 @@ def test_sign(self): self.assertEqual(signed.shape, comparison.shape) self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) - self.assertTrue(ht.equal(signed, comparison)) + self.assertTrue(ht.equal(signed.real, comparison.real)) + self.assertTrue(ht.equal(signed.imag, comparison.imag)) # zeros + 3d + complex + split a = ht.zeros((4, 4, 4), dtype=ht.complex128, split=2) @@ -346,7 +348,8 @@ def test_sign(self): self.assertEqual(signed.shape, comparison.shape) self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) - self.assertTrue(ht.equal(signed, comparison)) + self.assertTrue(ht.equal(signed.real, comparison.real)) + self.assertTrue(ht.equal(signed.imag, comparison.imag)) def test_sgn(self): # floats @@ -369,7 +372,8 @@ def test_sgn(self): self.assertEqual(signed.dtype, comparison.dtype) self.assertEqual(signed.shape, a.shape) self.assertEqual(signed.device, a.device) - self.assertTrue(ht.equal(signed, comparison)) + self.assertTrue(ht.equal(signed.real, comparison.real)) + self.assertTrue(ht.equal(signed.imag, comparison.imag)) def test_trunc(self): base_array = np.random.randn(20) From 71a2b136bd8bbd25e0d5207e95973b8bb70f174d Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Mon, 23 Aug 2021 16:52:19 +0200 Subject: [PATCH 09/17] use allclose comparison --- heat/core/tests/test_rounding.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index d8c9c06f93..a3b9c1e5e9 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -322,8 +322,8 @@ def test_sign(self): self.assertEqual(signed.shape, comparison.shape) self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) - self.assertTrue(ht.equal(signed.real, comparison.real)) - self.assertTrue(ht.equal(signed.imag, comparison.imag)) + self.assertTrue(ht.allclose(signed.real, comparison.real)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag)) # complex + split + out a = ht.array([[1 - 2j, -0.5 + 1j], [0, 4 + 6j]], split=1) @@ -336,8 +336,8 @@ def test_sign(self): self.assertEqual(signed.shape, comparison.shape) self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) - self.assertTrue(ht.equal(signed.real, comparison.real)) - self.assertTrue(ht.equal(signed.imag, comparison.imag)) + self.assertTrue(ht.allclose(signed.real, comparison.real)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag)) # zeros + 3d + complex + split a = ht.zeros((4, 4, 4), dtype=ht.complex128, split=2) @@ -348,8 +348,8 @@ def test_sign(self): self.assertEqual(signed.shape, comparison.shape) self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) - self.assertTrue(ht.equal(signed.real, comparison.real)) - self.assertTrue(ht.equal(signed.imag, comparison.imag)) + self.assertTrue(ht.allclose(signed.real, comparison.real)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag)) def test_sgn(self): # floats @@ -372,8 +372,8 @@ def test_sgn(self): self.assertEqual(signed.dtype, comparison.dtype) self.assertEqual(signed.shape, a.shape) self.assertEqual(signed.device, a.device) - self.assertTrue(ht.equal(signed.real, comparison.real)) - self.assertTrue(ht.equal(signed.imag, comparison.imag)) + self.assertTrue(ht.allclose(signed.real, comparison.real)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag)) def test_trunc(self): base_array = np.random.randn(20) From cbc3f8c788f7ce1f163301b0627d9e7f90da6ce5 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Mon, 23 Aug 2021 17:05:23 +0200 Subject: [PATCH 10/17] test sgn device --- heat/core/tests/test_rounding.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index a3b9c1e5e9..eaa374fd7f 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -365,15 +365,13 @@ def test_sgn(self): # complex a = ht.array([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]], split=0) signed = ht.sgn(a) - comparison = ht.array( - torch.sgn(torch.tensor([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]])), split=0 - ) + comparison = torch.sgn(torch.tensor([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]])) + comparison = comparison.to(a.device.torch_device) - self.assertEqual(signed.dtype, comparison.dtype) + self.assertEqual(signed.dtype, ht.heat_type_of(comparison)) self.assertEqual(signed.shape, a.shape) self.assertEqual(signed.device, a.device) - self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag)) + self.assertTrue(ht.equal(signed, ht.array(comparison, split=0))) def test_trunc(self): base_array = np.random.randn(20) From a8e7cc636e5cd5fd5f4f7a04fff2d023e56c15ec Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Tue, 24 Aug 2021 09:23:33 +0200 Subject: [PATCH 11/17] debug --- heat/core/tests/test_rounding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index eaa374fd7f..14fd835ae6 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -1,3 +1,4 @@ +from heat.core.rounding import sign import numpy as np import torch @@ -322,6 +323,7 @@ def test_sign(self): self.assertEqual(signed.shape, comparison.shape) self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) + print(signed, comparison) self.assertTrue(ht.allclose(signed.real, comparison.real)) self.assertTrue(ht.allclose(signed.imag, comparison.imag)) From 911be2e4b27030f5571571b4f0d4df6182aed81b Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Tue, 24 Aug 2021 09:29:15 +0200 Subject: [PATCH 12/17] higher tolerance --- heat/core/tests/test_rounding.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index 14fd835ae6..3d52de6402 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -1,4 +1,3 @@ -from heat.core.rounding import sign import numpy as np import torch @@ -323,9 +322,8 @@ def test_sign(self): self.assertEqual(signed.shape, comparison.shape) self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) - print(signed, comparison) self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=2e-5)) # complex + split + out a = ht.array([[1 - 2j, -0.5 + 1j], [0, 4 + 6j]], split=1) @@ -339,7 +337,7 @@ def test_sign(self): self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=2e-5)) # zeros + 3d + complex + split a = ht.zeros((4, 4, 4), dtype=ht.complex128, split=2) @@ -351,7 +349,7 @@ def test_sign(self): self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=2e-5)) def test_sgn(self): # floats From 7ab3781d22ee4b9553fbb79aa387840f36af9aec Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Tue, 24 Aug 2021 09:58:02 +0200 Subject: [PATCH 13/17] higher tolerance --- heat/core/tests/test_rounding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index 3d52de6402..0b269cc8e0 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -323,7 +323,7 @@ def test_sign(self): self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=2e-5)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=1e-4)) # complex + split + out a = ht.array([[1 - 2j, -0.5 + 1j], [0, 4 + 6j]], split=1) @@ -337,7 +337,7 @@ def test_sign(self): self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=2e-5)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=1e-4)) # zeros + 3d + complex + split a = ht.zeros((4, 4, 4), dtype=ht.complex128, split=2) @@ -349,7 +349,7 @@ def test_sign(self): self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=2e-5)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=1e-4)) def test_sgn(self): # floats From cde6fa0f40980823cfd8490fcabc933335ec40cf Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Tue, 24 Aug 2021 10:03:43 +0200 Subject: [PATCH 14/17] use atol --- heat/core/tests/test_rounding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index 0b269cc8e0..88f86482cd 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -323,7 +323,7 @@ def test_sign(self): self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=1e-4)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag, atol=2e-5)) # complex + split + out a = ht.array([[1 - 2j, -0.5 + 1j], [0, 4 + 6j]], split=1) @@ -337,7 +337,7 @@ def test_sign(self): self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=1e-4)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag, atol=2e-5)) # zeros + 3d + complex + split a = ht.zeros((4, 4, 4), dtype=ht.complex128, split=2) @@ -349,7 +349,7 @@ def test_sign(self): self.assertEqual(signed.device, a.device) self.assertEqual(signed.split, a.split) self.assertTrue(ht.allclose(signed.real, comparison.real)) - self.assertTrue(ht.allclose(signed.imag, comparison.imag, rtol=1e-4)) + self.assertTrue(ht.allclose(signed.imag, comparison.imag, atol=2e-5)) def test_sgn(self): # floats From ef780abd93ab4e58a130ad2b82cd86dfc705227f Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Thu, 26 Aug 2021 07:50:04 +0200 Subject: [PATCH 15/17] ignore code for pytorch>1.8 in coverage --- heat/core/rounding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/rounding.py b/heat/core/rounding.py index 1b0a3421df..f99f9a2883 100644 --- a/heat/core/rounding.py +++ b/heat/core/rounding.py @@ -376,7 +376,7 @@ def sign(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: # NOTE remove when min version >= 1.9 if "1.7" in torch.__version__ or "1.8" in torch.__version__: pos = data != 0 - else: + else: # pragma: no cover indices = torch.nonzero(data) pos = torch.split(indices, 1, 1) data[pos] = x.larray[pos] / torch.sqrt(torch.square(x.larray[pos])) From 0dbdac093aa868a167a00699bb9c2efb2653b7ce Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Fri, 17 Sep 2021 10:00:58 +0200 Subject: [PATCH 16/17] alphabetical order --- heat/core/rounding.py | 56 ++++++++++++++++---------------- heat/core/tests/test_rounding.py | 44 ++++++++++++------------- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/heat/core/rounding.py b/heat/core/rounding.py index f99f9a2883..f9fa6b005d 100644 --- a/heat/core/rounding.py +++ b/heat/core/rounding.py @@ -21,8 +21,8 @@ "floor", "modf", "round", - "sign", "sgn", + "sign", "trunc", ] @@ -340,6 +340,33 @@ def round( DNDarray.round.__doc__ = round.__doc__ +def sgn(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: + """ + Returns an indication of the sign of a number, element-wise. The definition for complex values is equivalent to :math:`x / |x|`. + + Parameters + ---------- + x : DNDarray + Input array + out : DNDarray, optional + A location in which to store the results. + + See Also + -------- + :func:`sign` + Equivalent function on non-complex arrays. The definition for complex values is equivalent to :math:`x / \\sqrt{x \\cdot x}` + + Examples + -------- + >>> a = ht.array([-1, -0.5, 0, 0.5, 1]) + >>> ht.sign(a) + DNDarray([-1., -1., 0., 1., 1.], dtype=ht.float32, device=cpu:0, split=None) + >>> ht.sgn(ht.array([5-2j, 3+4j])) + DNDarray([(0.9284766912460327-0.3713906705379486j), (0.6000000238418579+0.800000011920929j)], dtype=ht.complex64, device=cpu:0, split=None) + """ + return _operations.__local_op(torch.sgn, x, out) + + def sign(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: """ Returns an indication of the sign of a number, element-wise. The definition for complex values is equivalent to :math:`x / \\sqrt{x \\cdot x}`. @@ -397,33 +424,6 @@ def sign(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: return _operations.__local_op(torch.sign, x, out) -def sgn(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: - """ - Returns an indication of the sign of a number, element-wise. The definition for complex values is equivalent to :math:`x / |x|`. - - Parameters - ---------- - x : DNDarray - Input array - out : DNDarray, optional - A location in which to store the results. - - See Also - -------- - :func:`sign` - Equivalent function on non-complex arrays. The definition for complex values is equivalent to :math:`x / \\sqrt{x \\cdot x}` - - Examples - -------- - >>> a = ht.array([-1, -0.5, 0, 0.5, 1]) - >>> ht.sign(a) - DNDarray([-1., -1., 0., 1., 1.], dtype=ht.float32, device=cpu:0, split=None) - >>> ht.sgn(ht.array([5-2j, 3+4j])) - DNDarray([(0.9284766912460327-0.3713906705379486j), (0.6000000238418579+0.800000011920929j)], dtype=ht.complex64, device=cpu:0, split=None) - """ - return _operations.__local_op(torch.sgn, x, out) - - def trunc(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: """ Return the trunc of the input, element-wise. diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py index 88f86482cd..379c0e0ca4 100644 --- a/heat/core/tests/test_rounding.py +++ b/heat/core/tests/test_rounding.py @@ -301,6 +301,28 @@ def test_round(self): self.assertEqual(float64_round_distrbd.dtype, ht.float64) self.assert_array_equal(float64_round_distrbd, comparison) + def test_sgn(self): + # floats + a = ht.array([-1, -0.5, 0, 0.5, 1]) + signed = ht.sgn(a) + comparison = ht.array([-1.0, -1, 0, 1, 1]) + + self.assertEqual(signed.dtype, comparison.dtype) + self.assertEqual(signed.shape, comparison.shape) + self.assertEqual(signed.device, a.device) + self.assertTrue(ht.equal(signed, comparison)) + + # complex + a = ht.array([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]], split=0) + signed = ht.sgn(a) + comparison = torch.sgn(torch.tensor([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]])) + comparison = comparison.to(a.device.torch_device) + + self.assertEqual(signed.dtype, ht.heat_type_of(comparison)) + self.assertEqual(signed.shape, a.shape) + self.assertEqual(signed.device, a.device) + self.assertTrue(ht.equal(signed, ht.array(comparison, split=0))) + def test_sign(self): # floats 1d a = ht.array([-1, -0.5, 0, 0.5, 1]) @@ -351,28 +373,6 @@ def test_sign(self): self.assertTrue(ht.allclose(signed.real, comparison.real)) self.assertTrue(ht.allclose(signed.imag, comparison.imag, atol=2e-5)) - def test_sgn(self): - # floats - a = ht.array([-1, -0.5, 0, 0.5, 1]) - signed = ht.sgn(a) - comparison = ht.array([-1.0, -1, 0, 1, 1]) - - self.assertEqual(signed.dtype, comparison.dtype) - self.assertEqual(signed.shape, comparison.shape) - self.assertEqual(signed.device, a.device) - self.assertTrue(ht.equal(signed, comparison)) - - # complex - a = ht.array([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]], split=0) - signed = ht.sgn(a) - comparison = torch.sgn(torch.tensor([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]])) - comparison = comparison.to(a.device.torch_device) - - self.assertEqual(signed.dtype, ht.heat_type_of(comparison)) - self.assertEqual(signed.shape, a.shape) - self.assertEqual(signed.device, a.device) - self.assertTrue(ht.equal(signed, ht.array(comparison, split=0))) - def test_trunc(self): base_array = np.random.randn(20) From b6fa52f84f0a4a4cc2744cffe08112a9b5f35b36 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Fri, 17 Sep 2021 10:30:32 +0200 Subject: [PATCH 17/17] fix changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b62857575f..c6d15776e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,10 +18,10 @@ - [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll` - [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes` - [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis` -### Rounding -- [#827](https://github.com/helmholtz-analytics/heat/pull/827) New feature: `sign`, `sgn` ### Random - [#858](https://github.com/helmholtz-analytics/heat/pull/858) New Feature: `standard_normal`, `normal` +### Rounding +- [#827](https://github.com/helmholtz-analytics/heat/pull/827) New feature: `sign`, `sgn` # v1.1.1 - [#864](https://github.com/helmholtz-analytics/heat/pull/864) Dependencies: constrain `torchvision` version range to match supported `pytorch` version range.