Skip to content

Commit

Permalink
Merge pull request #853 from helmholtz-analytics/feature/849-swapaxes
Browse files Browse the repository at this point in the history
add swapaxes
  • Loading branch information
coquelin77 authored Aug 20, 2021
2 parents e7c13c6 + 9b12c64 commit 634a2bd
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
### Manipulations
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
- [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes`
- [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis`


Expand Down
50 changes: 50 additions & 0 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"split",
"squeeze",
"stack",
"swapaxes",
"topk",
"unique",
"vsplit",
Expand Down Expand Up @@ -3023,6 +3024,55 @@ def stack(
return stacked


def swapaxes(x: DNDarray, axis1: int, axis2: int) -> DNDarray:
"""
Interchanges two axes of an array.
Parameters
----------
x : DNDarray
Input array.
axis1 : int
First axis.
axis2 : int
Second axis.
See Also
--------
:func:`~heat.core.linalg.basics.transpose`
Permute the dimensions of an array.
Examples
--------
>>> x = ht.array([[[0,1],[2,3]],[[4,5],[6,7]]])
>>> ht.swapaxes(x, 0, 1)
DNDarray([[[0, 1],
[4, 5]],
[[2, 3],
[6, 7]]], dtype=ht.int64, device=cpu:0, split=None)
>>> ht.swapaxes(x, 0, 2)
DNDarray([[[0, 4],
[2, 6]],
[[1, 5],
[3, 7]]], dtype=ht.int64, device=cpu:0, split=None)
"""
axes = list(range(x.ndim))
try:
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
except TypeError:
raise TypeError(
"'axis1' and 'axis2' must be of type int, found {} and {}".format(
type(axis1), type(axis2)
)
)

return linalg.transpose(x, axes)


DNDarray.swapaxes = lambda self, axis1, axis2: swapaxes(self, axis1, axis2)
DNDarray.swapaxes.__doc__ = swapaxes.__doc__


def unique(
a: DNDarray, sorted: bool = False, return_inverse: bool = False, axis: int = None
) -> Tuple[DNDarray, torch.tensor]:
Expand Down
11 changes: 11 additions & 0 deletions heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3256,6 +3256,17 @@ def test_stack(self):
with self.assertRaises(ValueError):
ht.stack((ht_a_split, ht_b_split, ht_c_split), out=out_wrong_split)

def test_swapaxes(self):
x = ht.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
swapped = ht.swapaxes(x, 0, 1)

self.assertTrue(
ht.equal(swapped, ht.array([[[0, 1], [4, 5]], [[2, 3], [6, 7]]], dtype=ht.int64))
)

with self.assertRaises(TypeError):
ht.swapaxes(x, 4.9, "abc")

def test_topk(self):
size = ht.MPI_WORLD.size
if size == 1:
Expand Down

0 comments on commit 634a2bd

Please sign in to comment.