diff --git a/README.md b/README.md index de98326..83c10c1 100644 --- a/README.md +++ b/README.md @@ -92,3 +92,17 @@ descent) to update the network parameters. ## Advanced usage +## 📄 Citation + +If you found this library useful in your work, please cite (arXiv link): + +```bibtex +@article{innocenti2024jpc, + title={JPC: Predictive Coding Networks in JAX}, + author={Innocenti, Francesco and Kinghorn, Paul and Singh, Ryan and + De Llanza Varona, Miguel and Buckley, Christopher}, + journal={arXiv preprint}, + year={2024} +} +``` +Also consider starring the repo! ⭐️ diff --git a/docs/index.md b/docs/index.md index 79d779c..b84f4d7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,9 +4,9 @@ JPC is a [JAX](https://github.com/google/jax) library to train neural networks with predictive coding. It is built on top of three main libraries: * [Equinox](https://github.com/patrick-kidger/equinox), to define neural -networks with PyTorch-like syntax, and +networks with PyTorch-like syntax, * [Diffrax](https://github.com/patrick-kidger/diffrax), to solve the PC -activity (inference) dynamics. +activity (inference) dynamics, and * [Optax](https://github.com/google-deepmind/optax), for parameter optimisation. JPC provides a simple but flexible API for research of PCNs compatible with diff --git a/jpc/_test.py b/jpc/_test.py index 97bcc50..8be91b8 100644 --- a/jpc/_test.py +++ b/jpc/_test.py @@ -50,7 +50,7 @@ def test_generative_pc( layer_sizes: PyTree[int], batch_size: int, sigma: Scalar = 0.05, - solver: AbstractSolver = Euler(), + ode_solver: AbstractSolver = Euler(), dt: float | int = 1, n_iters: int = 20, stepsize_controller: AbstractStepSizeController = ConstantStepSize() @@ -74,7 +74,7 @@ def test_generative_pc( - `sigma`: Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. - - `solver`: Diffrax (ODE) solver to be used. Default is Euler. + - `ode_solver`: Diffrax ODE solver to be used. Default is Euler. - `dt`: Integration step size. Defaults to 1. - `n_iters`: Number of integration steps (20 as default). - `stepsize_controller`: diffrax controller for step size integration. @@ -96,7 +96,7 @@ def test_generative_pc( model=model, activities=activities, y=y, - solver=solver, + solver=ode_solver, n_iters=n_iters, stepsize_controller=stepsize_controller, dt=dt @@ -116,7 +116,7 @@ def test_hpc( layer_sizes: PyTree[int], batch_size: int, sigma: Scalar = 0.05, - solver: AbstractSolver = Euler(), + ode_solver: AbstractSolver = Euler(), dt: float | int = 1, n_iters: int = 20, stepsize_controller: AbstractStepSizeController = ConstantStepSize() @@ -142,7 +142,7 @@ def test_hpc( - `sigma`: Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. - - `solver`: Diffrax (ODE) solver to be used. Default is Euler. + - `ode_solver`: Diffrax ODE solver to be used. Default is Euler. - `dt`: Integration step size. Defaults to 1. - `n_iters`: Number of integration steps (20 as default). - `stepsize_controller`: diffrax controller for step size integration. @@ -163,7 +163,7 @@ def test_hpc( model=generator, activities=amort_activities, y=y, - solver=solver, + solver=ode_solver, n_iters=n_iters, stepsize_controller=stepsize_controller, dt=dt @@ -179,7 +179,7 @@ def test_hpc( model=generator, activities=activities, y=y, - solver=solver, + solver=ode_solver, n_iters=n_iters, stepsize_controller=stepsize_controller, dt=dt diff --git a/jpc/_train.py b/jpc/_train.py index b29a7df..b47e86e 100644 --- a/jpc/_train.py +++ b/jpc/_train.py @@ -32,7 +32,7 @@ def make_pc_step( opt_state: OptState, y: ArrayLike, x: Optional[ArrayLike] = None, - solver: AbstractSolver = Euler(), + ode_solver: AbstractSolver = Euler(), dt: float | int = 1, n_iters: Optional[int] = 20, stepsize_controller: AbstractStepSizeController = ConstantStepSize(), @@ -61,7 +61,7 @@ def make_pc_step( **Other arguments:** - - `solver`: Diffrax (ODE) solver to be used. Default is Euler. + - `ode_solver`: Diffrax ODE solver to be used. Default is Euler. - `dt`: Integration step size. Defaults to 1. - `n_iters`: Number of integration steps (20 as default). - `stepsize_controller`: diffrax controller for step size integration. @@ -112,7 +112,7 @@ def make_pc_step( activities=activities, y=y, x=x, - solver=solver, + solver=ode_solver, n_iters=n_iters, stepsize_controller=stepsize_controller, dt=dt, @@ -161,7 +161,7 @@ def make_hpc_step( opt_states: Tuple[OptState], y: ArrayLike, x: ArrayLike, - solver: AbstractSolver = Euler(), + ode_solver: AbstractSolver = Euler(), dt: float | int = 1, n_iters: Optional[int] = 20, stepsize_controller: AbstractStepSizeController = ConstantStepSize(), @@ -197,7 +197,7 @@ def make_hpc_step( **Other arguments:** - - `solver`: Diffrax (ODE) solver to be used. Default is Euler. + - `ode_solver`: Diffrax ODE solver to be used. Default is Euler. - `dt`: Integration step size. Defaults to 1. - `n_iters`: Number of integration steps (20 as default). - `stepsize_controller`: diffrax controller for step size integration. @@ -240,7 +240,7 @@ def make_hpc_step( activities=amort_activities[1:], y=y, x=x, - solver=solver, + solver=ode_solver, n_iters=n_iters, stepsize_controller=stepsize_controller, dt=dt,