diff --git a/CHANGELOG.md b/CHANGELOG.md index d9baacd5d3..e7f45d8100 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ ### Linear Algebra - [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()` - [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm` +### Logical +- [#862](https://github.com/helmholtz-analytics/heat/pull/862) New feature `signbit` ### Manipulations - [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll` - [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes` diff --git a/heat/core/logical.py b/heat/core/logical.py index 5e299bd6e5..4d9ff4a463 100644 --- a/heat/core/logical.py +++ b/heat/core/logical.py @@ -31,6 +31,7 @@ "logical_not", "logical_or", "logical_xor", + "signbit", ] @@ -508,3 +509,23 @@ def sanitize_input_type( else: return x, y + + +def signbit(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: + """ + Checks if signbit is set element-wise (less than zero). + + Parameters + ---------- + x : DNDarray + The input array. + out : DNDarray, optional + The output array. + + Examples + -------- + >>> a = ht.array([2, -1.3, 0]) + >>> ht.signbit(a) + DNDarray([False, True, False], dtype=ht.bool, device=cpu:0, split=None) + """ + return _operations.__local_op(torch.signbit, x, out, no_cast=True) diff --git a/heat/core/tests/test_logical.py b/heat/core/tests/test_logical.py index aa1782138e..a995d53db3 100644 --- a/heat/core/tests/test_logical.py +++ b/heat/core/tests/test_logical.py @@ -481,3 +481,14 @@ def test_logical_xor(self): ht.array([[False, False], [False, False]]), ) ) + + def test_signbit(self): + a = ht.array([2, -1.3, 0, -5], split=0) + + sb = ht.signbit(a) + cmp = ht.array([False, True, False, True]) + + self.assertEqual(sb.dtype, ht.bool) + self.assertEqual(sb.split, 0) + self.assertEqual(sb.device, a.device) + self.assertTrue(ht.equal(sb, cmp))