-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example for theoretical energy of linear nets
- Loading branch information
1 parent
0057257
commit 2d5cd42
Showing
2 changed files
with
399 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,392 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Analytical test\n", | ||
"\n", | ||
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/thebuckleylab/jpc/blob/main/examples/analytical_test.ipynb)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%%capture\n", | ||
"!pip install torch==2.3.1\n", | ||
"!pip install torchvision==0.18.1\n", | ||
"!pip install plotly==5.11.0\n", | ||
"!pip install -U kaleido" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import jpc\n", | ||
"\n", | ||
"import jax\n", | ||
"import equinox as eqx\n", | ||
"import equinox.nn as nn\n", | ||
"import optax\n", | ||
"\n", | ||
"import torch\n", | ||
"from torch.utils.data import DataLoader\n", | ||
"from torchvision import datasets, transforms\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import plotly.graph_objs as go\n", | ||
"import plotly.io as pio\n", | ||
"\n", | ||
"pio.renderers.default = 'iframe'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Hyperparameters\n", | ||
"\n", | ||
"We define some global parameters, including network architecture, learning rate, batch size etc." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"SEED = 0\n", | ||
"LEARNING_RATE = 1e-3\n", | ||
"BATCH_SIZE = 64\n", | ||
"TEST_EVERY = 10\n", | ||
"N_TRAIN_ITERS = 100" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"jp-MarkdownHeadingCollapsed": true | ||
}, | ||
"source": [ | ||
"## Dataset\n", | ||
"\n", | ||
"Some utils to fetch MNIST." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title data utils\n", | ||
"\n", | ||
"\n", | ||
"def get_mnist_loaders(batch_size):\n", | ||
" train_data = MNIST(train=True, normalise=True)\n", | ||
" test_data = MNIST(train=False, normalise=True)\n", | ||
" train_loader = DataLoader(\n", | ||
" dataset=train_data,\n", | ||
" batch_size=batch_size,\n", | ||
" shuffle=True,\n", | ||
" drop_last=True\n", | ||
" )\n", | ||
" test_loader = DataLoader(\n", | ||
" dataset=test_data,\n", | ||
" batch_size=batch_size,\n", | ||
" shuffle=True,\n", | ||
" drop_last=True\n", | ||
" )\n", | ||
" return train_loader, test_loader\n", | ||
"\n", | ||
"\n", | ||
"class MNIST(datasets.MNIST):\n", | ||
" def __init__(self, train, normalise=True, save_dir=\"data\"):\n", | ||
" if normalise:\n", | ||
" transform = transforms.Compose(\n", | ||
" [\n", | ||
" transforms.ToTensor(),\n", | ||
" transforms.Normalize(\n", | ||
" mean=(0.1307), std=(0.3081)\n", | ||
" )\n", | ||
" ]\n", | ||
" )\n", | ||
" else:\n", | ||
" transform = transforms.Compose([transforms.ToTensor()])\n", | ||
" super().__init__(save_dir, download=True, train=train, transform=transform)\n", | ||
"\n", | ||
" def __getitem__(self, index):\n", | ||
" img, label = super().__getitem__(index)\n", | ||
" img = torch.flatten(img)\n", | ||
" label = one_hot(label)\n", | ||
" return img, label\n", | ||
"\n", | ||
"\n", | ||
"def one_hot(labels, n_classes=10):\n", | ||
" arr = torch.eye(n_classes)\n", | ||
" return arr[labels]\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"jp-MarkdownHeadingCollapsed": true | ||
}, | ||
"source": [ | ||
"## Plotting" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def plot_energies(energies):\n", | ||
" n_train_iters = len(energies[\"theory\"])\n", | ||
" train_iters = [b+1 for b in range(n_train_iters)]\n", | ||
"\n", | ||
" fig = go.Figure()\n", | ||
" for energy_type, energy in energies.items():\n", | ||
" fig.add_traces(\n", | ||
" go.Scatter(\n", | ||
" x=train_iters,\n", | ||
" y=energy,\n", | ||
" name=energy_type,\n", | ||
" mode=\"lines\",\n", | ||
" line=dict(\n", | ||
" width=3, \n", | ||
" dash=\"dash\" if energy_type == \"theory\" else \"solid\"\n", | ||
" )\n", | ||
" )\n", | ||
" )\n", | ||
"\n", | ||
" fig.update_layout(\n", | ||
" height=300,\n", | ||
" width=400,\n", | ||
" xaxis=dict(\n", | ||
" title=\"Training iteration\",\n", | ||
" tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],\n", | ||
" ticktext=[1, int(train_iters[-1]/2), train_iters[-1]],\n", | ||
" ),\n", | ||
" yaxis=dict(\n", | ||
" title=\"Energy\",\n", | ||
" nticks=3\n", | ||
" ),\n", | ||
" font=dict(size=16),\n", | ||
" )\n", | ||
" fig.write_image(\"dln_energy_example.pdf\")\n", | ||
" return fig" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"jp-MarkdownHeadingCollapsed": true | ||
}, | ||
"source": [ | ||
"## Linear network" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"key = jax.random.PRNGKey(0)\n", | ||
"subkeys = jax.random.split(key, 5)\n", | ||
"\n", | ||
"network = [\n", | ||
" eqx.nn.Linear(784, 300, key=subkeys[0], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 300, key=subkeys[1], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 300, key=subkeys[2], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 300, key=subkeys[3], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 300, key=subkeys[1], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 300, key=subkeys[2], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 300, key=subkeys[3], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),\n", | ||
" eqx.nn.Linear(300, 10, key=subkeys[5], use_bias=False),\n", | ||
"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Train and test\n", | ||
"\n", | ||
"A PC network can be trained in a single line of code with `jpc.make_pc_step()`. See the documentation for more. Similarly, we can use `jpc.test_discriminative_pc()` to compute the network accuracy. Note that these functions are already \"jitted\" for performance.\n", | ||
"\n", | ||
"Below we simply wrap each of these functions in our training and test loops, respectively." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def evaluate(model, test_loader):\n", | ||
" test_acc = 0\n", | ||
" for batch_id, (img_batch, label_batch) in enumerate(test_loader):\n", | ||
" img_batch = img_batch.numpy()\n", | ||
" label_batch = label_batch.numpy()\n", | ||
"\n", | ||
" test_acc += jpc.test_discriminative_pc(\n", | ||
" model=model,\n", | ||
" y=label_batch,\n", | ||
" x=img_batch\n", | ||
" )\n", | ||
"\n", | ||
" return test_acc / len(test_loader)\n", | ||
"\n", | ||
"\n", | ||
"def train(\n", | ||
" model, \n", | ||
" lr,\n", | ||
" batch_size,\n", | ||
" test_every,\n", | ||
" n_train_iters\n", | ||
"):\n", | ||
" optim = optax.adam(lr)\n", | ||
" opt_state = optim.init(eqx.filter(model, eqx.is_array))\n", | ||
" train_loader, test_loader = get_mnist_loaders(batch_size)\n", | ||
"\n", | ||
" num_energies, theory_energies = [], []\n", | ||
" for iter, (img_batch, label_batch) in enumerate(train_loader):\n", | ||
" img_batch = img_batch.numpy()\n", | ||
" label_batch = label_batch.numpy()\n", | ||
"\n", | ||
" theory_energies.append(\n", | ||
" jpc.linear_equilib_energy(\n", | ||
" network=model, \n", | ||
" x=img_batch, \n", | ||
" y=label_batch\n", | ||
" )\n", | ||
" )\n", | ||
" result = jpc.make_pc_step(\n", | ||
" model,\n", | ||
" optim,\n", | ||
" opt_state,\n", | ||
" y=label_batch,\n", | ||
" x=img_batch,\n", | ||
" record_energies=True\n", | ||
" )\n", | ||
" model, optim, opt_state = result[\"model\"], result[\"optim\"], result[\"opt_state\"]\n", | ||
" train_loss, t_max = result[\"loss\"], result[\"t_max\"]\n", | ||
" num_energies.append(result[\"energies\"][:, t_max-1].sum())\n", | ||
"\n", | ||
" if ((iter+1) % test_every) == 0:\n", | ||
" avg_test_acc = evaluate(model, test_loader)\n", | ||
" print(\n", | ||
" f\"Train iter {iter+1}, train loss={train_loss:4f}, \"\n", | ||
" f\"avg test accuracy={avg_test_acc:4f}\"\n", | ||
" )\n", | ||
" if (iter+1) >= n_train_iters:\n", | ||
" break\n", | ||
"\n", | ||
" return {\n", | ||
" \"experiment\": num_energies,\n", | ||
" \"theory\": theory_energies,\n", | ||
" }\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Run" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/Users/fi69/PycharmProjects/jpc/venv/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning:\n", | ||
"\n", | ||
"unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Train iter 10, train loss=0.078592, avg test accuracy=0.570713\n", | ||
"Train iter 20, train loss=0.072349, avg test accuracy=0.705729\n", | ||
"Train iter 30, train loss=0.059754, avg test accuracy=0.742087\n", | ||
"Train iter 40, train loss=0.054584, avg test accuracy=0.767328\n", | ||
"Train iter 50, train loss=0.052529, avg test accuracy=0.780248\n", | ||
"Train iter 60, train loss=0.042935, avg test accuracy=0.829427\n", | ||
"Train iter 70, train loss=0.052370, avg test accuracy=0.823417\n", | ||
"Train iter 80, train loss=0.044771, avg test accuracy=0.811498\n", | ||
"Train iter 90, train loss=0.045305, avg test accuracy=0.821114\n", | ||
"Train iter 100, train loss=0.045466, avg test accuracy=0.816306\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<iframe\n", | ||
" scrolling=\"no\"\n", | ||
" width=\"420px\"\n", | ||
" height=\"320\"\n", | ||
" src=\"iframe_figures/figure_8.html\"\n", | ||
" frameborder=\"0\"\n", | ||
" allowfullscreen\n", | ||
"></iframe>\n" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"energies = train(\n", | ||
" model=network,\n", | ||
" lr=LEARNING_RATE,\n", | ||
" batch_size=BATCH_SIZE,\n", | ||
" test_every=TEST_EVERY,\n", | ||
" n_train_iters=N_TRAIN_ITERS\n", | ||
")\n", | ||
"plot_energies(energies)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
Oops, something went wrong.