Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to newest versions of jax + jaxlib and add Windows support for JAX Solver #3550

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
9b73f8e
#3443 Add various versions for `jax` and `jaxlib`
agriyakhetarpal Nov 17, 2023
8a049d9
#3312 #3443 Build both arm64 + amd64 images for all solvers
agriyakhetarpal Nov 17, 2023
d5d22d2
Merge branch 'develop' into bump-jax-jaxlib-versions
agriyakhetarpal Nov 22, 2023
e2d849d
#3443 Install `jax` extras in the same command
agriyakhetarpal Nov 22, 2023
e17163f
#3443 Bump to latest version of `jax` and `jaxlib`
agriyakhetarpal Nov 22, 2023
05d1061
#3443 Add Windows support via nox
agriyakhetarpal Nov 22, 2023
30db13b
#3443 Install `[jax]` for the integration tests
agriyakhetarpal Nov 22, 2023
5fd45c6
#3443 Fix expression tree Jax evaluator test
agriyakhetarpal Nov 22, 2023
9103a10
Remove explainer comments about version constraints
agriyakhetarpal Nov 22, 2023
d99acee
Remove explainer comment about pinning
agriyakhetarpal Nov 22, 2023
96e059b
Merge branch 'develop' into bump-jax-jaxlib-versions
agriyakhetarpal Nov 25, 2023
c066c81
Bump `jax` and `jaxlib` versions again
agriyakhetarpal Nov 25, 2023
f229ab8
Add a condition to not install `jax` if < py3.9
agriyakhetarpal Nov 25, 2023
a47e78d
Add a CHANGELOG entry for `jax` and `jax` versions
agriyakhetarpal Nov 25, 2023
8301d26
Remove incorrect `jax` and `jaxlib` version pins
agriyakhetarpal Nov 25, 2023
3f422bd
Update changelog about breaking change for Jax solver
agriyakhetarpal Dec 7, 2023
ae9a637
#3443 Add minimal docs about Windows and Python support
agriyakhetarpal Dec 7, 2023
f41be98
Merge branch 'develop' into bump-jax-jaxlib-versions
agriyakhetarpal Dec 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 1 addition & 26 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ jobs:
echo "tag=all" >> "$GITHUB_OUTPUT"
fi

- name: Build and push Docker image to Docker Hub (no solvers)
if: matrix.build-args == 'No solvers'
- name: Build and push Docker image to Docker Hub (${{ matrix.build-args }})
uses: docker/build-push-action@v5
with:
context: .
Expand All @@ -58,29 +57,5 @@ jobs:
push: true
platforms: linux/amd64, linux/arm64

- name: Build and push Docker image to Docker Hub (with ODES and IDAKLU solvers)
if: matrix.build-args == 'ODES' || matrix.build-args == 'IDAKLU'
uses: docker/build-push-action@v5
with:
context: .
file: scripts/Dockerfile
tags: pybamm/pybamm:${{ steps.tags.outputs.tag }}
push: true
build-args: ${{ matrix.build-args }}=true
platforms: linux/amd64, linux/arm64

- name: Build and push Docker image to Docker Hub (with ALL and JAX solvers)
if: matrix.build-args == 'ALL' || matrix.build-args == 'JAX'
uses: docker/build-push-action@v5
with:
context: .
file: scripts/Dockerfile
tags: pybamm/pybamm:${{ steps.tags.outputs.tag }}
push: true
build-args: ${{ matrix.build-args }}=true
# exclude arm64 for JAX and ALL builds for now, see
# https://github.com/google/jax/issues/13608
platforms: linux/amd64

- name: List built image(s)
run: docker images
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
- Fixed bug in calculation of theoretical energy that made it very slow ([#3506](https://github.com/pybamm-team/PyBaMM/pull/3506))
- The irreversible plating model now increments `f"{Domain} dead lithium concentration [mol.m-3]"`, not `f"{Domain} lithium plating concentration [mol.m-3]"` as it did previously. ([#3485](https://github.com/pybamm-team/PyBaMM/pull/3485))

## Optimizations

- Updated `jax` and `jaxlib` to the latest available versions and added Windows (Python 3.9+) support for the Jax solver ([#3550](https://github.com/pybamm-team/PyBaMM/pull/3550))

# [v23.9](https://github.com/pybamm-team/PyBaMM/tree/v23.9) - 2023-10-31

## Features
Expand Down
47 changes: 38 additions & 9 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,12 @@ def run_coverage(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("coverage", silent=False)
if sys.platform != "win32":
session.install("-e", ".[all,odes,jax]", silent=False)
session.install("-e", ".[all,jax,odes]", silent=False)
else:
session.install("-e", ".[all]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all]", silent=False)
else:
session.install("-e", ".[all,jax]", silent=False)
session.run("coverage", "run", "--rcfile=.coveragerc", "run-tests.py", "--nosub")
session.run("coverage", "combine")
session.run("coverage", "xml")
Expand All @@ -74,9 +77,12 @@ def run_integration(session):
"""Run the integration tests."""
set_environment_variables(PYBAMM_ENV, session=session)
if sys.platform != "win32":
session.install("-e", ".[all,odes,jax]", silent=False)
session.install("-e", ".[all,jax,odes]", silent=False)
else:
session.install("-e", ".[all]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all]", silent=False)
else:
session.install("-e", ".[all,jax]", silent=False)
session.run("python", "run-tests.py", "--integration")


Expand All @@ -92,9 +98,12 @@ def run_unit(session):
"""Run the unit tests."""
set_environment_variables(PYBAMM_ENV, session=session)
if sys.platform != "win32":
session.install("-e", ".[all,odes,jax]", silent=False)
session.install("-e", ".[all,jax,odes]", silent=False)
else:
session.install("-e", ".[all]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all]", silent=False)
else:
session.install("-e", ".[all,jax]", silent=False)
session.run("python", "run-tests.py", "--unit")


Expand Down Expand Up @@ -143,17 +152,37 @@ def set_dev(session):
external=True,
)
else:
session.run(python, "-m", "pip", "install", "-e", ".[all,dev]", external=True)
if sys.version_info < (3, 9):
session.run(
python,
"-m",
"pip",
"install",
".[all,dev]",
external=True,
)
else:
session.run(
python,
"-m",
"pip",
"install",
".[all,dev,jax]",
external=True,
)


@nox.session(name="tests")
def run_tests(session):
"""Run the unit tests and integration tests sequentially."""
set_environment_variables(PYBAMM_ENV, session=session)
if sys.platform != "win32":
session.install("-e", ".[all,odes,jax]", silent=False)
session.install("-e", ".[all,jax,odes]", silent=False)
else:
session.install("-e", ".[all]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all]", silent=False)
else:
session.install("-e", ".[all,jax]", silent=False)
session.run("python", "run-tests.py", "--all")


Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = [
"setuptools>=64",
"wheel",
# On Windows, use the CasADi vcpkg registry and CMake bundled from MSVC
"casadi>=3.6.0; platform_system!='Windows'",
"casadi>=3.6.3; platform_system!='Windows'",
"cmake; platform_system!='Windows'",
]
build-backend = "setuptools.build_meta"
Expand Down Expand Up @@ -108,13 +108,13 @@ dev = [
"nbmake",
]
# Reading CSV files
pandas = [
pandas = [
"pandas>=1.5.0",
]
# For the Jax solver. Note: these must be kept in sync with the versions defined in pybamm/util.py.
jax = [
"jax>=0.4,<=0.5",
"jaxlib>=0.4,<=0.5",
"jax==0.4.20; python_version >= '3.9'",
"jaxlib==0.4.20; python_version >= '3.9'",
]
# For the scikits.odes solver
odes = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def test_evaluator_jax(self):
expr = pybamm.exp(a * b)
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator(t=None, y=np.array([[2], [3]]))
self.assertEqual(result, np.exp(6))
np.testing.assert_array_almost_equal(result, np.exp(6), decimal=15)

# test a constant expression
expr = pybamm.Scalar(2) * pybamm.Scalar(3)
Expand Down
Loading