Skip to content

Commit

Permalink
Add example for theoretical energy of linear nets
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jul 4, 2024
1 parent 0057257 commit 2d5cd42
Show file tree
Hide file tree
Showing 2 changed files with 399 additions and 4 deletions.
392 changes: 392 additions & 0 deletions examples/linear_net_theoretical_energy.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,392 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Analytical test\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/thebuckleylab/jpc/blob/main/examples/analytical_test.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!pip install torch==2.3.1\n",
"!pip install torchvision==0.18.1\n",
"!pip install plotly==5.11.0\n",
"!pip install -U kaleido"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import jpc\n",
"\n",
"import jax\n",
"import equinox as eqx\n",
"import equinox.nn as nn\n",
"import optax\n",
"\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets, transforms\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import plotly.graph_objs as go\n",
"import plotly.io as pio\n",
"\n",
"pio.renderers.default = 'iframe'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameters\n",
"\n",
"We define some global parameters, including network architecture, learning rate, batch size etc."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"SEED = 0\n",
"LEARNING_RATE = 1e-3\n",
"BATCH_SIZE = 64\n",
"TEST_EVERY = 10\n",
"N_TRAIN_ITERS = 100"
]
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Dataset\n",
"\n",
"Some utils to fetch MNIST."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"#@title data utils\n",
"\n",
"\n",
"def get_mnist_loaders(batch_size):\n",
" train_data = MNIST(train=True, normalise=True)\n",
" test_data = MNIST(train=False, normalise=True)\n",
" train_loader = DataLoader(\n",
" dataset=train_data,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" drop_last=True\n",
" )\n",
" test_loader = DataLoader(\n",
" dataset=test_data,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" drop_last=True\n",
" )\n",
" return train_loader, test_loader\n",
"\n",
"\n",
"class MNIST(datasets.MNIST):\n",
" def __init__(self, train, normalise=True, save_dir=\"data\"):\n",
" if normalise:\n",
" transform = transforms.Compose(\n",
" [\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(\n",
" mean=(0.1307), std=(0.3081)\n",
" )\n",
" ]\n",
" )\n",
" else:\n",
" transform = transforms.Compose([transforms.ToTensor()])\n",
" super().__init__(save_dir, download=True, train=train, transform=transform)\n",
"\n",
" def __getitem__(self, index):\n",
" img, label = super().__getitem__(index)\n",
" img = torch.flatten(img)\n",
" label = one_hot(label)\n",
" return img, label\n",
"\n",
"\n",
"def one_hot(labels, n_classes=10):\n",
" arr = torch.eye(n_classes)\n",
" return arr[labels]\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Plotting"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def plot_energies(energies):\n",
" n_train_iters = len(energies[\"theory\"])\n",
" train_iters = [b+1 for b in range(n_train_iters)]\n",
"\n",
" fig = go.Figure()\n",
" for energy_type, energy in energies.items():\n",
" fig.add_traces(\n",
" go.Scatter(\n",
" x=train_iters,\n",
" y=energy,\n",
" name=energy_type,\n",
" mode=\"lines\",\n",
" line=dict(\n",
" width=3, \n",
" dash=\"dash\" if energy_type == \"theory\" else \"solid\"\n",
" )\n",
" )\n",
" )\n",
"\n",
" fig.update_layout(\n",
" height=300,\n",
" width=400,\n",
" xaxis=dict(\n",
" title=\"Training iteration\",\n",
" tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],\n",
" ticktext=[1, int(train_iters[-1]/2), train_iters[-1]],\n",
" ),\n",
" yaxis=dict(\n",
" title=\"Energy\",\n",
" nticks=3\n",
" ),\n",
" font=dict(size=16),\n",
" )\n",
" fig.write_image(\"dln_energy_example.pdf\")\n",
" return fig"
]
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Linear network"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"key = jax.random.PRNGKey(0)\n",
"subkeys = jax.random.split(key, 5)\n",
"\n",
"network = [\n",
" eqx.nn.Linear(784, 300, key=subkeys[0], use_bias=False),\n",
" eqx.nn.Linear(300, 300, key=subkeys[1], use_bias=False),\n",
" eqx.nn.Linear(300, 300, key=subkeys[2], use_bias=False),\n",
" eqx.nn.Linear(300, 300, key=subkeys[3], use_bias=False),\n",
" eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),\n",
" eqx.nn.Linear(300, 300, key=subkeys[1], use_bias=False),\n",
" eqx.nn.Linear(300, 300, key=subkeys[2], use_bias=False),\n",
" eqx.nn.Linear(300, 300, key=subkeys[3], use_bias=False),\n",
" eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),\n",
" eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),\n",
" eqx.nn.Linear(300, 10, key=subkeys[5], use_bias=False),\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train and test\n",
"\n",
"A PC network can be trained in a single line of code with `jpc.make_pc_step()`. See the documentation for more. Similarly, we can use `jpc.test_discriminative_pc()` to compute the network accuracy. Note that these functions are already \"jitted\" for performance.\n",
"\n",
"Below we simply wrap each of these functions in our training and test loops, respectively."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model, test_loader):\n",
" test_acc = 0\n",
" for batch_id, (img_batch, label_batch) in enumerate(test_loader):\n",
" img_batch = img_batch.numpy()\n",
" label_batch = label_batch.numpy()\n",
"\n",
" test_acc += jpc.test_discriminative_pc(\n",
" model=model,\n",
" y=label_batch,\n",
" x=img_batch\n",
" )\n",
"\n",
" return test_acc / len(test_loader)\n",
"\n",
"\n",
"def train(\n",
" model, \n",
" lr,\n",
" batch_size,\n",
" test_every,\n",
" n_train_iters\n",
"):\n",
" optim = optax.adam(lr)\n",
" opt_state = optim.init(eqx.filter(model, eqx.is_array))\n",
" train_loader, test_loader = get_mnist_loaders(batch_size)\n",
"\n",
" num_energies, theory_energies = [], []\n",
" for iter, (img_batch, label_batch) in enumerate(train_loader):\n",
" img_batch = img_batch.numpy()\n",
" label_batch = label_batch.numpy()\n",
"\n",
" theory_energies.append(\n",
" jpc.linear_equilib_energy(\n",
" network=model, \n",
" x=img_batch, \n",
" y=label_batch\n",
" )\n",
" )\n",
" result = jpc.make_pc_step(\n",
" model,\n",
" optim,\n",
" opt_state,\n",
" y=label_batch,\n",
" x=img_batch,\n",
" record_energies=True\n",
" )\n",
" model, optim, opt_state = result[\"model\"], result[\"optim\"], result[\"opt_state\"]\n",
" train_loss, t_max = result[\"loss\"], result[\"t_max\"]\n",
" num_energies.append(result[\"energies\"][:, t_max-1].sum())\n",
"\n",
" if ((iter+1) % test_every) == 0:\n",
" avg_test_acc = evaluate(model, test_loader)\n",
" print(\n",
" f\"Train iter {iter+1}, train loss={train_loss:4f}, \"\n",
" f\"avg test accuracy={avg_test_acc:4f}\"\n",
" )\n",
" if (iter+1) >= n_train_iters:\n",
" break\n",
"\n",
" return {\n",
" \"experiment\": num_energies,\n",
" \"theory\": theory_energies,\n",
" }\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/fi69/PycharmProjects/jpc/venv/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning:\n",
"\n",
"unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train iter 10, train loss=0.078592, avg test accuracy=0.570713\n",
"Train iter 20, train loss=0.072349, avg test accuracy=0.705729\n",
"Train iter 30, train loss=0.059754, avg test accuracy=0.742087\n",
"Train iter 40, train loss=0.054584, avg test accuracy=0.767328\n",
"Train iter 50, train loss=0.052529, avg test accuracy=0.780248\n",
"Train iter 60, train loss=0.042935, avg test accuracy=0.829427\n",
"Train iter 70, train loss=0.052370, avg test accuracy=0.823417\n",
"Train iter 80, train loss=0.044771, avg test accuracy=0.811498\n",
"Train iter 90, train loss=0.045305, avg test accuracy=0.821114\n",
"Train iter 100, train loss=0.045466, avg test accuracy=0.816306\n"
]
},
{
"data": {
"text/html": [
"<iframe\n",
" scrolling=\"no\"\n",
" width=\"420px\"\n",
" height=\"320\"\n",
" src=\"iframe_figures/figure_8.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
"></iframe>\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"energies = train(\n",
" model=network,\n",
" lr=LEARNING_RATE,\n",
" batch_size=BATCH_SIZE,\n",
" test_every=TEST_EVERY,\n",
" n_train_iters=N_TRAIN_ITERS\n",
")\n",
"plot_energies(energies)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit 2d5cd42

Please sign in to comment.