From 9b73f8eacc76b170842eeb2e15021b4a0c886ebf Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Sat, 18 Nov 2023 03:33:06 +0530 Subject: [PATCH 01/15] #3443 Add various versions for `jax` and `jaxlib` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Add support for Python 3.11 on aarch64 containers 2. Keep Python 3.8 support on older version 3. Add Python 3.9–3.11 support on newer version (same as the one for point 1) 4. Add support for CPU-only Windows installation 5. Pin all versions so as to not break anything. --- setup.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 886378f44f..c6ddb9d83f 100644 --- a/setup.py +++ b/setup.py @@ -265,15 +265,26 @@ def compile_KLU(): "pandas": [ "pandas>=1.5.0", ], + # Note: jax and jaxlib must be pinned to a specific version + # to avoid upstream breaking changes. "jax": [ - "jax==0.4.8", - "jaxlib==0.4.7", + # 0.4.18 provides support for Jax on aarch64 containers + # via the PyBaMM images on Docker Hub which come with + # Python 3.11 installed. + # It also provides support for CPU-only Jax on Windows. + "jax==0.4.18; python_version >= '3.9'", + "jaxlib==0.4.18; python_version >= '3.9'", + # Jax 0.4.13 was the last version to support Python 3.8. + # Support for CPU-only Windows was added in 0.4.13, so + # this version supports Windows too. + "jax==0.4.13; python_version < '3.9'", + "jaxlib==0.4.13; python_version < '3.9'", ], "odes": ["scikits.odes"], "all": [ "autograd>=1.6.2", "scikit-fem>=8.1.0", - "pybamm[examples,plot,cite,latexify,bpx,tqdm,pandas]" + "pybamm[examples,plot,cite,latexify,bpx,tqdm,pandas]", ], }, entry_points={ From 8a049d9864e7ef1eaaa0fe93b4a0217ff087e10a Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Sat, 18 Nov 2023 03:34:33 +0530 Subject: [PATCH 02/15] #3312 #3443 Build both arm64 + amd64 images for all solvers --- .github/workflows/docker.yml | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index b6994795d6..f92ee76c9e 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -48,8 +48,7 @@ jobs: echo "tag=all" >> "$GITHUB_OUTPUT" fi - - name: Build and push Docker image to Docker Hub (no solvers) - if: matrix.build-args == 'No solvers' + - name: Build and push Docker image to Docker Hub (${{ matrix.build-args }}) uses: docker/build-push-action@v5 with: context: . @@ -58,29 +57,5 @@ jobs: push: true platforms: linux/amd64, linux/arm64 - - name: Build and push Docker image to Docker Hub (with ODES and IDAKLU solvers) - if: matrix.build-args == 'ODES' || matrix.build-args == 'IDAKLU' - uses: docker/build-push-action@v5 - with: - context: . - file: scripts/Dockerfile - tags: pybamm/pybamm:${{ steps.tags.outputs.tag }} - push: true - build-args: ${{ matrix.build-args }}=true - platforms: linux/amd64, linux/arm64 - - - name: Build and push Docker image to Docker Hub (with ALL and JAX solvers) - if: matrix.build-args == 'ALL' || matrix.build-args == 'JAX' - uses: docker/build-push-action@v5 - with: - context: . - file: scripts/Dockerfile - tags: pybamm/pybamm:${{ steps.tags.outputs.tag }} - push: true - build-args: ${{ matrix.build-args }}=true - # exclude arm64 for JAX and ALL builds for now, see - # https://github.com/google/jax/issues/13608 - platforms: linux/amd64 - - name: List built image(s) run: docker images From e2d849d80e6dc903428c322ab4d3037b9805bb5a Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Thu, 23 Nov 2023 00:16:51 +0530 Subject: [PATCH 03/15] #3443 Install `jax` extras in the same command --- noxfile.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/noxfile.py b/noxfile.py index 430ad59659..d0ca7fc6f6 100644 --- a/noxfile.py +++ b/noxfile.py @@ -62,8 +62,7 @@ def run_coverage(session): session.install("coverage", silent=False) session.install("-e", ".[all]", silent=False) if sys.platform != "win32": - session.install("-e", ".[odes]", silent=False) - session.install("-e", ".[jax]", silent=False) + session.install("-e", ".[odes,jax]", silent=False) session.run("coverage", "run", "--rcfile=.coveragerc", "run-tests.py", "--nosub") session.run("coverage", "combine") session.run("coverage", "xml") @@ -92,8 +91,7 @@ def run_unit(session): set_environment_variables(PYBAMM_ENV, session=session) session.install("-e", ".[all]", silent=False) if sys.platform == "linux": - session.install("-e", ".[odes]", silent=False) - session.install("-e", ".[jax]", silent=False) + session.install("-e", ".[odes,jax]", silent=False) session.run("python", "run-tests.py", "--unit") @@ -139,8 +137,7 @@ def run_tests(session): set_environment_variables(PYBAMM_ENV, session=session) session.install("-e", ".[all]", silent=False) if sys.platform == "linux" or sys.platform == "darwin": - session.install("-e", ".[odes]", silent=False) - session.install("-e", ".[jax]", silent=False) + session.install("-e", ".[odes, jax]", silent=False) session.run("python", "run-tests.py", "--all") From e17163ffbce029bc15e2a8c3548b4c2ec287c3f4 Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Thu, 23 Nov 2023 00:17:44 +0530 Subject: [PATCH 04/15] #3443 Bump to latest version of `jax` and `jaxlib` tested with `--upgrade` and `--upgrade-strategy eager` plus `--no-cache-dir` --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index c6ddb9d83f..c6e7414829 100644 --- a/setup.py +++ b/setup.py @@ -272,8 +272,8 @@ def compile_KLU(): # via the PyBaMM images on Docker Hub which come with # Python 3.11 installed. # It also provides support for CPU-only Jax on Windows. - "jax==0.4.18; python_version >= '3.9'", - "jaxlib==0.4.18; python_version >= '3.9'", + "jax==0.4.20; python_version >= '3.9'", + "jaxlib==0.4.20; python_version >= '3.9'", # Jax 0.4.13 was the last version to support Python 3.8. # Support for CPU-only Windows was added in 0.4.13, so # this version supports Windows too. From 05d106158e987f038f7d0c21ac783e4166239004 Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Thu, 23 Nov 2023 00:20:10 +0530 Subject: [PATCH 05/15] #3443 Add Windows support via nox --- noxfile.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/noxfile.py b/noxfile.py index d0ca7fc6f6..4444778162 100644 --- a/noxfile.py +++ b/noxfile.py @@ -60,9 +60,9 @@ def run_coverage(session): """Run the coverage tests and generate an XML report.""" set_environment_variables(PYBAMM_ENV, session=session) session.install("coverage", silent=False) - session.install("-e", ".[all]", silent=False) + session.install("-e", ".[all,jax]", silent=False) if sys.platform != "win32": - session.install("-e", ".[odes,jax]", silent=False) + session.install("-e", ".[odes]", silent=False) session.run("coverage", "run", "--rcfile=.coveragerc", "run-tests.py", "--nosub") session.run("coverage", "combine") session.run("coverage", "xml") @@ -89,9 +89,9 @@ def run_doctests(session): def run_unit(session): """Run the unit tests.""" set_environment_variables(PYBAMM_ENV, session=session) - session.install("-e", ".[all]", silent=False) + session.install("-e", ".[all,jax]", silent=False) if sys.platform == "linux": - session.install("-e", ".[odes,jax]", silent=False) + session.install("-e", ".[odes]", silent=False) session.run("python", "run-tests.py", "--unit") @@ -128,16 +128,16 @@ def set_dev(session): external=True, ) else: - session.run(python, "-m", "pip", "install", "-e", ".[all,dev]", external=True) + session.run(python, "-m", "pip", "install", "-e", ".[all,dev,jax]", external=True) @nox.session(name="tests") def run_tests(session): """Run the unit tests and integration tests sequentially.""" set_environment_variables(PYBAMM_ENV, session=session) - session.install("-e", ".[all]", silent=False) + session.install("-e", ".[all,jax]", silent=False) if sys.platform == "linux" or sys.platform == "darwin": - session.install("-e", ".[odes, jax]", silent=False) + session.install("-e", ".[odes]", silent=False) session.run("python", "run-tests.py", "--all") From 30db13bc547dcd93e752950344d1961162033005 Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Thu, 23 Nov 2023 00:41:32 +0530 Subject: [PATCH 06/15] #3443 Install `[jax]` for the integration tests --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index 4444778162..156378e2cd 100644 --- a/noxfile.py +++ b/noxfile.py @@ -72,7 +72,7 @@ def run_coverage(session): def run_integration(session): """Run the integration tests.""" set_environment_variables(PYBAMM_ENV, session=session) - session.install("-e", ".[all]", silent=False) + session.install("-e", ".[all,jax]", silent=False) if sys.platform == "linux": session.install("-e", ".[odes]", silent=False) session.run("python", "run-tests.py", "--integration") From 5fd45c670704d8b04029f83b284fc0fa27255e51 Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Thu, 23 Nov 2023 00:52:08 +0530 Subject: [PATCH 07/15] #3443 Fix expression tree Jax evaluator test --- .../test_operations/test_evaluate_python.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index 50c9dbb744..0484af9451 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -503,7 +503,7 @@ def test_evaluator_jax(self): expr = pybamm.exp(a * b) evaluator = pybamm.EvaluatorJax(expr) result = evaluator(t=None, y=np.array([[2], [3]])) - self.assertEqual(result, np.exp(6)) + np.testing.assert_array_almost_equal(result, np.exp(6), decimal=15) # test a constant expression expr = pybamm.Scalar(2) * pybamm.Scalar(3) From 9103a106668ed3e7c85ee3783ee8c454823d2041 Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Thu, 23 Nov 2023 02:11:39 +0530 Subject: [PATCH 08/15] Remove explainer comments about version constraints Co-authored-by: Eric G. Kratz --- setup.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/setup.py b/setup.py index c6e7414829..b02293a212 100644 --- a/setup.py +++ b/setup.py @@ -268,15 +268,9 @@ def compile_KLU(): # Note: jax and jaxlib must be pinned to a specific version # to avoid upstream breaking changes. "jax": [ - # 0.4.18 provides support for Jax on aarch64 containers - # via the PyBaMM images on Docker Hub which come with - # Python 3.11 installed. - # It also provides support for CPU-only Jax on Windows. "jax==0.4.20; python_version >= '3.9'", "jaxlib==0.4.20; python_version >= '3.9'", - # Jax 0.4.13 was the last version to support Python 3.8. - # Support for CPU-only Windows was added in 0.4.13, so - # this version supports Windows too. + # The versions below can be removed once PyBaMM no longer supports python 3.8 "jax==0.4.13; python_version < '3.9'", "jaxlib==0.4.13; python_version < '3.9'", ], From d99acee69d42c09f738225bf6d8aab77930c02aa Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Thu, 23 Nov 2023 02:12:54 +0530 Subject: [PATCH 09/15] Remove explainer comment about pinning Co-authored-by: Eric G. Kratz --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index b02293a212..0fca17f820 100644 --- a/setup.py +++ b/setup.py @@ -265,8 +265,6 @@ def compile_KLU(): "pandas": [ "pandas>=1.5.0", ], - # Note: jax and jaxlib must be pinned to a specific version - # to avoid upstream breaking changes. "jax": [ "jax==0.4.20; python_version >= '3.9'", "jaxlib==0.4.20; python_version >= '3.9'", From c066c81abafa932c2aa7e2548d1f7ce8082985dd Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Sat, 25 Nov 2023 19:59:41 +0530 Subject: [PATCH 10/15] Bump `jax` and `jaxlib` versions again and also bump `casadi` build-time dependency minimum version --- pyproject.toml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4569c7c6c3..15f8582537 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = [ "setuptools>=64", "wheel", # On Windows, use the CasADi vcpkg registry and CMake bundled from MSVC - "casadi>=3.6.0; platform_system!='Windows'", + "casadi>=3.6.3; platform_system!='Windows'", "cmake; platform_system!='Windows'", ] build-backend = "setuptools.build_meta" @@ -108,13 +108,16 @@ dev = [ "nbmake", ] # Reading CSV files -pandas = [ +pandas = [ "pandas>=1.5.0", ] # For the Jax solver. Note: these must be kept in sync with the versions defined in pybamm/util.py. jax = [ - "jax>=0.4,<=0.5", - "jaxlib>=0.4,<=0.5", + "jax==0.4.20; python_version >= '3.9'", + "jaxlib==0.4.20; python_version >= '3.9'", + # The versions below can be removed once PyBaMM no longer supports Python 3.8 + "jax==0.4.13; python_version < '3.9'", + "jaxlib==0.4.13; python_version < '3.9'", ] # For the scikits.odes solver odes = [ From f229ab8fa9955747736f577bec58b405f6c25b53 Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Sat, 25 Nov 2023 20:15:49 +0530 Subject: [PATCH 11/15] Add a condition to not install `jax` if < py3.9 --- noxfile.py | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/noxfile.py b/noxfile.py index 55ba5811cb..daa95e9833 100644 --- a/noxfile.py +++ b/noxfile.py @@ -63,7 +63,10 @@ def run_coverage(session): if sys.platform != "win32": session.install("-e", ".[all,jax,odes]", silent=False) else: - session.install("-e", ".[all,jax]", silent=False) + if sys.version_info < (3, 9): + session.install("-e", ".[all]", silent=False) + else: + session.install("-e", ".[all,jax]", silent=False) session.run("coverage", "run", "--rcfile=.coveragerc", "run-tests.py", "--nosub") session.run("coverage", "combine") session.run("coverage", "xml") @@ -76,7 +79,10 @@ def run_integration(session): if sys.platform != "win32": session.install("-e", ".[all,jax,odes]", silent=False) else: - session.install("-e", ".[all,jax]", silent=False) + if sys.version_info < (3, 9): + session.install("-e", ".[all]", silent=False) + else: + session.install("-e", ".[all,jax]", silent=False) session.run("python", "run-tests.py", "--integration") @@ -94,7 +100,10 @@ def run_unit(session): if sys.platform != "win32": session.install("-e", ".[all,jax,odes]", silent=False) else: - session.install("-e", ".[all,jax]", silent=False) + if sys.version_info < (3, 9): + session.install("-e", ".[all]", silent=False) + else: + session.install("-e", ".[all,jax]", silent=False) session.run("python", "run-tests.py", "--unit") @@ -143,7 +152,24 @@ def set_dev(session): external=True, ) else: - session.run(python, "-m", "pip", "install", "-e", ".[all,dev,jax]", external=True) + if sys.version_info < (3, 9): + session.run( + python, + "-m", + "pip", + "install", + ".[all,dev]", + external=True, + ) + else: + session.run( + python, + "-m", + "pip", + "install", + ".[all,dev,jax]", + external=True, + ) @nox.session(name="tests") @@ -153,7 +179,10 @@ def run_tests(session): if sys.platform != "win32": session.install("-e", ".[all,jax,odes]", silent=False) else: - session.install("-e", ".[all,jax]", silent=False) + if sys.version_info < (3, 9): + session.install("-e", ".[all]", silent=False) + else: + session.install("-e", ".[all,jax]", silent=False) session.run("python", "run-tests.py", "--all") From a47e78d1f17893049d6366da4003f118ee73b680 Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Sat, 25 Nov 2023 20:19:35 +0530 Subject: [PATCH 12/15] Add a CHANGELOG entry for `jax` and `jax` versions --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2eea9ed7c0..6b95d66bd3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ - Fixed bug in calculation of theoretical energy that made it very slow ([#3506](https://github.com/pybamm-team/PyBaMM/pull/3506)) - The irreversible plating model now increments `f"{Domain} dead lithium concentration [mol.m-3]"`, not `f"{Domain} lithium plating concentration [mol.m-3]"` as it did previously. ([#3485](https://github.com/pybamm-team/PyBaMM/pull/3485)) +## Optimizations + +- Updated `jax` and `jaxlib` to the latest available versions and added Windows (Python 3.9+) support for the Jax solver ([#3550](https://github.com/pybamm-team/PyBaMM/pull/3550)) + # [v23.9](https://github.com/pybamm-team/PyBaMM/tree/v23.9) - 2023-10-31 ## Features From 8301d26732c6fb73374a9af8c007021ea0a88b78 Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Sat, 25 Nov 2023 20:20:32 +0530 Subject: [PATCH 13/15] Remove incorrect `jax` and `jaxlib` version pins --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 15f8582537..966a40bac6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,9 +115,6 @@ pandas = [ jax = [ "jax==0.4.20; python_version >= '3.9'", "jaxlib==0.4.20; python_version >= '3.9'", - # The versions below can be removed once PyBaMM no longer supports Python 3.8 - "jax==0.4.13; python_version < '3.9'", - "jaxlib==0.4.13; python_version < '3.9'", ] # For the scikits.odes solver odes = [ From 3f422bdd5b1091008ae4cf1ac0f0e864f44b3607 Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Thu, 7 Dec 2023 19:08:25 +0530 Subject: [PATCH 14/15] Update changelog about breaking change for Jax solver --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b95d66bd3..c91272494b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ - Updated `jax` and `jaxlib` to the latest available versions and added Windows (Python 3.9+) support for the Jax solver ([#3550](https://github.com/pybamm-team/PyBaMM/pull/3550)) +## Breaking changes + +- Dropped support for the `[jax]` extra, i.e., the Jax solver when running on Python 3.8. The Jax solver is now available on Python 3.9 and above ([#3550](https://github.com/pybamm-team/PyBaMM/pull/3550)) + # [v23.9](https://github.com/pybamm-team/PyBaMM/tree/v23.9) - 2023-10-31 ## Features From ae9a637522e277af3f5db3c8d00aa910c9acc38d Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Thu, 7 Dec 2023 19:44:52 +0530 Subject: [PATCH 15/15] #3443 Add minimal docs about Windows and Python support --- docs/source/user_guide/installation/GNU-linux.rst | 5 ++++- docs/source/user_guide/installation/index.rst | 6 +++--- docs/source/user_guide/installation/windows.rst | 15 +++++++++++++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/source/user_guide/installation/GNU-linux.rst b/docs/source/user_guide/installation/GNU-linux.rst index ca95bbe1b5..cf027db587 100644 --- a/docs/source/user_guide/installation/GNU-linux.rst +++ b/docs/source/user_guide/installation/GNU-linux.rst @@ -133,7 +133,10 @@ Optional - JaxSolver ~~~~~~~~~~~~~~~~~~~~ Users can install ``jax`` and ``jaxlib`` to use the Jax solver. -Currently, only GNU/Linux and macOS are supported. + +.. note:: + + The Jax solver is not supported on Python 3.8. It is supported on Python 3.9, 3.10, and 3.11. .. code:: bash diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index 983f66842e..272061a7a6 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -216,13 +216,13 @@ Dependency Minimum Version p Jax dependencies ^^^^^^^^^^^^^^^^^ -Installable with ``pip install "pybamm[jax]"`` +Installable with ``pip install "pybamm[jax]"``, currently supported on Python 3.9-3.11. ========================================================================= ================== ================== ======================= Dependency Minimum Version pip extra Notes ========================================================================= ================== ================== ======================= -`JAX `__ 0.4.8 jax For JAX solvers -`jaxlib `__ 0.4.7 jax Support library for JAX +`JAX `__ 0.4.20 jax For the JAX solver +`jaxlib `__ 0.4.20 jax Support library for JAX ========================================================================= ================== ================== ======================= .. _install.odes_dependencies: diff --git a/docs/source/user_guide/installation/windows.rst b/docs/source/user_guide/installation/windows.rst index 5b104e91bd..5ad77b6f7f 100644 --- a/docs/source/user_guide/installation/windows.rst +++ b/docs/source/user_guide/installation/windows.rst @@ -66,6 +66,21 @@ installed automatically when you install PyBaMM using ``pip``. For an introduction to virtual environments, see (https://realpython.com/python-virtual-environments-a-primer/). +Optional - JaxSolver +~~~~~~~~~~~~~~~~~~~~ + +Users can install ``jax`` and ``jaxlib`` to use the Jax solver. + +.. note:: + + The Jax solver is not supported on Python 3.8. It is supported on Python 3.9, 3.10, and 3.11. + +.. code:: bash + + pip install "pybamm[jax]" + +The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. (``pybamm_install_jax`` is deprecated.) + Uninstall PyBaMM ----------------