Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to newest versions of jax + jaxlib and add Windows support for JAX Solver #3550

Merged

Conversation

agriyakhetarpal
Copy link
Member

Description

Closes #3443. With this PR, I have updated the versions for jax and jaxlib to match the latest version(s) including support for Windows and for retaining support on Python 3.8, which we still support.

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All tests pass: $ python run-tests.py --all (or $ nox -s tests)
  • The documentation builds: $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

1. Add support for Python 3.11 on aarch64 containers
2. Keep Python 3.8 support on older version
3. Add Python 3.9–3.11 support on newer version (same as the one for point 1)
4. Add support for CPU-only Windows installation
5. Pin all versions so as to not break anything.
tested with `--upgrade` and `--upgrade-strategy eager` plus `--no-cache-dir`
Copy link

codecov bot commented Nov 22, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (32fad00) 99.58% compared to head (f41be98) 99.58%.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3550   +/-   ##
========================================
  Coverage    99.58%   99.58%           
========================================
  Files          257      257           
  Lines        20708    20708           
========================================
  Hits         20623    20623           
  Misses          85       85           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@agriyakhetarpal
Copy link
Member Author

agriyakhetarpal commented Nov 22, 2023

It looks like I misread the jaxlib release notes since they do not appear to have Python 3.8 support on Windows; it is for Python 3.9+. The JAX developers often delete the releases from PyPI due to storage constraints and provide an alternate index here: https://storage.googleapis.com/jax-releases/jax_releases.html, browsing through which confirmed my suspicion.

The way around this is to either drop Python 3.8 support (big breaking change) or provide Windows Jax support for just Python 3.9–3.11 (more reasonable change) similar to how [odes] will be available for Python <3.12 in #3531. @Saransh-cpp and @brosaplanella: your thoughts will be appreciated on whether this is a good idea, following which I can make the necessary changes.

As I had mentioned, I did start working on rewriting the installation guide side-by-side; however, it shall shift almost all of the pages and their links and would be better as a separate PR. It may be merged in a follow-up queue shortly after the completion of this PR.


P.S. Stepping a bit aside from relevant discussion pertaining to the changes here, but @Saransh-cpp previously mentioned https://scientific-python.org/specs/spec-0000/, which suggested dropping Python 3.9 support by the end of the year since its support window has now ended. While we may not want to do that yet, dropping Python 3.8 could be a reasonable idea since 1. all our core dependencies have done so now with casadi being the last out of the pack, and 2. it is generally good practice for packages in the scientific Python ecosystem to be on track with each other and support an array of not more than four major Python versions at a time. It would be nice to discuss this further in the December monthly meeting.

setup.py Outdated Show resolved Hide resolved
setup.py Outdated Show resolved Hide resolved
agriyakhetarpal and others added 2 commits November 23, 2023 02:11
Co-authored-by: Eric G. Kratz <kratman@users.noreply.github.com>
Co-authored-by: Eric G. Kratz <kratman@users.noreply.github.com>
@kratman
Copy link
Contributor

kratman commented Nov 22, 2023

My 2 cents on dropping 3.8/3.9 is that any change like this will break python environments for some people. Probably better to do both at once than drop one, wait then drop the other. Probably less painful to get it done all at once for users with old versions.

@agriyakhetarpal
Copy link
Member Author

agriyakhetarpal commented Nov 22, 2023

My 2 cents on dropping 3.8/3.9 is that any change like this will break python environments for some people. Probably better to do both at once than drop one, wait then drop the other. Probably less painful to get it done all at once for users with old versions.

Ah I see, are you suggesting we should drop both 3.8 and 3.9? I don't think that would be a good idea since not all of our dependencies will be ready for Python 3.12. That would leave developers and users with less leeway to work between different Python versions, since supporting just 3.10 and 3.11 would be too little and cramped a space for a Python package like PyBaMM that is also useful as a library.

Python 3.8 support had closed a long time back and its EOL is less than a year away (October 2024) which is usually the recommended time for dropping support, but the Scientific Python ecosystem moves quicker and has its own, steadfast ways of proceeding with such things. While Python 3.9 also has a closed support window now, we have two years to reach EOL for it so the day to remove it from PyBaMM is quite far (not until Q3 2025 as the farthest estimate) if we decide to not follow the aforementioned SPEC 0000 in the releases starting next year. But if we do, it would come much quicker.

@Saransh-cpp
Copy link
Member

Saransh-cpp commented Nov 25, 2023

The way around this is to either drop Python 3.8 support (big breaking change) or provide Windows Jax support for just Python 3.9–3.11 (more reasonable change) similar to how [odes] will be available for Python <3.12 in #3531. @Saransh-cpp and @brosaplanella: your thoughts will be appreciated on whether this is a good idea, following which I can make the necessary changes.

Supporting Jax on Python 3.9+ for Windows sounds like a better option.

Stepping a bit aside from relevant discussion pertaining to the changes here, but @Saransh-cpp previously mentioned https://scientific-python.org/specs/spec-0000/, which suggested dropping Python 3.9 support by the end of the year since its support window has now ended. While we may not want to do that yet, dropping Python 3.8 could be a reasonable idea since 1. all our core dependencies have done so now with casadi being the last out of the pack, and 2. it is generally good practice for packages in the scientific Python ecosystem to be on track with each other and support an array of not more than four major Python versions at a time. It would be nice to discuss this further in the December monthly meeting.

I'm not really a big fan of the drop schedule of the Python version listed in SPEC 0000. Most of the libraries I am involved with still support Python 3.8 and will continue until it reaches EOL. I'll ask around and see what the maintainers of other ecosystems think about this, but until then, let's keep supporting Python 3.8.

@agriyakhetarpal
Copy link
Member Author

agriyakhetarpal commented Nov 25, 2023

I have added support for jax==0.4.20+ jaxlib==0.4.20 for Python 3.9+ and the PR should be ready for review, but I would suggest holding off on merging this, because I am working on a branch locally checked out from this branch, where I am updating the installation guide with a lot of changes. When I open a PR for that, I can keep both of the branches synced up together so that they can be merged one after the other without conflicts.

@Saransh-cpp Saransh-cpp self-requested a review November 25, 2023 19:49
@agriyakhetarpal
Copy link
Member Author

Discussion from the developer meeting today: there are some changes in the JAX API between 0.4.14 and 0.4.16 plus progress on #2282 is being tested on jaxlib>=0.4.16 onwards, therefore those changes can be deemed breaking with the version being bumped here to 0.4.20

This means we will not be able to support [jax] on Python 3.8 as mentioned above (we can for Python 3.9–3.11, I will make a note of this in the upcoming revisions for the installation guide). The previously mentioned Apple Metal / Windows unofficial GPU wheels won't get support either. GPU support on WSL and with TPUs is unaffected, of course. For GPU support on experimental platforms, all we can hope for is that the Windows community wheels get bumped (latest version: 0.4.11) and the Metal support continues to mature in the coming years (0.4.11 is the latest version there too).

@agriyakhetarpal agriyakhetarpal mentioned this pull request Dec 1, 2023
8 tasks
Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @agriyakhetarpal! Looks good!

Edit: Is this blocked by the discussion above?

@agriyakhetarpal
Copy link
Member Author

It's not blocked by anyone but me. I have to update the installation guide and restructure it to take into account both this PR and #3531. We can merge this after that – though I will still request some more reviews here as I had mentioned in the meeting

Copy link
Contributor

@jsbrittain jsbrittain left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@agriyakhetarpal These changes look good to me, thanks.

Copy link
Member

@BradyPlanden BradyPlanden left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice one, LGTM!

@agriyakhetarpal
Copy link
Member Author

agriyakhetarpal commented Dec 7, 2023

I have updated the installation guide here for the time being until I work on that in another branch and added an entry to the breaking changes in the CHANGELOG, this should be ready to merge now! I figured that it would be easier to unblock other PRs such as #3531 once this one gets merged.

@Saransh-cpp Saransh-cpp merged commit 22d1229 into pybamm-team:develop Dec 8, 2023
34 of 35 checks passed
@agriyakhetarpal agriyakhetarpal deleted the bump-jax-jaxlib-versions branch December 8, 2023 14:21
@agriyakhetarpal agriyakhetarpal mentioned this pull request Mar 27, 2024
5 tasks
@kratman kratman mentioned this pull request Apr 1, 2024
5 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Bump jax and jaxlib versions
5 participants