Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Gaussian Decay initializer #381

Merged
merged 1 commit into from
May 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions brainpy/_src/initialize/decay_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import numpy as np

from jax import vmap, jit, numpy as jnp
from functools import partial

from brainpy import math as bm
from brainpy.tools import to_size, size2num
from .base import _IntraLayerInitializer
Expand All @@ -13,6 +16,22 @@
]


@jit
@partial(vmap, in_axes=(0, None, None))
def gaussian_decay_dist_cal1(i_value, post_values, sigma):
dists = jnp.abs(i_value - post_values)
exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2)
return bm.asarray(exp_dists)


@jit
@partial(vmap, in_axes=(0, None, None, None))
def gaussian_decay_dist_cal2(i_value, post_values, value_sizes, sigma):
dists = jnp.abs(i_value - post_values)
dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists)
exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2)
return bm.asarray(exp_dists)

class GaussianDecay(_IntraLayerInitializer):
r"""Builds a Gaussian connectivity pattern within a population of neurons,
where the weights decay with gaussian function.
Expand Down Expand Up @@ -106,8 +125,9 @@ def __call__(self, shape, dtype=None):
value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))

# connectivity matrix
conn_mat = []
i_value_list = np.zeros(shape=(net_size, len(shape), 1))
for i in range(net_size):
list_index = i
# values for node i
i_coordinate = tuple()
for s in shape[:-1]:
Expand All @@ -117,23 +137,23 @@ def __call__(self, shape, dtype=None):
i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)])
if i_value.ndim < post_values.ndim:
i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
# distances
dists = np.abs(i_value - post_values)
if self.periodic_boundary:
dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists)
exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2)
conn_mat.append(exp_dists)
conn_mat = np.stack(conn_mat)
i_value_list[list_index] = i_value

if self.periodic_boundary:
conn_mat = gaussian_decay_dist_cal2(i_value_list, post_values, value_sizes, self.sigma)
else:
conn_mat = gaussian_decay_dist_cal1(i_value_list, post_values, self.sigma)

if self.normalize:
conn_mat /= conn_mat.max()
if not self.include_self:
np.fill_diagonal(conn_mat, 0.)
bm.fill_diagonal(conn_mat, 0.)

# connectivity weights
conn_weights = conn_mat * self.max_w
conn_weights = np.where(conn_weights < self.min_w, 0., conn_weights)
return bm.asarray(conn_weights, dtype=dtype)

conn_mat *= self.max_w
conn_mat = bm.where(conn_mat < self.min_w, 0., conn_mat)
return bm.asarray(conn_mat, dtype=dtype)
def __repr__(self):
name = self.__class__.__name__
bank = ' ' * len(name)
Expand Down