Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update jax config import #199

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

alberthli
Copy link

This PR changes the import statement for jax.config, which resolves the error

File "/usr/local/lib/python3.10/dist-packages/torchquad/utils/set_precision.py", line 58, in set_precision
    from jax.config import config
ImportError: cannot import name 'config' from 'jax.config' (/usr/local/lib/python3.10/dist-packages/jax/config.py)

@gomezzz gomezzz changed the base branch from main to develop April 25, 2024 10:17
@gomezzz gomezzz changed the base branch from develop to main April 25, 2024 10:17
@gomezzz
Copy link
Collaborator

gomezzz commented Apr 25, 2024

Hi @alberthli , thanks for the PR!

To understand a bit better, it is my understand jax deprecated this import recently.

Did you check by any chance if this works for older jax version as well? Currently, we require jax>=0.2.22 , so we may need to update the version

(P.S. don't worry update the failing tests, the failures are unrelated to this PR, I think, and a regression in the CI)

@alberthli
Copy link
Author

alberthli commented Apr 25, 2024

Hi @gomezzz, I haven't checked whether it works for earlier versions. This commit from 6 months ago seems relevant if you want to control versioning, though.

@gomezzz gomezzz mentioned this pull request Jun 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants