Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added default params rng to .apply #3698

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,9 +1074,19 @@ def apply(
def wrapper(
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
rngs: Optional[Union[PRNGKey, RNGSequences]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rngs: Optional[Union[PRNGKey, RNGSequences]] = None,
rngs: PRNGKey | RNGSequences | None = None,

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

won't this fail Github CI for python 3.9?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

our minimum Python version is still 3.9.

**kwargs,
) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]:
if rngs is not None:
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
if not _is_valid_rng(rngs):

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they are two separate functions: _is_valid_rng checks if the rng key rngs is valid, and _is_valid_rngs checks if the dictionary mapping rngs is valid (recursively)

raise ValueError(
'The ``rngs`` argument passed to an apply function should be a '
'``jax.PRNGKey`` or a dictionary mapping strings to '
'``jax.PRNGKey``.'
)
if not isinstance(rngs, (dict, FrozenDict)):
rngs = {'params': rngs}

# Try to detect if user accidentally passed {'params': {'params': ...}.
if (
'params' in variables
Expand Down Expand Up @@ -1118,10 +1128,10 @@ def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]:
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
raise ValueError(
'First argument passed to an init function should be a '
'`jax.PRNGKey` or a dictionary mapping strings to '
'`jax.PRNGKey`.'
'``jax.PRNGKey`` or a dictionary mapping strings to '
'``jax.PRNGKey``.'
)
if not isinstance(rngs, dict):
if not isinstance(rngs, (dict, FrozenDict)):
rngs = {'params': rngs}
init_flags = {**(flags if flags is not None else {}), 'initializing': True}
return apply(fn, mutable=mutable, flags=init_flags)(
Expand Down Expand Up @@ -1217,7 +1227,7 @@ def _is_valid_rng(rng: Array):
return True


def _is_valid_rngs(rngs: RNGSequences):
def _is_valid_rngs(rngs: Union[PRNGKey, RNGSequences]):
if not isinstance(rngs, (FrozenDict, dict)):
return False
for key, val in rngs.items():
Expand Down
10 changes: 9 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,7 +2041,7 @@ def apply(
self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
rngs: Optional[Union[PRNGKey, RNGSequences]] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = False,
capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
Expand Down Expand Up @@ -2115,6 +2115,14 @@ def apply(
"""
Module._module_checks(self)

if rngs is not None and not isinstance(rngs, dict):
if not core.scope._is_valid_rng(rngs):
raise errors.InvalidRngError(
'RNGs should be of shape (2,) or PRNGKey in Module '
f'{self.__class__.__name__}, but rngs are: {rngs}'
)
rngs = {'params': rngs}

if isinstance(method, str):
attribute_name = method
method = getattr(self, attribute_name)
Expand Down
37 changes: 37 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,43 @@ def __call__(self, x):
trace = mlp.apply(variables, x)
self.assertEqual(trace, expected_trace)

def test_default_params_rng_equivalence(self):
class Model(nn.Module):
@nn.compact
def __call__(self, x, add_dropout=False, add_noise=False):
x = nn.Dense(16)(x)
x = nn.Dropout(0.5)(x, deterministic=not add_dropout)
if add_noise:
x += jax.random.normal(self.make_rng('params'))
return x

model = Model()
key0, key1, key2 = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(key0, (10, 8))

with self.assertRaisesRegex(ValueError, 'First argument passed to an init function should be a ``jax.PRNGKey``'):
model.init({'params': 'test'}, x)
with self.assertRaisesRegex(errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test'):
model.init('test', x)
with self.assertRaisesRegex(errors.InvalidRngError, 'Dropout_0 needs PRNG for "dropout"'):
model.init(key1, x, add_dropout=True)

v = model.init({'params': key1}, x)
v2 = model.init(key1, x)
jax.tree_map(np.testing.assert_allclose, v, v2)

out = model.apply(v, x, add_noise=True, rngs={'params': key2})
out2 = model.apply(v, x, add_noise=True, rngs=key2)
np.testing.assert_allclose(out, out2)

with self.assertRaisesRegex(ValueError, 'The ``rngs`` argument passed to an apply function should be a ``jax.PRNGKey``'):
model.apply(v, x, rngs={'params': 'test'})
with self.assertRaisesRegex(errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test'):
model.apply(v, x, rngs='test')
with self.assertRaisesRegex(errors.InvalidRngError, 'Dropout_0 needs PRNG for "dropout"'):
model.apply(v, x, add_dropout=True, rngs=key2)


def test_module_apply_method(self):
class Foo(nn.Module):
not_callable: int = 1
Expand Down
Loading