Commit

Update readme.
francesco-innocenti committed Nov 22, 2024
1 parent 111f61f commit 6633836
Showing 1 changed file with 15 additions and 16 deletions.
README.md: 31 changes (15 additions & 16 deletions)
@@ -8,26 +8,23 @@

![status](https://img.shields.io/badge/status-active-green)

-JPC is a [**J**AX](https://github.com/google/jax) library to train neural networks
-with **P**redictive **C**oding (PC). It is built on top of three main libraries:
+JPC is a [**J**AX](https://github.com/google/jax) library for training neural
+networks with **P**redictive **C**oding (PC). It is built on top of three main
+libraries:

* [Equinox](https://github.com/patrick-kidger/equinox), to define neural
networks with PyTorch-like syntax,
* [Diffrax](https://github.com/patrick-kidger/diffrax), to solve the PC inference (activity) dynamics, and
* [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.
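
To give a rough sense of the division of labour, here is an illustrative sketch (not from the README): Equinox defines the network and Optax the optimiser, while Diffrax's role in solving the inference dynamics is sketched further down.

```py
import equinox as eqx
import jax
import optax

key = jax.random.PRNGKey(0)
# Equinox: a network defined with PyTorch-like syntax
model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=key)
# Optax: a parameter optimiser and its state
optim = optax.adam(learning_rate=1e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_array))
```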

-Unlike existing PC libraries, JPC leverages ordinary differential equation solvers
-to integrate the inference (activity) dynamics of PC networks, which we find
-can provide significant speed-ups compared to standard optimisers, especially
-for deeper models.
-
-JPC provides a **simple**, **relatively fast** and **flexible** API.
-1. It is simple in that, like JAX, JPC follows a fully functional paradigm,
-and the core library is <1000 lines of code.
-2. It is relatively fast in that higher-order solvers can provide speed-ups
-compared to standard optimisers, especially on deeper models.
-3. And it is flexible in that it allows training a variety of PC networks
-including discriminative, generative and hybrid models.
+JPC provides a **simple**, **relatively fast** and **flexible** API for
+training a variety of PC networks (PCNs), including discriminative,
+generative and hybrid models. Like JAX, JPC is completely functional, and
+the core library is <1000 lines of code. Unlike existing implementations,
+JPC leverages ordinary differential equation (ODE) solvers to integrate the
+inference dynamics of PCNs, which we find can provide significant speed-ups
+compared to standard optimisers, especially for deeper models. JPC also
+provides some analytical tools that can be used to study and diagnose issues
+with PCNs.
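
To make the ODE view concrete, the following is a small self-contained sketch (not JPC's internals, and not part of this commit): it writes inference in a toy two-layer linear PCN as gradient flow on the energy, dz/dt = -dE/dz, and integrates it with Diffrax. The weights `W1` and `W2` and the data `x` and `y` are made-up example values.

```py
import jax
import jax.numpy as jnp
import diffrax

# Made-up toy values: a linear PCN with one hidden layer.
W1 = jnp.array([[1.0, -0.5], [0.3, 0.8]])  # input -> hidden weights
W2 = jnp.array([[0.7, 0.2]])               # hidden -> output weights
x = jnp.array([1.0, 2.0])                  # input
y = jnp.array([0.5])                       # target output

def energy(z):
    # PC energy: squared prediction errors at the hidden and output layers
    return 0.5 * jnp.sum((z - W1 @ x) ** 2) + 0.5 * jnp.sum((y - W2 @ z) ** 2)

def vector_field(t, z, args):
    # inference as gradient flow on the energy: dz/dt = -dE/dz
    return -jax.grad(energy)(z)

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    solver=diffrax.Heun(),  # a second-order explicit solver
    t0=0.0,
    t1=20.0,
    dt0=0.1,
    y0=W1 @ x,  # initialise activities with a feedforward pass
)
equilibrated_z = sol.ys[-1]
```

Swapping `diffrax.Heun()` for a higher-order solver is the kind of choice behind the speed-up claim above.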

## Overview
* [Installation](#installation)
@@ -95,17 +92,18 @@ Under the hood, `jpc.make_pc_step`
## 🚀 Advanced usage
More advanced users can access any of the functionality used by `jpc.make_pc_step`.
A custom PC training step would look like the following

```py
import jpc

# 1. initialise activities with a feedforward pass
-activities0 = jpc.init_activities_with_ffwd(model=model, input=x)
+activities = jpc.init_activities_with_ffwd(model=model, input=x)

# 2. run the inference dynamics to equilibrium
equilibrated_activities = jpc.solve_pc_inference(
    params=(model, None),
-    activities=activities0,
+    activities=activities,
    output=y,
    input=x
)
@@ -120,6 +118,7 @@ step_result = jpc.update_params(
    input=x
)
```
+which can be embedded in a jitted function with any other computations, as
+sketched below.
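
For instance, the two steps above could be wrapped as follows. This is a sketch rather than a documented pattern: it assumes these calls are pure functions (which the functional design implies) and uses `equinox.filter_jit` so that the non-array `model` argument is handled automatically. The parameter update via `jpc.update_params` is omitted here because its full argument list is collapsed in the diff above.

```py
import equinox as eqx
import jpc

@eqx.filter_jit
def inference_step(model, x, y):
    # feedforward initialisation, then inference to equilibrium,
    # exactly as in the two numbered steps above
    activities = jpc.init_activities_with_ffwd(model=model, input=x)
    return jpc.solve_pc_inference(
        params=(model, None),
        activities=activities,
        output=y,
        input=x,
    )
```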

## 📄 Citation
If you found this library useful in your work, please cite (arXiv link):
