-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added default params rng to .apply #3698
Conversation
@@ -1074,9 +1074,19 @@ def apply( | |||
def wrapper( | |||
variables: VariableDict, | |||
*args, | |||
rngs: Optional[RNGSequences] = None, | |||
rngs: Optional[Union[PRNGKey, RNGSequences]] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rngs: Optional[Union[PRNGKey, RNGSequences]] = None, | |
rngs: PRNGKey | RNGSequences | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
won't this fail Github CI for python 3.9?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
our minimum Python version is still 3.9.
**kwargs, | ||
) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]: | ||
if rngs is not None: | ||
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): | |
if not _is_valid_rng(rngs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
they are two separate functions: _is_valid_rng
checks if the rng key rngs
is valid, and _is_valid_rngs
checks if the dictionary mapping rngs
is valid (recursively)
Added default params rng to
.apply
.Similarly to how you can get the same behavior by doing the following with
.init
:This PR allows you to do the same with
.apply
: