Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] Rngs and RngStream inherit from GraphNode #3793

Merged
merged 1 commit into from
Apr 1, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 28, 2024

What does this PR do?

  • Rngs and RngStream are now GraphNodes, this means their state can be more easily tracked by transforms.
  • Added a RngState(Variable) type.
  • RngStream.key and RngStream.count are now RngState variables. This also means that count is now a dynamic property.
  • Added a general GraphNode.check_valid_context method that can be used to check that the GraphNode is in the appropriate context.
  • Removed Rngs._trace_state and Rngs.is_valid in favor of using GraphNode.check_valid_context.
  • nnx.jit no longer has a special rule for Rngs, they are just handled as GraphNodes if present.

other

  • Created a ModuleMeta(GraphNodeMeta) and passed all dataclass post-init logic (e.g. calling .setup or setting rngs = None) to ModuleMeta.
  • Transform combinator metaclasses now inherit from ModuleMeta.

@codecov-commenter
Copy link

codecov-commenter commented Mar 28, 2024

Codecov Report

Attention: Patch coverage is 96.80851% with 3 lines in your changes are missing coverage. Please review.

Project coverage is 60.20%. Comparing base (514c111) to head (453f38b).

Files Patch % Lines
flax/experimental/nnx/nnx/module.py 91.66% 2 Missing ⚠️
flax/experimental/nnx/nnx/variables.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3793      +/-   ##
==========================================
+ Coverage   60.16%   60.20%   +0.03%     
==========================================
  Files         101      101              
  Lines       12842    12859      +17     
==========================================
+ Hits         7727     7742      +15     
- Misses       5115     5117       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@cgarciae cgarciae force-pushed the nnx-rngs-are-nodes branch 5 times, most recently from 3b668b5 to e0ea88e Compare March 28, 2024 15:36
@cgarciae cgarciae marked this pull request as ready for review March 28, 2024 15:55
Comment on lines +187 to +189
def __init__(self, not_rngs):
rngs = not_rngs
self.linear = nnx.Linear(2, 2, rngs=rngs)
Copy link
Collaborator

@chiamp chiamp Mar 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self, not_rngs):
rngs = not_rngs
self.linear = nnx.Linear(2, 2, rngs=rngs)
def __init__(self, rngs):
self.linear = nnx.Linear(2, 2, rngs=rngs)

Small nit: What is the purpose of naming the rng arg not_rngs when it's reassigned to the rngs variable and passed in as an rng key to an nnx.Module anyway?

Comment on lines +204 to +207
def f(m: Foo, x: jax.Array, not_rngs: nnx.Rngs):
rngs = not_rngs
x = m(x, rngs)
x = m(x, rngs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def f(m: Foo, x: jax.Array, not_rngs: nnx.Rngs):
rngs = not_rngs
x = m(x, rngs)
x = m(x, rngs)
def f(m: Foo, x: jax.Array, rngs: nnx.Rngs):
x = m(x, rngs)
x = m(x, rngs)

Same here.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@copybara-service copybara-service bot merged commit f4337c3 into main Apr 1, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-rngs-are-nodes branch April 1, 2024 13:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants