Commit

Update
Routhleck committed Nov 23, 2024
1 parent 2cb7463 commit 01108f7
Showing 2 changed files with 165 additions and 167 deletions.
165 changes: 0 additions & 165 deletions braintaichi/_jitconnop/csrmv.py
@@ -28,171 +28,6 @@
from braintaichi._primitive._xla_custom_op import XLACustomOp
from braintaichi.rand._taichi_rand import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal)

__all__ = [
'get_homo_weight_matrix',
'get_uniform_weight_matrix',
'get_normal_weight_matrix'
]


@set_module_as('braintaichi')
def get_homo_weight_matrix(
weight: float,
conn_prob: float,
seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
r"""Get the connection matrix :math:`M` with a connection probability `conn_prob`.
Parameters
----------
conn_prob: float
The connection probability.
shape: tuple of int
The matrix shape.
seed: int
The random number generation seed.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
Perform the parallel random generations along the out dimension or not.
It can be used to set the just-in-time generated :math:M^T: is the same
as the just-in-time generated :math:`M` when ``transpose=True``.
Returns
-------
out: Array, ndarray
The connection matrix :math:`M`.
"""
if isinstance(weight, numbers.Number):
weight = jnp.atleast_1d(
jnp.asarray(weight, dtype=jnp.float64 if jax.config.read('jax_enable_x64') else jnp.float32)
)
else:
raise ValueError(f'weight must be a number type, but got {type(weight)}')

conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
if seed is None:
with jax.ensure_compile_time_eval():
seed = np.random.randint(0, int(1e8), 1)
seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
r = raw_get_homo_weight_matrix(conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
r *= weight
if transpose:
return r.transpose()
else:
return r


@set_module_as('braintaichi')
def get_uniform_weight_matrix(
w_low: float,
w_high: float,
conn_prob: float,
seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
r"""Get the weight matrix :math:`M` with a uniform distribution for its value.
Parameters
----------
w_low: float
Lower boundary of the output interval.
w_high: float
Upper boundary of the output interval.
conn_prob: float
The connection probability.
shape: tuple of int
The matrix shape.
seed: int
The random number generation seed.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
Perform the parallel random generations along the out dimension or not.
It can be used to set the just-in-time generated :math:M^T: is the same
as the just-in-time generated :math:`M` when ``transpose=True``.
Returns
-------
out: Array, ndarray
The weight matrix :math:`M`.
"""
w_low = jnp.atleast_1d(jnp.asarray(w_low))
w_high = jnp.atleast_1d(jnp.asarray(w_high))
conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
if seed is None:
with jax.ensure_compile_time_eval():
seed = np.random.randint(0, int(1e8), 1)
seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
r = raw_get_uniform_weight_matrix(w_low, w_high, conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0]
if transpose:
return r.transpose()
else:
return r


@set_module_as('braintaichi')
def get_normal_weight_matrix(
w_mu: float,
w_sigma: float,
conn_prob: float,
seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
r"""Get the weight matrix :math:`M` with a normal distribution for its value.
Parameters
----------
w_mu: float
Mean (centre) of the distribution.
w_sigma: float
Standard deviation (spread or “width”) of the distribution. Must be non-negative.
shape: tuple of int
The matrix shape.
seed: int
The random number generation seed.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
Perform the parallel random generations along the out dimension or not.
It can be used to set the just-in-time generated :math:M^T: is the same
as the just-in-time generated :math:`M` when ``transpose=True``.
Returns
-------
out: Array, ndarray
The weight matrix :math:`M`.
"""
w_mu = jnp.atleast_1d(jnp.asarray(w_mu))
w_sigma = jnp.atleast_1d(jnp.asarray(w_sigma))
conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
if seed is None:
with jax.ensure_compile_time_eval():
seed = np.random.randint(0, int(1e8), 1)
seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
r = raw_get_normal_weight_matrix(w_mu, w_sigma, conn_len, seed,
shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0]
if transpose:
return r.transpose()
else:
return r


def raw_mv_prob_homo(
vector: jax.Array,
167 changes: 165 additions & 2 deletions braintaichi/_jitconnop/main.py
@@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numbers
from typing import Tuple, Optional

import jax
import numpy as np
from jax import numpy as jnp

from braintaichi._misc import set_module_as
from .csrmv import raw_mv_prob_homo, raw_mv_prob_uniform, raw_mv_prob_normal
from .csrmv import raw_mv_prob_homo, raw_mv_prob_uniform, raw_mv_prob_normal, raw_get_homo_weight_matrix, \
raw_get_uniform_weight_matrix, raw_get_normal_weight_matrix
from .event_csrmv import raw_event_mv_prob_homo, raw_event_mv_prob_uniform, raw_event_mv_prob_normal

__all__ = [
@@ -30,6 +31,9 @@
'jitc_event_mv_prob_homo',
'jitc_event_mv_prob_uniform',
'jitc_event_mv_prob_normal',
'get_homo_weight_matrix',
'get_uniform_weight_matrix',
'get_normal_weight_matrix'
]


@@ -349,3 +353,162 @@ def jitc_event_mv_prob_normal(


jitc_event_mv_prob_normal.__doc__ = jitc_mv_prob_normal.__doc__


@set_module_as('braintaichi')
def get_homo_weight_matrix(
weight: float,
conn_prob: float,
seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
r"""Get the connection matrix :math:`M` with a connection probability `conn_prob`.
Parameters
----------
conn_prob: float
The connection probability.
shape: tuple of int
The matrix shape.
seed: int
The random number generation seed.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
Perform the parallel random generations along the out dimension or not.
It can be used to set the just-in-time generated :math:M^T: is the same
as the just-in-time generated :math:`M` when ``transpose=True``.
Returns
-------
out: Array, ndarray
The connection matrix :math:`M`.
"""
if isinstance(weight, numbers.Number):
weight = jnp.atleast_1d(
jnp.asarray(weight, dtype=jnp.float64 if jax.config.read('jax_enable_x64') else jnp.float32)
)
else:
raise ValueError(f'weight must be a number type, but got {type(weight)}')

conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
if seed is None:
with jax.ensure_compile_time_eval():
seed = np.random.randint(0, int(1e8), 1)
seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
r = raw_get_homo_weight_matrix(conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
r *= weight
if transpose:
return r.transpose()
else:
return r


@set_module_as('braintaichi')
def get_uniform_weight_matrix(
w_low: float,
w_high: float,
conn_prob: float,
seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
r"""Get the weight matrix :math:`M` with a uniform distribution for its value.
Parameters
----------
w_low: float
Lower boundary of the output interval.
w_high: float
Upper boundary of the output interval.
conn_prob: float
The connection probability.
shape: tuple of int
The matrix shape.
seed: int
The random number generation seed.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
Perform the parallel random generations along the out dimension or not.
It can be used to set the just-in-time generated :math:M^T: is the same
as the just-in-time generated :math:`M` when ``transpose=True``.
Returns
-------
out: Array, ndarray
The weight matrix :math:`M`.
"""
w_low = jnp.atleast_1d(jnp.asarray(w_low))
w_high = jnp.atleast_1d(jnp.asarray(w_high))
conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
if seed is None:
with jax.ensure_compile_time_eval():
seed = np.random.randint(0, int(1e8), 1)
seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
r = raw_get_uniform_weight_matrix(w_low, w_high, conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0]
if transpose:
return r.transpose()
else:
return r


@set_module_as('braintaichi')
def get_normal_weight_matrix(
w_mu: float,
w_sigma: float,
conn_prob: float,
seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
r"""Get the weight matrix :math:`M` with a normal distribution for its value.
Parameters
----------
w_mu: float
Mean (centre) of the distribution.
w_sigma: float
Standard deviation (spread or “width”) of the distribution. Must be non-negative.
shape: tuple of int
The matrix shape.
seed: int
The random number generation seed.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
Perform the parallel random generations along the out dimension or not.
It can be used to set the just-in-time generated :math:M^T: is the same
as the just-in-time generated :math:`M` when ``transpose=True``.
Returns
-------
out: Array, ndarray
The weight matrix :math:`M`.
"""
w_mu = jnp.atleast_1d(jnp.asarray(w_mu))
w_sigma = jnp.atleast_1d(jnp.asarray(w_sigma))
conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
if seed is None:
with jax.ensure_compile_time_eval():
seed = np.random.randint(0, int(1e8), 1)
seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
r = raw_get_normal_weight_matrix(w_mu, w_sigma, conn_len, seed,
shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0]
if transpose:
return r.transpose()
else:
return r
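
For reference, a minimal usage sketch of the three getters now exposed from main.py. It assumes the functions are re-exported at the package top level (as the `set_module_as('braintaichi')` decorator suggests) and imports `braintaichi` as `bt`; the weight, probability, seed, and shape values below are illustrative only and are not taken from this commit.

# Hypothetical usage sketch (not part of this commit); argument values are illustrative.
import braintaichi as bt

# Homogeneous weights: every sampled connection carries the same value (1.5 here).
W_homo = bt.get_homo_weight_matrix(1.5, conn_prob=0.1, seed=42, shape=(200, 300))

# Uniformly distributed weights between w_low=0.0 and w_high=1.0.
W_uniform = bt.get_uniform_weight_matrix(0.0, 1.0, conn_prob=0.1, seed=42, shape=(200, 300))

# Normally distributed weights with mean w_mu=0.0 and standard deviation w_sigma=0.2.
W_normal = bt.get_normal_weight_matrix(0.0, 0.2, conn_prob=0.1, seed=42, shape=(200, 300))

Each call returns a jax.Array weight matrix; passing ``transpose=True`` returns the transposed matrix, and ``outdim_parallel`` controls whether the just-in-time generated :math:`M^T` matches the just-in-time generated :math:`M`, as the docstrings above describe.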
