Skip to content

Commit

Permalink
Rename solver arg to ode_solver for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 29, 2024
1 parent 2b34e0d commit 13fd3b9
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 15 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,17 @@ descent) to update the network parameters.

## Advanced usage

## 📄 Citation

If you found this library useful in your work, please cite (arXiv link):

```bibtex
@article{innocenti2024jpc,
title={JPC: Predictive Coding Networks in JAX},
author={Innocenti, Francesco and Kinghorn, Paul and Singh, Ryan and
De Llanza Varona, Miguel and Buckley, Christopher},
journal={arXiv preprint},
year={2024}
}
```
Also consider starring the repo! ⭐️
4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ JPC is a [JAX](https://github.com/google/jax) library to train neural networks
with predictive coding. It is built on top of three main libraries:

* [Equinox](https://github.com/patrick-kidger/equinox), to define neural
networks with PyTorch-like syntax, and
networks with PyTorch-like syntax,
* [Diffrax](https://github.com/patrick-kidger/diffrax), to solve the PC
activity (inference) dynamics.
activity (inference) dynamics, and
* [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.

JPC provides a simple but flexible API for research of PCNs compatible with
Expand Down
14 changes: 7 additions & 7 deletions jpc/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_generative_pc(
layer_sizes: PyTree[int],
batch_size: int,
sigma: Scalar = 0.05,
solver: AbstractSolver = Euler(),
ode_solver: AbstractSolver = Euler(),
dt: float | int = 1,
n_iters: int = 20,
stepsize_controller: AbstractStepSizeController = ConstantStepSize()
Expand All @@ -74,7 +74,7 @@ def test_generative_pc(
- `sigma`: Standard deviation for Gaussian to sample activities from.
Defaults to 5e-2.
- `solver`: Diffrax (ODE) solver to be used. Default is Euler.
- `ode_solver`: Diffrax ODE solver to be used. Default is Euler.
- `dt`: Integration step size. Defaults to 1.
- `n_iters`: Number of integration steps (20 as default).
- `stepsize_controller`: diffrax controller for step size integration.
Expand All @@ -96,7 +96,7 @@ def test_generative_pc(
model=model,
activities=activities,
y=y,
solver=solver,
solver=ode_solver,
n_iters=n_iters,
stepsize_controller=stepsize_controller,
dt=dt
Expand All @@ -116,7 +116,7 @@ def test_hpc(
layer_sizes: PyTree[int],
batch_size: int,
sigma: Scalar = 0.05,
solver: AbstractSolver = Euler(),
ode_solver: AbstractSolver = Euler(),
dt: float | int = 1,
n_iters: int = 20,
stepsize_controller: AbstractStepSizeController = ConstantStepSize()
Expand All @@ -142,7 +142,7 @@ def test_hpc(
- `sigma`: Standard deviation for Gaussian to sample activities from.
Defaults to 5e-2.
- `solver`: Diffrax (ODE) solver to be used. Default is Euler.
- `ode_solver`: Diffrax ODE solver to be used. Default is Euler.
- `dt`: Integration step size. Defaults to 1.
- `n_iters`: Number of integration steps (20 as default).
- `stepsize_controller`: diffrax controller for step size integration.
Expand All @@ -163,7 +163,7 @@ def test_hpc(
model=generator,
activities=amort_activities,
y=y,
solver=solver,
solver=ode_solver,
n_iters=n_iters,
stepsize_controller=stepsize_controller,
dt=dt
Expand All @@ -179,7 +179,7 @@ def test_hpc(
model=generator,
activities=activities,
y=y,
solver=solver,
solver=ode_solver,
n_iters=n_iters,
stepsize_controller=stepsize_controller,
dt=dt
Expand Down
12 changes: 6 additions & 6 deletions jpc/_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def make_pc_step(
opt_state: OptState,
y: ArrayLike,
x: Optional[ArrayLike] = None,
solver: AbstractSolver = Euler(),
ode_solver: AbstractSolver = Euler(),
dt: float | int = 1,
n_iters: Optional[int] = 20,
stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
Expand Down Expand Up @@ -61,7 +61,7 @@ def make_pc_step(
**Other arguments:**
- `solver`: Diffrax (ODE) solver to be used. Default is Euler.
- `ode_solver`: Diffrax ODE solver to be used. Default is Euler.
- `dt`: Integration step size. Defaults to 1.
- `n_iters`: Number of integration steps (20 as default).
- `stepsize_controller`: diffrax controller for step size integration.
Expand Down Expand Up @@ -112,7 +112,7 @@ def make_pc_step(
activities=activities,
y=y,
x=x,
solver=solver,
solver=ode_solver,
n_iters=n_iters,
stepsize_controller=stepsize_controller,
dt=dt,
Expand Down Expand Up @@ -161,7 +161,7 @@ def make_hpc_step(
opt_states: Tuple[OptState],
y: ArrayLike,
x: ArrayLike,
solver: AbstractSolver = Euler(),
ode_solver: AbstractSolver = Euler(),
dt: float | int = 1,
n_iters: Optional[int] = 20,
stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
Expand Down Expand Up @@ -197,7 +197,7 @@ def make_hpc_step(
**Other arguments:**
- `solver`: Diffrax (ODE) solver to be used. Default is Euler.
- `ode_solver`: Diffrax ODE solver to be used. Default is Euler.
- `dt`: Integration step size. Defaults to 1.
- `n_iters`: Number of integration steps (20 as default).
- `stepsize_controller`: diffrax controller for step size integration.
Expand Down Expand Up @@ -240,7 +240,7 @@ def make_hpc_step(
activities=amort_activities[1:],
y=y,
x=x,
solver=solver,
solver=ode_solver,
n_iters=n_iters,
stepsize_controller=stepsize_controller,
dt=dt,
Expand Down

0 comments on commit 13fd3b9

Please sign in to comment.