Bug/825 setitem slice dndarrays #826

Merged
merged 23 commits into master from bug/825-setitem-slice-dndarrays on Jul 20, 2021
Changes from 6 commits (23 commits total)
639cd57
setitem can now set values with DNDarrays which are not the same size…
coquelin77 Jun 29, 2021
d1a0e55
added new test cases (simple)
coquelin77 Jun 29, 2021
28c268f
added oop redistribute to manipulations
coquelin77 Jun 29, 2021
61266c1
changelog update
coquelin77 Jun 29, 2021
174c792
added more test cases to increase coverage and removed some dead code
coquelin77 Jun 29, 2021
1938f4a
abstracted section of setitem: key slice generation
coquelin77 Jul 8, 2021
66851a5
Merge branch 'master' into bug/825-setitem-slice-dndarrays
coquelin77 Jul 8, 2021
94af9f4
used key logic in getitem, added typehints/simple docstring to xitem_…
coquelin77 Jul 8, 2021
e4f5364
corrected false logic in key start stop adjustments
coquelin77 Jul 8, 2021
4960951
Merge branch 'master' into bug/825-setitem-slice-dndarrays
coquelin77 Jul 13, 2021
2a82e05
added a raise in setitem for when the value and self have different s…
coquelin77 Jul 13, 2021
dc77f17
Merge branch 'master' into bug/825-setitem-slice-dndarrays
coquelin77 Jul 13, 2021
30e2df4
added handling for single value DNDarrays in key for setitem
coquelin77 Jul 13, 2021
755c786
corrected try/except in setitem to work with torch tensors as well
coquelin77 Jul 13, 2021
fe6ad27
Merge branch 'master' into bug/825-setitem-slice-dndarrays
coquelin77 Jul 19, 2021
8d08330
removing dead code
coquelin77 Jul 20, 2021
f193b0b
Verb correction in lshape map creation
coquelin77 Jul 20, 2021
34ba9c5
new changelog to add pending additions again
coquelin77 Jul 20, 2021
2affb59
Merge branch 'bug/825-setitem-slice-dndarrays' of https://github.com/…
coquelin77 Jul 20, 2021
3ec7b44
added tests for lshape map property and forced creation
coquelin77 Jul 20, 2021
cf8aa1a
corrected incorrect changelog, wrong line was moved to the pending a…
coquelin77 Jul 20, 2021
bf39b25
added raise test for splits != case in setitem
coquelin77 Jul 20, 2021
89fb977
new raise test now only runs on multiple processes
coquelin77 Jul 20, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -50,6 +50,7 @@ Example on 2 processes:
- [#820](https://github.com/helmholtz-analytics/heat/pull/820) `randn` values are pushed away from 0 by the minimum value the given dtype before being transformed into the Gaussian shape
- [#821](https://github.com/helmholtz-analytics/heat/pull/821) Fixed `__getitem__` handling of distributed `DNDarray` key element
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
- [#831](https://github.com/helmholtz-analytics/heat/pull/831) `__getitem__` handling of `array-like` 1-element key
Pending Additions


## Feature additions
### Exponential
Expand Down
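The #826 entry above describes redistributing a value whose split-dimension chunks do not line up with the chunks of the target slice. A minimal pure-Python sketch of the underlying bookkeeping — hypothetical helper names, not heat's API — computes per-rank send counts by walking both chunk layouts over the same global index space:

```python
def chunk_counts(total: int, nprocs: int) -> list[int]:
    # Even block distribution: the first `total % nprocs` ranks
    # each hold one extra element.
    base, extra = divmod(total, nprocs)
    return [base + (1 if r < extra else 0) for r in range(nprocs)]

def redistribution_plan(src_counts: list[int], dst_counts: list[int]) -> list[list[int]]:
    # plan[i][j] = number of elements rank i sends to rank j, obtained by
    # sweeping the source and destination layouts left to right in lockstep.
    nprocs = len(src_counts)
    plan = [[0] * nprocs for _ in range(nprocs)]
    src_rank = dst_rank = 0
    src_left, dst_left = src_counts[0], dst_counts[0]
    while src_rank < nprocs and dst_rank < nprocs:
        moved = min(src_left, dst_left)
        plan[src_rank][dst_rank] += moved
        src_left -= moved
        dst_left -= moved
        if src_left == 0:
            src_rank += 1
            if src_rank < nprocs:
                src_left = src_counts[src_rank]
        if dst_left == 0:
            dst_rank += 1
            if dst_rank < nprocs:
                dst_left = dst_counts[dst_rank]
    return plan
```

For example, `chunk_counts(13, 4)` gives `[4, 3, 3, 3]`, matching the `(13, 5), split=0` arrays used in the tests below; the actual fix delegates this to the out-of-place `redistribute` added to manipulations.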
28 changes: 24 additions & 4 deletions heat/core/dndarray.py
@@ -576,18 +576,18 @@ def cpu(self) -> DNDarray:
self.__device = devices.cpu
return self

def create_lshape_map(self, recreate: bool = True) -> torch.Tensor:
def create_lshape_map(self, force_check: bool = True) -> torch.Tensor:
"""
Generate a 'map' of the lshapes of the data on all processes.
Units are ``(process rank, lshape)``

Parameters
----------
recreate : bool, optional
force_check : bool, optional
if ``False`` and the lshape map has already been created, use the previous
result; otherwise, create the ``lshape_map``. Defaults to ``True``.
"""
if not recreate and self.__lshape_map is not None:
if not force_check and self.__lshape_map is not None:
return self.__lshape_map

lshape_map = torch.zeros(
@@ -1367,6 +1367,18 @@ def __setitem__(
[0., 1., 0., 0., 0.]])
"""
key = key.copy() if hasattr(key, "copy") else key
try:
if value.split != self.split:
val_split = int(value.split)
sp = self.split
warnings.warn(
f"\nvalue.split {val_split} not equal to this DNDarray's split:"
f" {sp}. this may cause errors or unwanted behavior",
category=RuntimeWarning,
Comment on lines +1375 to +1377

Collaborator: 😘

Collaborator @mtar, Jul 20, 2021: Do you have a test for the warning? No test according to codecov
)
except (AttributeError, TypeError):
pass

if isinstance(key, DNDarray) and key.ndim == self.ndim:
# this splits the key into torch.Tensors in each dimension for advanced indexing
lkey = [slice(None, None, None)] * self.ndim
@@ -1392,6 +1404,13 @@ def __setitem__(
kend = key[ell_ind + 1 :]
slices = [slice(None)] * (self.ndim - (len(kst) + len(kend)))
key = kst + slices + kend

for c, k in enumerate(key):
try:
key[c] = k.item()
except (AttributeError, ValueError):
pass

key = tuple(key)

if not self.is_distributed():
@@ -1411,6 +1430,8 @@ def __setitem__(
chunk_start = chunk_slice[self.split].start
chunk_end = chunk_slice[self.split].stop

self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim)

if not isinstance(key, tuple):
return self.__setter(key, value) # returns None

@@ -1448,7 +1469,6 @@ def __setitem__(
target_reshape_map = torch.zeros(
(self.comm.size, self.ndim), dtype=torch.int, device=self.device.torch_device
)
self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim)
for r in range(self.comm.size):
if r not in actives:
loc_key = key.copy()
Expand Down
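The try/except block added in this hunk warns when the value's split axis differs from the target's, and swallows `AttributeError`/`TypeError` so that plain scalars and lists pass through silently. A simplified stand-in (plain-Python mock of a `DNDarray`; the class and function names are hypothetical) shows that control flow in isolation:

```python
import warnings

class FakeDNDarray:
    """Stand-in carrying only the attribute the check inspects."""
    def __init__(self, split):
        self.split = split

def warn_on_split_mismatch(target, value):
    # Mirrors the merged logic: only DNDarray-like values (those with a
    # .split attribute) trigger the comparison; ints, lists, etc. raise
    # AttributeError and fall through without warning.
    try:
        if value.split != target.split:
            warnings.warn(
                f"value.split {int(value.split)} not equal to this DNDarray's"
                f" split: {target.split}. this may cause errors or unwanted"
                f" behavior",
                category=RuntimeWarning,
            )
    except (AttributeError, TypeError):
        pass
```

Note that this mirrors a subtlety of the merged code: if `value.split` is `None` while the target is split, `int(None)` raises `TypeError` inside the `try`, so no warning is emitted for that case — which is what mtar's codecov comment above probes at.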
2 changes: 1 addition & 1 deletion heat/core/tests/test_dndarray.py
@@ -1089,7 +1089,7 @@ def test_setitem_getitem(self):

# slice in 1st dim across 1 node (2nd) w/ singular second dim
c = ht.zeros((13, 5), split=0)
c[8:12, 1] = 1
c[8:12, ht.array(1)] = 1
b = c[8:12, np.int64(1)]
self.assertTrue((b == 1).all())
self.assertEqual(b.gshape, (4,))
Expand Down
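The test change above exercises a 0-d `DNDarray` (`ht.array(1)`) as an index, which the new loop in `__setitem__` unwraps to a Python scalar via `.item()`. A stdlib-only sketch of that normalization step — `ZeroDimArray` is a hypothetical stand-in for a 0-d DNDarray or torch tensor — behaves the same way:

```python
class ZeroDimArray:
    """Stand-in for a 0-d array exposing .item(), like a 0-d torch tensor."""
    def __init__(self, value):
        self._value = value
    def item(self):
        return self._value

def normalize_key(key):
    # Mirrors the loop added in __setitem__: replace any single-element
    # array-like with the scalar it wraps. Objects without .item() raise
    # AttributeError; multi-element tensors would raise ValueError. Both
    # are ignored, so such entries pass through unchanged.
    key = list(key)
    for c, k in enumerate(key):
        try:
            key[c] = k.item()
        except (AttributeError, ValueError):
            pass
    return tuple(key)
```

With this, a key like `(slice(8, 12), ZeroDimArray(1))` normalizes to `(slice(8, 12), 1)`, so the singular-second-dim case in the test above hits the same code path as a plain integer index.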