Force no split in make_rng #3115

Open
cgarciae opened this issue on May 24, 2023 · Discussed in #3113 · 5 comments · Fixed by #3114

Comments

@cgarciae
Collaborator

Discussed in #3113

Originally posted by zaccharieramzi May 24, 2023
I have the following situation: I am applying Dropout layers multiple times outside of an nn.scan or nn.while_loop, so I cannot use split_rngs={"dropout": False}.
However, I would still like to use the same dropout mask twice.

Is it possible to specify "no split" in make_rng for certain collections?

If I just take the original dropout example I would like to do something like:

# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn

# Randomness.
seed = 0
root_key = jax.random.PRNGKey(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

# A simple network.
class MyModel(nn.Module):
  num_neurons: int
  training: bool
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a rate of 50% .
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
    return x

# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)

x = jax.random.uniform(key=main_key, shape=(3, 4, 4))

# Initialize with `flax.linen.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.) 
variables = my_model.init(params_key, x)

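# Perform the forward pass with `flax.linen.apply()`.
# Flax modules are frozen dataclasses, so use `clone` to get a copy with
# `training=True` instead of assigning to `my_model.training`.
my_model = my_model.clone(training=True)
y = my_model.apply(variables, x, rngs={'dropout': dropout_key})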

and still have jnp.sum(y == 0.) / (3*4*3) ≈ 0.5.

For more context, I am actually trying to implement Deep Equilibrium Models using jaxopt and Flax, where the function defining the fixed point uses dropout.
I also tried to see if the split_rngs functionality could be extended to jaxopt but I think it's going to be difficult.
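The ≈ 0.5 target follows from how the masks combine: with two independent 50% masks a unit survives only if both keep it, so about 1 - 0.5 × 0.5 = 0.75 of the outputs end up zero, while with one shared mask the fraction stays at about 0.5. A minimal pure-JAX sketch that simulates only the two masks (not the Flax layers), assuming a large, arbitrary unit count so the averages are stable:

```python
import jax
import jax.numpy as jnp

# Simulate many units so the empirical fractions sit close to their expectations.
shape = (100_000,)
keep_prob = 0.5  # Dropout rate of 50%: each unit is kept with probability 0.5.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# Two layers with independent keys: a unit survives only if both masks keep it.
mask_a = jax.random.bernoulli(key_a, keep_prob, shape)
mask_b = jax.random.bernoulli(key_b, keep_prob, shape)
independent_zero_frac = jnp.mean(~(mask_a & mask_b))  # expected 1 - 0.5 * 0.5 = 0.75

# Two layers sharing one key: the second mask is identical to the first.
mask_shared = jax.random.bernoulli(key_a, keep_prob, shape)
shared_zero_frac = jnp.mean(~(mask_a & mask_shared))  # expected 0.5
```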

@cgarciae
Collaborator Author

Hey @zaccharieramzi, I've converted the discussion into an issue, as it seems like something we should improve.
I've created #3114, which would allow you to optionally specify the rng key for each Dropout layer, e.g.:

# A simple network.
class MyModel(nn.Module):
  num_neurons: int
  training: bool
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a rate of 50% .
    # When the `deterministic` flag is `True`, dropout is turned off.
    key = self.make_rng('dropout')
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
    return x

This way both layers will produce the same mask.

@zaccharieramzi

Is there a way to propagate this information instead of passing the key to each Dropout layer?
In my case, I would need to call key = self.make_rng("dropout") and pass it down to the actual Dropout layers, which are nested deep inside different nn.Modules.

Something like:

key = self.make_rng('dropout')
x = MyModule(...)(x, rng=key)

where originally MyModule does not have the rng parameter in its API.

@zaccharieramzi

Of course, I understand this might be much more complex to do, so it's really just a question.

@zaccharieramzi

@cgarciae I see that this issue was closed, so maybe you missed my earlier question. In modules like dot_product_attention, the dropout is typically hard-coded, with no way to set the rng.
Do you think it's best, then, to reimplement all these modules with the option to pass the rng?

@chiamp
Collaborator

chiamp commented Nov 1, 2023

FYI @zaccharieramzi, I added a dropout_rng argument to nn.MultiHeadDotProductAttention in #3384, so you can get the same dropout mask.
