Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added LMU and FFT classes #21

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions docs/basic/psMNIST-FFT.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import time\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"\n",
"from lmu import LMU\n",
"\n",
"from tensorflow.keras.callbacks import ModelCheckpoint\n",
"from tensorflow.keras.layers import Dense\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.initializers import Constant\n",
"from tensorflow.keras.utils import to_categorical"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set seed to ensure this example is reproducible\n",
"seed = 0\n",
"tf.random.set_seed(seed)\n",
"np.random.seed(seed)\n",
"rng = np.random.RandomState(seed)\n",
"\n",
"# load mnist dataset\n",
"(\n",
" (train_images, train_labels),\n",
" (test_images, test_labels),\n",
") = tf.keras.datasets.mnist.load_data()\n",
"\n",
"# Change inputs to 0--1 range\n",
"train_images = train_images / 255\n",
"test_images = test_images / 255\n",
"\n",
"# Flatten images into sequences\n",
"train_images = train_images.reshape((train_images.shape[0], -1, 1))\n",
"test_images = test_images.reshape((test_images.shape[0], -1, 1))\n",
"\n",
"# Apply permutation\n",
"perm = rng.permutation(train_images.shape[1])\n",
"train_images = train_images[:, perm]\n",
"test_images = test_images[:, perm]\n",
"\n",
"X_train = train_images[0:50000]\n",
"X_valid = train_images[50000:]\n",
"X_test = test_images\n",
"\n",
"Y_train = train_labels[0:50000]\n",
"Y_valid = train_labels[50000:]\n",
"Y_test = test_labels\n",
"\n",
"print(X_train.shape, Y_train.shape)\n",
"print(X_valid.shape, Y_valid.shape)\n",
"print(X_test.shape, Y_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.subplot(111)\n",
"plt.title(\"Digit = %d\" % Y_train[1])\n",
"plt.imshow(X_train[1].reshape(28, 28))\n",
"plt.colorbar()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"padded_length = 28 ** 2 + 1\n",
"n_pixels = padded_length - 1\n",
"\n",
"\n",
"def lmu_layer(**kwargs):\n",
" return LMU(\n",
" units=212,\n",
" order=256,\n",
" theta=n_pixels,\n",
" memory_to_memory=False,\n",
" hidden_to_memory=False,\n",
" hidden_to_hidden=False,\n",
" input_encoders_initializer=Constant(1),\n",
" hidden_encoders_initializer=Constant(0),\n",
" memory_encoders_initializer=Constant(0),\n",
" input_kernel_initializer=Constant(0),\n",
" hidden_kernel_initializer=Constant(0),\n",
" memory_kernel_initializer=\"glorot_normal\",\n",
" return_sequences=False,\n",
vsheska marked this conversation as resolved.
Show resolved Hide resolved
" **kwargs\n",
" )\n",
"\n",
"\n",
"model = Sequential()\n",
"model.add(lmu_layer(input_shape=X_train.shape[1:],)) # (nr. of pixels, 1)\n",
"model.add(Dense(10, activation=\"softmax\"))\n",
"\n",
"model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"batch_size = 100\n",
"epochs = 2\n",
"t = time.time()\n",
"\n",
"fname = \"./psMNIST-fft.hdf5\"\n",
"callbacks = [\n",
" ModelCheckpoint(filepath=fname, monitor=\"val_loss\", verbose=1, save_best_only=True),\n",
"]\n",
"\n",
"result = model.fit(\n",
" X_train,\n",
" to_categorical(Y_train),\n",
" epochs=epochs,\n",
" batch_size=batch_size,\n",
" validation_data=(X_valid, to_categorical(Y_valid)),\n",
" callbacks=callbacks,\n",
")\n",
"\n",
"print(\"Took {:.2f} min\".format((time.time() - t) / 60))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure()\n",
"plt.plot(result.history[\"val_accuracy\"], label=\"Validation\")\n",
"plt.plot(result.history[\"accuracy\"], label=\"Training\")\n",
"plt.legend()\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Accuracy\")\n",
"plt.title(\"psMNIST - LMU\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"saved_epoch = np.argmin(result.history[\"val_loss\"])\n",
"print(result.history[\"val_accuracy\"][saved_epoch])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.load_weights(fname) # load best weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.evaluate(X_test, to_categorical(Y_test))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
28 changes: 12 additions & 16 deletions docs/basic/psMNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"\n",
"from lmu import LMUCell\n",
"from lmu import LMU\n",
"\n",
"from tensorflow.keras.callbacks import ModelCheckpoint\n",
"from tensorflow.keras.layers import Dense\n",
"from tensorflow.keras.layers import RNN\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.initializers import Constant\n",
"from tensorflow.keras.utils import to_categorical"
Expand Down Expand Up @@ -92,19 +91,16 @@
"\n",
"\n",
"def lmu_layer(**kwargs):\n",
" return RNN(\n",
" LMUCell(\n",
" units=212,\n",
" order=256,\n",
" theta=n_pixels,\n",
" input_encoders_initializer=Constant(1),\n",
" hidden_encoders_initializer=Constant(0),\n",
" memory_encoders_initializer=Constant(0),\n",
" input_kernel_initializer=Constant(0),\n",
" hidden_kernel_initializer=Constant(0),\n",
" memory_kernel_initializer=\"glorot_normal\",\n",
" ),\n",
" return_sequences=False,\n",
" return LMU(\n",
" units=212,\n",
" order=256,\n",
" theta=n_pixels,\n",
" input_encoders_initializer=Constant(1),\n",
" hidden_encoders_initializer=Constant(0),\n",
" memory_encoders_initializer=Constant(0),\n",
" input_kernel_initializer=Constant(0),\n",
" hidden_kernel_initializer=Constant(0),\n",
" memory_kernel_initializer=\"glorot_normal\",\n",
" **kwargs\n",
" )\n",
"\n",
Expand All @@ -121,7 +117,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
"scrolled": true
},
"outputs": [],
"source": [
Expand Down
10 changes: 9 additions & 1 deletion lmu/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
"""LMU provides a package for deep learning with Legendre Memory Units."""

from .lmu import Legendre, InputScaled, LMUCell, LMUCellODE, LMUCellGating
from .lmu import (
Legendre,
InputScaled,
LMUCell,
LMUCellODE,
LMUCellGating,
LMU,
FFTLayer,
)

from .version import version as __version__

Expand Down
Loading