add cross #850
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master     #850      +/-   ##
==========================================
+ Coverage   95.49%   95.50%   +0.01%
==========================================
  Files          64       64
  Lines        9535     9579      +44
==========================================
+ Hits         9105     9148      +43
- Misses        430      431       +1
heat/core/linalg/basics.py
Outdated
if x1.split != x2.split:
    raise ValueError(
        "'x1' and 'x2' must have the same split, {} != {}".format(x1.split, x2.split)
    )
Is there an open issue for this? We should probably have this functionality.
Could we use halos to do this?
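For context, one alternative to raising here would be to realign the distributions before computing. A minimal sketch, assuming ht.resplit keeps its documented behavior (a halo exchange, as suggested above, would avoid the full redistribution):

import heat as ht

def _align_splits(x1, x2):
    # Sketch only: redistribute x2 to match x1's split axis instead of
    # raising. ht.resplit returns a copy distributed along the new axis.
    if x1.split != x2.split:
        x2 = ht.resplit(x2, x1.split)
    return x1, x2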
heat/core/linalg/basics.py
Outdated
if x1.comm != x2.comm:  # pragma: no cover
    raise ValueError("'x1' and 'x2' must have the same comm, {} != {}".format(x1.comm, x2.comm))
This may result in errors. In most places we assume that all DNDarrays are mapped across all processes in the same way. I am uncertain whether this will cause errors at scale.
heat/core/linalg/basics.py
Outdated
# promote both operands to a common dtype, then compute the local cross product
promoted = torch.promote_types(x1.larray.dtype, x2.larray.dtype)
ret = torch.cross(x1.larray.type(promoted), x2.larray.type(promoted), dim=axis)
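For context, torch.promote_types applies PyTorch's standard type-promotion rules, so mixed-dtype inputs are upcast before the local computation:

import torch
torch.promote_types(torch.int32, torch.float32)  # torch.float32
torch.promote_types(torch.int8, torch.int16)     # torch.int16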
So if I'm reading this correctly, this is a limited cross product function. Do you have plans to implement more cases?
What other cases? The only restriction I made deliberately was disallowing the split axis of the DNDarray as the axis argument. What are you missing?
It actually doesn't make any sense to split along the vector axis; it can only have 2 or 3 elements. What's missing is being able to set the axisa, axisb parameters, and being able to perform cross products of 2D with 3D vectors.
Oh, also: I think it would be useful to add correctness tests to the unit tests.
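For reference, NumPy already handles the 2D-with-3D case by treating the missing z-component of the 2-element vectors as zero:

import numpy as np
np.cross([1, 2], [3, 4, 5])  # [1, 2] is padded to (1, 2, 0)
# array([10, -5, -2])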
@mtar thank you so much for addressing this, it's great to have it! I have some change requests, mainly to do with NumPy API compliance.
heat/core/linalg/basics.py
Outdated
sanitation.sanitize_in(x1)
sanitation.sanitize_in(x2)

if x1.gshape != x2.gshape:
Shapes should be the same except for the axes that define the vectors. Example:
import numpy as np
a = np.arange(2*12).reshape(2,-1,3)
b = np.arange(2*4*2).reshape(2,4,2)
np.cross(a,b)
array([[[ -2, 0, 0],
[ -15, 10, 1],
[ -40, 32, 2],
[ -77, 66, 3]],
[[-126, 112, 4],
[-187, 170, 5],
[-260, 240, 6],
[-345, 322, 7]]])
# different vector axis for b
b = np.transpose(b, (0, 2, 1))
b.shape
(2, 2, 4)
a.shape
(2, 4, 3)
np.cross(a, b, axisb=1)
array([[[ -2, 0, 0],
[ -15, 10, 1],
[ -40, 32, 2],
[ -77, 66, 3]],
[[-126, 112, 4],
[-187, 170, 5],
[-260, 240, 6],
[-345, 322, 7]]])
heat/core/linalg/basics.py
Outdated
if x1.gshape != x2.gshape:
    raise ValueError(
        "'x1' and 'x2' must have the same shape, {} != {}".format(x1.gshape, x2.gshape)
I think the check here should be whether a.shape and b.shape are broadcastable after purging axisa, axisb.
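A minimal sketch of such a check; the helper name is illustrative, not Heat API, and non-negative axis indices are assumed:

def broadcastable_without_vector_axes(shape_a, axisa, shape_b, axisb):
    # drop the vector axes, then apply the standard broadcasting rule,
    # comparing the remaining dimensions right-aligned
    rest_a = shape_a[:axisa] + shape_a[axisa + 1:]
    rest_b = shape_b[:axisb] + shape_b[axisb + 1:]
    for da, db in zip(reversed(rest_a), reversed(rest_b)):
        if da != db and 1 not in (da, db):
            return False
    return True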
heat/core/linalg/basics.py
Outdated
    )
if x1.split != x2.split:
    raise ValueError(
        "'x1' and 'x2' must have the same split, {} != {}".format(x1.split, x2.split)
The splits must match after purging the vector axis.
I don't know what you mean here. Is this error message not sufficient?
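To illustrate what "purging the vector axis" means (shapes are hypothetical, not from the PR): with x1 of shape (4, 3), split=0, axisa=1 and x2 of shape (3, 4), split=1, axisb=0, both operands reduce to shape (4,), distributed along the same remaining axis, once the vector axes are dropped; a plain x1.split != x2.split comparison would wrongly reject this pair.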
heat/core/linalg/basics.py
Outdated
    )
if x1.device != x2.device:
    raise ValueError(
        "'x1' and 'x2' must have the same device type, {} != {}".format(x1.device, x2.device)
Why? One of them will be copied to CPU, but the operation can still be performed, right?
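In plain PyTorch the copy described here would be explicit; a minimal sketch:

import torch

t1 = torch.randn(4, 3, device="cuda" if torch.cuda.is_available() else "cpu")
t2 = torch.randn(4, 3)  # lives on CPU
# move one operand to a common device before computing
out = torch.cross(t1, t2.to(t1.device), dim=1)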
I can remove the check. PyTorch will then throw an error about the device mismatch.
heat/core/linalg/basics.py
Outdated
# rebalance so both operands hold identical process-local chunks
x1.balance_()
x2.balance_()
Why?
self.assertEqual(cross.comm, a.comm)
self.assertEqual(cross.device, a.device)
self.assertTrue(ht.equal(cross, ht.array([[0, 0, -1], [-1, 0, 0], [0, -1, 0]])))
Add tests for axisa, axisb, and the 2D with 3D cross product.
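A sketch of what such tests might look like, assuming ht.cross ends up accepting axisa/axisb like np.cross; the expected values are hand-computed basis-vector cross products, and the snippet is meant to sit inside the existing test case:

a = ht.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
b = ht.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])

# vector axis of a moved to the front, selected via axisa
cross = ht.cross(a.T, b, axisa=0)
self.assertTrue(ht.equal(cross, ht.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])))

# 2D x 3D: the missing z-component of a is taken as zero
cross_2d = ht.cross(a[:, :2], b)
self.assertTrue(ht.equal(cross_2d, ht.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])))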
I was following the Python array API standard on that one, skipping NumPy's quirks 😃
Are all of the comments addressed here?
Implement axisa, axisb, axisc functionality for ht.cross
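Presumably the final interface mirrors numpy.cross; a sketch of the assumed signature (the -1 defaults are an assumption, not verified against the merged code):

def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
    # hypothetical: vector axes of x1, x2, and the output, NumPy-style
    ...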
Thanks @mtar!
Description
Implementation of the cross product for 2D and 3D vectors.
Issue/s resolved: #844
Changes proposed:
Type of change
Due Diligence
Does this change modify the behaviour of other functions? If so, which?
no