
refactor LinearLayer and HiddenLayer
zhaozewang committed Nov 23, 2024
1 parent 1d844a7 commit 2b1da32
Showing 11 changed files with 246 additions and 645 deletions.
215 changes: 205 additions & 10 deletions nn4n/layer/base_layer.py
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn

import numpy as np
import nn4n.utils as utils

@@ -73,7 +72,7 @@ def from_dict(cls, layer_struct):
"""
# Create an instance using the dictionary values
cls._check_keys(layer_struct)
instance = cls(
return cls(
input_dim=layer_struct["input_dim"],
output_dim=layer_struct["output_dim"],
weight=layer_struct.get("weight", "uniform"),
@@ -82,20 +81,216 @@ def from_dict(cls, layer_struct):
            sparsity_mask=layer_struct.get("sparsity_mask"),
            plasticity_mask=layer_struct.get("plasticity_mask"),
        )
        # Initialize the trainable parameters then check the layer
        instance._init_trainable()
        instance._check_layer()

        return instance
    def _check_layer(self):
        pass

    # INIT TRAINABLE
    # ======================================================================================
    def _init_trainable(self):
        # enforce constraints
        # Enforce constraints
        self._init_constraints()
        # convert weight and bias to torch tensor
        # Convert weight and bias to learnable parameters
        self.weight = nn.Parameter(
            self.weight, requires_grad=self.weight_dist is not None
        )
        self.bias = nn.Parameter(self.bias, requires_grad=self.bias_dist is not None)

    def _check_layer(self):
        pass
    def _init_constraints(self):
        """
        Initialize constraints
        It will also balance excitatory and inhibitory neurons
        """
        if self.sparsity_mask is not None:
            self.weight *= self.sparsity_mask
        if self.ei_mask is not None:
            # Apply Dale's law
            self.weight[self.ei_mask == 1] = torch.clamp(
                self.weight[self.ei_mask == 1], min=0
            )  # For excitatory neurons, set negative weights to 0
            self.weight[self.ei_mask == -1] = torch.clamp(
                self.weight[self.ei_mask == -1], max=0
            )  # For inhibitory neurons, set positive weights to 0

            # Balance excitatory and inhibitory neurons weight magnitudes
            self._balance_excitatory_inhibitory()

    def _generate_bias(self, bias_init):
        """Generate random bias"""
        if bias_init == "uniform":
            # If uniform, let b be uniform in [-sqrt(k), sqrt(k)]
            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
            b = torch.rand(self.output_dim) * sqrt_k
            b = b * 2 - sqrt_k
        elif bias_init == "normal":
            b = torch.randn(self.output_dim) / torch.sqrt(torch.tensor(float(self.input_dim)))
        elif bias_init == "zero" or bias_init is None:
            b = torch.zeros(self.output_dim)
        elif isinstance(bias_init, np.ndarray):
            b = torch.from_numpy(bias_init)
        else:
            raise NotImplementedError
        return b.float()

    def _generate_weight(self, weight_init):
        """Generate random weight"""
        if weight_init == "uniform":
            # If uniform, let w be uniform in [-sqrt(k), sqrt(k)]
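            # (k = 1 / input_dim, so this is the same U(-sqrt(k), sqrt(k)) range
            # that PyTorch's nn.Linear uses for its default initialization)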
            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
            w = torch.rand(self.output_dim, self.input_dim) * sqrt_k
            w = w * 2 - sqrt_k
        elif weight_init == "normal":
            w = torch.randn(self.output_dim, self.input_dim) / torch.sqrt(
                torch.tensor(float(self.input_dim))
            )
        elif weight_init == "zero":
            w = torch.zeros((self.output_dim, self.input_dim))
        elif isinstance(weight_init, np.ndarray):
            w = torch.from_numpy(weight_init)
        else:
            raise NotImplementedError
        return w.float()

    def _balance_excitatory_inhibitory(self):
        """Balance excitatory and inhibitory weights"""
        scale_mat = torch.ones_like(self.weight)
        ext_sum = self.weight[self.sparsity_mask == 1].sum()
        inh_sum = self.weight[self.sparsity_mask == -1].sum()
        if ext_sum == 0 or inh_sum == 0:
            # Automatically stop balancing if one of the sums is 0
            # divide by 10 to avoid recurrent explosion/decay
            self.weight /= 10
        else:
            if ext_sum > abs(inh_sum):
                _scale = abs(inh_sum).item() / ext_sum.item()
                scale_mat[self.sparsity_mask == 1] = _scale
            elif ext_sum < abs(inh_sum):
                _scale = ext_sum.item() / abs(inh_sum).item()
                scale_mat[self.sparsity_mask == -1] = _scale
            # Apply scaling
            self.weight *= scale_mat

    # TRAINING
    # ======================================================================================
    def to(self, device):
        """Move the layer to the device (cpu/gpu)"""
        super().to(device)
        if self.sparsity_mask is not None:
            self.sparsity_mask = self.sparsity_mask.to(device)
        if self.ei_mask is not None:
            self.ei_mask = self.ei_mask.to(device)
        if self.bias.requires_grad:
            self.bias = self.bias.to(device)
        return self

    def forward(self, x):
        """
        Forward pass of the layer
        Inputs:
            - x: input, shape: (batch_size, input_dim)
        Returns:
            - state: shape: (batch_size, output_dim)
        """
        return x.float() @ self.weight.T + self.bias

    def apply_plasticity(self):
        """
        Apply plasticity mask to the weight gradient
        """
        with torch.no_grad():
            # Assume the plasticity masks are valid; they are checked in the CTRNN class
            for scale in self.plasticity_scales:
                if self.weight.grad is not None:
                    self.weight.grad[self.plasticity_mask == scale] *= scale
                else:
                    raise RuntimeError(
                        "Weight gradient is None, possibly because the forward loop is non-differentiable"
                    )

    def freeze(self):
        """Freeze the layer"""
        self.weight.requires_grad = False
        self.bias.requires_grad = False

    def unfreeze(self):
        """Unfreeze the layer"""
        self.weight.requires_grad = True
        self.bias.requires_grad = True

    # CONSTRAINTS
    # ======================================================================================
    def enforce_constraints(self):
        """
        Enforce constraints
        The constraints are:
            - sparsity_mask: mask for sparse connectivity
            - ei_mask: mask for Dale's law
        """
        if self.sparsity_mask is not None:
            self._enforce_sparsity()
        if self.ei_mask is not None:
            self._enforce_ei()

    def _enforce_sparsity(self):
        """Enforce sparsity"""
        w = self.weight.detach().clone() * self.sparsity_mask
        self.weight.data.copy_(w)

    def _enforce_ei(self):
        """Enforce Dale's law"""
        w = self.weight.detach().clone()
        w[self.ei_mask == 1] = torch.clamp(w[self.ei_mask == 1], min=0)
        w[self.ei_mask == -1] = torch.clamp(w[self.ei_mask == -1], max=0)
        self.weight.data.copy_(w)

    # HELPER FUNCTIONS
    # ======================================================================================
    def set_weight(self, weight):
        """Set the value of weight"""
        assert (
            weight.shape == self.weight.shape
        ), f"Weight shape mismatch, expected {self.weight.shape}, got {weight.shape}"
        with torch.no_grad():
            self.weight.copy_(weight)

    def plot_layer(self):
        """Plot the weight matrix and weight distribution of the layer"""
        weight = (
            self.weight.cpu()
            if self.weight.device != torch.device("cpu")
            else self.weight
        )
        utils.plot_connectivity_matrix_dist(
            weight.detach().numpy(),
            "Weight",
            False,
            self.sparsity_mask is not None,
        )

    def get_specs(self):
        """Return the specs of the layer"""
        return {
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "weight_learnable": self.weight.requires_grad,
            "weight_min": self.weight.min().item(),
            "weight_max": self.weight.max().item(),
            "bias_learnable": self.bias.requires_grad,
            "bias_min": self.bias.min().item(),
            "bias_max": self.bias.max().item(),
            "sparsity": (
                self.sparsity_mask.sum() / self.sparsity_mask.numel()
                if self.sparsity_mask is not None
                else 1
            ),
        }

    def print_layer(self):
        """
        Print the specs of the layer
        """
        utils.print_dict("Layer Specs", self.get_specs())
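
For context, a minimal sketch of how the refactored layer might be constructed and exercised. The import path, the "bias" key, the dimensions, and the training steps are illustrative assumptions rather than part of this commit; only from_dict, forward, apply_plasticity, and enforce_constraints come from the code above.

    import torch
    from nn4n.layer import LinearLayer  # assumed import path

    layer_struct = {
        "input_dim": 10,
        "output_dim": 5,
        "weight": "uniform",        # or a fixed np.ndarray
        "bias": "zero",             # assumed key, mirroring _generate_bias
        "sparsity_mask": None,
        "plasticity_mask": None,
    }
    layer = LinearLayer.from_dict(layer_struct)

    x = torch.rand(32, 10)          # (batch_size, input_dim)
    out = layer(x)                  # (batch_size, output_dim)

    loss = out.pow(2).mean()
    loss.backward()
    # layer.apply_plasticity()      # would rescale gradients if a plasticity mask were set
    # optimizer.step() would go here
    layer.enforce_constraints()     # re-apply sparsity / Dale's law after the update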
