diff --git a/README.md b/README.md
index 83c10c1..1a821d6 100644
--- a/README.md
+++ b/README.md
@@ -6,8 +6,10 @@

 🧠 Predictive coding networks in JAX ⚡️

+![status](https://img.shields.io/badge/status-active-green)
+
 JPC is a [JAX](https://github.com/google/jax) library to train neural networks
-with predictive coding. It is built on top of three main libraries:
+with predictive coding (PC). It is built on top of three main libraries:
 
 * [Equinox](https://github.com/patrick-kidger/equinox), to define neural
 networks with PyTorch-like syntax,
@@ -15,18 +17,17 @@ networks with PyTorch-like syntax,
 activity (inference) dynamics, and
 * [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.
 
-JPC provides a simple but flexible API for research of PCNs compatible with
-useful JAX transforms such as `vmap` and `jit`.
-
-![status](https://img.shields.io/badge/status-active-green)
+JPC provides a simple, fast and flexible API for research on predictive coding
+networks (PCNs) that is compatible with all of JAX and leverages ODE solvers
+to integrate the PC inference dynamics.
 
 ## Overview
 
 * [Installation](#installation)
 * [Documentation](#documentation)
 * [Quick example](#quick-example)
-* [Basic usage](#basic-usage)
 * [Advanced usage](#advanced-usage)
+* [Citation](#citation)
 
 ## ️💻 Installation
 
@@ -43,54 +44,74 @@ Requires Python 3.9+, JAX 0.4.23+, [Equinox](https://github.com/patrick-kidger/e
 
 Available at https://thebuckleylab.github.io/jpc/.
 
 ## ⚡️ Quick example
-
-Given a neural network with callable layers
+Use `jpc.make_pc_step` to update the parameters of essentially any neural
+network with PC.
 ```py
+import jpc
 import jax
 import jax.numpy as jnp
-from equinox import nn as nn
+import equinox as eqx
+import optax
 
-# some data
+# toy data
 x = jnp.array([1., 1., 1.])
 y = -x
 
-# network
+# define model and optimiser
 key = jax.random.key(0)
-_, *subkeys = jax.random.split(key)
-network = [
-    nn.Sequential(
-        [
-            nn.Linear(3, 100, key=subkeys[0]),
-            nn.Lambda(jax.nn.relu)
-        ],
-    ),
-    nn.Linear(100, 3, key=subkeys[1]),
-]
+model = jpc.make_mlp(key, layer_sizes=[3, 5, 5, 3], act_fn="relu")
+optim = optax.adam(1e-3)
+opt_state = optim.init(eqx.filter(model, eqx.is_array))
+
+# update model parameters with PC
+result = jpc.make_pc_step(
+    model,
+    optim,
+    opt_state,
+    y,
+    x
+)
 ```
-We can train it with predictive coding in a few lines of code
+Under the hood, `jpc.make_pc_step`:
+1. integrates the activity (inference) dynamics using a [Diffrax](https://github.com/patrick-kidger/diffrax) ODE solver (Euler by default),
+2. computes the PC gradient w.r.t. the model parameters at the numerical solution of the activities, and
+3. updates the parameters with the provided [Optax](https://github.com/google-deepmind/optax) optimiser.
+
+> **NOTE**: All convenience training and test functions, including
+> `make_pc_step`, are already jitted for performance.
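+
+For intuition, the energy minimised by the inference dynamics is the sum of
+squared prediction errors across the network's layers. Below is a sketch of
+this textbook PC energy (illustrative only, not JPC's internals; `pc_energy`,
+`layers` and `activities` are hypothetical names):
+```py
+import jax.numpy as jnp
+
+def pc_energy(layers, activities, x):
+    # textbook Gaussian PC energy: F = 1/2 * sum_l ||z_l - f_l(z_{l-1})||^2,
+    # where z_l are the activities and f_l the layer functions
+    energy, z_prev = 0.0, x
+    for f, z in zip(layers, activities):
+        energy += 0.5 * jnp.sum((z - f(z_prev)) ** 2)  # prediction error at this layer
+        z_prev = z
+    return energy
+```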
+
+## Advanced usage
+More advanced users can access the step-by-step functionality used by
+`jpc.make_pc_step`.
 ```py
 import jpc
 
-# initialise layer activities with a feedforward pass
-activities = jpc.init_activities_with_ffwd(network, x)
+# 1. initialise activities with a feedforward pass
+activities0 = jpc.init_activities_with_ffwd(model, x)
 
-# run the inference dynamics to equilibrium
-equilib_activities = jpc.solve_pc_activities(network, activities, y, x)
+# 2. run the inference dynamics to equilibrium
+equilib_activities = jpc.solve_pc_activities(
+    model,
+    activities0,
+    y,
+    x
+)
 
-# compute the PC parameter gradients
-pc_param_grads = jpc.compute_pc_param_grads(
-    network,
+# 3. compute PC parameter gradients
+param_grads = jpc.compute_pc_param_grads(
+    model,
     equilib_activities,
     y,
     x
 )
-```
-The gradients can then be fed to your favourite optimiser (e.g. gradient
-descent) to update the network parameters.
-
-## Basic usage
-
-## Advanced usage
+
+# 4. update parameters
+updates, opt_state = optim.update(
+    param_grads,
+    opt_state,
+    model
+)
+model = eqx.apply_updates(model, updates)
+```
 
 ## 📄 Citation
diff --git a/docs/index.md b/docs/index.md
index b84f4d7..678bcaf 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,7 +1,7 @@
 # Getting started
 
 JPC is a [JAX](https://github.com/google/jax) library to train neural networks
-with predictive coding. It is built on top of three main libraries:
+with predictive coding (PC). It is built on top of three main libraries:
 
 * [Equinox](https://github.com/patrick-kidger/equinox), to define neural
 networks with PyTorch-like syntax,
@@ -9,8 +9,9 @@ networks with PyTorch-like syntax,
 activity (inference) dynamics, and
 * [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.
 
-JPC provides a simple but flexible API for research of PCNs compatible with
-useful JAX transforms such as `vmap` and `jit`.
+JPC provides a simple, fast and flexible API for research on predictive coding
+networks (PCNs) that is compatible with all of JAX and leverages ODE solvers
+to integrate the PC inference dynamics.
 
 ## 💻 Installation
 
@@ -24,47 +25,73 @@ Requires Python 3.9+, JAX 0.4.23+, [Equinox](https://github.com/patrick-kidger/e
 [Jaxtyping](https://github.com/patrick-kidger/jaxtyping) 0.2.24+.
 
 ## ⚡️ Quick example
-
-Given a neural network with callable layers
+Use `jpc.make_pc_step` to update the parameters of essentially any neural
+network with PC.
 ```py
+import jpc
 import jax
 import jax.numpy as jnp
-from equinox import nn as nn
+import equinox as eqx
+import optax
 
-# some data
+# toy data
 x = jnp.array([1., 1., 1.])
 y = -x
 
-# network
+# define model and optimiser
 key = jax.random.key(0)
-_, *subkeys = jax.random.split(key)
-network = [nn.Sequential(
-        [
-            nn.Linear(3, 100, key=subkeys[0]),
-            nn.Lambda(jax.nn.relu)],
-    ),
-    nn.Linear(100, 3, key=subkeys[1]),
-]
+model = jpc.make_mlp(key, layer_sizes=[3, 5, 5, 3], act_fn="relu")
+optim = optax.adam(1e-3)
+opt_state = optim.init(eqx.filter(model, eqx.is_array))
+
+# update model parameters with PC
+result = jpc.make_pc_step(
+    model,
+    optim,
+    opt_state,
+    y,
+    x
+)
 ```
-we can perform a PC parameter update with a single function call
+Under the hood, `jpc.make_pc_step`:
+1. integrates the activity (inference) dynamics using a [Diffrax](https://github.com/patrick-kidger/diffrax) ODE solver (Euler by default),
+2. computes the PC gradient w.r.t. the model parameters at the numerical solution of the activities, and
+3. updates the parameters with the provided [Optax](https://github.com/google-deepmind/optax) optimiser.
+
+> **NOTE**: All convenience training and test functions, including
+> `make_pc_step`, are already jitted for performance.
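+
+For intuition, the dynamics integrated in step 1 above are the gradient flow
+dz/dt = -∂F/∂z of the PC energy F, the sum of squared layer-wise prediction
+errors. Below is a hand-rolled Euler sketch of this idea (illustrative only,
+with hypothetical names; JPC itself delegates the integration to Diffrax):
+```py
+import jax
+import jax.numpy as jnp
+
+def pc_energy(layers, activities, x):
+    # textbook Gaussian PC energy: F = 1/2 * sum_l ||z_l - f_l(z_{l-1})||^2
+    energy, z_prev = 0.0, x
+    for f, z in zip(layers, activities):
+        energy += 0.5 * jnp.sum((z - f(z_prev)) ** 2)
+        z_prev = z
+    return energy
+
+def euler_inference(layers, activities, y, x, dt=0.1, n_steps=50):
+    # integrate dz/dt = -dF/dz with a fixed Euler step, keeping the output
+    # activity clamped to the target y (supervised PC)
+    dFdz = jax.grad(pc_energy, argnums=1)
+    for _ in range(n_steps):
+        grads = dFdz(layers, activities, x)
+        activities = [z - dt * g for z, g in zip(activities, grads)]
+        activities[-1] = y
+    return activities
+```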
+
+## Advanced usage
+More advanced users can access the step-by-step functionality used by
+`jpc.make_pc_step`.
 ```py
 import jpc
-import optax
-import equinox as eqx
 
-# optimiser
-optim = optax.adam(1e-3)
-opt_state = optim.init(eqx.filter(network, eqx.is_array))
+# 1. initialise activities with a feedforward pass
+activities0 = jpc.init_activities_with_ffwd(model, x)
 
-# PC parameter update
-result = jpc.make_pc_step(
-    model=network,
-    optim=optim,
-    opt_state=opt_state,
-    y=y,
-    x=x
+# 2. run the inference dynamics to equilibrium
+equilib_activities = jpc.solve_pc_activities(
+    model,
+    activities0,
+    y,
+    x
+)
+
+# 3. compute PC parameter gradients
+param_grads = jpc.compute_pc_param_grads(
+    model,
+    equilib_activities,
+    y,
+    x
 )
+
+# 4. update parameters
+updates, opt_state = optim.update(
+    param_grads,
+    opt_state,
+    model
+)
+model = eqx.apply_updates(model, updates)
 ```
 
 ## 📄 Citation