Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add reseed #4099

Merged
merged 1 commit into from
Jul 22, 2024
Merged

[nnx] add reseed #4099

merged 1 commit into from
Jul 22, 2024

Conversation

cgarciae
Copy link
Collaborator

What does this PR do?

Adds nnx.reseed to update the nested keys of a graph node.

class Model(nnx.Module):
  def __init__(self, rngs):
    self.linear = nnx.Linear(2, 3, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    return self.dropout(self.linear(x))

model = Model(nnx.Rngs(params=0, dropout=42))
x = jnp.ones((1, 2))

y1 = model(x)

# reset the ``dropout`` stream key to 42
nnx.reseed(model, dropout=42)
y2 = model(x)

assert jnp.allclose(y1, y2)

@copybara-service copybara-service bot merged commit a7bdadb into main Jul 22, 2024
18 checks passed
@copybara-service copybara-service bot deleted the nnx-reseed branch July 22, 2024 23:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants