
Commit

added mutli io
zhaozewang committed Jun 22, 2024
1 parent 525ce60 commit fdbc708
Showing 10 changed files with 170 additions and 120 deletions.
182 changes: 78 additions & 104 deletions nn4n/layer/recurrent_layer.py
@@ -7,77 +7,36 @@


class RecurrentLayer(nn.Module):
def __init__(
self,
hidden_size,
positivity_constraints,
sparsity_constraints,
layer_distributions,
layer_biases,
layer_masks,
preact_noise,
postact_noise,
learnable=True,
**kwargs
):
"""
Hidden layer of the RNN
Parameters:
@param hidden_size: number of hidden neurons
@param positivity_constraints: whether to enforce positivity constraint
@param sparsity_constraints: use sparsity_constraints or not
@param layer_distributions: distribution of weights for each layer, a list of 3 strings
@param layer_biases: use bias or not for each layer, a list of 3 boolean values
Keyword Arguments:
@kwarg activation: activation function, default: "relu"
@kwarg preact_noise: noise added to pre-activation, default: 0
@kwarg postact_noise: noise added to post-activation, default: 0
@kwarg dt: time step, default: 1
@kwarg tau: time constant, default: 1
@kwarg input_dim: input dimension, default: 1
@kwarg hidden_dist: distribution of hidden layer weights, default: "normal"
@kwarg self_connections: allow self connections or not, default: False
@kwarg init_state: initial state of the network, 'zero', 'keep', or 'learn'
"""
"""
Recurrent layer of the RNN. The layer is initialized by passing specs in layer_struct.
Required keywords in layer_struct:
- activation: activation function, default: "relu"
- preact_noise: noise added to pre-activation
- postact_noise: noise added to post-activation
- dt: time step, default: 10
- tau: time constant, default: 100
- init_state: initial state of the network. It defines the hidden state at t=0.
- 'zero': all zeros
- 'keep': keep the last state
- 'learn': learn the initial state
- in_struct: layer_struct for the input layer
- hid_struct: layer_struct for the hidden layer
"""
def __init__(self, layer_struct, **kwargs):
super().__init__()

self.hidden_size = hidden_size
self.preact_noise = preact_noise
self.postact_noise = postact_noise
self.alpha = kwargs.get("dt", 10) / kwargs.get("tau", 100)
self.layer_distributions = layer_distributions
self.layer_biases = layer_biases
self.layer_masks = layer_masks
self.alpha = layer_struct['dt']/layer_struct['tau']
self.hidden_size = layer_struct['hid_struct']['input_dim']
self.hidden_state = torch.zeros(self.hidden_size)
self.init_state = kwargs.get("init_state", 'zero')
self.act = kwargs.get("activation", "relu")
self.init_state = layer_struct['init_state']
self.act = layer_struct['activation']
self.activation = get_activation(self.act)
self.preact_noise = kwargs.pop("preact_noise", 0)
self.postact_noise = kwargs.pop("postact_noise", 0)
self._set_hidden_state()

self.input_layer = LinearLayer(
positivity_constraints=positivity_constraints[0],
sparsity_constraints=sparsity_constraints[0],
output_dim=self.hidden_size,
input_dim=kwargs.pop("input_dim", 1),
use_bias=self.layer_biases[0],
dist=self.layer_distributions[0],
mask=self.layer_masks[0],
learnable=learnable[0],
)
self.hidden_layer = HiddenLayer(
hidden_size=self.hidden_size,
sparsity_constraints=sparsity_constraints[1],
positivity_constraints=positivity_constraints[1],
dist=self.layer_distributions[1],
use_bias=self.layer_biases[1],
scaling=kwargs.get("scaling", 1.0),
mask=self.layer_masks[1],
self_connections=kwargs.get("self_connections", False),
learnable=learnable[1],
)
self.input_layer = LinearLayer(layer_struct=layer_struct['in_struct'])
self.hidden_layer = HiddenLayer(layer_struct=layer_struct['hid_struct'])

# INITIALIZATION
# ==================================================================================================
@@ -93,17 +52,59 @@ def _set_hidden_state(self):

# FORWARD
# ==================================================================================================
def to(self, device):
""" Move the network to the device (cpu/gpu) """
super().to(device)
self.input_layer.to(device)
self.hidden_layer.to(device)
self.hidden_state = self.hidden_state.to(device)

def forward(self, x):
"""
Propagate the input through the network over time
Inputs:
- x: input, shape: (n_timesteps, batch_size, input_dim)
Returns:
- states: shape: (n_timesteps, batch_size, hidden_size)
"""
v_t = self._reset_state().to(x.device)
fr_t = self.activation(v_t)
# update hidden state and append to stacked_states
stacked_states = []
for i in range(x.size(0)):
fr_t, v_t = self._recurrence(fr_t, v_t, x[i])
# append to stacked_states
stacked_states.append(fr_t)

# if keeping the last state, save it to hidden_state
if self.init_state == 'keep':
self.hidden_state = fr_t.detach().clone() # TODO: haven't tested this yet

return torch.stack(stacked_states, dim=0)

def _reset_state(self):
if self.init_state == 'learn' or self.init_state == 'keep':
return self.hidden_state
else:
return torch.zeros(self.hidden_size)

def apply_plasticity(self):
""" Apply plasticity masks to the weight gradients """
self.input_layer.apply_plasticity()
self.hidden_layer.apply_plasticity()

def enforce_constraints(self):
"""
Enforce sparsity and excitatory/inhibitory constraints if applicable.
This is by default automatically called after each forward pass,
but can be called manually if needed
"""
self.input_layer.enforce_constraints()
self.hidden_layer.enforce_constraints()

def recurrence(self, fr_t, v_t, u_t):
def _recurrence(self, fr_t, v_t, u_t):
""" Recurrence function """
# through input layer
v_in_u_t = self.input_layer(u_t) # u_t @ W_in
@@ -126,55 +127,28 @@ def recurrence(self, fr_t, v_t, u_t):
fr_t = fr_t + postact_epsilon

return fr_t, v_t

def forward(self, input):
"""
Propagate input through the network.
@param input: shape=(seq_len, batch, input_dim), network input
@return stacked_states: shape=(seq_len, batch, hidden_size), stack of hidden layer status
"""
v_t = self._reset_state().to(input.device)
fr_t = self.activation(v_t)
# update hidden state and append to stacked_states
stacked_states = []
for i in range(input.size(0)):
fr_t, v_t = self.recurrence(fr_t, v_t, input[i])
# append to stacked_states
stacked_states.append(fr_t)

# if keeping the last state, save it to hidden_state
if self.init_state == 'keep':
self.hidden_state = fr_t.detach().clone() # TODO: haven't tested this yet

return torch.stack(stacked_states, dim=0)
# ==================================================================================================

# HELPER FUNCTIONS
# ==================================================================================================
def to(self, device):
"""
Move the network to the device (cpu/gpu)
"""
super().to(device)
self.input_layer.to(device)
self.hidden_layer.to(device)
self.hidden_state = self.hidden_state.to(device)
def plot_layers(self, **kwargs):
""" Plot the weights matrix and distribution of each layer """
self.input_layer.plot_layers()
self.hidden_layer.plot_layers()

def print_layers(self):
""" Print the weights matrix and distribution of each layer """
param_dict = {
"hidden_min": self.hidden_state.min(),
"hidden_max": self.hidden_state.max(),
"hidden_mean": self.hidden_state.mean(),
"init_hidden_min": self.hidden_state.min(),
"init_hidden_max": self.hidden_state.max(),
"preact_noise": self.preact_noise,
"postact_noise": self.postact_noise,
"activation": self.act,
"alpha": self.alpha,
"init_state": self.init_state,
"init_state_learnable": self.hidden_state.requires_grad,
}
self.input_layer.print_layers()
print_dict("Recurrence", param_dict)
self.hidden_layer.print_layers()

def plot_layers(self, **kwargs):
self.input_layer.plot_layers()
self.hidden_layer.plot_layers()
# ==================================================================================================
# ==================================================================================================
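
For reference, the refactored RecurrentLayer is now driven entirely by a layer_struct dictionary. Below is a minimal, hypothetical sketch of such a dictionary and a forward call; only the keys actually read in this diff (dt, tau, init_state, activation, in_struct, hid_struct, and hid_struct['input_dim']) are grounded in the commit, and the contents of in_struct/hid_struct are assumptions, since the specs consumed by LinearLayer and HiddenLayer are not shown here.

import torch
from nn4n.layer.recurrent_layer import RecurrentLayer

# Hypothetical layer_struct; the in_struct/hid_struct entries are placeholders.
layer_struct = {
    "activation": "relu",
    "preact_noise": 0,
    "postact_noise": 0,
    "dt": 10,
    "tau": 100,
    "init_state": "zero",                                # 'zero', 'keep', or 'learn'
    "in_struct": {"input_dim": 1, "output_dim": 100},    # assumed LinearLayer spec
    "hid_struct": {"input_dim": 100, "output_dim": 100}, # assumed HiddenLayer spec
}

layer = RecurrentLayer(layer_struct=layer_struct)
x = torch.rand(50, 32, 1)   # (n_timesteps, batch_size, input_dim)
states = layer(x)           # (n_timesteps, batch_size, hidden_size)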
6 changes: 1 addition & 5 deletions nn4n/mask/__init__.py
@@ -1,8 +1,4 @@
from .multi_area import MultiArea
from .multi_area_ei import MultiAreaEI
from .random_input import RandomInput

if __name__ == "__main__":
print(MultiAreaEI)
print(MultiArea)
print(RandomInput)
from .multi_io import MultiIO
2 changes: 1 addition & 1 deletion nn4n/mask/base_mask.py
@@ -36,7 +36,7 @@ def _check_parameters(self):
assert isinstance(self.output_dim, int), "output_dim must be int"
assert self.output_dim > 0, "output_dim must be positive"

def _generate_mask(self):
def _generate_masks(self):
"""
Generate the mask for the multi-area network
"""
2 changes: 1 addition & 1 deletion nn4n/mask/multi_area.py
@@ -17,7 +17,7 @@ def __init__(self, **kwargs):
# run if it is not a child class
if self.__class__.__name__ == "MultiArea":
self._check_parameters()
self._generate_mask()
self._generate_masks()

def _check_parameters(self):
""" Check if parameters are valid """
4 changes: 2 additions & 2 deletions nn4n/mask/multi_area_ei.py
@@ -19,7 +19,7 @@ def __init__(self, **kwargs):
self.inh_readout = kwargs.get("inh_readout", True)
# check parameters and generate mask
self._check_parameters()
self._generate_mask()
self._generate_masks()

def _check_parameters(self):
super()._check_parameters()
@@ -35,7 +35,7 @@ def _generate_mask(self):
"""
Generate the mask for the multi-area network
"""
super()._generate_mask()
super()._generate_masks()
self._generate_ei_assigment()
self._masks_to_ei()

82 changes: 82 additions & 0 deletions nn4n/mask/multi_io.py
@@ -0,0 +1,82 @@
import numpy as np
from nn4n.mask.base_mask import BaseMask

class MultiIO(BaseMask):
def __init__(self, **kwargs):
"""
MultiIO generates masks for the case where multiple groups/types of signals (e.g., a 1-dim
olfactory signal and a 100-dim visual signal form two groups) need to be projected to
different hidden layer regions. The generated masks primarily act on the input/readout layers.
@kwarg input_dims: a list with the dimension of each group of input signals.
E.g., a 1-dim olfactory signal plus a 100-dim visual signal gives [1, 100];
must sum to dims[0].
@kwarg output_dims: a list with the dimension of each group of output signals.
E.g., a 1-dim olfactory signal plus a 100-dim visual signal gives [1, 100];
must sum to dims[2].
@kwarg input_table: a table denoting whether an input group projects to a given
hidden layer node. Must be an array of shape (n_input_groups, hidden_size) containing
only 0s and 1s, default: all ones.
@kwarg output_table: a table denoting whether a hidden layer node is used to generate a
given output group. Must be an array of shape (n_output_groups, hidden_size) containing
only 0s and 1s, default: all ones.
"""
super().__init__(**kwargs)
self.input_dims = kwargs.get("input_dims", [self.dims[0]])
self.output_dims = kwargs.get("output_dims", [self.dims[2]])
self.n_input_groups = len(self.input_dims) # number of groups of input signals
self.n_output_groups = len(self.output_dims) # number of groups of output signals
self.input_table = kwargs.get("input_table", np.ones((self.n_input_groups, self.dims[1])))
self.output_table = kwargs.get("output_table", np.ones((self.n_output_groups, self.dims[1])))

# check parameters and generate masks
self._check_parameters()
self._generate_masks()

def _check_parameters(self):
""" Check if parameters are valid """
super()._check_parameters()

# The input/output dims must be a list
assert type(self.input_dims) == list and self._check_int_list(self.input_dims), "input_dims must be a list of integers"
assert type(self.output_dims) == list and self._check_int_list(self.output_dims), "output_dims must be a list of integers"

# Check if the input_dims and output_dims all sum up to self.dims[0] and self.dims[2]
assert np.sum(self.input_dims) == self.dims[0], "input_dims must sum-up to the full input dimension specified in self.dims[0]"
assert np.sum(self.output_dims) == self.dims[2], "output_dims must sum-up to the full output dimension specified in self.dims[2]"

# Check if the input/output table dimension is valid
assert self.input_table.shape == (self.n_input_groups, self.dims[1])
assert self.output_table.shape == (self.n_output_groups, self.dims[1])

# TODO: check that all input/output table entries are zero or one.

@staticmethod
def _check_int_list(el_list):
all_int = True
for el in el_list:
all_int = all_int and type(el) == int
return all_int

def _generate_hidden_mask(self):
""" Hidden mask is not important for this class, thus all ones by default """
hidden_mask = np.ones((self.dims[1], self.dims[1]))
self.hidden_mask = hidden_mask.T # TODO: remove this and flip other masks

def _generate_input_mask(self):
input_mask = np.zeros((self.dims[0], self.dims[1]))
dim_counter = 0
for i, d in enumerate(self.input_dims):
input_idx = self.input_table[i].reshape(-1, 1)
input_mask[dim_counter:dim_counter+d,:] = np.tile(input_idx, d).T
dim_counter += d
self.input_mask = input_mask.T # TODO: remove this and flip other masks

def _generate_readout_mask(self):
readout_mask = np.zeros((self.dims[1], self.dims[2]))
dim_counter = 0
for i, d in enumerate(self.output_dims):
output_idx = self.output_table[i].reshape(-1, 1)
readout_mask[:,dim_counter:dim_counter+d] = np.tile(output_idx, d)
dim_counter += d
self.readout_mask = readout_mask.T # TODO: remove this and flip other masks
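
As a quick illustration of the new mask, here is a hedged usage sketch. It assumes BaseMask accepts a dims=[input_dim, hidden_size, output_dim] keyword (suggested by the dims[0]/dims[1]/dims[2] indexing above) and that _generate_masks() in BaseMask calls the three generator methods; neither is shown in this commit.

import numpy as np
from nn4n.mask.multi_io import MultiIO

hidden_size = 6
mask = MultiIO(
    dims=[101, hidden_size, 101],   # assumed BaseMask keyword
    input_dims=[1, 100],            # 1-dim olfactory + 100-dim visual input
    output_dims=[1, 100],
    # olfactory group -> first half of the hidden layer, visual group -> second half
    input_table=np.array([[1, 1, 1, 0, 0, 0],
                          [0, 0, 0, 1, 1, 1]]),
    output_table=np.ones((2, hidden_size)),
)

# Masks are stored transposed, as noted in the TODO comments above
print(mask.input_mask.shape)    # (hidden_size, total_input_dim) = (6, 101)
print(mask.readout_mask.shape)  # (total_output_dim, hidden_size) = (101, 6)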
2 changes: 1 addition & 1 deletion nn4n/mask/random_input.py
@@ -18,7 +18,7 @@ def __init__(self, **kwargs):
# run if it is not a child class
if self.__class__.__name__ == "RandomInput":
self._check_parameters()
self._generate_mask()
self._generate_masks()

def _check_parameters(self):
"""
5 changes: 0 additions & 5 deletions nn4n/model/__init__.py
@@ -1,8 +1,3 @@
from .base_nn import BaseNN
from .ctrnn import CTRNN
from .mlp import MLP

if __name__ == '__main__':
print(BaseNN)
print(CTRNN)
print(MLP)
2 changes: 1 addition & 1 deletion setup.py
@@ -21,5 +21,5 @@
'IPython',
'scipy',
],
python_requires='>=3.7',
python_requires='>=3.10',
)
3 changes: 3 additions & 0 deletions todo.md
@@ -0,0 +1,3 @@
# TODOs:
- [ ] The examples need to be updated. Especially on the main branch.
- [ ] Resolve the transpose issue in the model module and the mask module.
