Skip to content

Commit

Permalink
fix dependency on brainpy
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Apr 5, 2024
1 parent 949456a commit 2b58c49
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
26 changes: 16 additions & 10 deletions braintools/optim/_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import braincore as bc
import jax
import jax.numpy as jnp
from brainpy import check
from brainpy.errors import MathError

__all__ = [
'LRScheduler',
Expand Down Expand Up @@ -163,8 +161,10 @@ def __init__(
):
super().__init__(lr=lr, last_epoch=last_epoch)

self.step_size = check.is_integer(step_size, min_bound=1, allow_none=False)
self.gamma = check.is_float(gamma, min_bound=0., max_bound=1., allow_int=False)
assert step_size >= 1, 'step_size should be greater than or equal to 1.'
assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
self.step_size = step_size
self.gamma = gamma

def __call__(self, i=None):
i = (self.last_epoch.value + 1) if i is None else i
Expand Down Expand Up @@ -202,9 +202,13 @@ def __init__(
):
super().__init__(lr=lr, last_epoch=last_epoch)

milestones = check.is_sequence(milestones, elem_type=int, allow_none=False)
assert len(milestones) > 0, 'milestones should be a non-empty sequence.'
assert all([milestones[i] < milestones[i + 1] for i in range(len(milestones) - 1)]), (
'milestones should be a sequence of increasing integers.'
)
assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
self.milestones = jnp.asarray((-1,) + tuple(milestones), dtype=bc.environ.ditype())
self.gamma = check.is_float(gamma, min_bound=0., max_bound=1., allow_int=False)
self.gamma = gamma

def __call__(self, i=None):
i = (self.last_epoch.value + 1) if i is None else i
Expand Down Expand Up @@ -268,8 +272,9 @@ def __init__(
):
super().__init__(lr=lr, last_epoch=last_epoch)

assert T_max >= 1, 'T_max should be greater than or equal to 1.'
self._init_epoch = last_epoch
self.T_max = check.is_integer(T_max, min_bound=1)
self.T_max = T_max
self.eta_min = eta_min

def __call__(self, i=None):
Expand Down Expand Up @@ -382,7 +387,8 @@ def __init__(self,
gamma: float,
last_epoch: int = -1):
super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
self.gamma = check.is_float(gamma, min_bound=0., max_bound=1.)
assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
self.gamma = gamma

def __call__(self, i: int = None):
i = (self.last_epoch.value + 1) if i is None else i
Expand Down Expand Up @@ -447,9 +453,9 @@ def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1
boundaries = jnp.array(boundaries)
values = jnp.array(values)
if not boundaries.ndim == values.ndim == 1:
raise MathError("boundaries and values must be sequences")
raise ValueError("boundaries and values must be sequences")
if not boundaries.shape[0] == values.shape[0] - 1:
raise MathError("boundaries length must be one shorter than values length")
raise ValueError("boundaries length must be one shorter than values length")
self.boundaries = boundaries
self.values = values

Expand Down
4 changes: 1 addition & 3 deletions braintools/optim/_sgd_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# -*- coding: utf-8 -*-

import copy
import functools
from typing import Union, Dict, Optional, Tuple, Any, TypeVar

import braincore as bc
import jax
import jax.numpy as jnp
from brainpy import check

from ._lr_scheduler import make_schedule, LRScheduler

Expand Down Expand Up @@ -157,7 +155,7 @@ def __init__(
):
super().__init__(lr=lr, name=name)
self.lr: LRScheduler = make_schedule(lr)
weight_decay = check.is_float(weight_decay, min_bound=0., max_bound=1., allow_none=True)
assert weight_decay is None or 0. <= weight_decay <= 1., 'weight_decay must be in [0, 1].'
self.weight_decay = (fcast(weight_decay) if weight_decay is not None else None)

def extra_repr(self) -> str:
Expand Down

0 comments on commit 2b58c49

Please sign in to comment.