[nnx] add Randomness guide #4216
de289a6 to d9bd85c
Great work! Very informative!
docs_nnx/guides/randomness.md
Outdated
## Rngs, RngStream, and RngState
Flax NNX provides the `nnx.Rngs` type a the main convenice API for managing random state. Following Flax Linen's footsteps, `Rngs` has the ability to create multiple named RNG streams, each with its own state, for the purpose of allowing for tight control over randomness in the context of JAX transforms. Here's a breakdown of the main RNG-related types in Flax NNX:
nit:
convenice -> convenience
for the purpose of allowing for tight control -> for the purpose of tight control
## Rngs, RngStream, and RngState
* **Rngs**: The main user interface. It defines a set of named `RngStream` objects.
Can we add `nnx.` to all these titles? Like `nnx.Rngs`, `nnx.RngStream`, `nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`.
docs_nnx/guides/randomness.md
Outdated
Note that the `key` attribute does not change when a new keys are generated.

### Standard stream names
There are only two standard stream names used by Flax NNX, shown in the table below:
used by Flax NNX -> used by Flax NNX built-in layers
Just to avoid user confusion that other names are not allowed.
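For readers following the thread, here is a minimal plain-JAX sketch of the behaviour the quoted hunk describes: a stream whose `key` stays fixed while new keys are produced. The `ToyStream` class and the counter-plus-`fold_in` mechanism are assumptions made for illustration, not the NNX implementation. The stream names are the two standard ones from the hunk.

```python
import jax
import jax.numpy as jnp


class ToyStream:
    """Hypothetical stand-in for a named RNG stream: a fixed base key plus a counter."""

    def __init__(self, seed: int):
        self.key = jax.random.key(seed)  # never reassigned after construction
        self.count = 0                   # incremented on every request

    def __call__(self):
        # Derive a fresh key from (base key, counter) without touching self.key.
        new_key = jax.random.fold_in(self.key, self.count)
        self.count += 1
        return new_key


# Two independently seeded streams, each with its own state.
streams = {"params": ToyStream(0), "dropout": ToyStream(1)}

base_before = jax.random.key_data(streams["params"].key)
k1 = streams["params"]()
k2 = streams["params"]()
base_after = jax.random.key_data(streams["params"].key)

# The base key is untouched, yet the two generated keys are distinct.
base_unchanged = bool(jnp.array_equal(base_before, base_after))
keys_differ = not bool(
    jnp.array_equal(jax.random.key_data(k1), jax.random.key_data(k2))
)
```

Only the counter of the `params` stream advances here. The `dropout` stream is independent, and nothing in this sketch restricts the dictionary to those two names, which is the point of the review comment above.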
docs_nnx/guides/randomness.md
Outdated
| `params` | Used for parameter initialization |
| `dropout`| Used by `Dropout` to create dropout masks |

`params` is used my most of the standard layers (`Linear`, `Conv`, `MultiHeadAttention`, etc.) during construction to initialize their parameters. `dropout` is used by the `Dropout` and `MultiHeadAttention` to generate dropout masks. Here's a simple example of a model using `params` and `dropout` streams:
nit: `etc.` -> `etc`
docs_nnx/guides/randomness.md
Outdated
The other option is to use the `nnx.split_rngs` decorator which will automatically split the random state of any `RngStream`s found in the inputs of the function, and will automatically "lower" them once the function call ends so the `Rngs` can be used outside again. `split_rngs` allows passing Filters to the `only` keyword argument to select the `RngStream`s that should be split. Using `split_rngs` is useful in combination with a transform but here we will show a simple example without any transforms to illustrate the concept, we'll use `split_rngs` with a transform on the next section.
I have a feeling that we should introduce and suggest `nnx.split_rngs` first, since it works better with transform cases and doesn't have the drawback that splitting in the stream has (i.e., it can't work outside a transform).
Actually, if there isn't any case in which splitting in the stream is better, maybe it's fine to simply avoid introducing that... or to introduce it as a lower-level view of what `nnx.split_rngs` actually does.
Good point. I'll think about it. We might as well just delete the other example and stick with `split_rngs`.
The main motivation for showing how to manually lift was to make it more familiar, as the operation is explicit, but it might be less used.
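As a lower-level view of the splitting discussed above, the following plain-JAX sketch shows what splitting a stream's key for a mapped function amounts to. It deliberately avoids NNX APIs. `base_key` merely stands in for the key of a `dropout` stream, and `dropout_mask` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.key(1)                 # stand-in for a `dropout` stream key
split_keys = jax.random.split(base_key, 5)   # one key per mapped instance


def dropout_mask(key):
    # Inside the mapped function each instance sees a single, unbatched key.
    return jax.random.bernoulli(key, p=0.5, shape=(8,))


# vmap consumes the leading axis of split_keys, yielding one mask per key.
masks = jax.vmap(dropout_mask)(split_keys)

# base_key itself is still a scalar key and remains usable outside the map.
```

According to the quoted paragraph, `nnx.split_rngs` automates this split-then-restore pattern for the streams selected via `only`. The sketch makes the steps explicit, which is the familiarity argument raised in the reply above.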
docs_nnx/guides/randomness.md
Outdated
## Transforms
As stated before, in Flax NNX random state is just another type of state, this means that there is nothing special about it regarding transforms. This means that you should be able to use the state handling APIs of each transform to get the results you want. In this section we will two examples of using random state in transforms, one with `pmap` and another one with `scan`.
one with `pmap` and another one with `scan`.
What about this, to highlight the intention of both examples:
one with `pmap` to split multiple RNG keys and one with `scan` to broadcast a single RNG key.
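The split-versus-broadcast distinction in the suggestion above can be sketched in plain JAX. This is an illustration under stated assumptions, not the guide's example. Both variants use `jax.lax.scan` so they can be compared directly, and `pmap` is left out because it needs multiple devices. The step functions are hypothetical.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
xs = jnp.ones((4, 3))  # 4 scan steps, 3 features each


def step_broadcast(carry, x):
    # The same key is closed over by every step, so the noise repeats.
    noise = jax.random.normal(key, x.shape)
    return carry, x + noise


def step_split(carry, inputs):
    x, step_key = inputs  # one distinct key per step
    noise = jax.random.normal(step_key, x.shape)
    return carry, x + noise


_, ys_broadcast = jax.lax.scan(step_broadcast, None, xs)
_, ys_split = jax.lax.scan(step_split, None, (xs, jax.random.split(key, 4)))

# Broadcasting one key gives identical rows; splitting gives different rows.
rows_equal = bool(jnp.allclose(ys_broadcast, ys_broadcast[0]))
rows_differ = not bool(jnp.allclose(ys_split, ys_split[0]))
```

Broadcasting suits cases where every step should see the same randomness. Splitting suits cases where each step or replica should draw independently.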
d9bd85c to 54fde93
54fde93 to 65d4193
What does this PR do?
Adds Randomness guide.