-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
495ada9
commit ef88f9b
Showing
10 changed files
with
222 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .rnn_loss import RNNLoss | ||
from .mlp_loss import MLPLoss | ||
from .firing_rate_loss import * | ||
from .composite_loss import CompositeLoss | ||
from .composite_loss import CompositeLoss | ||
from .connectivity_loss import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
class RNNConnectivityLoss(nn.Module): | ||
def __init__(self, layer, metric='fro', **kwargs): | ||
super().__init__() | ||
assert metric in ['l1', 'fro'], "metric must be either l1 or l2" | ||
self.metric = metric | ||
self.layer = layer | ||
|
||
def forward(self, model, **kwargs): | ||
if self.layer == 'all': | ||
weights = [ | ||
model.recurrent_layer.input_layer.weight, | ||
model.recurrent_layer.hidden_layer.weight, | ||
model.readout_layer.weight | ||
] | ||
|
||
loss = torch.sum(torch.stack([self._compute_norm(weight) for weight in weights])) | ||
return loss | ||
elif self.layer == 'input': | ||
return self._compute_norm(model.recurrent_layer.input_layer.weight) | ||
elif self.layer == 'hidden': | ||
return self._compute_norm(model.recurrent_layer.hidden_layer.weight) | ||
elif self.layer == 'readout': | ||
return self._compute_norm(model.readout_layer.weight) | ||
else: | ||
raise ValueError(f"Invalid layer '{self.layer}'. Available layers are: 'all', 'input', 'hidden', 'readout'") | ||
|
||
def _compute_norm(self, weight): | ||
if self.metric == 'l1': | ||
return torch.norm(weight, p=1) | ||
else: | ||
return torch.norm(weight, p='fro') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .help_functions import * | ||
from .area_manager import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
|
||
def check_init(func): | ||
def wrapper(self, *args, **kwargs): | ||
if self._area_indices is None: | ||
raise ValueError("Area indices are not initialized") | ||
return func(self, *args, **kwargs) | ||
return wrapper | ||
|
||
class AreaManager: | ||
def __init__(self, area_indices=None, batch_first=True): | ||
""" | ||
Initialize the AreaManager | ||
Inputs: | ||
- area_indices: a list of indices (array) denoting a neuron's area assignment | ||
- batch_first: whether the states are batch-first (default: True) | ||
""" | ||
if area_indices is not None: | ||
self.set_area_indices(area_indices) | ||
else: | ||
self._area_indices = None # Ensure it's None initially | ||
|
||
self._batch_first = batch_first | ||
|
||
def set_area_indices(self, area_indices): | ||
""" | ||
Set the area indices | ||
Inputs: | ||
- area_indices: a list of indices (array) denoting a neuron's area assignment | ||
""" | ||
self._n_areas = len(area_indices) | ||
self._area_indices = area_indices | ||
|
||
@property | ||
def n_areas(self): | ||
return self._n_areas | ||
|
||
@property | ||
def ai(self): | ||
return self._area_indices | ||
|
||
@check_init | ||
def split_states(self, states): | ||
""" | ||
Parse the states of a complete RNN into a list of states of different areas | ||
Inputs: | ||
- states: network states of shape (batch_size, seq_len, hidden_size) | ||
Returns: | ||
- list of states of different areas | ||
""" | ||
if not self._batch_first: | ||
states = states.permute(1, 0, 2) | ||
|
||
area_states = [states[:, :, idx] for idx in self._area_indices] | ||
return area_states | ||
|
||
@check_init | ||
def get_area_states(self, states, area_idx): | ||
""" | ||
Get the states of a specific area | ||
Inputs: | ||
- states: network states of shape (batch_size, seq_len, hidden_size) | ||
- area_idx: index of the area | ||
Returns: | ||
- states of the specific area | ||
""" | ||
if self._batch_first: | ||
states = states.permute(1, 0, 2) | ||
|
||
return states[:, :, self._area_indices[area_idx]] | ||
|
||
@check_init | ||
def random_indices(self, area_idx, n, replace=False): | ||
""" | ||
Randomly pick n indices from an area | ||
Inputs: | ||
- n: number of indices to pick | ||
- area_idx: index of the area | ||
- replace: whether to sample with replacement (default: False) | ||
Returns: | ||
- indices of the neurons | ||
""" | ||
n_neurons = len(self._area_indices[area_idx]) | ||
if not replace and n > n_neurons: | ||
return self._area_indices[area_idx] | ||
return np.random.choice(self._area_indices[area_idx], n, replace=replace) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters