diff --git a/brainpy/_src/initialize/decay_inits.py b/brainpy/_src/initialize/decay_inits.py index 5ed1fc3ac..22c1b4441 100644 --- a/brainpy/_src/initialize/decay_inits.py +++ b/brainpy/_src/initialize/decay_inits.py @@ -2,6 +2,9 @@ import numpy as np +from jax import vmap, jit, numpy as jnp +from functools import partial + from brainpy import math as bm from brainpy.tools import to_size, size2num from .base import _IntraLayerInitializer @@ -13,6 +16,22 @@ ] +@jit +@partial(vmap, in_axes=(0, None, None)) +def gaussian_decay_dist_cal1(i_value, post_values, sigma): + dists = jnp.abs(i_value - post_values) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) + + +@jit +@partial(vmap, in_axes=(0, None, None, None)) +def gaussian_decay_dist_cal2(i_value, post_values, value_sizes, sigma): + dists = jnp.abs(i_value - post_values) + dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) + class GaussianDecay(_IntraLayerInitializer): r"""Builds a Gaussian connectivity pattern within a population of neurons, where the weights decay with gaussian function. @@ -106,8 +125,9 @@ def __call__(self, shape, dtype=None): value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) # connectivity matrix - conn_mat = [] + i_value_list = np.zeros(shape=(net_size, len(shape), 1)) for i in range(net_size): + list_index = i # values for node i i_coordinate = tuple() for s in shape[:-1]: @@ -117,23 +137,23 @@ def __call__(self, shape, dtype=None): i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) if i_value.ndim < post_values.ndim: i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - # distances - dists = np.abs(i_value - post_values) - if self.periodic_boundary: - dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) - exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) - conn_mat.append(exp_dists) - conn_mat = np.stack(conn_mat) + i_value_list[list_index] = i_value + + if self.periodic_boundary: + conn_mat = gaussian_decay_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) + else: + conn_mat = gaussian_decay_dist_cal1(i_value_list, post_values, self.sigma) + if self.normalize: conn_mat /= conn_mat.max() if not self.include_self: - np.fill_diagonal(conn_mat, 0.) + bm.fill_diagonal(conn_mat, 0.) # connectivity weights - conn_weights = conn_mat * self.max_w - conn_weights = np.where(conn_weights < self.min_w, 0., conn_weights) - return bm.asarray(conn_weights, dtype=dtype) - + conn_mat *= self.max_w + conn_mat = bm.where(conn_mat < self.min_w, 0., conn_mat) + return bm.asarray(conn_mat, dtype=dtype) + def __repr__(self): name = self.__class__.__name__ bank = ' ' * len(name)