How does make_rng() work in Flax/JAX? #1998
(Original question from @lucasb-eyer, rephrased by me) How does Flax Linen's `Module.make_rng()` work?
Is there a good example of `make_rng` being used?
(Original answer by @levskaya, rephrased by me):
When you call `Module.init()` or `Module.apply()`, you pass them a set of random keys, which are treated as "root" keys (note that `init` allows a single key as shorthand for `{'params': rng}`). At each submodule boundary, we fold in a hash of the submodule name for each of these root keys, and within a submodule where `self.make_rng('foo')` is called, we keep track of a counter that is also folded in to guarantee that each call returns a unique key.