
Commit

add connectivity loss
zhaozewang committed Oct 9, 2024
1 parent 495ada9 commit ef88f9b
Showing 10 changed files with 222 additions and 3 deletions.
3 changes: 2 additions & 1 deletion nn4n/criterion/__init__.py
@@ -1,4 +1,5 @@
from .rnn_loss import RNNLoss
from .mlp_loss import MLPLoss
from .firing_rate_loss import *
from .composite_loss import CompositeLoss
from .connectivity_loss import *
3 changes: 3 additions & 0 deletions nn4n/criterion/composite_loss.py
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
from .firing_rate_loss import *
from .connectivity_loss import *

class CompositeLoss(nn.Module):
def __init__(self, loss_cfg):
@@ -22,7 +23,9 @@ def __init__(self, loss_cfg):
loss_types = {
'fr': FiringRateLoss,
'fr_dist': FiringRateDistLoss,
'rnn_conn': RNNConnectivityLoss,
'state_pred': StatePredictionLoss,
'entropy': EntropyLoss,
'mse': nn.MSELoss,
}
torch_losses = ['mse']
35 changes: 35 additions & 0 deletions nn4n/criterion/connectivity_loss.py
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNConnectivityLoss(nn.Module):
def __init__(self, layer, metric='fro', **kwargs):
super().__init__()
        assert metric in ['l1', 'fro'], "metric must be either 'l1' or 'fro'"
self.metric = metric
self.layer = layer

def forward(self, model, **kwargs):
if self.layer == 'all':
weights = [
model.recurrent_layer.input_layer.weight,
model.recurrent_layer.hidden_layer.weight,
model.readout_layer.weight
]

loss = torch.sum(torch.stack([self._compute_norm(weight) for weight in weights]))
return loss
elif self.layer == 'input':
return self._compute_norm(model.recurrent_layer.input_layer.weight)
elif self.layer == 'hidden':
return self._compute_norm(model.recurrent_layer.hidden_layer.weight)
elif self.layer == 'readout':
return self._compute_norm(model.readout_layer.weight)
else:
raise ValueError(f"Invalid layer '{self.layer}'. Available layers are: 'all', 'input', 'hidden', 'readout'")

def _compute_norm(self, weight):
if self.metric == 'l1':
return torch.norm(weight, p=1)
else:
return torch.norm(weight, p='fro')
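
A minimal usage sketch of the new loss, assuming only the attribute layout that forward() indexes into (recurrent_layer.input_layer, recurrent_layer.hidden_layer, readout_layer); the _TinyRNN stand-in below is hypothetical and only mimics those attributes:

import torch
import torch.nn as nn
from nn4n.criterion import RNNConnectivityLoss

class _TinyRNN(nn.Module):
    # Hypothetical stand-in exposing the attributes RNNConnectivityLoss reads;
    # a real nn4n model provides these layers itself.
    def __init__(self, n_in=3, n_hid=8, n_out=2):
        super().__init__()
        self.recurrent_layer = nn.Module()
        self.recurrent_layer.input_layer = nn.Linear(n_in, n_hid)
        self.recurrent_layer.hidden_layer = nn.Linear(n_hid, n_hid)
        self.readout_layer = nn.Linear(n_hid, n_out)

model = _TinyRNN()
conn_loss = RNNConnectivityLoss(layer='hidden', metric='fro')
loss = conn_loss(model)   # Frobenius norm of the hidden weight matrix
loss.backward()           # gradients flow into the penalized weights

With layer='all', the same call instead sums the norms of the input, hidden, and readout weight matrices.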
72 changes: 72 additions & 0 deletions nn4n/criterion/firing_rate_loss.py
@@ -76,3 +76,75 @@ def forward(self, states, **kwargs):

# Use MSE loss instead of manual difference calculation
return F.mse_loss(states[:-self.tau], states[self.tau:], reduction='mean')


class EntropyLoss(nn.Module):
def __init__(self, reg=1e1, **kwargs):
super().__init__(**kwargs)
self.reg = reg

def forward(self, states):
# states shape: (batch_size, time_steps, num_neurons)
batch_size, time_steps, num_neurons = states.shape

# Normalize the states to create a probability distribution
# Add a small epsilon to avoid log(0)
eps = 1e-8
prob_states = states / (states.sum(dim=-1, keepdim=True) + eps)

# Compute the entropy of the neuron activations
entropy_loss = -torch.sum(prob_states * torch.log(prob_states + eps), dim=-1)

# Take the mean entropy over batches and time steps
mean_entropy = torch.mean(entropy_loss)

        # Add an L2 regularization term on the state norms
reg_loss = torch.mean(torch.norm(states, dim=-1) ** 2)
total_loss = mean_entropy + self.reg * reg_loss

return total_loss
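
In equation form (a reference sketch, not code from this diff), the loss treats each state vector r as an unnormalized distribution over neurons and minimizes its entropy plus an L2 penalty on the state norm, with eps the stability constant and reg the regularization weight:

p_i = \frac{r_i}{\sum_j r_j + \epsilon}, \qquad
H(r) = -\sum_i p_i \log(p_i + \epsilon), \qquad
\mathcal{L} = \overline{H(r)} + \mathrm{reg}\cdot\overline{\lVert r \rVert_2^{2}}

where the overbars denote averaging over batch and time; the normalization presumes non-negative activations such as firing rates.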


class PopulationKL(nn.Module):
def __init__(self, symmetric=True, reg=1e-3, reduction='mean'):
super().__init__()
self.symmetric = symmetric
self.reg = reg
self.reduction = reduction

def forward(self, states_0, states_1):
# Compute the mean and variance across batches and time steps
mean_0 = torch.mean(states_0, dim=(0, 1), keepdim=True) # Shape: (1, 1, n_neurons)
mean_1 = torch.mean(states_1, dim=(0, 1), keepdim=True) # Shape: (1, 1, n_neurons)
var_0 = torch.var(states_0, dim=(0, 1), unbiased=False, keepdim=True) # Shape: (1, 1, n_neurons)
var_1 = torch.var(states_1, dim=(0, 1), unbiased=False, keepdim=True) # Shape: (1, 1, n_neurons)

# Compute the KL divergence between the two populations (per neuron)
kl_div = 0.5 * (torch.log(var_1 / var_0) + (var_0 + (mean_0 - mean_1) ** 2) / var_1 - 1) # Shape: (1, 1, n_neurons)

# Symmetric KL divergence: average the KL(P || Q) and KL(Q || P)
if self.symmetric:
reverse_kl_div = 0.5 * (torch.log(var_0 / var_1) + (var_1 + (mean_1 - mean_0) ** 2) / var_0 - 1) # Shape: (1, 1, n_neurons)
kl_div = 0.5 * (kl_div + reverse_kl_div) # Shape: (1, 1, n_neurons)

# Apply reduction based on the reduction method
if self.reduction == 'mean':
kl_loss = torch.mean(kl_div) # Scalar value
elif self.reduction == 'sum':
kl_loss = torch.sum(kl_div) # Scalar value
elif self.reduction == 'none':
kl_loss = kl_div # Shape: (1, 1, n_neurons)
else:
raise ValueError(f"Invalid reduction mode: {self.reduction}")

# Regularization: L2 norm of the states across the neurons
reg_loss = torch.mean(torch.norm(states_0, dim=-1) ** 2) + torch.mean(torch.norm(states_1, dim=-1) ** 2)

# Combine the KL divergence with the regularization term
if self.reduction == 'none':
# If no reduction, add regularization element-wise
total_loss = kl_loss + self.reg * (torch.norm(states_0, dim=-1) ** 2 + torch.norm(states_1, dim=-1) ** 2)
else:
total_loss = kl_loss + self.reg * reg_loss

return total_loss
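
For reference, the per-neuron term computed above is the closed-form KL divergence between univariate Gaussians fitted to each population (means and variances taken over batch and time); the symmetric option averages the two directions:

\mathrm{KL}\big(\mathcal{N}(\mu_0,\sigma_0^2)\,\|\,\mathcal{N}(\mu_1,\sigma_1^2)\big)
= \tfrac{1}{2}\left(\log\frac{\sigma_1^2}{\sigma_0^2} + \frac{\sigma_0^2 + (\mu_0-\mu_1)^2}{\sigma_1^2} - 1\right), \qquad
\mathrm{KL}_{\mathrm{sym}} = \tfrac{1}{2}\big(\mathrm{KL}(P\|Q) + \mathrm{KL}(Q\|P)\big)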
5 changes: 5 additions & 0 deletions nn4n/layer/linear_layer.py
@@ -72,6 +72,11 @@ def _balance_excitatory_inhibitory(self):
# apply scaling
self.weight *= scale_mat

def freeze_layer(self):
""" Freeze the layer """
self.weight.requires_grad = False
self.bias.requires_grad = False

def _generate_weight(self, weight_init):
""" Generate random weight """
if weight_init == 'uniform':
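A small sketch of what freeze_layer achieves, using a plain nn.Linear as a hypothetical stand-in: once requires_grad is off, backpropagation skips those parameters while the rest of the model still trains.

import torch
import torch.nn as nn

frozen = nn.Linear(4, 4)              # stand-in for a layer after freeze_layer()
frozen.weight.requires_grad = False
frozen.bias.requires_grad = False
head = nn.Linear(4, 1)                # still trainable

loss = head(frozen(torch.randn(2, 4))).sum()
loss.backward()
print(frozen.weight.grad is None, head.weight.grad is None)  # True False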
8 changes: 8 additions & 0 deletions nn4n/mask/multi_area.py
@@ -115,6 +115,14 @@ def get_readout_idx(self):
""" Return the indices of neurons that send output """
return np.where(self.readout_mask.sum(axis=0) != 0)[0]

def get_area_indices(self):
""" Get all area indices """
# Get all areas from node assignment
area_indices = []
for i in self.get_areas():
area_indices.append(self.get_area_idx(i))
return area_indices

def get_area_idx(self, area):
""" Return the indices of neurons in area """
if isinstance(area, str):
2 changes: 2 additions & 0 deletions nn4n/utils/__init__.py
@@ -0,0 +1,2 @@
from .help_functions import *
from .area_manager import *
91 changes: 91 additions & 0 deletions nn4n/utils/area_manager.py
@@ -0,0 +1,91 @@
import numpy as np

def check_init(func):
def wrapper(self, *args, **kwargs):
if self._area_indices is None:
raise ValueError("Area indices are not initialized")
return func(self, *args, **kwargs)
return wrapper

class AreaManager:
def __init__(self, area_indices=None, batch_first=True):
"""
Initialize the AreaManager
Inputs:
            - area_indices: a list of index arrays, one per area, giving the neurons assigned to that area
- batch_first: whether the states are batch-first (default: True)
"""
if area_indices is not None:
self.set_area_indices(area_indices)
else:
self._area_indices = None # Ensure it's None initially

self._batch_first = batch_first

def set_area_indices(self, area_indices):
"""
Set the area indices
Inputs:
            - area_indices: a list of index arrays, one per area, giving the neurons assigned to that area
"""
self._n_areas = len(area_indices)
self._area_indices = area_indices

@property
def n_areas(self):
return self._n_areas

@property
def ai(self):
return self._area_indices

@check_init
def split_states(self, states):
"""
Parse the states of a complete RNN into a list of states of different areas
Inputs:
- states: network states of shape (batch_size, seq_len, hidden_size)
Returns:
- list of states of different areas
"""
if not self._batch_first:
states = states.permute(1, 0, 2)

area_states = [states[:, :, idx] for idx in self._area_indices]
return area_states

@check_init
def get_area_states(self, states, area_idx):
"""
Get the states of a specific area
Inputs:
- states: network states of shape (batch_size, seq_len, hidden_size)
- area_idx: index of the area
Returns:
- states of the specific area
"""
        if not self._batch_first:
            states = states.permute(1, 0, 2)  # convert to batch-first, matching split_states

return states[:, :, self._area_indices[area_idx]]

@check_init
def random_indices(self, area_idx, n, replace=False):
"""
Randomly pick n indices from an area
Inputs:
            - area_idx: index of the area
            - n: number of indices to pick
- replace: whether to sample with replacement (default: False)
Returns:
- indices of the neurons
"""
n_neurons = len(self._area_indices[area_idx])
if not replace and n > n_neurons:
return self._area_indices[area_idx]
return np.random.choice(self._area_indices[area_idx], n, replace=replace)
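
A hedged usage sketch of the new AreaManager, with two hypothetical areas over a 3-neuron hidden state; the index arrays could plausibly come from MultiArea.get_area_indices() added above, though that wiring is an assumption, not something shown in this diff:

import numpy as np
import torch
from nn4n.utils import AreaManager

area_indices = [np.array([0, 1]), np.array([2])]   # hypothetical area assignment
am = AreaManager(area_indices)                     # batch_first=True by default

states = torch.randn(5, 10, 3)                     # (batch_size, seq_len, hidden_size)
per_area = am.split_states(states)
print([s.shape for s in per_area])                 # [torch.Size([5, 10, 2]), torch.Size([5, 10, 1])]

picked = am.random_indices(area_idx=0, n=1)        # one random neuron index from area 0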
File renamed without changes.
6 changes: 4 additions & 2 deletions todo.md
Original file line number Diff line number Diff line change
@@ -2,7 +2,9 @@
- [ ] The examples need to be updated. Especially on the main branch.
- [ ] Resolve the transpose issue in the model module and the mask module.
- [ ] Make the model use `batch_first` by default.
- [ ] Refactor the RNNLoss part, let it take a dictionary instead of many separate `lambda_*` parameters.
- [x] Refactor the RNNLoss part, let it take a dictionary instead of many separate `lambda_*` parameters. --> added the `CompositeLoss` instead.
- [x] Added batch_first parameter. Adjusted to batch_first by default to follow PyTorch standard.
- [ ] Varying alpha
- [ ] Need to adjust the implementation of `apply_plasticity`, as it won't support the SSL framework.
- [ ] Change output to readout.
- [ ] Some quick methods to access firing rates of different values
