Skip to content

Commit

Permalink
updated rng docstring for init and apply
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed Mar 18, 2024
1 parent 718aa8c commit 76f6c46
Showing 1 changed file with 126 additions and 15 deletions.
141 changes: 126 additions & 15 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,7 +1897,26 @@ def make_rng(self, name: str = 'params') -> PRNGKey:
the user in ``.init`` or ``.apply`` for this name), then ``name``
will default to ``'params'``.
TODO: Link to Flax RNG design note.
Example::
>>> import jax
>>> import flax.linen as nn
>>> class ParamsModule(nn.Module):
... def __call__(self):
... return self.make_rng('params')
>>> class OtherModule(nn.Module):
... def __call__(self):
... return self.make_rng('other')
>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out
Learn more about RNGs by reading the Flax RNG guide:
https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html
Args:
name: The RNG sequence name.
Expand Down Expand Up @@ -2061,6 +2080,10 @@ def apply(
Transformer module has a method called ``encode``, then the following calls
``apply`` on that method::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np
>>> class Transformer(nn.Module):
... def encode(self, x):
... ...
Expand Down Expand Up @@ -2092,6 +2115,50 @@ def apply(
>>> model.apply(variables, x, method=other_fn)
If you pass a single ``PRNGKey``, Flax will use it to feed the ``'params'``
RNG stream. If you want to use a different RNG stream or need to use
multiple streams, you can pass a dictionary mapping each RNG stream name
to its corresponding ``PRNGKey`` to ``apply``. If ``self.make_rng(name)``
is called on an RNG stream name that isn't passed by the user, it will
default to using the ``'params'`` RNG stream.
Example::
>>> class Foo(nn.Module):
... @nn.compact
... def __call__(self, x, add_noise=False):
... x = nn.Dense(16)(x)
... x = nn.relu(x)
...
... if add_noise:
... # Add gaussian noise
... noise_key = self.make_rng('noise')
... x = x + jax.random.normal(noise_key, x.shape)
...
... return nn.Dense(1)(x)
>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)
>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)
>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)
Args:
variables: A dictionary containing variables keyed by variable
collections. See :mod:`flax.core.variables` for more details about
Expand Down Expand Up @@ -2240,8 +2307,8 @@ def init(
Example::
>>> import flax.linen as nn
>>> import jax.numpy as jnp
>>> import jax
>>> import jax, jax.numpy as jnp
>>> import numpy as np
>>> class Foo(nn.Module):
... @nn.compact
Expand All @@ -2251,42 +2318,86 @@ def init(
... x = nn.relu(x)
... return nn.Dense(1)(x)
>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, jnp.empty((1, 7)), train=False)
>>> variables = module.init(key, x, train=False)
If you pass a single ``PRNGKey``, Flax will use it to feed the ``'params'``
RNG stream. If you want to use a different RNG stream or need to use
multiple streams, you must pass a dictionary mapping each RNG stream name
to its corresponding ``PRNGKey`` to ``init``.
multiple streams, you can pass a dictionary mapping each RNG stream name
to its corresponding ``PRNGKey`` to ``init``. If ``self.make_rng(name)``
is called on an RNG stream name that isn't passed by the user, it will
default to using the ``'params'`` RNG stream.
Example::
>>> class Foo(nn.Module):
... @nn.compact
... def __call__(self, x, train):
... def __call__(self, x):
... x = nn.Dense(16)(x)
... x = nn.BatchNorm(use_running_average=not train)(x)
... x = nn.relu(x)
...
... # Add gaussian noise
... noise_key = self.make_rng('noise')
... x = x + jax.random.normal(noise_key, x.shape)
... other_variable = self.variable(
... 'other_collection',
... 'other_variable',
... lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape),
... x,
... )
... x = x + other_variable.value
...
... return nn.Dense(1)(x)
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0),
... 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, jnp.empty((1, 7)), train=False)
>>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)}
>>> variables0 = module.init(rngs, x)
>>> rngs['other_rng'] = jax.random.key(0)
>>> variables1 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
... np.testing.assert_allclose, variables0['params'], variables1['params']
... )
>>> # different other_variable (key(1) vs key(0))
>>> np.testing.assert_raises(
... AssertionError,
... np.testing.assert_allclose,
... variables0['other_collection']['other_variable'],
... variables1['other_collection']['other_variable'],
... )
>>> del rngs['other_rng']
>>> # self.make_rng('other_rng') will default to using the 'params' RNG stream
>>> variables2 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
... np.testing.assert_allclose, variables1['params'], variables2['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
... variables1['other_collection']['other_variable'],
... variables2['other_collection']['other_variable'],
... )
>>> # passing in a single key is equivalent to passing in {'params': key}
>>> variables3 = module.init(jax.random.key(0), x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
... np.testing.assert_allclose, variables2['params'], variables3['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
... variables2['other_collection']['other_variable'],
... variables3['other_collection']['other_variable'],
... )
Jitting ``init`` initializes a model lazily using only the shapes of the
provided arguments, and avoids computing the forward pass with actual
values. Example::
>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), jnp.empty((1, 7)))
>>> variables = init_jit(jax.random.key(0), x)
``init`` is a light wrapper over ``apply``, so other ``apply`` arguments
like ``method``, ``mutable``, and ``capture_intermediates`` are also
Expand Down

0 comments on commit 76f6c46

Please sign in to comment.