-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5c8df12
commit bcee8af
Showing
14 changed files
with
455 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
name: Build docs | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
|
||
jobs: | ||
build: | ||
strategy: | ||
matrix: | ||
python-version: [ 3.11 ] | ||
os: [ ubuntu-latest ] | ||
runs-on: ${{ matrix.os }} | ||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v2 | ||
|
||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
python -m pip install . | ||
#python -m pip install -r docs/requirements.txt | ||
- name: Build docs | ||
run: | | ||
mkdocs build | ||
mkdocs build | ||
- name: Upload docs | ||
uses: actions/upload-artifact@v2 | ||
with: | ||
name: docs | ||
path: site # where `mkdocs build` puts the built site |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
ErrorDocument 404 /diffrax/404.html |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
{% import "partials/language.html" as lang with context %} | ||
<a href="{{ config.repo_url }}" title="{{ lang.t('source.link.title') }}" class="md-source" data-md-component="source"> | ||
<div class="md-source__icon md-icon"> | ||
{% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} | ||
{% include ".icons/" ~ icon ~ ".svg" %} | ||
</div> | ||
<div class="md-source__repository"> | ||
{{ config.repo_name }} | ||
</div> | ||
</a> | ||
{% if config.theme.twitter_url %} | ||
<a href="{{ config.theme.twitter_url }}" title="Go to Twitter" class="md-source"> | ||
<div class="md-source__icon md-icon"> | ||
{% include ".icons/fontawesome/brands/twitter.svg" %} | ||
</div> | ||
<div class="md-source__repository"> | ||
{{ config.theme.twitter_name }} | ||
</div> | ||
</a> | ||
{% endif %} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ | ||
html { | ||
scroll-padding-top: 50px; | ||
} | ||
|
||
/* Fit the Twitter handle alongside the GitHub one in the top right. */ | ||
|
||
div.md-header__source { | ||
width: revert; | ||
max-width: revert; | ||
} | ||
|
||
a.md-source { | ||
display: inline-block; | ||
} | ||
|
||
.md-source__repository { | ||
max-width: 100%; | ||
} | ||
|
||
/* Emphasise sections of nav on left hand side */ | ||
|
||
nav.md-nav { | ||
padding-left: 5px; | ||
} | ||
|
||
nav.md-nav--secondary { | ||
border-left: revert !important; | ||
} | ||
|
||
.md-nav__title { | ||
font-size: 0.9rem; | ||
} | ||
|
||
.md-nav__item--section > .md-nav__link { | ||
font-size: 0.9rem; | ||
} | ||
|
||
/* Indent autogenerated documentation */ | ||
|
||
div.doc-contents { | ||
padding-left: 25px; | ||
border-left: 4px solid rgba(230, 230, 230); | ||
} | ||
|
||
/* Increase visibility of splitters "---" */ | ||
|
||
[data-md-color-scheme="default"] .md-typeset hr { | ||
border-bottom-color: rgb(0, 0, 0); | ||
border-bottom-width: 1pt; | ||
} | ||
|
||
[data-md-color-scheme="slate"] .md-typeset hr { | ||
border-bottom-color: rgb(230, 230, 230); | ||
} | ||
|
||
/* More space at the bottom of the page */ | ||
|
||
.md-main__inner { | ||
margin-bottom: 1.5rem; | ||
} | ||
|
||
/* Remove prev/next footer buttons */ | ||
|
||
.md-footer__inner { | ||
display: none; | ||
} | ||
|
||
/* Change font sizes */ | ||
|
||
html { | ||
/* Decrease font size for overall webpage | ||
Down from 137.5% which is the Material default */ | ||
font-size: 110%; | ||
} | ||
|
||
.md-typeset .admonition { | ||
/* Increase font size in admonitions */ | ||
font-size: 100% !important; | ||
} | ||
|
||
.md-typeset details { | ||
/* Increase font size in details */ | ||
font-size: 100% !important; | ||
} | ||
|
||
.md-typeset h1 { | ||
font-size: 1.6rem; | ||
} | ||
|
||
.md-typeset h2 { | ||
font-size: 1.5rem; | ||
} | ||
|
||
.md-typeset h3 { | ||
font-size: 1.3rem; | ||
} | ||
|
||
.md-typeset h4 { | ||
font-size: 1.1rem; | ||
} | ||
|
||
.md-typeset h5 { | ||
font-size: 0.9rem; | ||
} | ||
|
||
.md-typeset h6 { | ||
font-size: 0.8rem; | ||
} | ||
|
||
/* Bugfix: remove the superfluous parts generated when doing: | ||
??? Blah | ||
::: library.something | ||
*/ | ||
|
||
.md-typeset details .mkdocstrings > h4 { | ||
display: none; | ||
} | ||
|
||
.md-typeset details .mkdocstrings > h5 { | ||
display: none; | ||
} | ||
|
||
/* Change default colours for <a> tags */ | ||
|
||
[data-md-color-scheme="default"] { | ||
--md-typeset-a-color: rgb(0, 189, 164) !important; | ||
} | ||
[data-md-color-scheme="slate"] { | ||
--md-typeset-a-color: rgb(0, 189, 164) !important; | ||
} | ||
|
||
/* Highlight functions, classes etc. type signatures. Really helps to make clear where | ||
one item ends and another begins. */ | ||
|
||
[data-md-color-scheme="default"] { | ||
--doc-heading-color: #DDD; | ||
--doc-heading-border-color: #CCC; | ||
--doc-heading-color-alt: #F0F0F0; | ||
} | ||
[data-md-color-scheme="slate"] { | ||
--doc-heading-color: rgb(25,25,33); | ||
--doc-heading-border-color: rgb(25,25,33); | ||
--doc-heading-color-alt: rgb(33,33,44); | ||
--md-code-bg-color: rgb(38,38,50); | ||
} | ||
|
||
h4.doc-heading { | ||
/* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ | ||
background-color: var(--doc-heading-color); | ||
border: solid var(--doc-heading-border-color); | ||
border-width: 1.5pt; | ||
border-radius: 2pt; | ||
padding: 0pt 5pt 2pt 5pt; | ||
} | ||
h5.doc-heading, h6.heading { | ||
background-color: var(--doc-heading-color-alt); | ||
border-radius: 2pt; | ||
padding: 0pt 5pt 2pt 5pt; | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
window.MathJax = { | ||
tex: { | ||
inlineMath: [["\\(", "\\)"]], | ||
displayMath: [["\\[", "\\]"]], | ||
processEscapes: true, | ||
processEnvironments: true | ||
}, | ||
options: { | ||
ignoreHtmlClass: ".*|", | ||
processHtmlClass: "arithmatex" | ||
} | ||
}; | ||
|
||
document$.subscribe(() => { | ||
MathJax.startup.output.clearCache() | ||
MathJax.typesetClear() | ||
MathJax.texReset() | ||
MathJax.typesetPromise() | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Energy functions | ||
|
||
::: jpc.pc_energy_fn | ||
|
||
::: jpc.hpc_energy_fn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Gradients | ||
|
||
::: jpc.compute_pc_param_grads | ||
|
||
::: jpc.compute_gen_param_grads | ||
|
||
::: jpc.compute_amort_param_grads |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Inference | ||
|
||
::: jpc.solve_pc_activities |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Initialisation | ||
|
||
::: jpc.init_activities_with_ffwd | ||
|
||
::: jpc.init_activities_from_gaussian |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Getting started | ||
|
||
JPC is a [JAX](https://github.com/google/jax) library for predictive | ||
coding networks (PCNs). It is built on top of two main libraries: | ||
|
||
* [Equinox](https://github.com/patrick-kidger/equinox) to define neural | ||
networks with PyTorch-like syntax, and | ||
* [Diffrax](https://github.com/patrick-kidger/diffrax) to solve the PC | ||
activity (inference) dynamics. | ||
|
||
JPC provides a simple but flexible API for research of PCNs compatible with | ||
useful JAX transforms such as `vmap` and `jit`. | ||
|
||
|
||
## Installation | ||
|
||
``` | ||
pip install jpc | ||
``` | ||
|
||
Requires Python 3.8+, JAX 0.4.13+, [Equinox](https://github.com/patrick-kidger/equinox) | ||
0.10.4+, [Diffrax](https://github.com/patrick-kidger/diffrax) 0.3.1+, and | ||
[Jaxtyping](https://github.com/patrick-kidger/jaxtyping) 0.2.19+. | ||
|
||
|
||
## Quick example | ||
|
||
Given a neural network with callable layers, for example defined with | ||
[Equinox](https://github.com/patrick-kidger/equinox) | ||
```py | ||
import jax | ||
import jax.numpy as jnp | ||
from equinox import nn as nn | ||
|
||
# some data | ||
x = jnp.array([1., 1., 1.]) | ||
y = -x | ||
|
||
# network | ||
key = jax.random.key(0) | ||
_, *subkeys = jax.random.split(key) | ||
network = [ | ||
nn.Sequential( | ||
[ | ||
nn.Linear(3, 100, key=subkeys[0]), | ||
nn.Lambda(jax.nn.relu) | ||
], | ||
), | ||
nn.Linear(100, 3, key=subkeys[1]), | ||
] | ||
``` | ||
We can train it with predictive coding in a few lines of code | ||
```py | ||
import jpc | ||
|
||
# initialise layer activities with a feedforward pass | ||
activities = jpc.init_activities_with_ffwd(network, x) | ||
|
||
# run the inference dynamics to equilibrium | ||
equilib_activities = jpc.solve_pc_activities(network, activities, y, x) | ||
|
||
# compute the PC parameter gradients | ||
pc_param_grads = jpc.compute_pc_param_grads( | ||
network, | ||
equilib_activities, | ||
y, | ||
x | ||
) | ||
``` | ||
The gradients can then be fed to your favourite optimiser (e.g. gradient | ||
descent) to update the network parameters. | ||
|
||
|
||
## Citation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Latest versions at time of writing. | ||
mkdocs==1.3.0 # Main documentation generator. | ||
mkdocs-material==7.3.6 # Theme | ||
pymdown-extensions==9.4 # Markdown extensions e.g. to handle LaTeX. | ||
mkdocstrings==0.17.0 # Autogenerate documentation from docstrings. | ||
mknotebooks==0.7.1 # Turn Jupyter Lab notebooks into webpages. | ||
pytkdocs_tweaks==0.0.8 # Tweaks mkdocstrings to improve various aspects | ||
mkdocs_include_exclude_files==0.0.1 # Allow for customising which files get included | ||
jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings. | ||
nbconvert==6.5.0 # | Older verson to avoid error | ||
nbformat==5.4.0 # | | ||
pygments==2.14.0 | ||
|
||
# Install latest version of our dependencies | ||
jax[cpu] |
Oops, something went wrong.