diff --git a/README.md b/README.md
index 83c10c1..1a821d6 100644
--- a/README.md
+++ b/README.md
@@ -6,8 +6,10 @@
🧠 Predictive coding networks in JAX ⚡️
+![status](https://img.shields.io/badge/status-active-green)
+
JPC is a [JAX](https://github.com/google/jax) library to train neural networks
-with predictive coding. It is built on top of three main libraries:
+with predictive coding (PC). It is built on top of three main libraries:
* [Equinox](https://github.com/patrick-kidger/equinox), to define neural
networks with PyTorch-like syntax,
@@ -15,18 +17,17 @@ networks with PyTorch-like syntax,
activity (inference) dynamics, and
* [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.
-JPC provides a simple but flexible API for research of PCNs compatible with
-useful JAX transforms such as `vmap` and `jit`.
-
-![status](https://img.shields.io/badge/status-active-green)
+JPC provides a simple, fast and flexible API for research on PCNs. It is
+compatible with all of JAX and leverages ODE solvers to integrate the PC
+inference dynamics.
## Overview
* [Installation](#installation)
* [Documentation](#documentation)
* [Quick example](#quick-example)
-* [Basic usage](#basic-usage)
* [Advanced usage](#advanced-usage)
+* [Citation](#citation)
## 💻 Installation
@@ -43,54 +44,74 @@ Requires Python 3.9+, JAX 0.4.23+, [Equinox](https://github.com/patrick-kidger/e
Available at https://thebuckleylab.github.io/jpc/.
## ⚡️ Quick example
-
-Given a neural network with callable layers
+Use `jpc.make_pc_step` to update the parameters of essentially any neural
+network with PC:
```py
+import jpc
import jax
import jax.numpy as jnp
-from equinox import nn as nn
+import equinox as eqx
+import optax
-# some data
+# toy data
x = jnp.array([1., 1., 1.])
y = -x
-# network
+# define model and optimiser
key = jax.random.key(0)
-_, *subkeys = jax.random.split(key)
-network = [
- nn.Sequential(
- [
- nn.Linear(3, 100, key=subkeys[0]),
- nn.Lambda(jax.nn.relu)
- ],
- ),
- nn.Linear(100, 3, key=subkeys[1]),
-]
+model = jpc.make_mlp(key, layer_sizes=[3, 5, 5, 3], act_fn="relu")
+optim = optax.adam(1e-3)
+opt_state = optim.init(eqx.filter(model, eqx.is_array))
+
+# update model parameters with PC
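+# (note the argument order: the target y is passed before the input x)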
+result = jpc.make_pc_step(
+ model,
+ optim,
+ opt_state,
+ y,
+ x
+)
```
-We can train it with predictive coding in a few lines of code
+Under the hood, `jpc.make_pc_step`
+1. integrates the activity (inference) dynamics using a [Diffrax](https://github.com/patrick-kidger/diffrax) ODE solver (Euler by default),
+2. computes the PC gradient w.r.t. the model parameters at the numerical solution of the activities, and
+3. updates the parameters with the provided [Optax](https://github.com/google-deepmind/optax) optimiser.
+
+> **NOTE**: All convenience training and test functions, including `make_pc_step`,
+> are already "jitted" for performance.
+
+## Advanced usage
+More advanced users can access the functionality used by `jpc.make_pc_step`.
```py
import jpc
-# initialise layer activities with a feedforward pass
-activities = jpc.init_activities_with_ffwd(network, x)
+# 1. initialise activities with a feedforward pass
+activities0 = jpc.init_activities_with_ffwd(model, x)
-# run the inference dynamics to equilibrium
-equilib_activities = jpc.solve_pc_activities(network, activities, y, x)
+# 2. run the inference dynamics to equilibrium
+equilib_activities = jpc.solve_pc_activities(
+ model,
+ activities0,
+ y,
+ x
+)
-# compute the PC parameter gradients
-pc_param_grads = jpc.compute_pc_param_grads(
- network,
+# 3. compute PC parameter gradients
+param_grads = jpc.compute_pc_param_grads(
+ model,
equilib_activities,
y,
x
)
-```
-The gradients can then be fed to your favourite optimiser (e.g. gradient
-descent) to update the network parameters.
-## Basic usage
-
-## Advanced usage
+# 4. update parameters
+updates, opt_state = optim.update(
+ param_grads,
+ opt_state,
+ model
+)
+model = eqx.apply_updates(model, updates)
+```
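+
+These four steps can also be chained into a basic training loop. The sketch
+below simply repeats them on the toy `(x, y)` pair, reusing `model`, `optim`
+and `opt_state` from above; the number of iterations is arbitrary and only the
+functions already shown are used.
+```py
+import equinox as eqx
+import jpc
+
+for step in range(100):
+    # 1-2. infer the equilibrated activities for the current parameters
+    activities0 = jpc.init_activities_with_ffwd(model, x)
+    equilib_activities = jpc.solve_pc_activities(model, activities0, y, x)
+
+    # 3. PC parameter gradients at the activity equilibrium
+    param_grads = jpc.compute_pc_param_grads(model, equilib_activities, y, x)
+
+    # 4. apply the optimiser update
+    updates, opt_state = optim.update(param_grads, opt_state, model)
+    model = eqx.apply_updates(model, updates)
+```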
## 📄 Citation
diff --git a/docs/index.md b/docs/index.md
index b84f4d7..678bcaf 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,7 +1,7 @@
# Getting started
JPC is a [JAX](https://github.com/google/jax) library to train neural networks
-with predictive coding. It is built on top of three main libraries:
+with predictive coding (PC). It is built on top of three main libraries:
* [Equinox](https://github.com/patrick-kidger/equinox), to define neural
networks with PyTorch-like syntax,
@@ -9,8 +9,9 @@ networks with PyTorch-like syntax,
activity (inference) dynamics, and
* [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.
-JPC provides a simple but flexible API for research of PCNs compatible with
-useful JAX transforms such as `vmap` and `jit`.
+JPC provides a simple, fast and flexible API for research on PCNs. It is
+compatible with all of JAX and leverages ODE solvers to integrate the PC
+inference dynamics.
## 💻 Installation
@@ -24,47 +25,73 @@ Requires Python 3.9+, JAX 0.4.23+, [Equinox](https://github.com/patrick-kidger/e
[Jaxtyping](https://github.com/patrick-kidger/jaxtyping) 0.2.24+.
## ⚡️ Quick example
-
-Given a neural network with callable layers
+Use `jpc.make_pc_step` to update the parameters of essentially any neural
+network with PC:
```py
+import jpc
import jax
import jax.numpy as jnp
-from equinox import nn as nn
+import equinox as eqx
+import optax
-# some data
+# toy data
x = jnp.array([1., 1., 1.])
y = -x
-# network
+# define model and optimiser
key = jax.random.key(0)
-_, *subkeys = jax.random.split(key)
-network = [nn.Sequential(
- [
- nn.Linear(3, 100, key=subkeys[0]),
- nn.Lambda(jax.nn.relu)],
- ),
- nn.Linear(100, 3, key=subkeys[1]),
-]
+model = jpc.make_mlp(key, layer_sizes=[3, 5, 5, 3], act_fn="relu")
+optim = optax.adam(1e-3)
+opt_state = optim.init(eqx.filter(model, eqx.is_array))
+
+# update model parameters with PC
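+# (note the argument order: the target y is passed before the input x)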
+result = jpc.make_pc_step(
+ model,
+ optim,
+ opt_state,
+ y,
+ x
+)
```
-we can perform a PC parameter update with a single function call
+Under the hood, `jpc.make_pc_step`
+1. integrates the activity (inference) dynamics using a [Diffrax](https://github.com/patrick-kidger/diffrax) ODE solver (Euler by default),
+2. computes the PC gradient w.r.t. the model parameters at the numerical solution of the activities, and
+3. updates the parameters with the provided [Optax](https://github.com/google-deepmind/optax) optimiser.
+
+> **NOTE**: All convenience training and test functions, including `make_pc_step`,
+> are already "jitted" for performance.
+
+## Advanced usage
+More advanced users can access the functionality used by `jpc.make_pc_step`.
```py
import jpc
-import optax
-import equinox as eqx
-# optimiser
-optim = optax.adam(1e-3)
-opt_state = optim.init(eqx.filter(network, eqx.is_array))
+# 1. initialise activities with a feedforward pass
+activities0 = jpc.init_activities_with_ffwd(model, x)
-# PC parameter update
-result = jpc.make_pc_step(
- model=network,
- optim=optim,
- opt_state=opt_state,
- y=y,
- x=x
+# 2. run the inference dynamics to equilibrium
+equilib_activities = jpc.solve_pc_activities(
+ model,
+ activities0,
+ y,
+ x
+)
+
+# 3. compute PC parameter gradients
+param_grads = jpc.compute_pc_param_grads(
+ model,
+ equilib_activities,
+ y,
+ x
)
+# 4. update parameters
+updates, opt_state = optim.update(
+ param_grads,
+ opt_state,
+ model
+)
+model = eqx.apply_updates(model, updates)
```
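+
+Since the convenience functions are already jitted, a custom update built from
+these lower-level calls can typically be compiled as well. The following is
+only a sketch: it assumes the lower-level functions above are jit-compatible,
+reuses `model`, `optim`, `opt_state`, `x` and `y` from the examples above, and
+the `pc_update` helper is not part of JPC.
+```py
+import equinox as eqx
+import jpc
+
+@eqx.filter_jit
+def pc_update(model, opt_state, x, y):
+    # 1-2. infer the equilibrated activities
+    activities0 = jpc.init_activities_with_ffwd(model, x)
+    equilib_activities = jpc.solve_pc_activities(model, activities0, y, x)
+    # 3. PC parameter gradients at the equilibrium
+    param_grads = jpc.compute_pc_param_grads(model, equilib_activities, y, x)
+    # 4. optimiser update
+    updates, opt_state = optim.update(param_grads, opt_state, model)
+    return eqx.apply_updates(model, updates), opt_state
+
+model, opt_state = pc_update(model, opt_state, x, y)
+```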
## 📄 Citation