-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Driver handling in svdvals function in torch_frontend #23718
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR Compliance Checks
Thank you for your Pull Request! We have run several checks on this pull request in order to make sure it's suitable for merging into this project. The results are listed in the following section.
Conventional Commit PR Title
In order to be considered for merging, the pull request title must match the specification in conventional commits. You can edit the title in order for this check to pass.
Most often, our PR titles are something like one of these:
- docs: correct typo in README
- feat: implement dark mode"
- fix: correct remove button behavior
Linting Errors
- Found type "null", must be one of "feat","fix","docs","style","refactor","perf","test","build","ci","chore","revert"
- No subject found
@@ -310,6 +311,12 @@ def svd(A, /, *, full_matrices=True, driver=None, out=None): | |||
) | |||
def svdvals(A, *, driver=None, out=None): | |||
# TODO: add handling for driver | |||
if driver == "gesvd": | |||
return torch.linalg.svdvals(A, out=out, driver=driver) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey! I think this could be simplified into ... driver in [...]
and they can share a return. Also, please add the test for this argument in the svdvals
test in the torch frontend test folder, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i have added the test and simplified the function implmentaiton
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the test should be modified in test_linalg.py
in the torch frontend tests folder within ivy_test
, not sure I've missed anything but the file is not shown as changed in the diff of the PR, I don't think the tests are added yet, also, please don't update the test_array_api
submodule yet, thanks!
|
@@ -1,5 +1,6 @@ | |||
# local | |||
import math | |||
import torch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey! The native torch shouldn't be imported in the frontends, if ivy.svdvals
doesn't support this argument yet, you should first implement the new argument in the backend, add the related argument in the tests, and then use the ivy function in this frontend, thanks!
Thank you for this PR, here is the CI results: This pull request does not result in any additional test failures. Congratulations! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey! Thanks for the great work! The svdvals
function in the other backends (jax, tf, paddle, np) should also have the argument included. If they don't support it natively, you should add some compositional logic for the support, thanks!
can you merge the chnanges for now as i need to complete for the hiring process , and i will create a new issue for handling the driver in the other backends and work on them just like i did in torch issue |
ivy/functional/ivy/linear_algebra.py
Outdated
x: Union[ivy.Array, ivy.NativeArray], | ||
/, | ||
*, | ||
driver: Optional[str], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hey! sure, but could you please at least add this argument to every backend + a todo comment so that its usage won't break when a user use this on a backend other than torch? thanks!
PS: the driver argument should have a = None
return torch.linalg.svdvals(x, out=out) | ||
@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, backend_version) | ||
def svdvals( | ||
x: torch.Tensor, /, *, driver: Optional[str], out: Optional[torch.Tensor] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
driver: ... = None
required
x: JaxArray, /, *, driver: Optional[str] = None, out: Optional[JaxArray] = None | ||
) -> JaxArray: | ||
# TODO: handling the driver argument | ||
return jnp.linalg.svd(x, driver=driver, compute_uv=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but not here as this is not a valid signature based on official docs
i think all is good now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for contributing!
todo: added driver parameter to torch linalg frontend
handling for the driver in the svdvals function in torch frontend
Related Issue
Close #23145
Checklist
Socials: