-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Features cumsum & cumprod #524
Conversation
heat/core/tests/test_arithmetics.py
Outdated
cprod = ht.cumprod(a, 0) | ||
self.assertTrue(ht.equal(cprod, result)) | ||
|
||
a = ht.full((2, 4, 2), 2, dtype=ht.float32, split=1, device=ht_device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please test up to 3 dims
heat/core/arithmetics.py
Outdated
|
||
recv = torch.ones(cprod.shape[:axis] + cprod.shape[axis + 1 :], dtype=send.dtype) | ||
|
||
a.comm.Exscan(send, recv, MPI.PROD) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@krajsek does this work with the upcoming AD stuff?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not yet.
heat/core/arithmetics.py
Outdated
cumprod : DNDarray | ||
A new array holding the result is returned unless `out` is | ||
specified, in which case a reference to out is returned. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add Raises section for exceptions that are thrown from within this function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall the raises section only be in __cum_op() or also here?
heat/core/arithmetics.py
Outdated
|
||
recv = torch.ones(cprod.shape[:axis] + cprod.shape[axis + 1 :], dtype=send.dtype) | ||
|
||
a.comm.Exscan(send, recv, MPI.PROD) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not yet.
heat/core/arithmetics.py
Outdated
`axis` is not None or `a` is a 1-d array. | ||
|
||
""" | ||
if axis is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sanitize_axis
heat/core/arithmetics.py
Outdated
|
||
recv = torch.zeros(csum.shape[:axis] + csum.shape[axis + 1 :], dtype=send.dtype) | ||
|
||
a.comm.Exscan(send, recv, MPI.SUM) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More or less identical code as above. Could you please introduce a __cum_op() call analogous to __reduce_op() in operations.py and make use of it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function added
heat/core/arithmetics.py
Outdated
if csum.numel() > 0: | ||
indices = -1 | ||
Ni, Nk = csum.shape[:axis], csum.shape[axis + 1 :] | ||
for ii in np.ndindex(Ni): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we somehow manage to do this without the Python loops via views? Performance on them should be slow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I replaced it with torch.index_select()
…-analytics/heat into enhancement/243-cum_operations
Codecov Report
@@ Coverage Diff @@
## master #524 +/- ##
==========================================
- Coverage 96.38% 96.35% -0.03%
==========================================
Files 75 75
Lines 14481 14585 +104
==========================================
+ Hits 13958 14054 +96
- Misses 523 531 +8
Continue to review full report at Codecov.
|
Description
Include a summary of the change/s.
Please also include relevant motivation and context. List any dependencies that are required for this change.
implementation of cumsum and cumprod functions. numpy's axis=None is not supported.
Issue/s resolved: #243
Changes proposed:
Type of change
Remove irrelevant options:
Due Diligence
Does this change modify the behaviour of other functions? If so, which?
yes / no