Features/807 roll #829
Conversation
Codecov Report
@@ Coverage Diff @@
## master #829 +/- ##
==========================================
+ Coverage 95.32% 95.35% +0.02%
==========================================
Files 64 64
Lines 9056 9125 +69
==========================================
+ Hits 8633 8701 +68
- Misses 423 424 +1
test fails on pytorch 1.8.1: RuntimeError: repeats has to be Long tensor
rerun tests
also fails on torch 1.7 with the same error on the 8-GPU tests
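The PR description notes that a simple version check for PyTorch 1.7/1.8 promotes the dtype to work around this error. A plain-Python sketch of such version gating (the helper name and the version-string parsing are illustrative, not taken from the PR):

```python
def needs_long_repeats(torch_version):
    """Return True on PyTorch releases before 1.9, where
    torch.repeat_interleave requires a Long (int64) `repeats` tensor.
    (Hypothetical helper; in the real code the promotion would be a
    cast like `repeats.to(torch.int64)`.)"""
    major, minor = (int(part) for part in torch_version.split(".")[:2])
    return (major, minor) < (1, 9)

print(needs_long_repeats("1.8.1"))  # True: promote repeats to int64
print(needs_long_repeats("1.9.0"))  # False: no promotion needed
```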
heat/core/manipulations.py
Outdated
lshape_map = x.create_lshape_map()[:, x.split]  # local elements along the split axis
cumsum_map = torch.cumsum(lshape_map, dim=0)    # cumulative sum along the split axis
indices = torch.arange(size)
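As a rough illustration of what the cumulative map is for: a plain-Python sketch (with hypothetical local shapes) of how cumulative local sizes along the split axis map a global index to the process rank that owns it:

```python
import bisect
from itertools import accumulate

# Hypothetical local element counts along the split axis, one per rank.
lshape_map = [3, 4, 3]                     # 10 elements over 3 processes
cumsum_map = list(accumulate(lshape_map))  # [3, 7, 10]

def owner_rank(global_index):
    """Rank whose local slice contains the given global index."""
    return bisect.bisect_right(cumsum_map, global_index)

print([owner_rank(i) for i in range(10)])  # [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]
```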
device
heat/core/manipulations.py
Outdated
# use pytorch if it's not the split axis
rolled = torch.roll(x.larray, shift, axis)
return DNDarray(
    rolled,
    gshape=x.shape,
    dtype=x.dtype,
    split=x.split,
    device=x.device,
    comm=x.comm,
    balanced=x.balanced,
)
it would be wonderful if these loops could be restructured slightly so they are not so nested. That being said, I understand why you did it this way. I am not against keeping it this way; it's just odd to read at the moment
looks great, can you add one more test? a 3D case
sorry, I just put it in myself. I thought that upon refresh this would be deleted
Description
This PR adds the roll function to Heat. It uses PyTorch if the roll is not along the split axis. It also improves on PyTorch in that a single shift can now be applied to multiple axes (int/tuple), similar to NumPy.
NOTE
Requires PyTorch 1.9. A simple check for 1.7 and 1.8 has been added that promotes the type; it can be removed when support for those versions is dropped.
Issue/s resolved: #807
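The NumPy-style semantics described above (one integer shift applied to each axis in a tuple) can be sketched in plain Python on nested lists; `roll1d`/`roll2d` are illustrative names, not Heat API:

```python
def roll1d(seq, shift):
    """Cyclically shift a sequence right by `shift` (may be negative or zero)."""
    n = len(seq)
    if n == 0:
        return list(seq)
    shift %= n
    return list(seq[-shift:]) + list(seq[:-shift])

def roll2d(mat, shift, axis=(0, 1)):
    """Apply the same integer shift along every axis listed in `axis`."""
    out = [list(row) for row in mat]
    if 0 in axis:
        out = roll1d(out, shift)   # roll the rows
    if 1 in axis:
        out = [roll1d(row, shift) for row in out]  # roll within each row
    return out

print(roll2d([[1, 2], [3, 4]], 1))  # [[4, 3], [2, 1]]
```

This matches NumPy's behavior for `np.roll(a, 1, axis=(0, 1))`, where the single shift is broadcast to both axes.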
Changes proposed:
roll
Type of change
Due Diligence
Does this change modify the behaviour of other functions? If so, which?
no