From 2d5cd4228aec40efe41556ce5058ee8592e56c35 Mon Sep 17 00:00:00 2001 From: Francesco Innocenti Date: Thu, 4 Jul 2024 14:04:36 +0100 Subject: [PATCH] Add example for theoretical energy of linear nets --- examples/linear_net_theoretical_energy.ipynb | 392 +++++++++++++++++++ mkdocs.yml | 11 +- 2 files changed, 399 insertions(+), 4 deletions(-) create mode 100644 examples/linear_net_theoretical_energy.ipynb diff --git a/examples/linear_net_theoretical_energy.ipynb b/examples/linear_net_theoretical_energy.ipynb new file mode 100644 index 0000000..f4ad5d8 --- /dev/null +++ b/examples/linear_net_theoretical_energy.ipynb @@ -0,0 +1,392 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Analytical test\n", + "\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/thebuckleylab/jpc/blob/main/examples/analytical_test.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install torch==2.3.1\n", + "!pip install torchvision==0.18.1\n", + "!pip install plotly==5.11.0\n", + "!pip install -U kaleido" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import jpc\n", + "\n", + "import jax\n", + "import equinox as eqx\n", + "import equinox.nn as nn\n", + "import optax\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from torchvision import datasets, transforms\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import plotly.graph_objs as go\n", + "import plotly.io as pio\n", + "\n", + "pio.renderers.default = 'iframe'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hyperparameters\n", + "\n", + "We define some global parameters, including network architecture, learning rate, batch size etc." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "SEED = 0\n", + "LEARNING_RATE = 1e-3\n", + "BATCH_SIZE = 64\n", + "TEST_EVERY = 10\n", + "N_TRAIN_ITERS = 100" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Dataset\n", + "\n", + "Some utils to fetch MNIST." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "#@title data utils\n", + "\n", + "\n", + "def get_mnist_loaders(batch_size):\n", + " train_data = MNIST(train=True, normalise=True)\n", + " test_data = MNIST(train=False, normalise=True)\n", + " train_loader = DataLoader(\n", + " dataset=train_data,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " drop_last=True\n", + " )\n", + " test_loader = DataLoader(\n", + " dataset=test_data,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " drop_last=True\n", + " )\n", + " return train_loader, test_loader\n", + "\n", + "\n", + "class MNIST(datasets.MNIST):\n", + " def __init__(self, train, normalise=True, save_dir=\"data\"):\n", + " if normalise:\n", + " transform = transforms.Compose(\n", + " [\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(\n", + " mean=(0.1307), std=(0.3081)\n", + " )\n", + " ]\n", + " )\n", + " else:\n", + " transform = transforms.Compose([transforms.ToTensor()])\n", + " super().__init__(save_dir, download=True, train=train, transform=transform)\n", + "\n", + " def __getitem__(self, index):\n", + " img, label = super().__getitem__(index)\n", + " img = torch.flatten(img)\n", + " label = one_hot(label)\n", + " return img, label\n", + "\n", + "\n", + "def one_hot(labels, n_classes=10):\n", + " arr = torch.eye(n_classes)\n", + " return arr[labels]\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Plotting" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_energies(energies):\n", + " n_train_iters = len(energies[\"theory\"])\n", + " train_iters = [b+1 for b in range(n_train_iters)]\n", + "\n", + " fig = go.Figure()\n", + " for energy_type, energy in energies.items():\n", + " fig.add_traces(\n", + " go.Scatter(\n", + " x=train_iters,\n", + " y=energy,\n", + " name=energy_type,\n", + " mode=\"lines\",\n", + " line=dict(\n", + " width=3, \n", + " dash=\"dash\" if energy_type == \"theory\" else \"solid\"\n", + " )\n", + " )\n", + " )\n", + "\n", + " fig.update_layout(\n", + " height=300,\n", + " width=400,\n", + " xaxis=dict(\n", + " title=\"Training iteration\",\n", + " tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],\n", + " ticktext=[1, int(train_iters[-1]/2), train_iters[-1]],\n", + " ),\n", + " yaxis=dict(\n", + " title=\"Energy\",\n", + " nticks=3\n", + " ),\n", + " font=dict(size=16),\n", + " )\n", + " fig.write_image(\"dln_energy_example.pdf\")\n", + " return fig" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Linear network" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "key = jax.random.PRNGKey(0)\n", + "subkeys = jax.random.split(key, 5)\n", + "\n", + "network = [\n", + " eqx.nn.Linear(784, 300, key=subkeys[0], use_bias=False),\n", + " eqx.nn.Linear(300, 300, key=subkeys[1], use_bias=False),\n", + " eqx.nn.Linear(300, 300, key=subkeys[2], use_bias=False),\n", + " eqx.nn.Linear(300, 300, key=subkeys[3], use_bias=False),\n", + " eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),\n", + " eqx.nn.Linear(300, 300, key=subkeys[1], use_bias=False),\n", + " eqx.nn.Linear(300, 300, key=subkeys[2], use_bias=False),\n", + " eqx.nn.Linear(300, 300, key=subkeys[3], use_bias=False),\n", + " eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),\n", + " eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),\n", + " eqx.nn.Linear(300, 10, key=subkeys[5], use_bias=False),\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train and test\n", + "\n", + "A PC network can be trained in a single line of code with `jpc.make_pc_step()`. See the documentation for more. Similarly, we can use `jpc.test_discriminative_pc()` to compute the network accuracy. Note that these functions are already \"jitted\" for performance.\n", + "\n", + "Below we simply wrap each of these functions in our training and test loops, respectively." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(model, test_loader):\n", + " test_acc = 0\n", + " for batch_id, (img_batch, label_batch) in enumerate(test_loader):\n", + " img_batch = img_batch.numpy()\n", + " label_batch = label_batch.numpy()\n", + "\n", + " test_acc += jpc.test_discriminative_pc(\n", + " model=model,\n", + " y=label_batch,\n", + " x=img_batch\n", + " )\n", + "\n", + " return test_acc / len(test_loader)\n", + "\n", + "\n", + "def train(\n", + " model, \n", + " lr,\n", + " batch_size,\n", + " test_every,\n", + " n_train_iters\n", + "):\n", + " optim = optax.adam(lr)\n", + " opt_state = optim.init(eqx.filter(model, eqx.is_array))\n", + " train_loader, test_loader = get_mnist_loaders(batch_size)\n", + "\n", + " num_energies, theory_energies = [], []\n", + " for iter, (img_batch, label_batch) in enumerate(train_loader):\n", + " img_batch = img_batch.numpy()\n", + " label_batch = label_batch.numpy()\n", + "\n", + " theory_energies.append(\n", + " jpc.linear_equilib_energy(\n", + " network=model, \n", + " x=img_batch, \n", + " y=label_batch\n", + " )\n", + " )\n", + " result = jpc.make_pc_step(\n", + " model,\n", + " optim,\n", + " opt_state,\n", + " y=label_batch,\n", + " x=img_batch,\n", + " record_energies=True\n", + " )\n", + " model, optim, opt_state = result[\"model\"], result[\"optim\"], result[\"opt_state\"]\n", + " train_loss, t_max = result[\"loss\"], result[\"t_max\"]\n", + " num_energies.append(result[\"energies\"][:, t_max-1].sum())\n", + "\n", + " if ((iter+1) % test_every) == 0:\n", + " avg_test_acc = evaluate(model, test_loader)\n", + " print(\n", + " f\"Train iter {iter+1}, train loss={train_loss:4f}, \"\n", + " f\"avg test accuracy={avg_test_acc:4f}\"\n", + " )\n", + " if (iter+1) >= n_train_iters:\n", + " break\n", + "\n", + " return {\n", + " \"experiment\": num_energies,\n", + " \"theory\": theory_energies,\n", + " }\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/fi69/PycharmProjects/jpc/venv/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning:\n", + "\n", + "unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release.\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train iter 10, train loss=0.078592, avg test accuracy=0.570713\n", + "Train iter 20, train loss=0.072349, avg test accuracy=0.705729\n", + "Train iter 30, train loss=0.059754, avg test accuracy=0.742087\n", + "Train iter 40, train loss=0.054584, avg test accuracy=0.767328\n", + "Train iter 50, train loss=0.052529, avg test accuracy=0.780248\n", + "Train iter 60, train loss=0.042935, avg test accuracy=0.829427\n", + "Train iter 70, train loss=0.052370, avg test accuracy=0.823417\n", + "Train iter 80, train loss=0.044771, avg test accuracy=0.811498\n", + "Train iter 90, train loss=0.045305, avg test accuracy=0.821114\n", + "Train iter 100, train loss=0.045466, avg test accuracy=0.816306\n" + ] + }, + { + "data": { + "text/html": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "energies = train(\n", + " model=network,\n", + " lr=LEARNING_RATE,\n", + " batch_size=BATCH_SIZE,\n", + " test_every=TEST_EVERY,\n", + " n_train_iters=N_TRAIN_ITERS\n", + ")\n", + "plot_energies(energies)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/mkdocs.yml b/mkdocs.yml index 7a207e8..62bcf28 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -101,10 +101,13 @@ nav: - Advanced usage: 'advanced_usage.md' - Extending JPC: 'extending_jpc.md' - 📚 Examples: - - Discriminative PC: 'examples/discriminative_pc.ipynb' - - Supervised generative PC: 'examples/supervised_generative_pc.ipynb' - - Unsupervised generative PC: 'examples/unsupervised_generative_pc.ipynb' - - Hybrid PC: 'examples/hybrid_pc.ipynb' + - Introductory: + - Discriminative PC: 'examples/discriminative_pc.ipynb' + - Supervised generative PC: 'examples/supervised_generative_pc.ipynb' + - Unsupervised generative PC: 'examples/unsupervised_generative_pc.ipynb' + - Advanced: + - Hybrid PC: 'examples/hybrid_pc.ipynb' + - Linear theoretical energy: 'examples/linear_net_theoretical_energy.ipynb' - 🌱 Basic API: - 'api/Training.md' - 'api/Testing.md'