-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] Rngs and RngStream inherit from GraphNode #3793
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #3793 +/- ##
==========================================
+ Coverage 60.16% 60.20% +0.03%
==========================================
Files 101 101
Lines 12842 12859 +17
==========================================
+ Hits 7727 7742 +15
- Misses 5115 5117 +2 ☔ View full report in Codecov by Sentry. |
3b668b5
to
e0ea88e
Compare
e0ea88e
to
fbb7d1e
Compare
def __init__(self, not_rngs): | ||
rngs = not_rngs | ||
self.linear = nnx.Linear(2, 2, rngs=rngs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def __init__(self, not_rngs): | |
rngs = not_rngs | |
self.linear = nnx.Linear(2, 2, rngs=rngs) | |
def __init__(self, rngs): | |
self.linear = nnx.Linear(2, 2, rngs=rngs) |
Small nit: What is the purpose of naming the rng arg not_rngs
when it's reassigned to the rngs
variable and passed in as an rng key to an nnx.Module
anyway?
def f(m: Foo, x: jax.Array, not_rngs: nnx.Rngs): | ||
rngs = not_rngs | ||
x = m(x, rngs) | ||
x = m(x, rngs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def f(m: Foo, x: jax.Array, not_rngs: nnx.Rngs): | |
rngs = not_rngs | |
x = m(x, rngs) | |
x = m(x, rngs) | |
def f(m: Foo, x: jax.Array, rngs: nnx.Rngs): | |
x = m(x, rngs) | |
x = m(x, rngs) |
Same here.
fbb7d1e
to
e9a83a0
Compare
e9a83a0
to
1c75d1a
Compare
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
1c75d1a
to
453f38b
Compare
What does this PR do?
Rngs
andRngStream
are nowGraphNode
s, this means their state can be more easily tracked by transforms.RngState(Variable)
type.RngStream.key
andRngStream.count
are nowRngState
variables. This also means thatcount
is now a dynamic property.GraphNode.check_valid_context
method that can be used to check that the GraphNode is in the appropriate context.Rngs._trace_state
andRngs.is_valid
in favor of usingGraphNode.check_valid_context
.nnx.jit
no longer has a special rule forRngs
, they are just handled asGraphNode
s if present.other
ModuleMeta(GraphNodeMeta)
and passed all dataclass post-init logic (e.g. calling.setup
or settingrngs = None
) toModuleMeta
.ModuleMeta
.