Bioflax provides an unofficial JAX implementation of biologically plausible deep learning algorithms. In particular, Feedback Alignment, Kolen-Pollack, and Direct Feedback Alignment are implemented, together with a framework for running experiments with them. The code implements custom Flax modules that integrate seamlessly with the Flax framework.
The network structures of the respective algorithms are depicted in the following scheme. For a more detailed overview, please refer to the docs.
Figure: network architectures for Backpropagation, Feedback Alignment & Kolen-Pollack, and Direct Feedback Alignment. Taken from [5].
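At their core, the three algorithms differ in how the error signal travels backwards: Feedback Alignment replaces the transposed forward weights in the backward pass with a fixed random feedback matrix [4], Kolen-Pollack additionally trains the feedback weights so that they converge towards the forward weights [6], and Direct Feedback Alignment projects the output error directly to each hidden layer through fixed random matrices [5]. As a rough illustration of the Feedback Alignment case only, here is a minimal sketch (not bioflax's actual implementation) of a linear layer whose backward pass is overridden with jax.custom_vjp:

```python
# Minimal sketch of the Feedback Alignment idea (not bioflax's actual implementation):
# the forward pass is a standard linear layer, but the upstream error is propagated
# through a fixed random matrix B instead of the transposed forward weights W^T.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def fa_linear(x, W, b, B):
    return x @ W + b

def fa_linear_fwd(x, W, b, B):
    return x @ W + b, (x, B)              # save input and feedback matrix for the backward pass

def fa_linear_bwd(res, g):
    x, B = res
    dx = g @ B.T                          # error routed through fixed B (backprop would use W.T)
    dW = x.T @ g                          # weight update has the same form as in backprop
    db = g.sum(axis=0)
    return dx, dW, db, jnp.zeros_like(B)  # B itself is never updated

fa_linear.defvjp(fa_linear_fwd, fa_linear_bwd)
```

bioflax packages this kind of mechanism into the custom Flax modules shown below, so the backward pass never has to be handled manually.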
To run the code on your own machine, run pip install -r requirements.txt.
For the GPU installation of JAX, which is a little more involved, please refer to the JAX installation instructions.
For data loading, PyTorch is needed. It has to be installed separately and, importantly, as the CPU-only version to avoid conflicts with JAX. Please refer to the PyTorch installation instructions.
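For example, a CPU-only PyTorch build can typically be installed with the command below (taken from the PyTorch installation matrix for Linux; double-check the instructions for your platform):

```
pip install torch --index-url https://download.pytorch.org/whl/cpu
```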
In this first release, the code has built-in support for running experiments on the MNIST dataset, a teacher-student dataset, and a sin-regression dataset. None of these require an explicit data download: MNIST is downloaded automatically on execution, and the teacher-student and sin-regression datasets are generated on the fly.
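For reference, the automatic MNIST download presumably goes through PyTorch's torchvision datasets, which is what produces the data/MNIST/raw directory listed below; a hypothetical sketch that may differ from dataloading.py in its details:

```python
# Hypothetical sketch of how the MNIST data ends up under data/MNIST/raw;
# bioflax's dataloading.py may differ in its details.
import torch
from torchvision import datasets, transforms

train_set = datasets.MNIST(root="./data", train=True, download=True,
                           transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
```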
Directories and files that ship with the GitHub repo:
.github/workflows/
run_tests.yml Workflow to run unit tests that ensure the functionality of the layer implementations.
bioflax/
dataloading.py Dataloading functions.
metric_computation.py Functions for metric computation.
model.py Layer implementations for FA, KP, and DFA.
test_layers.py Unit tests for layer implementations.
train.py Training loop code.
train_helpers.py Functions for optimization, training, and evaluation steps.
requirements.txt Requirements for running the code.
run_train.py Training loop entry point.
Directories that may be created on the fly:
data/MNIST/raw Raw MNIST data as downloaded.
wandb/ Local WandB log files.
Separately from the rest of the code, the custom Flax modules (the biological layer implementations) can be used to define Dense networks that train with the respective deep learning algorithm. For example, a two-layer Dense network with a sigmoid activation in the hidden layer, which integrates seamlessly with the Flax framework, can be created for each of the algorithms as follows:
import jax
import flax.linen as nn
from model import (
RandomDenseLinearFA,
RandomDenseLinearKP,
RandomDenseLinearDFAOutput,
RandomDenseLinearDFAHidden
)
class NetworkFA(nn.Module):
@nn.compact
def __call__(self, x):
x = RandomDenseLinearFA(15)(x)
x = nn.sigmoid(x)
x = RandomDenseLinearFA(10)(x)
return x
class NetworkKP(nn.Module):
@nn.compact
def __call__(self, x):
x = RandomDenseLinearKP(15)(x)
x = nn.sigmoid(x)
x = RandomDenseLinearKP(10)(x)
return x
# Note the differences for DFA: the activation function is passed to the hidden layer itself and must not
# appear anywhere else on the computational path. In addition, hidden layers take the final output dimension
# as a second constructor argument.
class NetworkDFA(nn.Module):
@nn.compact
def __call__(self, x):
x = RandomDenseLinearDFAHidden(15, 10, nn.sigmoid)(x)
x = RandomDenseLinearDFAOutput(10)(x)
return x
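Because these are ordinary Flax modules, initialization and the forward pass work exactly as for built-in layers; the input feature size (20 below) is an arbitrary choice for illustration:

```python
import jax
import jax.numpy as jnp

model = NetworkFA()
x = jnp.ones((4, 20))                          # dummy batch of 4 inputs with 20 features
params = model.init(jax.random.PRNGKey(0), x)  # initializes the model parameters
logits = model.apply(params, x)                # forward pass; output has 10 features
```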
If you need help with using such modules for actual learning in Flax, please refer to the Flax documentation.
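To sketch how such a network might be trained, here is a hedged example of a single training step using optax (a common choice with Flax); the project's actual loop lives in train.py and train_helpers.py and may differ. Because the modules override the backward pass internally, a plain jax.value_and_grad call already yields the algorithm-specific updates:

```python
import jax
import jax.numpy as jnp
import optax

model = NetworkFA()                                # any of the networks defined above
x = jnp.ones((32, 20))                             # dummy inputs
targets = jax.nn.one_hot(jnp.arange(32) % 10, 10)  # dummy one-hot labels

params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(1e-3)                              # optimizer choice is illustrative
opt_state = tx.init(params)

def loss_fn(params, x, targets):
    logits = model.apply(params, x)
    return optax.softmax_cross_entropy(logits, targets).mean()

@jax.jit
def train_step(params, opt_state, x, targets):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, targets)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params, opt_state, loss = train_step(params, opt_state, x, targets)
```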
To run an experiment, execute
python run_train.py
which will start a run with the default configuration. For information about the available arguments and their default settings, execute one of the following commands:
python run_train.py --help
python run_train.py -h
[1] David E. Rumelhart, Geoffrey E. Hinton, and Ronald J. Williams. Learning representations by back-propagating errors. Nature, 323(6088), 1986.
[2] Stephen Grossberg. Competitive learning: From interactive activation to adaptive resonance. Cognitive Science, 11(1):23–63, 1987.
[3] Francis Crick. The recent excitement about neural networks. Nature, 337:129–132, 1989.
[4] Timothy P. Lillicrap, Daniel Cownden, Douglas B. Tweed, and Colin J. Akerman. Random synaptic feedback weights support error backpropagation for deep learning. Nature Communications, 7(1), 2016.
[5] Arild Nøkland. Direct feedback alignment provides learning in deep neural networks. In Advances in Neural Information Processing Systems, 2016.
[6] J.F. Kolen and J.B. Pollack. Backpropagation without weight transport. In Proceedings of 1994 IEEE International Conference on Neural Networks, 1994.