Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cross #850

Merged
merged 14 commits into from
Jan 11, 2022
Merged

add cross #850

merged 14 commits into from
Jan 11, 2022

Conversation

mtar
Copy link
Collaborator

@mtar mtar commented Jul 29, 2021

Description

implementation of the cross product for 2d and 3d vectors

Issue/s resolved: #844

Changes proposed:

  • add cross

Type of change

  • New feature

Due Diligence

  • All split configurations tested
  • Multiple dtypes tested in relevant functions
  • Documentation updated (if needed)
  • Updated changelog.md under the title "Pending Additions"

Does this change modify the behaviour of other functions? If so, which?

no

@codecov
Copy link

codecov bot commented Jul 29, 2021

Codecov Report

Merging #850 (88038d1) into master (bc74848) will increase coverage by 0.01%.
The diff coverage is 97.72%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #850      +/-   ##
==========================================
+ Coverage   95.49%   95.50%   +0.01%     
==========================================
  Files          64       64              
  Lines        9535     9579      +44     
==========================================
+ Hits         9105     9148      +43     
- Misses        430      431       +1     
Flag Coverage Δ
gpu 94.63% <97.72%> (+0.01%) ⬆️
unit 91.06% <95.45%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
heat/core/linalg/basics.py 95.04% <97.72%> (+0.15%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update bc74848...88038d1. Read the comment docs.

@mtar mtar requested a review from coquelin77 July 29, 2021 09:48
Comment on lines 71 to 74
if x1.split != x2.split:
raise ValueError(
"'x1' and 'x2' must have the same split, {} != {}".format(x1.split, x2.split)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there an open issue for this? we should probably have this functionality

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we use halos to do this?

Comment on lines 79 to 80
if x1.comm != x2.comm: # pragma: no cover
raise ValueError("'x1' and 'x2' must have the same comm, {} != {}".format(x1.comm, x2.comm))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this may result in errors. in most places we assume that all DNDarrays are mapped across all processes the same way. I am very uncertain if this will cause errors at scale.


promoted = torch.promote_types(x1.larray.dtype, x2.larray.dtype)

ret = torch.cross(x1.larray.type(promoted), x2.larray.type(promoted), dim=axis)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if im reading the correctly, this is a limited cross product function. do you have plans to implement more cases?

Copy link
Collaborator Author

@mtar mtar Aug 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other cases? The only restriction I made deliberately was not allowing the split axis of the DNDarray as the axis argument. What do you miss?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it actually doesn't make any sense to split along the vector axis, it can only be 2 or 3 elements. What's missing is being able to set the axisa, axisb parameters, and being able to perform cross products of 2D with 3D vectors.

@coquelin77
Copy link
Member

oh also, i think it would be useful to add correctness tests to the unittests

@ClaudiaComito ClaudiaComito added this to the 1.2.x milestone Aug 18, 2021
Copy link
Contributor

@ClaudiaComito ClaudiaComito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mtar thank you so much for addressing this, it's great to have it! I have some changes requests, mainly to do with the numpy-API compliance.

heat/core/linalg/basics.py Show resolved Hide resolved
heat/core/linalg/basics.py Outdated Show resolved Hide resolved
sanitation.sanitize_in(x1)
sanitation.sanitize_in(x2)

if x1.gshape != x2.gshape:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shapes should be the same except for the axes that define the vectors. Example:

import numpy as np
a = np.arange(2*12).reshape(2,-1,3)
b = np.arange(2*4*2).reshape(2,4,2)
np.cross(a,b)
array([[[  -2,    0,    0],
        [ -15,   10,    1],
        [ -40,   32,    2],
        [ -77,   66,    3]],

       [[-126,  112,    4],
        [-187,  170,    5],
        [-260,  240,    6],
        [-345,  322,    7]]])

# different vector axis for b
b = np.transpose(b, (0, 2, 1))
b.shape
(2, 2, 4)
a.shape
(2, 4, 3)
np.cross(a, b, axisb=1)
array([[[  -2,    0,    0],
        [ -15,   10,    1],
        [ -40,   32,    2],
        [ -77,   66,    3]],

       [[-126,  112,    4],
        [-187,  170,    5],
        [-260,  240,    6],
        [-345,  322,    7]]])


if x1.gshape != x2.gshape:
raise ValueError(
"'x1' and 'x2' must have the same shape, {} != {}".format(x1.gshape, x2.gshape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the check here should be whether a.shape and b.shape are broadcastable after purging axisa, axisb

)
if x1.split != x2.split:
raise ValueError(
"'x1' and 'x2' must have the same split, {} != {}".format(x1.split, x2.split)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the splits must match after purging the vector axis

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i dont know what you mean here. Is this error message not sufficient?

)
if x1.device != x2.device:
raise ValueError(
"'x1' and 'x2' must have the same device type, {} != {}".format(x1.device, x2.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why? one of them will be copied to cpu, but the operation can still be performed, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can remove the check. PyTorch will throw an error about the device mismatch then.

Comment on lines 95 to 96
x1.balance_()
x2.balance_()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?


promoted = torch.promote_types(x1.larray.dtype, x2.larray.dtype)

ret = torch.cross(x1.larray.type(promoted), x2.larray.type(promoted), dim=axis)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it actually doesn't make any sense to split along the vector axis, it can only be 2 or 3 elements. What's missing is being able to set the axisa, axisb parameters, and being able to perform cross products of 2D with 3D vectors.

self.assertEqual(cross.comm, a.comm)
self.assertEqual(cross.device, a.device)
self.assertTrue(ht.equal(cross, ht.array([[0, 0, -1], [-1, 0, 0], [0, -1, 0]])))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add tests for axisa, axisb, 2D with 3D cross product

@mtar
Copy link
Collaborator Author

mtar commented Oct 5, 2021

@mtar thank you so much for addressing this, it's great to have it! I have some changes requests, mainly to do with the numpy-API compliance.

I was following the Python array API standard on that one skipping NumPy's quirks 😃
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cross-x1-x2-axis-1

coquelin77
coquelin77 previously approved these changes Oct 8, 2021
@coquelin77
Copy link
Member

are all of the comments addressed here?

@mtar mtar added High priority, urgent and removed High priority, urgent labels Dec 10, 2021
ClaudiaComito and others added 2 commits December 21, 2021 11:49
Implement axisa, axisb, axisc functionality for ht.cross
Copy link
Contributor

@ClaudiaComito ClaudiaComito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mtar !

@ClaudiaComito ClaudiaComito merged commit 293d873 into master Jan 11, 2022
@ClaudiaComito ClaudiaComito deleted the feature/844-cross branch January 11, 2022 14:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

implement cross
3 participants