diff --git a/.github/workflows/build_docs.yml b/.github/workflows/build_docs.yml new file mode 100644 index 0000000..3f5116f --- /dev/null +++ b/.github/workflows/build_docs.yml @@ -0,0 +1,39 @@ +name: Build docs + +on: + push: + branches: + - main + +jobs: + build: + strategy: + matrix: + python-version: [ 3.11 ] + os: [ ubuntu-latest ] + runs-on: ${{ matrix.os }} + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install . + #python -m pip install -r docs/requirements.txt + + - name: Build docs + run: | + mkdocs build + mkdocs build + + - name: Upload docs + uses: actions/upload-artifact@v2 + with: + name: docs + path: site # where `mkdocs build` puts the built site \ No newline at end of file diff --git a/docs/.htaccess b/docs/.htaccess new file mode 100644 index 0000000..88c2d9c --- /dev/null +++ b/docs/.htaccess @@ -0,0 +1 @@ +ErrorDocument 404 /diffrax/404.html \ No newline at end of file diff --git a/docs/FAQs.md b/docs/FAQs.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/_overrides/partials/source.html b/docs/_overrides/partials/source.html new file mode 100644 index 0000000..10bba66 --- /dev/null +++ b/docs/_overrides/partials/source.html @@ -0,0 +1,20 @@ +{% import "partials/language.html" as lang with context %} + + + {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} + {% include ".icons/" ~ icon ~ ".svg" %} + + + {{ config.repo_name }} + + +{% if config.theme.twitter_url %} + + + {% include ".icons/fontawesome/brands/twitter.svg" %} + + + {{ config.theme.twitter_name }} + + +{% endif %} \ No newline at end of file diff --git a/docs/_static/custom_css.css b/docs/_static/custom_css.css new file mode 100644 index 0000000..e9be9fc --- /dev/null +++ b/docs/_static/custom_css.css @@ -0,0 +1,162 @@ +/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ +html { + scroll-padding-top: 50px; +} + +/* Fit the Twitter handle alongside the GitHub one in the top right. */ + +div.md-header__source { + width: revert; + max-width: revert; +} + +a.md-source { + display: inline-block; +} + +.md-source__repository { + max-width: 100%; +} + +/* Emphasise sections of nav on left hand side */ + +nav.md-nav { + padding-left: 5px; +} + +nav.md-nav--secondary { + border-left: revert !important; +} + +.md-nav__title { + font-size: 0.9rem; +} + +.md-nav__item--section > .md-nav__link { + font-size: 0.9rem; +} + +/* Indent autogenerated documentation */ + +div.doc-contents { + padding-left: 25px; + border-left: 4px solid rgba(230, 230, 230); +} + +/* Increase visibility of splitters "---" */ + +[data-md-color-scheme="default"] .md-typeset hr { + border-bottom-color: rgb(0, 0, 0); + border-bottom-width: 1pt; +} + +[data-md-color-scheme="slate"] .md-typeset hr { + border-bottom-color: rgb(230, 230, 230); +} + +/* More space at the bottom of the page */ + +.md-main__inner { + margin-bottom: 1.5rem; +} + +/* Remove prev/next footer buttons */ + +.md-footer__inner { + display: none; +} + +/* Change font sizes */ + +html { + /* Decrease font size for overall webpage + Down from 137.5% which is the Material default */ + font-size: 110%; +} + +.md-typeset .admonition { + /* Increase font size in admonitions */ + font-size: 100% !important; +} + +.md-typeset details { + /* Increase font size in details */ + font-size: 100% !important; +} + +.md-typeset h1 { + font-size: 1.6rem; +} + +.md-typeset h2 { + font-size: 1.5rem; +} + +.md-typeset h3 { + font-size: 1.3rem; +} + +.md-typeset h4 { + font-size: 1.1rem; +} + +.md-typeset h5 { + font-size: 0.9rem; +} + +.md-typeset h6 { + font-size: 0.8rem; +} + +/* Bugfix: remove the superfluous parts generated when doing: + +??? Blah + + ::: library.something +*/ + +.md-typeset details .mkdocstrings > h4 { + display: none; +} + +.md-typeset details .mkdocstrings > h5 { + display: none; +} + +/* Change default colours for tags */ + +[data-md-color-scheme="default"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} +[data-md-color-scheme="slate"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} + +/* Highlight functions, classes etc. type signatures. Really helps to make clear where + one item ends and another begins. */ + +[data-md-color-scheme="default"] { + --doc-heading-color: #DDD; + --doc-heading-border-color: #CCC; + --doc-heading-color-alt: #F0F0F0; +} +[data-md-color-scheme="slate"] { + --doc-heading-color: rgb(25,25,33); + --doc-heading-border-color: rgb(25,25,33); + --doc-heading-color-alt: rgb(33,33,44); + --md-code-bg-color: rgb(38,38,50); +} + +h4.doc-heading { + /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ + background-color: var(--doc-heading-color); + border: solid var(--doc-heading-border-color); + border-width: 1.5pt; + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} +h5.doc-heading, h6.heading { + background-color: var(--doc-heading-color-alt); + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} \ No newline at end of file diff --git a/docs/_static/favicon.png b/docs/_static/favicon.png new file mode 100644 index 0000000..5e71b54 Binary files /dev/null and b/docs/_static/favicon.png differ diff --git a/docs/_static/mathjax.js b/docs/_static/mathjax.js new file mode 100644 index 0000000..0b00d2f --- /dev/null +++ b/docs/_static/mathjax.js @@ -0,0 +1,19 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.startup.output.clearCache() + MathJax.typesetClear() + MathJax.texReset() + MathJax.typesetPromise() +}) \ No newline at end of file diff --git a/docs/api/Energy functions.md b/docs/api/Energy functions.md new file mode 100644 index 0000000..24ac42f --- /dev/null +++ b/docs/api/Energy functions.md @@ -0,0 +1,5 @@ +# Energy functions + +::: jpc.pc_energy_fn + +::: jpc.hpc_energy_fn \ No newline at end of file diff --git a/docs/api/Gradients.md b/docs/api/Gradients.md new file mode 100644 index 0000000..8f4495a --- /dev/null +++ b/docs/api/Gradients.md @@ -0,0 +1,7 @@ +# Gradients + +::: jpc.compute_pc_param_grads + +::: jpc.compute_gen_param_grads + +::: jpc.compute_amort_param_grads \ No newline at end of file diff --git a/docs/api/Inference.md b/docs/api/Inference.md new file mode 100644 index 0000000..2f4ed30 --- /dev/null +++ b/docs/api/Inference.md @@ -0,0 +1,3 @@ +# Inference + +::: jpc.solve_pc_activities \ No newline at end of file diff --git a/docs/api/Initialisation.md b/docs/api/Initialisation.md new file mode 100644 index 0000000..3b86173 --- /dev/null +++ b/docs/api/Initialisation.md @@ -0,0 +1,5 @@ +# Initialisation + +::: jpc.init_activities_with_ffwd + +::: jpc.init_activities_from_gaussian \ No newline at end of file diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..92b3065 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,74 @@ +# Getting started + +JPC is a [JAX](https://github.com/google/jax) library for predictive +coding networks (PCNs). It is built on top of two main libraries: + +* [Equinox](https://github.com/patrick-kidger/equinox) to define neural +networks with PyTorch-like syntax, and +* [Diffrax](https://github.com/patrick-kidger/diffrax) to solve the PC +activity (inference) dynamics. + +JPC provides a simple but flexible API for research of PCNs compatible with +useful JAX transforms such as `vmap` and `jit`. + + +## Installation + +``` +pip install jpc +``` + +Requires Python 3.8+, JAX 0.4.13+, [Equinox](https://github.com/patrick-kidger/equinox) +0.10.4+, [Diffrax](https://github.com/patrick-kidger/diffrax) 0.3.1+, and +[Jaxtyping](https://github.com/patrick-kidger/jaxtyping) 0.2.19+. + + +## Quick example + +Given a neural network with callable layers, for example defined with +[Equinox](https://github.com/patrick-kidger/equinox) +```py +import jax +import jax.numpy as jnp +from equinox import nn as nn + +# some data +x = jnp.array([1., 1., 1.]) +y = -x + +# network +key = jax.random.key(0) +_, *subkeys = jax.random.split(key) +network = [ + nn.Sequential( + [ + nn.Linear(3, 100, key=subkeys[0]), + nn.Lambda(jax.nn.relu) + ], + ), + nn.Linear(100, 3, key=subkeys[1]), +] +``` +We can train it with predictive coding in a few lines of code +```py +import jpc + +# initialise layer activities with a feedforward pass +activities = jpc.init_activities_with_ffwd(network, x) + +# run the inference dynamics to equilibrium +equilib_activities = jpc.solve_pc_activities(network, activities, y, x) + +# compute the PC parameter gradients +pc_param_grads = jpc.compute_pc_param_grads( + network, + equilib_activities, + y, + x +) +``` +The gradients can then be fed to your favourite optimiser (e.g. gradient +descent) to update the network parameters. + + +## Citation diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..c6bedb3 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,15 @@ +# Latest versions at time of writing. +mkdocs==1.3.0 # Main documentation generator. +mkdocs-material==7.3.6 # Theme +pymdown-extensions==9.4 # Markdown extensions e.g. to handle LaTeX. +mkdocstrings==0.17.0 # Autogenerate documentation from docstrings. +mknotebooks==0.7.1 # Turn Jupyter Lab notebooks into webpages. +pytkdocs_tweaks==0.0.8 # Tweaks mkdocstrings to improve various aspects +mkdocs_include_exclude_files==0.0.1 # Allow for customising which files get included +jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings. +nbconvert==6.5.0 # | Older verson to avoid error +nbformat==5.4.0 # | +pygments==2.14.0 + +# Install latest version of our dependencies +jax[cpu] \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..c29665c --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,105 @@ +theme: + name: material + features: + - navigation.sections # Sections are included in the navigation on the left. + - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right. + - header.autohide # header disappears as you scroll + palette: + # Light mode / dark mode + # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as + # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. + - scheme: default + primary: white + accent: amber + toggle: + icon: material/weather-night + name: Switch to dark mode + - scheme: slate + primary: black + accent: amber + toggle: + icon: material/weather-sunny + name: Switch to light mode + icon: + repo: fontawesome/brands/github # GitHub logo in top right + logo: "material/brain" # brain logo in top left + favicon: "_static/favicon.png" + custom_dir: "docs/_overrides" # Overriding part of the HTML + + # These additions are my own custom ones, having overridden a partial. + #twitter_name: "@InnocFrancesco" + #twitter_url: "https://x.com/InnocFrancesco" + +site_name: jpc +site_description: The documentation for the jpc software library. +site_author: Francesco Innocenti + +repo_url: https://github.com/thebuckleylab/jpc +repo_name: thebuckleylab/jpc +edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate + +strict: true # Don't allow warnings during the build process + +extra_javascript: + # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ + - _static/mathjax.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js + +extra_css: + - _static/custom_css.css + +markdown_extensions: + - pymdownx.arithmatex: # Render LaTeX via MathJax + generic: true + - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. + - pymdownx.details # Allowing hidden expandable regions denoted by ??? + - pymdownx.snippets: # Include one Markdown file into another + base_path: docs + - admonition + - toc: + permalink: "ยค" # Adds a clickable permalink to each section heading + toc_depth: 4 + - attr_list + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + +plugins: + - search # default search plugin; needs manually re-enabling when using any other plugins + - autorefs # Cross-links to headings + - include_exclude_files: + include: + - ".htaccess" + exclude: + - "_overrides" + - "_static/README.md" + - mknotebooks # Jupyter examples + - mkdocstrings: + handlers: + python: + setup_commands: + - import pytkdocs_tweaks + - pytkdocs_tweaks.main() + selection: + inherited_members: true # Allow looking up inherited methods + options: + show_root_heading: true # actually display anything at all... + show_root_full_path: true # display "jpc.asdf" not just "asdf" + show_if_no_docstring: true + show_signature_annotations: true + show_source: false # don't include source code + members_order: source # order methods according to their order of definition in the source code, not alphabetical order + heading_level: 4 + +nav: + - 'index.md' + - API: + - 'api/Initialisation.md' + - 'api/Inference.md' + - 'api/Energy functions.md' + - 'api/Gradients.md' + - 'FAQs.md' + +copyright: | + © 2024 Francesco Innocenti \ No newline at end of file