diff --git a/README.md b/README.md
index 0604e60..9ea5686 100644
--- a/README.md
+++ b/README.md
@@ -8,26 +8,23 @@
 ![status](https://img.shields.io/badge/status-active-green)
 
-JPC is a [**J**AX](https://github.com/google/jax) library to train neural networks
-with **P**redictive **C**oding (PC). It is built on top of three main libraries:
+JPC is a [**J**AX](https://github.com/google/jax) library for training neural
+networks with **P**redictive **C**oding (PC). It is built on top of three main
+libraries:
 
 * [Equinox](https://github.com/patrick-kidger/equinox), to define neural networks with PyTorch-like syntax,
 * [Diffrax](https://github.com/patrick-kidger/diffrax), to solve the PC inference (activity) dynamics, and
 * [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.
 
-Unlike existing PC libraries, JPC leverages ordinary differential equation solvers
-to integrate the inference (activity) dynamics of PC networks, which we find
-can provide significant speed-ups compared to standard optimisers, especially
-for deeper models.
-
-JPC provides a **simple**, **relatively fast** and **flexible** API.
-1. It is simple in that, like JAX, JPC follows a fully functional paradigm,
-and the core library is <1000 lines of code.
-2. It is relatively fast in that higher-order solvers can provide speed-ups
-compared to standard optimisers, especially on deeper models.
-3. And it is flexible in that it allows training a variety of PC networks
-including discriminative, generative and hybrid models.
+JPC provides a **simple**, **relatively fast** and **flexible** API for
+training a variety of predictive coding networks (PCNs), including
+discriminative, generative and hybrid models. Like JAX, JPC is completely
+functional, and the core library is <1000 lines of code. Unlike existing
+implementations, JPC leverages ordinary differential equation (ODE) solvers
+to integrate the inference dynamics of PCNs, which we find can provide
+significant speed-ups compared to standard optimisers, especially for deeper
+models. JPC also provides analytical tools for studying and diagnosing
+issues with PCNs.
 
 ## Overview
 
 * [Installation](#installation)
@@ -95,17 +92,18 @@ Under the hood, `jpc.make_pc_step`
 ## 🚀 Advanced usage
 
 More advanced users can access any of the functionality used by `jpc.make_pc_step`.
+A custom PC training step would look like the following:
 ```py
 import jpc
 
 # 1. initialise activities with a feedforward pass
-activities0 = jpc.init_activities_with_ffwd(model=model, input=x)
+activities = jpc.init_activities_with_ffwd(model=model, input=x)
 
 # 2. run the inference dynamics to equilibrium
 equilibrated_activities = jpc.solve_pc_inference(
     params=(model, None),
-    activities=activities0,
+    activities=activities,
     output=y,
     input=x
 )
@@ -120,6 +118,7 @@ step_result = jpc.update_params(
     input=x
 )
 ```
+which can be embedded in a jitted function together with any additional computations.
 
 ## 📄 Citation
 
 If you found this library useful in your work, please cite (arXiv link):
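
The last hunk's added sentence says the custom step can be embedded in a jitted function. A minimal sketch of that pattern, using `jax` alone: the `jpc` calls are replaced here by a hypothetical toy quadratic energy and gradient step, since the exact `jpc` signatures are not confirmed by this diff, so the snippet is self-contained and runnable.

```python
import jax
import jax.numpy as jnp


@jax.jit  # the whole training step is compiled as one function
def train_step(w, x, y):
    # Stand-in for the jpc inference + parameter update sequence:
    # a simple quadratic "energy" minimised by one gradient step.
    def energy(w):
        return jnp.mean((x @ w - y) ** 2)

    e, grad = jax.value_and_grad(energy)(w)
    w = w - 0.1 * grad
    # Any additional computations (metrics, logging values, etc.)
    # can be returned alongside the updated parameters.
    return w, e


w = jnp.zeros(3)
x = jnp.eye(3)
y = jnp.ones(3)
w, e = train_step(w, x, y)  # initial energy e == 1.0 for this toy setup
```

In practice, the body of such a jitted step would contain the `jpc.init_activities_with_ffwd` / `jpc.solve_pc_inference` / `jpc.update_params` sequence shown in the diff above, in place of the toy energy used here.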