Commit 1005264: Update docs
francesco-innocenti committed Jul 2, 2024 (1 parent: a1a1af5)

Showing 2 changed files with 112 additions and 64 deletions.

README.md (56 additions, 35 deletions)

<h2 align='center'>🧠 Predictive coding networks in JAX ⚡️</h2>

![status](https://img.shields.io/badge/status-active-green)

JPC is a [JAX](https://github.com/google/jax) library to train neural networks
with predictive coding (PC). It is built on top of three main libraries:

* [Equinox](https://github.com/patrick-kidger/equinox), to define neural
networks with PyTorch-like syntax,
* [Diffrax](https://github.com/patrick-kidger/diffrax), to solve the PC
activity (inference) dynamics, and
* [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.

JPC provides a simple, fast and flexible API for research on predictive coding
networks (PCNs) that is compatible with all of JAX and leverages ODE solvers to
integrate the PC inference dynamics.

## Overview

* [Installation](#installation)
* [Documentation](#documentation)
* [Quick example](#quick-example)
* [Basic usage](#basic-usage)
* [Advanced usage](#advanced-usage)
* [Citation](#citation)

## 💻 Installation

Requires Python 3.9+, JAX 0.4.23+, [Equinox](https://github.com/patrick-kidger/equinox), and
[Jaxtyping](https://github.com/patrick-kidger/jaxtyping) 0.2.24+.

## 📖 Documentation

Available at https://thebuckleylab.github.io/jpc/.

## ⚡️ Quick example

Use `jpc.make_pc_step` to update the parameters of essentially any neural
network with PC:
```py
import jpc
import jax
import jax.numpy as jnp
from equinox import nn as nn
import equinox as eqx
import optax

# toy data
x = jnp.array([1., 1., 1.])
y = -x

# define model and optimiser
key = jax.random.key(0)
model = jpc.make_mlp(key, layer_sizes=[3, 5, 5, 3], act_fn="relu")
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

# update model parameters with PC
result = jpc.make_pc_step(
    model,
    optim,
    opt_state,
    y,
    x
)
```
Under the hood, `jpc.make_pc_step`:
1. integrates the activity (inference) dynamics using a [Diffrax](https://github.com/patrick-kidger/diffrax) ODE solver (Euler by default),
2. computes the PC gradient w.r.t. the model parameters at the numerical solution of the activities, and
3. updates the parameters with the provided [Optax](https://github.com/google-deepmind/optax) optimiser.
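
For intuition, step 1 amounts to a gradient flow on the PC energy with respect
to the activities. Below is a minimal hand-rolled sketch of this idea for a toy
network with a single hidden layer, using a forward Euler step. This is
illustrative only: `pc_energy`, `infer` and all shapes are assumptions for the
sketch, not JPC's actual implementation (JPC delegates the integration to
Diffrax).
```py
import jax
import jax.numpy as jnp

# illustrative PC energy for one hidden activity vector z between
# input x and target y: sum of squared prediction errors per layer
def pc_energy(z, W1, W2, x, y):
    e1 = z - W1 @ x   # hidden-layer prediction error
    e2 = y - W2 @ z   # output-layer prediction error
    return 0.5 * (e1 @ e1 + e2 @ e2)

# forward Euler integration of the gradient flow dz/dt = -dF/dz
def infer(z0, W1, W2, x, y, dt=0.1, n_steps=100):
    dFdz = jax.grad(pc_energy)  # gradient w.r.t. the activities z
    z = z0
    for _ in range(n_steps):
        z = z - dt * dFdz(z, W1, W2, x, y)
    return z
```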

> **NOTE**: All convenience training and test functions, including
> `make_pc_step`, are already jitted for performance.

## Advanced usage

More advanced users can access the functionality used by `jpc.make_pc_step`
directly:
```py
import jpc

# 1. initialise activities with a feedforward pass
activities0 = jpc.init_activities_with_ffwd(model, x)

# 2. run the inference dynamics to equilibrium
equilib_activities = jpc.solve_pc_activities(
    model,
    activities0,
    y,
    x
)

# 3. compute PC parameter gradients
param_grads = jpc.compute_pc_param_grads(
    model,
    equilib_activities,
    y,
    x
)

# 4. update parameters
updates, opt_state = optim.update(
    param_grads,
    opt_state,
    model
)
model = eqx.apply_updates(model, updates)
```
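
Because these functions are pure, they should also compose with standard JAX
and Equinox transforms. As a sketch (the filtering pattern below is the usual
Equinox convention, not a documented JPC recipe):
```py
import equinox as eqx

# jit the lower-level gradient computation; filter_jit traces the array
# leaves of its inputs and treats everything else as static
jitted_grads = eqx.filter_jit(jpc.compute_pc_param_grads)
param_grads = jitted_grads(model, equilib_activities, y, x)
```
Note that the convenience functions such as `make_pc_step` are already jitted,
so this is only relevant when composing the lower-level pieces yourself.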

## 📄 Citation

docs/index.md (56 additions, 29 deletions)
# Getting started

JPC is a [JAX](https://github.com/google/jax) library to train neural networks
with predictive coding (PC). It is built on top of three main libraries:

* [Equinox](https://github.com/patrick-kidger/equinox), to define neural
networks with PyTorch-like syntax,
* [Diffrax](https://github.com/patrick-kidger/diffrax), to solve the PC
activity (inference) dynamics, and
* [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.

JPC provides a simple, fast and flexible API for research on predictive coding
networks (PCNs) that is compatible with all of JAX and leverages ODE solvers to
integrate the PC inference dynamics.

## 💻 Installation

Requires Python 3.9+, JAX 0.4.23+, [Equinox](https://github.com/patrick-kidger/equinox), and
[Jaxtyping](https://github.com/patrick-kidger/jaxtyping) 0.2.24+.

## ⚡️ Quick example

Use `jpc.make_pc_step` to update the parameters of essentially any neural
network with PC:
```py
import jpc
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# toy data
x = jnp.array([1., 1., 1.])
y = -x

# define model and optimiser
key = jax.random.key(0)
model = jpc.make_mlp(key, layer_sizes=[3, 5, 5, 3], act_fn="relu")
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

# update model parameters with PC
result = jpc.make_pc_step(
    model,
    optim,
    opt_state,
    y,
    x
)
```
Under the hood, `jpc.make_pc_step`:
1. integrates the activity (inference) dynamics using a [Diffrax](https://github.com/patrick-kidger/diffrax) ODE solver (Euler by default),
2. computes the PC gradient w.r.t. the model parameters at the numerical solution of the activities, and
3. updates the parameters with the provided [Optax](https://github.com/google-deepmind/optax) optimiser.
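
To make "integrating the activity dynamics" concrete, here is a standalone
Diffrax sketch that integrates a gradient flow dz/dt = -dF/dz on a toy
quadratic energy with the Euler solver, JPC's default. The energy, time span
and step size are illustrative, not JPC's internals:
```py
import diffrax
import jax
import jax.numpy as jnp

# toy quadratic energy standing in for the PC energy
energy = lambda z: 0.5 * jnp.sum(z ** 2)
grad_energy = jax.grad(energy)

# gradient flow on the energy, integrated with forward Euler
term = diffrax.ODETerm(lambda t, z, args: -grad_energy(z))
sol = diffrax.diffeqsolve(
    term,
    diffrax.Euler(),
    t0=0.0,
    t1=5.0,
    dt0=0.1,
    y0=jnp.ones(3),  # initial activities
)
print(sol.ys)  # activities relaxed towards the energy minimum
```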

> **NOTE**: All convenience training and test functions, including
> `make_pc_step`, are already jitted for performance.

## Advanced usage

More advanced users can access the functionality used by `jpc.make_pc_step`
directly:
```py
import jpc
import optax
import equinox as eqx

# 1. initialise activities with a feedforward pass
activities0 = jpc.init_activities_with_ffwd(model, x)

# 2. run the inference dynamics to equilibrium
equilib_activities = jpc.solve_pc_activities(
    model,
    activities0,
    y,
    x
)

# 3. compute PC parameter gradients
param_grads = jpc.compute_pc_param_grads(
    model,
    equilib_activities,
    y,
    x
)

# 4. update parameters
updates, opt_state = optim.update(
    param_grads,
    opt_state,
    model
)
model = eqx.apply_updates(model, updates)
```
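
If you would rather not use Optax for step 4, plain gradient descent on the PC
gradients is a drop-in alternative. A sketch, with an arbitrary learning rate:
```py
import jax

lr = 1e-3  # illustrative step size
# scale the gradients to form updates, then apply them to the model
updates = jax.tree_util.tree_map(lambda g: -lr * g, param_grads)
model = eqx.apply_updates(model, updates)
```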

## 📄 Citation