Force no split in make_rng
#3115
Comments
Hey @zaccharieramzi, I've converted the discussion into an issue, as it seems like something that we should improve.
# A simple network.
class MyModel(nn.Module):
    num_neurons: int
    training: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.num_neurons)(x)
        # Set the dropout layers with a rate of 50%.
        # When the `deterministic` flag is `True`, dropout is turned off.
        # Draw a single key and pass it to both layers so they share it.
        key = self.make_rng('dropout')
        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
        return x
This way both layers will produce the same mask.
Would there be a way to propagate this information rather than having to pass it around to each dropout? Something like:
key = self.make_rng('dropout')
x = MyModule(...)(x, rng=key)
where originally
Of course, I understand it might be way more complex to do, so it's really just a question.
@cgarciae I see that this was closed, so maybe you missed my earlier question. Typically in modules like
FYI @zaccharieramzi, I added a
Discussed in #3113
Originally posted by zaccharieramzi May 24, 2023
I have the following situation: I am using a Dropout layer multiple times without an nn.scan or nn.while_loop, therefore I cannot use split_rngs={"dropout": False}. However, I would still like to use the same dropout mask twice.
Is it possible to specify "no split" to make_rng for certain collections?
If I just take the original dropout example I would like to do something like:
and still have jnp.sum(y == 0.) / (3*4*3) == 0.5 approx.
For more context, I am actually trying to implement Deep Equilibrium Models using jaxopt and flax, where the fixed point defining function uses dropout. I also tried to see if the split_rngs functionality could be extended to jaxopt, but I think it's going to be difficult.