
How does make_rng() work in Flax/JAX? #1998

Answered by marcvanzee
marcvanzee asked this question in Q&A

(Original answer by @levskaya, rephrased by me):

When you call Module.init() or Module.apply(), you pass in a set of PRNG keys, one per named RNG stream (for example {'params': key1, 'dropout': key2}). These are treated as "root" keys. Note that init also accepts a single key as shorthand for {'params': key}. At each submodule boundary, a hash of the submodule's name is folded into each of these root keys, so every submodule derives its own keys. Within a submodule, each call to self.make_rng('foo') additionally folds in a counter that is incremented after every call, which guarantees that each call returns a unique key.
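The scheme can be modeled in a few lines of plain JAX. This is a simplified illustration, not Flax's actual implementation: ToyScope and name_to_int (including its SHA-1 based hash) are hypothetical stand-ins. Only jax.random.PRNGKey, jax.random.fold_in and jax.random.normal are real JAX APIs. The sketch shows the three properties described above: different submodules get different keys, repeated make_rng calls in the same submodule get different keys, and re-deriving the same path from the same root keys is deterministic.

```python
import hashlib

import jax
import jax.numpy as jnp


def name_to_int(name: str) -> int:
    # Hypothetical deterministic name hash (NOT Flax's actual function).
    # Python's built-in hash() is salted per process for str, so a stable
    # digest is used instead; the result is kept below 2**31 to fit fold_in.
    digest = hashlib.sha1(name.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "little") % (2**31)


class ToyScope:
    """Hypothetical, simplified model of per-module PRNG handling."""

    def __init__(self, rngs):
        # rngs maps a stream name ('params', 'dropout', ...) to a PRNG key.
        self.rngs = dict(rngs)
        # One call counter per stream, starting at 0 in every new scope.
        self.counters = {stream: 0 for stream in self.rngs}

    def child(self, submodule_name):
        # Submodule boundary: fold a hash of the submodule name into each key.
        tag = name_to_int(submodule_name)
        return ToyScope(
            {stream: jax.random.fold_in(key, tag) for stream, key in self.rngs.items()}
        )

    def make_rng(self, stream):
        # Fold in the current counter, then increment it so that the next
        # call on the same stream in the same scope yields a different key.
        key = jax.random.fold_in(self.rngs[stream], self.counters[stream])
        self.counters[stream] += 1
        return key


# Root keys, as they would be passed to init()/apply() (one per stream).
root = ToyScope({"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)})

dense0 = root.child("Dense_0")
a = jax.random.normal(dense0.make_rng("dropout"), (3,))
b = jax.random.normal(dense0.make_rng("dropout"), (3,))  # 2nd call, same scope
c = jax.random.normal(root.child("Dense_1").make_rng("dropout"), (3,))  # other submodule

# Re-deriving the same path (same root key, name and call index) is deterministic.
a_again = jax.random.normal(root.child("Dense_0").make_rng("dropout"), (3,))

print(bool(jnp.array_equal(a, a_again)))  # True
print(bool(jnp.array_equal(a, b)), bool(jnp.array_equal(a, c)))
```

Because every key is derived purely from the root keys, the submodule path and the call index, results are reproducible for a fixed set of root keys without any global RNG state.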

Replies: 2 comments, 2 replies

Category: Q&A
Labels: None yet
Participants: 2