Updates training apis and docs (#145)
chaoming0625 authored Apr 2, 2022
2 parents d9c662d + 0f37a24 commit e52e104
Showing 20 changed files with 1,959 additions and 643 deletions.
8 changes: 8 additions & 0 deletions brainpy/datasets/chaotic_systems.py
@@ -197,6 +197,14 @@ def lorenz_series(duration, dt=0.001, sigma=10, beta=8 / 3, rho=28, method='rk4'
The Lorenz system is a system of ordinary differential equations first
studied by mathematician and meteorologist Edward Lorenz.
+
+Returns
+-------
+data: dict
+  A dict with the keys ``ts``, ``x``, ``y``, and ``z``, where ``ts``
+  is the history of time points and ``x``, ``y``, ``z`` are the
+  histories of the corresponding Lorenz variables.
+
References
----------
.. [6] https://brainpy-examples.readthedocs.io/en/latest/classical_dynamical_systems/lorenz_system.html
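As context for the new Returns section, a minimal usage sketch (the module path comes from the file header above; the argument values are illustrative):

# Hypothetical usage of lorenz_series, following the docstring above.
from brainpy.datasets.chaotic_systems import lorenz_series

data = lorenz_series(duration=100., dt=0.01)  # a dict, per the Returns section
ts, x, y, z = data['ts'], data['x'], data['y'], data['z']
# ts: history of time points; x, y, z: the matching Lorenz trajectories.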
3 changes: 2 additions & 1 deletion brainpy/math/random.py
@@ -127,7 +127,8 @@ def permutation(self, x):

def shuffle(self, x, axis=0):
x = x.value if isinstance(x, JaxArray) else x
-return JaxArray(jr.permutation(self.split_key(), x, axis=axis, independent=True))
+# return JaxArray(jr.permutation(self.split_key(), x, axis=axis, independent=True))
+return JaxArray(jr.permutation(self.split_key(), x, axis=axis))

def beta(self, a, b, size=None):
a = a.value if isinstance(a, JaxArray) else a
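The `shuffle` change above drops the `independent=True` keyword and keeps only `axis` in the call to `jax.random.permutation`. A small sketch of the retained behavior (assuming a JAX version whose `permutation` accepts `axis`):

import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)
# Permutes whole slices along axis 0 (rows move together); with
# independent=True, each column would have been shuffled separately.
shuffled = jr.permutation(key, x, axis=0)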
113 changes: 65 additions & 48 deletions brainpy/nn/base.py
@@ -36,7 +36,7 @@
detect_cycle,
detect_path)
from brainpy.tools.checking import (check_dict_data,
-check_batch_shape,
+check_shape_except_batch,
check_integer)
from brainpy.types import Tensor

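The import rename above reflects what the helper verifies: every dimension except the leading batch axis. A standalone sketch of that semantics, inferred from the name and the call sites in this diff (not the library implementation):

def shapes_match_except_batch(expected, got):
    # Compare trailing dimensions only; the batch axis (position 0) is free.
    return tuple(expected[1:]) == tuple(got[1:])

assert shapes_match_except_batch((None, 3, 4), (32, 3, 4))      # batch differs: ok
assert not shapes_match_except_batch((None, 3, 4), (32, 3, 5))  # feature dim differs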
@@ -231,10 +231,10 @@ def set_state(self, state):
"""
if self.state is None:
if self.output_shape is not None:
-check_batch_shape(self.output_shape, state.shape)
+check_shape_except_batch(self.output_shape, state.shape)
self._state = bm.Variable(state) if not isinstance(state, bm.Variable) else state
else:
-check_batch_shape(self.state.shape, state.shape)
+check_shape_except_batch(self.state.shape, state.shape)
if self.state.dtype != state.dtype:
raise MathError('Cannot set the state, because the dtype is not consistent: '
f'{self.state.dtype} != {state.dtype}')
@@ -261,10 +261,10 @@ def set_fb_output(self, state: Tensor):
"""
if self.fb_output is None:
if self.output_shape is not None:
-check_batch_shape(self.output_shape, state.shape)
+check_shape_except_batch(self.output_shape, state.shape)
self._fb_output = bm.Variable(state) if not isinstance(state, bm.Variable) else state
else:
-check_batch_shape(self.fb_output.shape, state.shape)
+check_shape_except_batch(self.fb_output.shape, state.shape)
if self.fb_output.dtype != state.dtype:
raise MathError('Cannot set the feedback state, because the dtype is '
f'not consistent: {self.fb_output.dtype} != {state.dtype}')
@@ -313,15 +313,13 @@ def set_feedforward_shapes(self, feedforward_shapes: Dict):
self._feedforward_shapes = feedforward_shapes
else:
if self.feedforward_shapes is not None:
-for key, size in self._feedforward_shapes.items():
-if key not in feedforward_shapes:
-raise ValueError(f"Impossible to reset the input shape of {self.name}. "
-f"Because this Node has the input dimension {size} from {key}. "
-f"While we do not find it in the given feedforward_shapes")
-if not check_batch_shape(size, feedforward_shapes[key], mode='bool'):
-raise ValueError(f"Impossible to reset the input shape of {self.name}. "
-f"Because this Node has the input dimension {size} from {key}. "
-f"While the give shape is {feedforward_shapes[key]}")
+sizes1 = sorted(list(self._feedforward_shapes.values()))
+sizes2 = sorted(list(feedforward_shapes.values()))
+if sizes1 != sizes2:
+raise ValueError(f"Impossible to reset the input shapes of {self.name}. "
+f"Because this Node has the input shapes {sizes1}. "
+f"While we got input shapes of {sizes2}")
+self._feedforward_shapes = feedforward_shapes

@property
def feedback_shapes(self):
@@ -341,15 +339,13 @@ def set_feedback_shapes(self, fb_shapes: Dict):
self._feedback_shapes = fb_shapes
else:
if self.feedback_shapes is not None:
-for key, size in self._feedforward_shapes.items():
-if key not in fb_shapes:
-raise ValueError(f"Impossible to reset the input data of {self.name}. "
-f"Because this Node has the input dimension {size} from {key}. "
-f"While we do not find it in {fb_shapes}")
-if not check_batch_shape(size, fb_shapes[key], mode='bool'):
-raise ValueError(f"Impossible to reset the input data of {self.name}. "
-f"Because this Node has the input dimension {size} from {key}. "
-f"While the give shape is {fb_shapes[key]}")
+sizes1 = sorted(list(self._feedback_shapes.values()))
+sizes2 = sorted(list(fb_shapes.values()))
+if sizes1 != sizes2:
+raise ValueError(f"Impossible to reset the feedback shapes of {self.name}. "
+f"Because this Node has the feedback shapes {sizes1}. "
+f"While we got feedback shapes of {sizes2}")
+self._feedback_shapes = fb_shapes

@property
def output_shape(self) -> Optional[Tuple[int]]:
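Both setter hunks above replace the per-key lookup with a comparison of sorted shape lists, so renamed keys no longer trigger an error as long as the shapes agree. A tiny illustration with hypothetical shapes (batch axis written as None):

old_shapes = {'node_a': (None, 3), 'node_b': (None, 5)}
new_shapes = {'node_b': (None, 5), 'node_a': (None, 3)}
# Keys may differ or be reordered; only the multiset of shapes must match.
if sorted(old_shapes.values()) != sorted(new_shapes.values()):
    raise ValueError('Inconsistent shapes.')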
@@ -380,7 +376,7 @@ def set_output_shape(self, shape: Sequence[int]):
raise ValueError(f'Must be a sequence of int, but got {shape}')
self._output_shape = tuple(shape)
else:
-check_batch_shape(shape, self.output_shape)
+check_shape_except_batch(shape, self.output_shape)

def nodes(self, method='absolute', level=1, include_self=True):
return super(Node, self).nodes(method=method, level=level, include_self=include_self)
@@ -468,7 +464,7 @@ def init_fb_output(self, num_batch=1) -> Optional[Tensor]:
"""
return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_)

-def initialize(self, num_batch: int):
+def initialize(self, num_batch: int = 1):
"""
Initialize the node. This function must be called before applying JIT.
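With the new default, single-batch use no longer needs an explicit argument. A hypothetical call pattern (`model` stands for any constructed Node or Network):

# model.initialize()              # now equivalent to model.initialize(num_batch=1)
# model.initialize(num_batch=32)  # explicit batch size, as before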
@@ -506,7 +502,15 @@ def _check_inputs(self, ff, fb=None):
for k, size in self._feedforward_shapes.items():
if k not in ff:
raise ValueError(f"The required key {k} is not provided in feedforward inputs.")
-check_batch_shape(size, ff[k].shape)
+check_shape_except_batch(size, ff[k].shape)
+if self.state is not None:
+for inp in ff.values():
+if self.state.shape[0] != inp.shape[0]:
+raise ValueError(f'The batch size of the input data {inp.shape[0]} is not '
+f'equal to the batch size of the node state {self.state.shape[0]}. '
+f'Maybe you need to reinitialize the node with the desired '
+f'batch size by ".initialize(num_batch)", or change the data '
+f'to be consistent with the state batch size {self.state.shape[0]}.')

# check feedback inputs
if fb is not None:
@@ -516,8 +520,15 @@ def _check_inputs(self, ff, fb=None):
for k, size in self._feedback_shapes.items():
if k not in fb:
raise ValueError(f"The required key {k} is not provided in feedback inputs.")
-check_batch_shape(size, fb[k].shape)
-
+check_shape_except_batch(size, fb[k].shape)
+if self.state is not None:
+for inp in fb.values():
+if self.state.shape[0] != inp.shape[0]:
+raise ValueError(f'The batch size of the feedback data {inp.shape[0]} is not '
+f'equal to the batch size of the node state {self.state.shape[0]}. '
+f'Maybe you need to reinitialize the node with the desired '
+f'batch size by ".initialize(num_batch)", or change the data '
+f'to be consistent with the state batch size {self.state.shape[0]}.')
# data
ff = self.data_pass_func(ff)
fb = self.data_pass_func(fb)
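The two blocks added to `_check_inputs` fail fast when incoming data disagrees with the batch size the node was initialized with. A self-contained sketch of that guard (plain JAX, not the library code itself):

import jax.numpy as jnp

def check_batch_consistency(state, inputs):
    # Mirrors the guard above: every input's leading axis must match the state's.
    for name, inp in inputs.items():
        if state.shape[0] != inp.shape[0]:
            raise ValueError(f'Batch size of {name!r} ({inp.shape[0]}) does not '
                             f'match the state batch size ({state.shape[0]}).')

state = jnp.zeros((32, 10))                                    # batch size 32
check_batch_consistency(state, {'ff': jnp.zeros((32, 5))})     # passes
# check_batch_consistency(state, {'ff': jnp.zeros((16, 5))})   # would raise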
@@ -540,8 +551,9 @@ def _call(self,
forced_states = {self.name: forced_states}
check_dict_data(forced_states, key_type=str, val_type=(bm.ndarray, jnp.ndarray))
if forced_feedbacks is not None:
-raise ValueError('Single instance of brainpy.nn.Node do '
-'not support "forced_feedbacks"')
+if len(forced_feedbacks) != 0:
+raise ValueError('Single instance of brainpy.nn.Node does '
+'not support "forced_feedbacks"')
# monitors
need_return_monitor = True
if monitors is None:
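The relaxed check above now tolerates an empty `forced_feedbacks` dict and only rejects a non-empty one. A standalone sketch of the new condition:

def check_forced_feedbacks(forced_feedbacks):
    # Empty dicts pass after this commit; non-empty ones still raise.
    if forced_feedbacks is not None and len(forced_feedbacks) != 0:
        raise ValueError('Single instance of brainpy.nn.Node does not '
                         'support "forced_feedbacks"')

check_forced_feedbacks(None)  # ok
check_forced_feedbacks({})    # ok after this change; raised before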
@@ -550,7 +562,6 @@ def _call(self,
attr_monitors: Dict[str, Tensor] = {}
state_monitors: Dict[str, Tensor] = {}
for key in monitors:
-
splits = key.split('.')
if len(splits) != 2:
raise ValueError(f'Every term in "monitors" must be (node.item), '
@@ -714,13 +725,13 @@ def set_state(self, state):
"""
if self.state is None:
if self.output_shape is not None:
-check_batch_shape(self.output_shape, state.shape)
+check_shape_except_batch(self.output_shape, state.shape)
self._state = bm.Variable(state) if not isinstance(state, bm.Variable) else state
if self.state_trainable:
self._train_state = bm.TrainVar(self._state[0]) # get the first elements as the initial state
self._state[:] = self._train_state # set all batch states the same
else:
-check_batch_shape(self.state.shape, state.shape)
+check_shape_except_batch(self.state.shape, state.shape)
if self.state.dtype != state.dtype:
raise MathError('Cannot set the state, because the dtype is not consistent: '
f'{self.state.dtype} != {state.dtype}')
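For the trainable-state branch above, the code keeps one trainable initial state and copies it across the batch axis. A plain-JAX sketch of that broadcast (shapes hypothetical):

import jax.numpy as jnp

train_state = 0.1 * jnp.ones((10,))     # single trainable initial state
state = jnp.tile(train_state, (32, 1))  # every batch row starts identical
assert state.shape == (32, 10)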
@@ -937,7 +948,7 @@ def set_output_shape(self, shape: Dict[str, Sequence[int]]):
self._output_shape = shape
else:
for val in shape.values():
-check_batch_shape(val, self.output_shape)
+check_shape_except_batch(val, self.output_shape)

def init_ff_conn(self):
"""Initialize the feedforward connections of the network.
@@ -1025,7 +1036,7 @@ def _init_fb_output(self, num_batch=1):
node._init_fb_output(num_batch)
self._is_fb_state_initialized = True

-def initialize(self, num_batch: int):
+def initialize(self, num_batch: int = 1):
"""
Initialize the whole network. This function must be called before applying JIT.
@@ -1097,35 +1108,41 @@ def _check_inputs(self, ff, fb=None):
# feedforward inputs
if isinstance(ff, (bm.ndarray, jnp.ndarray)):
ff = {self.entry_nodes[0].name: ff}
-assert isinstance(ff, dict), f'ff must be a dict or a tensor, got {type(ff)}: {ff}'
-assert len(self.entry_nodes) == len(ff), (f'This network has {len(self.entry_nodes)} '
-f'entry nodes. While only {len(ff)} input '
-f'data are given.')
+if not isinstance(ff, dict):
+raise ValueError(f'ff must be a dict or a tensor, got {type(ff)}: {ff}')
+if len(self.entry_nodes) != len(ff):
+raise ValueError(f'This network has {len(self.entry_nodes)} '
+f'entry nodes. While only {len(ff)} input '
+f'data are given.')
for n in self.entry_nodes:
if n.name not in ff:
raise ValueError(f'Cannot find the input of the node: \n{n}')
for k, size in self._feedforward_shapes.items():
-assert k in ff, f"The required key {k} is not provided in feedforward inputs."
-if not check_batch_shape(size, ff[k].shape, mode='bool'):
+if k not in ff:
+raise ValueError(f"The required key {k} is not provided in feedforward inputs.")
+if not check_shape_except_batch(size, ff[k].shape, mode='bool'):
raise ValueError(f'Input size {ff[k].shape} is not consistent with '
f'the input size {size}')

# feedback inputs
if fb is not None:
if isinstance(fb, (bm.ndarray, jnp.ndarray)):
fb = {self.feedback_nodes[0]: fb}
-assert isinstance(fb, dict), (f'fb must be a dict or a tensor, '
-f'got {type(fb)}: {fb}')
-assert len(self.feedback_nodes) == len(fb), (f'This network has {len(self.feedback_nodes)} '
-f'feedback nodes. While only {len(ff)} '
-f'feedback data are given.')
+if not isinstance(fb, dict):
+raise ValueError(f'fb must be a dict or a tensor, '
+f'got {type(fb)}: {fb}')
+if len(self.feedback_nodes) != len(fb):
+raise ValueError(f'This network has {len(self.feedback_nodes)} '
+f'feedback nodes. While only {len(fb)} '
+f'feedback data are given.')
for n in self.feedback_nodes:
if n.name not in fb:
raise ValueError(f'Cannot find the feedback data from the node {n}')
# check feedback consistency
for k, size in self._feedback_shapes.items():
-assert k in fb, f"The required key {k} is not provided in feedback inputs."
-check_batch_shape(size, fb[k].shape)
+if k not in fb:
+raise ValueError(f"The required key {k} is not provided in feedback inputs.")
+check_shape_except_batch(size, fb[k].shape)

# data transformation
ff = self.data_pass_func(ff)
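A general note on the assert-to-raise conversions in the hunks above: `assert` statements are removed when Python runs with the `-O` flag, so input validation should raise explicitly. A minimal illustration (generic Python, not the BrainPy API):

def validate_ff(ff):
    if not isinstance(ff, dict):  # survives `python -O`
        raise ValueError(f'ff must be a dict or a tensor, got {type(ff)}: {ff}')
    return ff

# assert isinstance(ff, dict)    # would be skipped entirely under `python -O`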
2 changes: 2 additions & 0 deletions brainpy/nn/runners/__init__.py
@@ -6,13 +6,15 @@
for various neural networks.
The supported training algorithms include
- offline training methods, like ridge regression, linear regression, etc.
- online training methods, like recursive least squares (RLS, or Force Learning),
least mean squares (LMS), etc.
- back-propagation learning method
- and others
The supported neural networks include
- reservoir computing networks,
- artificial recurrent neural networks,
- and others.
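As a pointer for the offline methods listed above, a generic ridge-regression readout fit (plain linear algebra; not the brainpy.nn trainer API):

import jax.numpy as jnp

def ridge_readout(X, Y, alpha=1e-6):
    # Solve W = (X^T X + alpha * I)^{-1} X^T Y for a linear readout.
    # X: (num_samples, num_features); Y: (num_samples, num_targets).
    d = X.shape[1]
    return jnp.linalg.solve(X.T @ X + alpha * jnp.eye(d), X.T @ Y)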
