-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added default params rng to .apply #3698
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -1074,9 +1074,19 @@ def apply( | |||||
def wrapper( | ||||||
variables: VariableDict, | ||||||
*args, | ||||||
rngs: Optional[RNGSequences] = None, | ||||||
rngs: Optional[Union[PRNGKey, RNGSequences]] = None, | ||||||
**kwargs, | ||||||
) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]: | ||||||
if rngs is not None: | ||||||
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. they are two separate functions: |
||||||
raise ValueError( | ||||||
'The ``rngs`` argument passed to an apply function should be a ' | ||||||
'``jax.PRNGKey`` or a dictionary mapping strings to ' | ||||||
'``jax.PRNGKey``.' | ||||||
) | ||||||
if not isinstance(rngs, (dict, FrozenDict)): | ||||||
rngs = {'params': rngs} | ||||||
|
||||||
# Try to detect if user accidentally passed {'params': {'params': ...}. | ||||||
if ( | ||||||
'params' in variables | ||||||
|
@@ -1118,10 +1128,10 @@ def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]: | |||||
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): | ||||||
raise ValueError( | ||||||
'First argument passed to an init function should be a ' | ||||||
'`jax.PRNGKey` or a dictionary mapping strings to ' | ||||||
'`jax.PRNGKey`.' | ||||||
'``jax.PRNGKey`` or a dictionary mapping strings to ' | ||||||
'``jax.PRNGKey``.' | ||||||
) | ||||||
if not isinstance(rngs, dict): | ||||||
if not isinstance(rngs, (dict, FrozenDict)): | ||||||
rngs = {'params': rngs} | ||||||
init_flags = {**(flags if flags is not None else {}), 'initializing': True} | ||||||
return apply(fn, mutable=mutable, flags=init_flags)( | ||||||
|
@@ -1217,7 +1227,7 @@ def _is_valid_rng(rng: Array): | |||||
return True | ||||||
|
||||||
|
||||||
def _is_valid_rngs(rngs: RNGSequences): | ||||||
def _is_valid_rngs(rngs: Union[PRNGKey, RNGSequences]): | ||||||
if not isinstance(rngs, (FrozenDict, dict)): | ||||||
return False | ||||||
for key, val in rngs.items(): | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
won't this fail Github CI for python 3.9?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
our minimum Python version is still 3.9.