
Commit

refactored layers
zhaozewang committed Nov 23, 2024
1 parent b258dc0 commit 1d844a7
Showing 28 changed files with 974 additions and 1,721 deletions.
696 changes: 0 additions & 696 deletions examples/CTRNN.ipynb

This file was deleted.

393 changes: 0 additions & 393 deletions examples/MultiArea.ipynb

This file was deleted.

29 changes: 15 additions & 14 deletions nn4n/criterion/composite_loss.py
@@ -22,21 +22,21 @@ def __init__(self, loss_cfg):

# Mapping of loss types to their respective classes or instances
loss_types = {
'fr': FiringRateLoss,
'fr_dist': FiringRateDistLoss,
'rnn_conn': RNNConnectivityLoss,
'state_pred': StatePredictionLoss,
'entropy': EntropyLoss,
'mse': nn.MSELoss,
'hebbian': HebbianLoss,
"fr": FiringRateLoss,
"fr_dist": FiringRateDistLoss,
"rnn_conn": RNNConnectivityLoss,
"state_pred": StatePredictionLoss,
"entropy": EntropyLoss,
"mse": nn.MSELoss,
"hebbian": HebbianLoss,
}
torch_losses = ['mse']
torch_losses = ["mse"]

# Iterate over the loss_cfg to instantiate and store losses
for loss_name, loss_spec in loss_cfg.items():
loss_type = loss_spec['type']
loss_params = loss_spec.get('params', {})
loss_lambda = loss_spec.get('lambda', 1.0)
loss_type = loss_spec["type"]
loss_params = loss_spec.get("params", {})
loss_lambda = loss_spec.get("lambda", 1.0)

# Instantiate the loss function
if loss_type in loss_types:
@@ -51,7 +51,9 @@ def __init__(self, loss_cfg):
# Store the loss instance and its weight in a dictionary
self.loss_components[loss_name] = (loss_instance, loss_lambda)
else:
raise ValueError(f"Invalid loss type '{loss_type}'. Available types are: {list(loss_types.keys())}")
raise ValueError(
f"Invalid loss type '{loss_type}'. Available types are: {list(loss_types.keys())}"
)

def forward(self, loss_input_dict):
"""
@@ -70,8 +72,7 @@ def forward(self, loss_input_dict):
loss_inputs = loss_input_dict[loss_name]
if isinstance(loss_fn, nn.MSELoss):
# For MSELoss, assume the inputs are 'input' and 'target'
loss_value = loss_fn(
loss_inputs['input'], loss_inputs['target'])
loss_value = loss_fn(loss_inputs["input"], loss_inputs["target"])
else:
loss_value = loss_fn(**loss_inputs)
loss_dict[loss_name] = loss_weight * loss_value
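
A minimal usage sketch of the configuration-driven design above (not part of the commit). The "type", "params", and "lambda" keys come straight from the loop in __init__; the loss names, parameter values, and the import path below are assumptions for illustration only.

import torch
from nn4n.criterion.composite_loss import CompositeLoss  # import path assumed from the file layout

# Each entry names a loss, picks a registered type, and sets an optional weight.
loss_cfg = {
    "mse_main": {"type": "mse", "lambda": 1.0},
    "fr_reg": {"type": "fr", "params": {"metric": "l2"}, "lambda": 0.01},
}
criterion = CompositeLoss(loss_cfg)

# forward() expects one input dict per configured loss, keyed by the same names;
# MSELoss entries take "input"/"target", the custom losses take their own kwargs.
pred = torch.randn(8, 50, 2)
target = torch.randn(8, 50, 2)
states = torch.randn(8, 50, 32)
losses = criterion(
    {
        "mse_main": {"input": pred, "target": target},
        "fr_reg": {"states": states},
    }
)
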
27 changes: 15 additions & 12 deletions nn4n/criterion/connectivity_loss.py
@@ -4,34 +4,37 @@


class RNNConnectivityLoss(nn.Module):
def __init__(self, layer, metric='fro', **kwargs):
def __init__(self, layer, metric="fro", **kwargs):
super().__init__()
assert metric in ['l1', 'fro'], "metric must be either l1 or fro"
assert metric in ["l1", "fro"], "metric must be either l1 or fro"
self.metric = metric
self.layer = layer

def forward(self, model, **kwargs):
if self.layer == 'all':
if self.layer == "all":
weights = [
model.recurrent_layer.input_layer.weight,
model.recurrent_layer.hidden_layer.weight,
model.readout_layer.weight
model.readout_layer.weight,
]

loss = torch.sum(torch.stack(
[self._compute_norm(weight) for weight in weights]))
loss = torch.sum(
torch.stack([self._compute_norm(weight) for weight in weights])
)
return loss
elif self.layer == 'input':
elif self.layer == "input":
return self._compute_norm(model.recurrent_layer.input_layer.weight)
elif self.layer == 'hidden':
elif self.layer == "hidden":
return self._compute_norm(model.recurrent_layer.hidden_layer.weight)
elif self.layer == 'readout':
elif self.layer == "readout":
return self._compute_norm(model.readout_layer.weight)
else:
raise ValueError(f"Invalid layer '{self.layer}'. Available layers are: 'all', 'input', 'hidden', 'readout'")
raise ValueError(
f"Invalid layer '{self.layer}'. Available layers are: 'all', 'input', 'hidden', 'readout'"
)

def _compute_norm(self, weight):
if self.metric == 'l1':
if self.metric == "l1":
return torch.norm(weight, p=1)
else:
return torch.norm(weight, p='fro')
return torch.norm(weight, p="fro")
77 changes: 41 additions & 36 deletions nn4n/criterion/firing_rate_loss.py
@@ -13,26 +13,26 @@ def forward(self, **kwargs):


class FiringRateLoss(CustomLoss):
def __init__(self, metric='l2', **kwargs):
def __init__(self, metric="l2", **kwargs):
super().__init__(**kwargs)
assert metric in ['l1', 'l2'], "metric must be either l1 or l2"
assert metric in ["l1", "l2"], "metric must be either l1 or l2"
self.metric = metric

def forward(self, states, **kwargs):
# Calculate the mean firing rate across specified dimensions
mean_fr = torch.mean(states, dim=(0, 1))

# Replace custom norm calculation with PyTorch's built-in norm
if self.metric == 'l1':
return F.l1_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')
if self.metric == "l1":
return F.l1_loss(mean_fr, torch.zeros_like(mean_fr), reduction="mean")
else:
return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')
return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction="mean")
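
A standalone illustration (not from the commit) of what FiringRateLoss computes; the (batch, time, neurons) layout is an assumption consistent with averaging over dims (0, 1), and the tensor values are placeholders.

import torch

states = torch.randn(8, 100, 32)      # (batch, time, neurons) -- layout assumed
mean_fr = states.mean(dim=(0, 1))     # average firing rate per neuron
l2_term = (mean_fr ** 2).mean()       # equals F.mse_loss(mean_fr, zeros) with "mean" reduction
l1_term = mean_fr.abs().mean()        # equals F.l1_loss(mean_fr, zeros) with "mean" reduction
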


class FiringRateDistLoss(CustomLoss):
def __init__(self, metric='sd', **kwargs):
def __init__(self, metric="sd", **kwargs):
super().__init__(**kwargs)
valid_metrics = ['sd', 'cv', 'mean_ad', 'max_ad']
valid_metrics = ["sd", "cv", "mean_ad", "max_ad"]
assert metric in valid_metrics, (
"metric must be chosen from 'sd' (standard deviation), "
"'cv' (coefficient of variation), 'mean_ad' (mean abs deviation), "
@@ -44,21 +44,21 @@ def forward(self, states, **kwargs):
mean_fr = torch.mean(states, dim=(0, 1))

# Standard deviation
if self.metric == 'sd':
if self.metric == "sd":
return torch.std(mean_fr)

# Coefficient of variation
elif self.metric == 'cv':
elif self.metric == "cv":
return torch.std(mean_fr) / torch.mean(mean_fr)

# Mean absolute deviation
elif self.metric == 'mean_ad':
elif self.metric == "mean_ad":
avg_mean_fr = torch.mean(mean_fr)
# Use F.l1_loss for mean absolute deviation
return F.l1_loss(mean_fr, avg_mean_fr.expand_as(mean_fr), reduction='mean')
return F.l1_loss(mean_fr, avg_mean_fr.expand_as(mean_fr), reduction="mean")

# Maximum absolute deviation
elif self.metric == 'max_ad':
elif self.metric == "max_ad":
avg_mean_fr = torch.mean(mean_fr)
return torch.max(torch.abs(mean_fr - avg_mean_fr))
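
The four dispersion metrics above, written out on a dummy tensor (illustration only, not from the commit; shape and values are placeholders).

import torch

states = torch.rand(8, 100, 32)
mean_fr = states.mean(dim=(0, 1))                     # per-neuron mean firing rate
sd = mean_fr.std()                                    # 'sd'
cv = mean_fr.std() / mean_fr.mean()                   # 'cv'
mean_ad = (mean_fr - mean_fr.mean()).abs().mean()     # 'mean_ad'
max_ad = (mean_fr - mean_fr.mean()).abs().max()       # 'max_ad'
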

@@ -73,10 +73,12 @@ def forward(self, states, **kwargs):
states = states.transpose(0, 1)

# Ensure the sequence is long enough for the prediction window
assert states.shape[1] > self.tau, "The sequence length is shorter than the prediction window."
assert (
states.shape[1] > self.tau
), "The sequence length is shorter than the prediction window."

# Use MSE loss instead of manual difference calculation
return F.mse_loss(states[:-self.tau], states[self.tau:], reduction='mean')
return F.mse_loss(states[: -self.tau], states[self.tau :], reduction="mean")
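
A standalone sketch (not from the commit) of the shifted MSE above: every state is compared with the state tau steps later, so the penalty is small when the trajectory is temporally predictable. The tau value and trajectory shape below are placeholders.

import torch
import torch.nn.functional as F

tau = 5
traj = torch.randn(100, 32)                                        # (time, neurons)
pred_loss = F.mse_loss(traj[:-tau], traj[tau:], reduction="mean")
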


class HebbianLoss(nn.Module):
@@ -88,7 +90,7 @@ def forward(self, states, weights):
# weights shape: (num_neurons, num_neurons)

# Compute correlations by averaging over time steps
correlations = torch.einsum('bti,btj->btij', states, states)
correlations = torch.einsum("bti,btj->btij", states, states)

# Apply weights to correlations and sum to get Hebbian loss for each batch
hebbian_loss = torch.sum(weights * correlations, dim=(-1, -2))
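
An illustration (not from the commit) of the shapes going into the einsum above: states of shape (batch, time, neurons) and weights of shape (neurons, neurons) yield one outer product per batch element and time step, which is then weighted and summed over the neuron-pair dimensions. Shapes below are placeholders.

import torch

states = torch.randn(4, 50, 16)
weights = torch.randn(16, 16)
correlations = torch.einsum("bti,btj->btij", states, states)   # (4, 50, 16, 16)
hebbian = torch.sum(weights * correlations, dim=(-1, -2))      # (4, 50)
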
@@ -114,8 +116,7 @@ def forward(self, states):
prob_states = states / (states.sum(dim=-1, keepdim=True) + eps)

# Compute the entropy of the neuron activations
entropy_loss = -torch.sum(prob_states *
torch.log(prob_states + eps), dim=-1)
entropy_loss = -torch.sum(prob_states * torch.log(prob_states + eps), dim=-1)

# Take the mean entropy over batches and time steps
mean_entropy = torch.mean(entropy_loss)
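
A standalone sketch (not from the commit) of the entropy computation above: activations are normalized into a per-timestep distribution over neurons, and the Shannon entropy of that distribution is averaged over batches and time steps. Non-negative activations and an eps of 1e-8 are assumptions, since the hunk does not show where eps is defined.

import torch

eps = 1e-8
states = torch.rand(4, 50, 16)                              # (batch, time, neurons)
prob = states / (states.sum(dim=-1, keepdim=True) + eps)    # per-step distribution over neurons
entropy = -(prob * torch.log(prob + eps)).sum(dim=-1)
mean_entropy = entropy.mean()
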
@@ -128,7 +129,7 @@ def forward(self, states):


class PopulationKL(nn.Module):
def __init__(self, symmetric=True, reg=1e-3, reduction='mean'):
def __init__(self, symmetric=True, reg=1e-3, reduction="mean"):
super().__init__()
self.symmetric = symmetric
self.reg = reg
@@ -140,45 +141,49 @@ def forward(self, states_0, states_1):
mean_0 = torch.mean(states_0, dim=(0, 1), keepdim=True)
# Shape: (1, 1, n_neurons)
mean_1 = torch.mean(states_1, dim=(0, 1), keepdim=True)
var_0 = torch.var(states_0, dim=(0, 1), unbiased=False,
keepdim=True) # Shape: (1, 1, n_neurons)
var_1 = torch.var(states_1, dim=(0, 1), unbiased=False,
keepdim=True) # Shape: (1, 1, n_neurons)
var_0 = torch.var(
states_0, dim=(0, 1), unbiased=False, keepdim=True
) # Shape: (1, 1, n_neurons)
var_1 = torch.var(
states_1, dim=(0, 1), unbiased=False, keepdim=True
) # Shape: (1, 1, n_neurons)

# Compute the KL divergence between the two populations (per neuron)
# Shape: (1, 1, n_neurons)
kl_div = 0.5 * (torch.log(var_1 / var_0) +
(var_0 + (mean_0 - mean_1) ** 2) / var_1 - 1)
kl_div = 0.5 * (
torch.log(var_1 / var_0) + (var_0 + (mean_0 - mean_1) ** 2) / var_1 - 1
)

# Symmetric KL divergence: average the KL(P || Q) and KL(Q || P)
if self.symmetric:
# Shape: (1, 1, n_neurons)
reverse_kl_div = 0.5 * \
(torch.log(var_0 / var_1) +
(var_1 + (mean_1 - mean_0) ** 2) / var_0 - 1)
reverse_kl_div = 0.5 * (
torch.log(var_0 / var_1) + (var_1 + (mean_1 - mean_0) ** 2) / var_0 - 1
)
# Shape: (1, 1, n_neurons)
kl_div = 0.5 * (kl_div + reverse_kl_div)

# Apply reduction based on the reduction method
if self.reduction == 'mean':
if self.reduction == "mean":
kl_loss = torch.mean(kl_div) # Scalar value
elif self.reduction == 'sum':
elif self.reduction == "sum":
kl_loss = torch.sum(kl_div) # Scalar value
elif self.reduction == 'none':
elif self.reduction == "none":
kl_loss = kl_div # Shape: (1, 1, n_neurons)
else:
raise ValueError(f"Invalid reduction mode: {self.reduction}")

# Regularization: L2 norm of the states across the neurons
reg_loss = torch.mean(torch.norm(states_0, dim=-1) ** 2) + \
torch.mean(torch.norm(states_1, dim=-1) ** 2)
reg_loss = torch.mean(torch.norm(states_0, dim=-1) ** 2) + torch.mean(
torch.norm(states_1, dim=-1) ** 2
)

# Combine the KL divergence with the regularization term
if self.reduction == 'none':
if self.reduction == "none":
# If no reduction, add regularization element-wise
total_loss = kl_loss + self.reg * \
(torch.norm(states_0, dim=-1) ** 2 +
torch.norm(states_1, dim=-1) ** 2)
total_loss = kl_loss + self.reg * (
torch.norm(states_0, dim=-1) ** 2 + torch.norm(states_1, dim=-1) ** 2
)
else:
total_loss = kl_loss + self.reg * reg_loss

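A standalone sketch (not from the commit) of the per-neuron Gaussian KL used above: each population is summarized by a per-neuron mean and variance, and KL(P0 || P1) = 0.5 * (log(v1/v0) + (v0 + (m0 - m1)^2) / v1 - 1); the symmetric variant averages the two directions. Tensor shapes and the 0.5 offset are placeholders.

import torch

states_0 = torch.randn(4, 50, 16)
states_1 = torch.randn(4, 50, 16) + 0.5
m0, m1 = states_0.mean(dim=(0, 1)), states_1.mean(dim=(0, 1))
v0 = states_0.var(dim=(0, 1), unbiased=False)
v1 = states_1.var(dim=(0, 1), unbiased=False)
kl_01 = 0.5 * (torch.log(v1 / v0) + (v0 + (m0 - m1) ** 2) / v1 - 1)
kl_10 = 0.5 * (torch.log(v0 / v1) + (v1 + (m1 - m0) ** 2) / v0 - 1)
sym_kl = 0.5 * (kl_01 + kl_10)   # one value per neuron; the "mean" reduction averages them
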
10 changes: 5 additions & 5 deletions nn4n/criterion/mlp_loss.py
@@ -30,7 +30,7 @@ def _init_losses(self, **kwargs):
self.loss_list = loss_list

def _loss_fr(self, states, **kwargs):
""" Compute the loss for firing rate """
"""Compute the loss for firing rate"""
# return torch.sqrt(torch.square(states)).mean()
loss = []
for s in states:
@@ -39,7 +39,7 @@ def _loss_fr_sd(self, states, **kwargs):
return torch.stack(loss).mean()

def _loss_fr_sd(self, states, **kwargs):
""" Compute the loss for firing rate for each neuron in terms of SD """
"""Compute the loss for firing rate for each neuron in terms of SD"""
# return torch.sqrt(torch.square(states)).mean(dim=(0)).std()
return torch.pow(torch.mean(states, dim=(0, 1)), 2).std()

@@ -50,17 +50,17 @@ def forward(self, pred, label, **kwargs):
@param label: size=(-1, batch_size, 2), labels
@param dur: duration of the trial
"""
loss = [self.lambda_mse * torch.square(pred-label).mean()]
loss = [self.lambda_mse * torch.square(pred - label).mean()]
for i in range(len(self.loss_list)):
if self.lambda_list[i] == 0:
continue
else:
loss.append(self.lambda_list[i]*self.loss_list[i](**kwargs))
loss.append(self.lambda_list[i] * self.loss_list[i](**kwargs))
loss = torch.stack(loss)
return loss.sum(), loss

def to(self, device):
""" Move to device """
"""Move to device"""
super().to(device)
self.lambda_list = self.lambda_list.to(device)
return self
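
A brief illustration (not from the commit) of the weighting pattern in forward() above: per-term losses are scaled by their lambdas, stacked, and both the weighted total and the per-term vector are returned so a training loop can log each component separately. The values below are placeholders.

import torch

terms = [0.5 * torch.tensor(0.12), 0.01 * torch.tensor(3.4)]   # lambda_i * loss_i, placeholder values
stacked = torch.stack(terms)
total, per_term = stacked.sum(), stacked
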
39 changes: 24 additions & 15 deletions nn4n/criterion/rnn_loss.py
@@ -24,7 +24,7 @@ class RNNLoss(nn.Module):
- lambda_hid: coefficient for the hidden layer loss, default: 0
- lambda_out: coefficient for the readout layer loss, default: 0
- lambda_fr: coefficient for the overall firing rate loss, default: 0
- lambda_fr_sd: coefficient for the standard deviation of firing rate
- lambda_fr_sd: coefficient for the standard deviation of firing rate
loss (to evenly distribute firing rate across neurons), default: 0
- lambda_fr_cv: coefficient for the coefficient of variation of firing
rate loss (to evenly distribute firing rate across neurons), default: 0
@@ -72,21 +72,30 @@ def _init_losses(self, **kwargs):
n_in = self.model.recurrent_layer.input_layer.weight.shape[1]
n_size = self.model.recurrent_layer.hidden_layer.weight.shape[0]
n_out = self.model.readout_layer.weight.shape[0]
self.n_in_dividend = n_in*n_size
self.n_hid_dividend = n_size*n_size
self.n_out_dividend = n_out*n_size
self.n_in_dividend = n_in * n_size
self.n_hid_dividend = n_size * n_size
self.n_out_dividend = n_out * n_size

def _loss_in(self, **kwargs):
""" Compute the loss for InputLayer """
return torch.norm(self.model.recurrent_layer.input_layer.weight, p='fro')**2/self.n_in_dividend
"""Compute the loss for InputLayer"""
return (
torch.norm(self.model.recurrent_layer.input_layer.weight, p="fro") ** 2
/ self.n_in_dividend
)

def _loss_hid(self, **kwargs):
""" Compute the loss for RecurrentLayer """
return torch.norm(self.model.recurrent_layer.hidden_layer.weight, p='fro')**2/self.n_hid_dividend
"""Compute the loss for RecurrentLayer"""
return (
torch.norm(self.model.recurrent_layer.hidden_layer.weight, p="fro") ** 2
/ self.n_hid_dividend
)

def _loss_out(self, **kwargs):
""" Compute the loss for ReadoutLayer """
return torch.norm(self.model.readout_layer.weight, p='fro')**2/self.n_out_dividend
"""Compute the loss for ReadoutLayer"""
return (
torch.norm(self.model.readout_layer.weight, p="fro") ** 2
/ self.n_out_dividend
)

def _loss_fr(self, states, **kwargs):
"""
@@ -99,10 +108,10 @@ def _loss_fr(self, states, **kwargs):
mean_fr = torch.mean(states, dim=(0, 1))
# return torch.pow(torch.mean(states, dim=(0, 1)), 2).mean() # this might not be correct
# return torch.norm(states, p='fro')**2/states.numel() # this might not be correct
return torch.norm(mean_fr, p=2)**2/mean_fr.numel()
return torch.norm(mean_fr, p=2) ** 2 / mean_fr.numel()

def _loss_fr_sd(self, states, **kwargs):
"""
"""
Compute the loss for firing rate for each neuron in terms of SD
This will take the average firing rate of each neuron across all timesteps and batch_size
and compute the standard deviation of the firing rate across all neurons
@@ -127,7 +136,7 @@ def _loss_fr_cv(self, states, **kwargs):
if not self.batch_first:
states = states.transpose(0, 1)
avg_fr = torch.mean(torch.sqrt(torch.square(states)), dim=(0, 1))
return avg_fr.std()/avg_fr.mean()
return avg_fr.std() / avg_fr.mean()

def forward(self, pred, label, **kwargs):
"""
@@ -139,11 +148,11 @@ def forward(self, pred, label, **kwargs):
where -1 is the sequence length
"""
loss = [self.lambda_mse * torch.square(pred-label).mean()]
loss = [self.lambda_mse * torch.square(pred - label).mean()]
for i in range(len(self.loss_list)):
if self.lambda_list[i] == 0:
continue
else:
loss.append(self.lambda_list[i]*self.loss_list[i](**kwargs))
loss.append(self.lambda_list[i] * self.loss_list[i](**kwargs))
loss = torch.stack(loss)
return loss.sum(), loss
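
A standalone sketch (not from the commit) of the normalized weight penalties defined in _loss_in, _loss_hid, and _loss_out above: the squared Frobenius norm of each weight matrix divided by its number of entries, so the penalty behaves like an average squared weight rather than a size-dependent sum. The matrix size below is a placeholder.

import torch

W_hid = torch.randn(64, 64)                                    # placeholder hidden weights
hid_penalty = torch.norm(W_hid, p="fro") ** 2 / W_hid.numel()  # matches n_hid_dividend = n_size * n_size
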
