Skip to content

Commit

Permalink
Jax export (#1861)
Browse files Browse the repository at this point in the history
* basic prototype

* add dimerization example, add second order code, refactor jit

* remove equinox dependency, list dependencies

* make jax optional

* support conservation laws

* fixup

* fix jit nesting

* use vmap for vectorization

* fixups

* add multithreaded simulation runner

* fix my

* fixes

* fixup merge

* fix install

* actually generate code

* fix

* fix

* add better default coefficients, fix jax

* ignore fujita in jax

* ignore smith

* optimize & fix bachmann

* fix import/wokflow

* Update __init__.template.py

* fix jax imports

* Update setup.cfg

* add preequilibration support

* fix jax tests

* add filterwarning

* fix parameter transformation

* reenable ruff format

* post merge cleanup

* "fix" splines

* Update .pre-commit-config.yaml

* force optimistix 0.0.9

* add support for heavyside functions

* cleanup & actually run tests

* simply tests + add support for non-dynamic simulation in jax

* fix for NONCONST_CLS

* fix petab path

* fixup merge

* support postequilibration

* fixup

* fix

* fix gradients

* fix hessian

* Update test_petab_benchmark.py

* skip smith in jax

* exclude more models

* refactor: remove use of edatas

* update template

* Update .pre-commit-config.yaml

* fix python jax tests

* simplify petab interface

* add parameter values to model class

* refactor parameter mapping

* refactor & simplify

* refsctor

* update template

* Update .pre-commit-config.yaml

* refactor fix test

* Update petab.py

* fixups

* fixup

* add documentation and typing

* add runtime typechecks to jax tests

* add coverage from benchmark tests

* add api versioning and reenable jit compilation

* review comments

* use temporary directories

* fix doc

* Update test_jax.py

* don't generate code if jax/diffrax not available

* add example

* fix doc

* fix notebook symlink

* update notebook

* Update ExampleJaxPEtab.ipynb

* Update ExampleJaxPEtab.ipynb

* fix compilation issue

* fix
  • Loading branch information
FFroehlich authored Nov 19, 2024
1 parent 022de60 commit db25bc8
Show file tree
Hide file tree
Showing 24 changed files with 2,829 additions and 55 deletions.
21 changes: 17 additions & 4 deletions .github/workflows/test_benchmark_collection_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
run: |
pip3 install --user petab[vis] && \
AMICI_PARALLEL_COMPILE="" pip3 install -v --user \
$(ls -t python/sdist/dist/amici-*.tar.gz | head -1)[petab,test,vis]
$(ls -t python/sdist/dist/amici-*.tar.gz | head -1)[petab,test,vis,jax]
- name: Install test dependencies
run: |
Expand All @@ -60,14 +60,27 @@ jobs:
- name: Download benchmark collection
run: |
git clone --depth 1 https://github.com/benchmarking-initiative/Benchmark-Models-PEtab.git \
&& python3 -m pip install -e Benchmark-Models-PEtab/src/python
pip install git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python
- name: Run tests
env:
AMICI_PARALLEL_COMPILE: ""
run: |
cd tests/benchmark-models && pytest --durations=10
cd tests/benchmark-models && pytest \
--durations=10
--cov=amici \
--cov-report=xml:"coverage_py.xml" \
--cov-append \
- name: Codecov Python
if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev'
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: coverage_py.xml
flags: python
fail_ci_if_error: true
verbose: true

# collect & upload results
- name: Aggregate results
Expand Down
5 changes: 0 additions & 5 deletions .github/workflows/test_python_cplusplus.yml
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,6 @@ jobs:
- name: Install python package
run: scripts/installAmiciSource.sh

- name: Install notebook dependencies
run: |
source venv/bin/activate \
&& pip install jax[cpu]
- name: example notebooks
run: scripts/runNotebook.sh python/examples/example_*/

Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ tests/test/*
*/tests/explicit_amici/*
*/tests/fixed_initial_amici/*
*/tests/localfunc_amici/*
*/tests/conversion/*
*/tests/dimerization/*
tests/cpp/writeResults.h5
tests/cpp/writeResults.h5.bak
tests/sbml-test-suite/*
Expand Down
1 change: 1 addition & 0 deletions documentation/ExampleJaxPEtab.ipynb
1 change: 1 addition & 0 deletions documentation/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def install_doxygen():
"numpy": ("https://numpy.org/devdocs/", None),
"sympy": ("https://docs.sympy.org/latest/", None),
"python": ("https://docs.python.org/3", None),
"jax": ["https://jax.readthedocs.io/en/latest/", None],
}

# Add notebooks prolog with binder links
Expand Down
1 change: 1 addition & 0 deletions documentation/python_examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ Various example notebooks.
example_errors.ipynb
example_large_models/example_performance_optimization.ipynb
ExampleJax.ipynb
ExampleJaxPEtab.ipynb
ExampleSplines.ipynb
ExampleSplinesSwameye2003.ipynb
1 change: 1 addition & 0 deletions documentation/python_modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ AMICI Python API
amici.petab_objective
amici.petab_simulate
amici.import_utils
amici.jax
amici.de_export
amici.de_model
amici.de_model_components
Expand Down
2 changes: 2 additions & 0 deletions documentation/rtd_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ sphinx<8
mock>=5.0.2
setuptools>=67.7.2
pysb>=1.11.0
jax>=0.4.26
diffrax>=0.5.0
matplotlib==3.7.1
nbsphinx==0.9.1
nbformat==5.8.0
Expand Down
1,162 changes: 1,162 additions & 0 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions python/sdist/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def _imported_from_setup() -> bool:
assignmentRules2observables,
)

try:
from .jax import JAXModel
except (ImportError, ModuleNotFoundError):
JAXModel = object

@runtime_checkable
class ModelModule(Protocol): # noqa: F811
"""Type of AMICI-generated model modules.
Expand All @@ -135,6 +140,11 @@ def get_model(self) -> amici.Model:
"""Create a model instance."""
...

def get_jax_model(self) -> JAXModel:
...

AmiciModel = Union[amici.Model, amici.ModelPtr]


class add_path:
"""Context manager for temporarily changing PYTHONPATH"""
Expand Down
12 changes: 11 additions & 1 deletion python/sdist/amici/__init__.template.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""AMICI-generated module for model TPL_MODELNAME"""

from pathlib import Path

from typing import TYPE_CHECKING
import amici

if TYPE_CHECKING:
from amici.jax import JAXModel

# Ensure we are binary-compatible, see #556
if "TPL_AMICI_VERSION" != amici.__version__:
raise amici.AmiciVersionError(
Expand All @@ -18,4 +21,11 @@
from .TPL_MODELNAME import * # noqa: F403, F401
from .TPL_MODELNAME import getModel as get_model # noqa: F401


def get_jax_model() -> "JAXModel":
from .jax import JAXModel_TPL_MODELNAME

return JAXModel_TPL_MODELNAME()


__version__ = "TPL_PACKAGE_VERSION"
Loading

0 comments on commit db25bc8

Please sign in to comment.