Commit

fixed major recurrent_layer bug
zhaozewang committed Jun 22, 2024
1 parent 525ce60 commit 9c32ecc
Showing 2 changed files with 79 additions and 105 deletions.
182 changes: 78 additions & 104 deletions nn4n/layer/recurrent_layer.py
@@ -7,77 +7,36 @@


class RecurrentLayer(nn.Module):
def __init__(
self,
hidden_size,
positivity_constraints,
sparsity_constraints,
layer_distributions,
layer_biases,
layer_masks,
preact_noise,
postact_noise,
learnable=True,
**kwargs
):
"""
Hidden layer of the RNN
Parameters:
@param hidden_size: number of hidden neurons
@param positivity_constraints: whether to enforce positivity constraint
@param sparsity_constraints: use sparsity_constraints or not
@param layer_distributions: distribution of weights for each layer, a list of 3 strings
@param layer_biases: use bias or not for each layer, a list of 3 boolean values
Keyword Arguments:
@kwarg activation: activation function, default: "relu"
@kwarg preact_noise: noise added to pre-activation, default: 0
@kwarg postact_noise: noise added to post-activation, default: 0
@kwarg dt: time step, default: 1
@kwarg tau: time constant, default: 1
@kwarg input_dim: input dimension, default: 1
@kwarg hidden_dist: distribution of hidden layer weights, default: "normal"
@kwarg self_connections: allow self connections or not, default: False
@kwarg init_state: initial state of the network, 'zero', 'keep', or 'learn'
"""
"""
Recurrent layer of the RNN. The layer is initialized by passing specs in layer_struct.
Required keywords in layer_struct:
- activation: activation function, default: "relu"
- preact_noise: noise added to pre-activation
- postact_noise: noise added to post-activation
- dt: time step, default: 10
- tau: time constant, default: 100
- init_state: initial state of the network. It defines the hidden state at t=0.
- 'zero': all zeros
- 'keep': keep the last state
- 'learn': learn the initial state
- in_struct: layer_struct for the input layer
- hid_struct: layer_struct for the hidden layer
"""
def __init__(self, layer_struct, **kwargs):
super().__init__()

self.hidden_size = hidden_size
self.preact_noise = preact_noise
self.postact_noise = postact_noise
self.alpha = kwargs.get("dt", 10) / kwargs.get("tau", 100)
self.layer_distributions = layer_distributions
self.layer_biases = layer_biases
self.layer_masks = layer_masks
self.alpha = layer_struct['dt']/layer_struct['tau']
self.hidden_size = layer_struct['hid_struct']['input_dim']
self.hidden_state = torch.zeros(self.hidden_size)
self.init_state = kwargs.get("init_state", 'zero')
self.act = kwargs.get("activation", "relu")
self.init_state = layer_struct['init_state']
self.act = layer_struct['activation']
self.activation = get_activation(self.act)
self.preact_noise = kwargs.pop("preact_noise", 0)
self.postact_noise = kwargs.pop("postact_noise", 0)
self._set_hidden_state()

self.input_layer = LinearLayer(
positivity_constraints=positivity_constraints[0],
sparsity_constraints=sparsity_constraints[0],
output_dim=self.hidden_size,
input_dim=kwargs.pop("input_dim", 1),
use_bias=self.layer_biases[0],
dist=self.layer_distributions[0],
mask=self.layer_masks[0],
learnable=learnable[0],
)
self.hidden_layer = HiddenLayer(
hidden_size=self.hidden_size,
sparsity_constraints=sparsity_constraints[1],
positivity_constraints=positivity_constraints[1],
dist=self.layer_distributions[1],
use_bias=self.layer_biases[1],
scaling=kwargs.get("scaling", 1.0),
mask=self.layer_masks[1],
self_connections=kwargs.get("self_connections", False),
learnable=learnable[1],
)
self.input_layer = LinearLayer(layer_struct=layer_struct['in_struct'])
self.hidden_layer = HiddenLayer(layer_struct=layer_struct['hid_struct'])
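
For orientation, a minimal sketch of the kind of layer_struct the new __init__ reads. The top-level keys are the ones named in the docstring and the code above; the contents of in_struct and hid_struct are placeholders, since what LinearLayer and HiddenLayer expect is not shown in this diff.

# Hypothetical layer_struct (keys from the docstring / __init__ above; values are illustrative).
in_struct = {}                       # placeholder: handed straight to LinearLayer(layer_struct=...)
hid_struct = {"input_dim": 100}      # hidden size is read from hid_struct['input_dim']
layer_struct = {
    "activation": "relu",
    "preact_noise": 0,
    "postact_noise": 0,
    "dt": 10,
    "tau": 100,
    "init_state": "zero",            # 'zero', 'keep', or 'learn'
    "in_struct": in_struct,
    "hid_struct": hid_struct,
}
# rnn_layer = RecurrentLayer(layer_struct)   # would also need whatever keys the sub-layers require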

# INITIALIZATION
# ==================================================================================================
@@ -93,17 +52,59 @@ def _set_hidden_state(self):

# FORWARD
# ==================================================================================================
def to(self, device):
""" Move the network to the device (cpu/gpu) """
super().to(device)
self.input_layer.to(device)
self.hidden_layer.to(device)
self.hidden_state = self.hidden_state.to(device)

def forward(self, x):
"""
Propagate the input through the network, updating the hidden state at each timestep
Inputs:
- x: input, shape: (n_timesteps, batch_size, input_dim)
Returns:
- states: shape: (n_timesteps, batch_size, hidden_size)
"""
v_t = self._reset_state().to(x.device)
fr_t = self.activation(v_t)
# update hidden state and append to stacked_states
stacked_states = []
for i in range(x.size(0)):
fr_t, v_t = self._recurrence(fr_t, v_t, x[i])
# append to stacked_states
stacked_states.append(fr_t)

# if keeping the last state, save it to hidden_state
if self.init_state == 'keep':
self.hidden_state = fr_t.detach().clone() # TODO: haven't tested this yet

return torch.stack(stacked_states, dim=0)
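
A hedged usage sketch of the new forward signature, using the shapes stated in the docstring; rnn_layer is the hypothetical instance from the sketch above and is therefore left commented out.

import torch

x = torch.randn(50, 32, 100)                 # (n_timesteps=50, batch_size=32, input_dim=100)
# states = rnn_layer(x)                      # rnn_layer: hypothetical instance from the earlier sketch
# states.shape == (50, 32, rnn_layer.hidden_size), one firing-rate vector per timestep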

def _reset_state(self):
if self.init_state == 'learn' or self.init_state == 'keep':
return self.hidden_state
else:
return torch.zeros(self.hidden_size)

def apply_plasticity(self):
""" Apply plasticity masks to the weight gradients """
self.input_layer.apply_plasticity()
self.hidden_layer.apply_plasticity()

def enforce_constraints(self):
"""
Enforce sparsity and excitatory/inhibitory constraints if applicable.
This is by default automatically called after each forward pass,
but can be called manually if needed
"""
self.input_layer.enforce_constraints()
self.hidden_layer.enforce_constraints()
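
A toy illustration of the gradient-masking idea behind apply_plasticity; the actual masks live in LinearLayer and HiddenLayer and are not shown in this diff, so this is only the general pattern, not the library's code.

import torch

w = torch.randn(3, 3, requires_grad=True)
mask = torch.tensor([[1., 0., 1.],
                     [0., 1., 0.],
                     [1., 1., 0.]])          # 1 = plastic (trainable), 0 = frozen
loss = (w ** 2).sum()
loss.backward()
w.grad *= mask                               # frozen entries keep zero gradient before optimizer.step()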

def recurrence(self, fr_t, v_t, u_t):
def _recurrence(self, fr_t, v_t, u_t):
""" Recurrence function """
# through input layer
v_in_u_t = self.input_layer(u_t) # u_t @ W_in
@@ -126,55 +127,28 @@ def recurrence(self, fr_t, v_t, u_t):
fr_t = fr_t + postact_epsilon

return fr_t, v_t
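
The body of the recurrence is collapsed above; a plausible reading, given alpha = dt / tau, is the standard leaky-integrator update sketched below with stand-in weights and the noise terms omitted. This is an assumption, not code lifted from the file.

import torch

alpha = 10 / 100                              # dt=10, tau=100 (defaults named in the docstrings)
hidden_size, input_dim = 4, 3
W_in = torch.randn(hidden_size, input_dim)    # stand-ins for the input / hidden weights
W_hid = torch.randn(hidden_size, hidden_size)
v_t = torch.zeros(hidden_size)                # pre-activation ("membrane potential")
fr_t = torch.relu(v_t)                        # post-activation ("firing rate")
u_t = torch.randn(input_dim)                  # input at one timestep

v_t = (1 - alpha) * v_t + alpha * (fr_t @ W_hid.T + u_t @ W_in.T)  # leaky integration
fr_t = torch.relu(v_t)                        # preact/postact noise would be added around this step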

def forward(self, input):
"""
Propagate input through the network.
@param input: shape=(seq_len, batch, input_dim), network input
@return stacked_states: shape=(seq_len, batch, hidden_size), stack of hidden layer states
"""
v_t = self._reset_state().to(input.device)
fr_t = self.activation(v_t)
# update hidden state and append to stacked_states
stacked_states = []
for i in range(input.size(0)):
fr_t, v_t = self.recurrence(fr_t, v_t, input[i])
# append to stacked_states
stacked_states.append(fr_t)

# if keeping the last state, save it to hidden_state
if self.init_state == 'keep':
self.hidden_state = fr_t.detach().clone() # TODO: haven't tested this yet

return torch.stack(stacked_states, dim=0)
# ==================================================================================================

# HELPER FUNCTIONS
# ==================================================================================================
def to(self, device):
"""
Move the network to the device (cpu/gpu)
"""
super().to(device)
self.input_layer.to(device)
self.hidden_layer.to(device)
self.hidden_state = self.hidden_state.to(device)
def plot_layers(self, **kwargs):
""" Plot the weights matrix and distribution of each layer """
self.input_layer.plot_layers()
self.hidden_layer.plot_layers()

def print_layers(self):
""" Print the weights matrix and distribution of each layer """
param_dict = {
"hidden_min": self.hidden_state.min(),
"hidden_max": self.hidden_state.max(),
"hidden_mean": self.hidden_state.mean(),
"init_hidden_min": self.hidden_state.min(),
"init_hidden_max": self.hidden_state.max(),
"preact_noise": self.preact_noise,
"postact_noise": self.postact_noise,
"activation": self.act,
"alpha": self.alpha,
"init_state": self.init_state,
"init_state_learnable": self.hidden_state.requires_grad,
}
self.input_layer.print_layers()
print_dict("Recurrence", param_dict)
self.hidden_layer.print_layers()

def plot_layers(self, **kwargs):
self.input_layer.plot_layers()
self.hidden_layer.plot_layers()
# ==================================================================================================
# ==================================================================================================
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@

setup(
name='nn4n',
version='1.1.0',
version='1.1.1',
description='Neural Networks for Neuroscience Research',
long_description=long_description,
long_description_content_type='text/markdown',
