-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Features/lshape map+redistribute #428
Conversation
Codecov Report
@@ Coverage Diff @@
## master #428 +/- ##
==========================================
+ Coverage 96.58% 96.59% +0.01%
==========================================
Files 57 57
Lines 11981 12053 +72
==========================================
+ Hits 11572 11643 +71
- Misses 409 410 +1
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is going to be really useful, thanks a lot. I'm running into a problem with the unittests though, see comment under test_redistribute.
Returns | ||
------- | ||
None, the local shapes of the DNDarray are modified | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a couple examples would be good, esp. about "the only important parts of the target map are the values along the split axis".
if lshape_map.shape != (self.comm.size, len(self.gshape)): | ||
raise ValueError( | ||
"lshape_map must have the dimensions ({}, {}), currently {}".format( | ||
self.comm.size, len(self.gshape), lshape_map.shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mix up between "dimensions" and "shape"? (lshape_map must have shape blah blah, currently blah)
if not isinstance(target_map, torch.Tensor): | ||
raise TypeError( | ||
"target_map must be a torch.Tensor, currently {}".format(type(lshape_map)) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
format(type(lshape_map))
should be format(type(target_map))
if target_map.shape != (self.comm.size, len(self.gshape)): | ||
raise ValueError( | ||
"target_map must have the dimensions {}, currently {}".format( | ||
(self.comm.size, len(self.gshape)), target_map.shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dimensions vs. shape, see above
Description
The balance function now uses a redistribute function. The logic is the same as it was before, but it can now be given an arbitrary shape to match.
Fixes: None, updates balance function in view of purposely unbalancing tensors for other operations
Changes proposed:
create_lshape_map
function to generate the lshape of a functionType of change
Select relevant options.
Are all split configurations tested and accounted for?
[x] yes [ ] no
Does this change require a documentation update outside of the changes proposed?
[ ] yes [x] no
Does this change modify the behaviour of other functions?
[x] yes [ ] no
Are there code practices which require justification?
[ ] yes [x] no