Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Driver handling in svdvals function in torch_frontend #23718

Merged
merged 26 commits into from
Oct 24, 2023
Merged

Driver handling in svdvals function in torch_frontend #23718

merged 26 commits into from
Oct 24, 2023

Conversation

AhmedHossam23
Copy link
Contributor

@AhmedHossam23 AhmedHossam23 commented Sep 16, 2023

todo: added driver parameter to torch linalg frontend

handling for the driver in the svdvals function in torch frontend

Related Issue

Close #23145

Checklist

  • Did you add a function?
  • Did you add the tests?
  • Did you run your tests and are your tests passing?
  • Did pre-commit not fail on any check?
  • Did you follow the steps we provided?

Socials:

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR Compliance Checks

Thank you for your Pull Request! We have run several checks on this pull request in order to make sure it's suitable for merging into this project. The results are listed in the following section.

Conventional Commit PR Title

In order to be considered for merging, the pull request title must match the specification in conventional commits. You can edit the title in order for this check to pass.
Most often, our PR titles are something like one of these:

  • docs: correct typo in README
  • feat: implement dark mode"
  • fix: correct remove button behavior

Linting Errors

  • Found type "null", must be one of "feat","fix","docs","style","refactor","perf","test","build","ci","chore","revert"
  • No subject found

@ivy-leaves ivy-leaves added the PyTorch Frontend Developing the PyTorch Frontend, checklist triggered by commenting add_frontend_checklist label Sep 16, 2023
@@ -310,6 +311,12 @@ def svd(A, /, *, full_matrices=True, driver=None, out=None):
)
def svdvals(A, *, driver=None, out=None):
# TODO: add handling for driver
if driver == "gesvd":
return torch.linalg.svdvals(A, out=out, driver=driver)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! I think this could be simplified into ... driver in [...] and they can share a return. Also, please add the test for this argument in the svdvals test in the torch frontend test folder, thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have added the test and simplified the function implmentaiton

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the test should be modified in test_linalg.py in the torch frontend tests folder within ivy_test, not sure I've missed anything but the file is not shown as changed in the diff of the PR, I don't think the tests are added yet, also, please don't update the test_array_api submodule yet, thanks!

@AhmedHossam23
Copy link
Contributor Author

  • I added the test in the last commit , and i have not updated test_array_api module
  • Should i add both the test and the function or it is already added and able to merge ?

@@ -1,5 +1,6 @@
# local
import math
import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! The native torch shouldn't be imported in the frontends, if ivy.svdvals doesn't support this argument yet, you should first implement the new argument in the backend, add the related argument in the tests, and then use the ivy function in this frontend, thanks!

@github-actions
Copy link
Contributor

github-actions bot commented Oct 4, 2023

Thank you for this PR, here is the CI results:


This pull request does not result in any additional test failures. Congratulations!

Copy link
Contributor

@juliagsy juliagsy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Thanks for the great work! The svdvals function in the other backends (jax, tf, paddle, np) should also have the argument included. If they don't support it natively, you should add some compositional logic for the support, thanks!

@AhmedHossam23
Copy link
Contributor Author

Hey! Thanks for the great work! The svdvals function in the other backends (jax, tf, paddle, np) should also have the argument included. If they don't support it natively, you should add some compositional logic for the support, thanks!

can you merge the chnanges for now as i need to complete for the hiring process , and i will create a new issue for handling the driver in the other backends and work on them just like i did in torch issue

x: Union[ivy.Array, ivy.NativeArray],
/,
*,
driver: Optional[str],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey! sure, but could you please at least add this argument to every backend + a todo comment so that its usage won't break when a user use this on a backend other than torch? thanks!

PS: the driver argument should have a = None

return torch.linalg.svdvals(x, out=out)
@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, backend_version)
def svdvals(
x: torch.Tensor, /, *, driver: Optional[str], out: Optional[torch.Tensor] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

driver: ... = None required

ivy/functional/backends/jax/linear_algebra.py Show resolved Hide resolved
x: JaxArray, /, *, driver: Optional[str] = None, out: Optional[JaxArray] = None
) -> JaxArray:
# TODO: handling the driver argument
return jnp.linalg.svd(x, driver=driver, compute_uv=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but not here as this is not a valid signature based on official docs

@AhmedHossam23
Copy link
Contributor Author

i think all is good now

Copy link
Contributor

@juliagsy juliagsy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for contributing!

@juliagsy juliagsy merged commit bb0b201 into ivy-llc:main Oct 24, 2023
259 of 269 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Ivy Functional API PyTorch Frontend Developing the PyTorch Frontend, checklist triggered by commenting add_frontend_checklist
Projects
None yet
Development

Successfully merging this pull request may close these issues.

svdvals
4 participants