change to batch_first; add criterions
zhaozewang committed Aug 25, 2024
1 parent fdbc708 commit 9945383
Showing 14 changed files with 105 additions and 30 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -15,8 +15,9 @@ Artificial Neural Networks (ANNs) trained with backpropagation, despite being bi

This project implements Recurrent Neural Networks (RNNs) and Multilayer Perceptrons (MLPs) designed for a parametrized and granular control over network modularity, synaptic plasticity, and other constraints to enable biologically feasible modeling of brain regions.

## Documentation
Documentation is available at [here](https://nn4n.org/).
## [GitHub](https://github.com/NN4Neurosim/nn4n)

## [Documentation](https://nn4n.org/)
- [Installation](https://nn4n.org/install/installation/)
- [Quickstart](https://nn4n.org/install/quickstart/)

6 changes: 6 additions & 0 deletions nn4n/__init__.py
@@ -0,0 +1,6 @@
from . import model
from . import layer
from . import mask
from . import utils
from . import criterion
from . import constraint
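With these additions, every submodule is importable directly from the package root; a quick sanity check (a minimal sketch, assuming nothing beyond the imports above):

```python
import nn4n

# each submodule registered in __init__.py is now an attribute of the package
print(nn4n.model, nn4n.layer, nn4n.mask, nn4n.utils, nn4n.criterion, nn4n.constraint)
```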
5 changes: 1 addition & 4 deletions nn4n/criterion/__init__.py
@@ -1,6 +1,3 @@
from .rnn_loss import RNNLoss
from .mlp_loss import MLPLoss

if __name__ == '__main__':
print(RNNLoss)
print(MLPLoss)
from .firing_rate import *
20 changes: 20 additions & 0 deletions nn4n/criterion/firing_rate.py
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn


class L1FiringRateLoss(nn.Module):
def __init__(self):
super().__init__()

def forward(self, states):
mean_fr = torch.mean(states, dim=(0, 1))
return torch.norm(mean_fr, p=1)**2/mean_fr.numel()


class L2FiringRateLoss(nn.Module):
def __init__(self):
super().__init__()

def forward(self, states):
mean_fr = torch.mean(states, dim=(0, 1))
return torch.norm(mean_fr, p=2)**2/mean_fr.numel()
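A minimal usage sketch for the two new criteria; the `(batch, time, hidden)` layout follows the `batch_first` convention introduced in this commit, and the tensor values are made up for illustration:

```python
import torch
from nn4n.criterion import L1FiringRateLoss, L2FiringRateLoss  # re-exported via firing_rate

states = torch.rand(8, 50, 100)        # (batch_size, n_timesteps, hidden_size)

l1 = L1FiringRateLoss()(states)        # squared L1 norm of per-neuron mean rates / n_neurons
l2 = L2FiringRateLoss()(states)        # squared L2 norm of per-neuron mean rates / n_neurons
print(l1.item(), l2.item())
```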
3 changes: 2 additions & 1 deletion nn4n/criterion/mlp_loss.py
@@ -62,4 +62,5 @@ def forward(self, pred, label, **kwargs):
def to(self, device):
""" Move to device """
super().to(device)
self.lambda_list = self.lambda_list.to(device)
self.lambda_list = self.lambda_list.to(device)
return self
47 changes: 40 additions & 7 deletions nn4n/criterion/rnn_loss.py
@@ -1,3 +1,10 @@
"""
Loss functions for RNNs.

TODO: Refactor this implementation. Let this module be initialized with a list of dicts,
where each dict contains the parameters for one type of loss. This will make the code
more readable and easier to maintain.
"""

import torch
import torch.nn as nn

@@ -26,6 +33,7 @@ class RNNLoss(nn.Module):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
self.batch_first = model.batch_first
if type(self.model) != CTRNN:
raise TypeError("model must be CTRNN")
self._init_losses(**kwargs)
@@ -79,23 +87,48 @@ def _loss_out(self, **kwargs):
return torch.norm(self.model.readout_layer.weight, p='fro')**2/self.n_out_dividend

def _loss_fr(self, states, **kwargs):
""" Compute the loss for firing rate """
return torch.pow(torch.mean(states, dim=(0, 1)), 2).mean()
"""
Compute the loss for firing rate.
This averages the hidden states over all timesteps and the batch to get each neuron's
mean firing rate, then returns (for now) the squared L2 norm of those means divided
by the number of neurons.
"""
if not self.batch_first: states = states.transpose(0, 1)
mean_fr = torch.mean(states, dim=(0, 1))
# return torch.pow(torch.mean(states, dim=(0, 1)), 2).mean() # this might not be correct
# return torch.norm(states, p='fro')**2/states.numel() # this might not be correct
return torch.norm(mean_fr, p=2)**2/mean_fr.numel()
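For reference, the returned quantity equals the mean of the squared per-neuron mean firing rates; a small self-contained sketch (shapes are made up) of the equivalence:

```python
import torch

states = torch.rand(8, 50, 100)              # (batch, time, hidden), batch_first layout
mean_fr = states.mean(dim=(0, 1))            # per-neuron mean firing rate, shape (hidden,)
loss_a = torch.norm(mean_fr, p=2) ** 2 / mean_fr.numel()
loss_b = (mean_fr ** 2).mean()               # the same quantity, written directly
assert torch.allclose(loss_a, loss_b)
```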

def _loss_fr_sd(self, states, **kwargs):
""" Compute the loss for firing rate for each neuron in terms of SD """
return torch.pow(torch.mean(states, dim=(0, 1)), 2).std()
"""
Compute the loss for firing rate for each neuron in terms of SD
This takes the average firing rate of each neuron across all timesteps and the batch,
then computes the standard deviation of those mean firing rates across neurons.
Parameters:
- states: size=(batch_size, n_timesteps, hidden_size), hidden states of the network
"""
if not self.batch_first: states = states.transpose(0, 1)
avg_fr = torch.mean(states, dim=(0, 1))
return avg_fr.std()

def _loss_fr_cv(self, states, **kwargs):
""" Compute the loss for firing rate for each neuron in terms of coefficient of variation """
avg_fr = torch.sqrt(torch.square(states)).mean(dim=0)
"""
Compute the loss for firing rate for each neuron in terms of coefficient of variation
This takes the average absolute firing rate of each neuron across all timesteps and the batch,
then computes the coefficient of variation of those mean rates across neurons.
Parameters:
- states: size=(batch_size, n_timesteps, hidden_size), hidden states of the network
"""
if not self.batch_first: states = states.transpose(0, 1)
avg_fr = torch.mean(torch.sqrt(torch.square(states)), dim=(0, 1))
return avg_fr.std()/avg_fr.mean()
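A small sketch (made-up shapes) of what the SD and CV penalties measure on the per-neuron average rates:

```python
import torch

states = torch.rand(8, 50, 100)                        # (batch, time, hidden), batch_first
avg_fr = states.mean(dim=(0, 1))                       # per-neuron mean firing rate
loss_fr_sd = avg_fr.std()                              # spread of mean rates across neurons

abs_fr = torch.sqrt(torch.square(states)).mean(dim=(0, 1))  # per-neuron mean |rate|
loss_fr_cv = abs_fr.std() / abs_fr.mean()              # coefficient of variation across neurons
```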

def forward(self, pred, label, **kwargs):
"""
Compute the loss
Inputs:
Parameters:
- pred: size=(-1, batch_size, 2), predicted labels
- label: size=(-1, batch_size, 2), labels
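A hedged construction sketch for `RNNLoss` after this change; the import paths, the `lambda_*` keyword names, and how the hidden states reach `forward` are assumptions inferred from the `_loss_*` methods, not confirmed by this diff:

```python
import torch
from nn4n.model import CTRNN          # assuming CTRNN is re-exported from nn4n.model
from nn4n.criterion import RNNLoss

model = CTRNN(dims=[1, 100, 1], batch_first=True)
criterion = RNNLoss(model, lambda_fr=0.01)      # keyword name assumed from _loss_fr

x = torch.rand(8, 50, 1)              # (batch_size, n_timesteps, input_dim)
target = torch.rand(8, 50, 1)         # made-up target in the prediction's layout
pred, states = model(x)               # CTRNN.forward returns (output, [hidden_states])
loss = criterion(pred, target)        # firing-rate terms presumably also receive the
                                      # hidden states through **kwargs; check forward()
```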
1 change: 1 addition & 0 deletions nn4n/layer/hidden_layer.py
@@ -118,6 +118,7 @@ def to(self, device):
self.ei_mask = self.ei_mask.to(device)
if self.bias.requires_grad:
self.bias = self.bias.to(device)
return self

def forward(self, x):
"""
5 changes: 3 additions & 2 deletions nn4n/layer/linear_layer.py
@@ -143,6 +143,7 @@ def to(self, device):
self.ei_mask = self.ei_mask.to(device)
if self.bias.requires_grad:
self.bias = self.bias.to(device)
return self

def forward(self, x):
"""
@@ -204,9 +205,9 @@ def plot_layers(self):
""" Plot the weights matrix and distribution of each layer """
weight = self.weight.cpu() if self.weight.device != torch.device('cpu') else self.weight
if weight.size(0) < weight.size(1):
utils.plot_connectivity_matrix_dist(weight.detach().numpy(), "Weight Matrix (Transposed)", False, self.sparsity_mask is not None)
utils.plot_connectivity_matrix_dist(weight.detach().numpy(), "Weight Matrix", False, self.sparsity_mask is not None)
else:
utils.plot_connectivity_matrix_dist(weight.detach().numpy().T, "Weight Matrix", False, self.sparsity_mask is not None)
utils.plot_connectivity_matrix_dist(weight.detach().numpy().T, "Weight Matrix (Transposed)", False, self.sparsity_mask is not None)

def print_layers(self):
""" Print the specs of each layer """
13 changes: 7 additions & 6 deletions nn4n/layer/recurrent_layer.py
@@ -58,31 +58,32 @@ def to(self, device):
self.input_layer.to(device)
self.hidden_layer.to(device)
self.hidden_state = self.hidden_state.to(device)
return self

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forwardly update network
Inputs:
- x: input, shape: (n_timesteps, batch_size, input_dim)
- x: input, shape: (batch_size, n_timesteps, input_dim)
Returns:
- states: shape: (n_timesteps, batch_size, hidden_size)
- stacked_states: hidden states of the network, shape: (batch_size, n_timesteps, hidden_size)
"""
v_t = self._reset_state().to(x.device)
fr_t = self.activation(v_t)
# update hidden state and append to stacked_states
stacked_states = []
for i in range(x.size(0)):
fr_t, v_t = self._recurrence(fr_t, v_t, x[i])
for i in range(x.size(1)):
fr_t, v_t = self._recurrence(fr_t, v_t, x[:,i])
# append to stacked_states
stacked_states.append(fr_t)

# if keeping the last state, save it to hidden_state
if self.init_state == 'keep':
self.hidden_state = fr_t.detach().clone() # TODO: haven't tested this yet

return torch.stack(stacked_states, dim=0)
return torch.stack(stacked_states, dim=1)

def _reset_state(self):
if self.init_state == 'learn' or self.init_state == 'keep':
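The substance of the `batch_first` change is which axis the recurrence walks over; a minimal standalone sketch of the new indexing (shapes made up, no actual recurrence applied):

```python
import torch

x = torch.rand(8, 50, 10)                    # (batch_size, n_timesteps, input_dim)
stacked_states = []
for i in range(x.size(1)):                   # iterate over the time axis (dim 1)
    x_t = x[:, i]                            # (batch_size, input_dim) slice at timestep i
    stacked_states.append(x_t)               # a real layer would run the recurrence here
states = torch.stack(stacked_states, dim=1)  # restack along time: (batch, time, input_dim)
assert states.shape == x.shape
```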
2 changes: 1 addition & 1 deletion nn4n/mask/multi_io.py
@@ -79,4 +79,4 @@ def _generate_readout_mask(self):
output_idx = self.output_table[i].reshape(-1, 1)
readout_mask[:,dim_counter:dim_counter+d] = np.tile(output_idx, d)
dim_counter += d
self.readout_mask = readout_mask.T # TODO: remove this and flip other masks
self.readout_mask = readout_mask.T # TODO: remove this and flip other masks
12 changes: 11 additions & 1 deletion nn4n/model/ctrnn.py
@@ -28,6 +28,7 @@ class CTRNN(BaseNN):
- activation: activation function, default: "relu", can be "relu", "sigmoid", "tanh", "retanh"
- dt: time step, default: 10
- tau: time constant, default: 100
- batch_first: whether the input is batch first or not, default: True
- biases: use bias or not for each layer, a list of 3 values or a single value
if a single value is passed, it will be broadcasted to a list of 3 values, it can be:
- None: no bias
@@ -73,6 +74,7 @@ def _initialize(self, **kwargs):
self.dims = kwargs.pop("dims", [1, 100, 1])
self.biases = kwargs.pop("biases", None)
self.weights = kwargs.pop("weights", 'uniform')
self.batch_first = kwargs.pop("batch_first", True)

# network dynamics parameters
self.sparsity_masks = kwargs.pop("sparsity_masks", None)
@@ -306,19 +308,27 @@ def to(self, device):
super().to(device)
self.recurrent_layer.to(device)
self.readout_layer.to(device)
return self

def forward(self, x):
"""
Forwardly update network
Inputs:
- x: input, shape: (n_timesteps, batch_size, input_dim)
- x: input, shape: (batch_size, n_timesteps, input_dim)
"""
if not self.batch_first:
x = x.transpose(0, 1)

# skip constraints if the model is not in training mode
if self.training:
self._enforce_constraints()
hidden_states = self.recurrent_layer(x)
output = self.readout_layer(hidden_states.float())

if not self.batch_first:
output = output.transpose(0, 1)
hidden_states = hidden_states.transpose(0, 1)
return output, [hidden_states]

def train(self):
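A hedged end-to-end sketch of the new default layout; `dims` follows the default shown in `_initialize`, while the import path and the returned shapes are assumptions:

```python
import torch
from nn4n.model import CTRNN   # assuming CTRNN is re-exported from nn4n.model

model = CTRNN(dims=[1, 100, 1])            # batch_first defaults to True
x = torch.rand(8, 50, 1)                   # (batch_size, n_timesteps, input_dim)
output, (hidden,) = model(x)               # expected: output (8, 50, 1), hidden (8, 50, 100)

# Time-major callers can keep the old layout by disabling batch_first:
model_tm = CTRNN(dims=[1, 100, 1], batch_first=False)
x_tm = torch.rand(50, 8, 1)                # (n_timesteps, batch_size, input_dim)
output_tm, _ = model_tm(x_tm)              # returned in the same time-major layout
```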
6 changes: 4 additions & 2 deletions nn4n/model/mlp.py
@@ -164,11 +164,13 @@ def _broadcast_values(value, length, is_mask=False):
raise ValueError(f"Expected a list of length {length}, got a list of length {len(value)}")
return value

def forward(self, x):
def forward(self, x, batch_first=True):
"""
Inputs:
- x: size=(batch_size, input_dim)
- x: size=(batch_size, n_timesteps, input_dim)
"""
if not batch_first:
x = x.transpose(0, 1)
hidden_states = []
for i, layer in enumerate(self.layers):
if i == len(self.layers)-1:
6 changes: 3 additions & 3 deletions setup.py
@@ -6,8 +6,8 @@

setup(
name='nn4n',
version='1.1.0',
description='Neural Networks for Neuroscience Research',
version='1.1.2',
description='Neural Networks for Neurosimulation',
long_description=long_description,
long_description_content_type='text/markdown',
author='Zhaoze Wang',
@@ -21,5 +21,5 @@
'IPython',
'scipy',
],
python_requires='>=3.10',
python_requires='>=3.9',
)
4 changes: 3 additions & 1 deletion todo.md
@@ -1,3 +1,5 @@
# TODOs:
- [ ] The examples need to be updated. Especially on the main branch.
- [ ] Resolve the transpose issue in the model module and the mask module.
- [ ] Resolve the transpose issue in the model module and the mask module.
- [ ] Make the model use `batch_first` by default.
- [ ] Refactor the RNNLoss part, let it take a dictionary instead of many separate `lambda_*` parameters.
