[nnx] grad accepts argnums #3798
Conversation
Codecov Report

Attention: Patch coverage is

```
@@            Coverage Diff             @@
##             main    #3798      +/-   ##
==========================================
+ Coverage   60.34%   60.51%   +0.17%
==========================================
  Files         101      101
  Lines       12862    12908      +46
==========================================
+ Hits         7761     7811      +50
+ Misses       5101     5097       -4
```

View full report in Codecov by Sentry.
```python
def test_multiple_graph_nodes(self):
  rngs = nnx.Rngs(0)
  m1 = nnx.Linear(2, 3, rngs=rngs)
  m2 = nnx.Linear(3, 3, rngs=rngs)
  loss_fn = lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2)
  grad_fn = nnx.grad(loss_fn, argnums=(0, 1), wrt=nnx.Param)
  x = jax.random.uniform(rngs(), (1, 2))
  y = jnp.ones((1, 3))
  grads_m1, grads_m2 = grad_fn(m1, m2, x, y)

  assert 'kernel' in grads_m1
  assert grads_m1.kernel.raw_value.shape == (2, 3)
  assert 'bias' in grads_m1
  assert grads_m1.bias.raw_value.shape == (3,)
  assert 'kernel' in grads_m2
  assert grads_m2.kernel.raw_value.shape == (3, 3)
  assert 'bias' in grads_m2
  assert grads_m2.bias.raw_value.shape == (3,)

def test_multiple_graph_nodes_mix_positions(self):
  rngs = nnx.Rngs(0)
  m1 = nnx.Linear(2, 3, rngs=rngs)
  m2 = nnx.Linear(3, 3, rngs=rngs)
  loss_fn = lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2)
  grad_fn = nnx.grad(loss_fn, argnums=(1, 3), wrt=nnx.Param)
  x = jax.random.uniform(rngs(), (1, 2))
  y = jnp.ones((1, 3))
  grads_m1, grads_m2 = grad_fn(x, m1, y, m2)

  assert 'kernel' in grads_m1
  assert grads_m1.kernel.raw_value.shape == (2, 3)
  assert 'bias' in grads_m1
  assert grads_m1.bias.raw_value.shape == (3,)
  assert 'kernel' in grads_m2
  assert grads_m2.kernel.raw_value.shape == (3, 3)
  assert 'bias' in grads_m2
  assert grads_m2.bias.raw_value.shape == (3,)
```
The two tests seem similar enough that we could combine them using `parameterized.parameters`:

```python
@parameterized.parameters(
  {'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (0, 1)},
  {'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (1, 3)},
)
def test_multiple_graph_nodes(self, loss_fn, argnums):
  rngs = nnx.Rngs(0)
  m1 = nnx.Linear(2, 3, rngs=rngs)
  m2 = nnx.Linear(3, 3, rngs=rngs)
  grad_fn = nnx.grad(loss_fn, argnums=argnums, wrt=nnx.Param)
  x = jax.random.uniform(rngs(), (1, 2))
  y = jnp.ones((1, 3))
  grads_m1, grads_m2 = grad_fn(m1, m2, x, y)
  assert 'kernel' in grads_m1
  assert grads_m1.kernel.raw_value.shape == (2, 3)
  assert 'bias' in grads_m1
  assert grads_m1.bias.raw_value.shape == (3,)
  assert 'kernel' in grads_m2
  assert grads_m2.kernel.raw_value.shape == (3, 3)
  assert 'bias' in grads_m2
  assert grads_m2.bias.raw_value.shape == (3,)
```
I think you can do it, but the suggested implementation is not quite right: you also need to parametrize `argnums`, and `loss_fn` is being redefined.
ah yes good catch, I fixed it now
still doesn't work because you need to pass arguments to `grad_fn` in different positions.
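For reference, a minimal sketch of how both cases could be folded into one parameterized test while still calling `grad_fn` with the arguments in different positions; the `make_args` helper below is hypothetical and not part of this PR:

```python
# (inside the same test class as above, with the same imports)
@parameterized.parameters(
  # Graph nodes in the leading positions.
  {
    'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2),
    'argnums': (0, 1),
    'make_args': lambda m1, m2, x, y: (m1, m2, x, y),
  },
  # Graph nodes mixed in between array arguments.
  {
    'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2),
    'argnums': (1, 3),
    'make_args': lambda m1, m2, x, y: (x, m1, y, m2),
  },
)
def test_multiple_graph_nodes(self, loss_fn, argnums, make_args):
  rngs = nnx.Rngs(0)
  m1 = nnx.Linear(2, 3, rngs=rngs)
  m2 = nnx.Linear(3, 3, rngs=rngs)
  grad_fn = nnx.grad(loss_fn, argnums=argnums, wrt=nnx.Param)
  x = jax.random.uniform(rngs(), (1, 2))
  y = jnp.ones((1, 3))
  # Arrange the call arguments to match each loss_fn signature.
  grads_m1, grads_m2 = grad_fn(*make_args(m1, m2, x, y))
  assert grads_m1.kernel.raw_value.shape == (2, 3)
  assert grads_m2.kernel.raw_value.shape == (3, 3)
```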
merging for now
What does this PR do?
`nnx.grad` now accepts `argnums`, and multiple graph nodes can be passed.
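A minimal usage sketch of the new API, adapted from the tests in this PR: two modules are passed as separate positional arguments, and `argnums` selects which of them to differentiate with respect to.

```python
import jax
import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)
m1 = nnx.Linear(2, 3, rngs=rngs)
m2 = nnx.Linear(3, 3, rngs=rngs)

# Differentiate the loss w.r.t. the Params of both modules (positions 0 and 1).
loss_fn = lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2)
grad_fn = nnx.grad(loss_fn, argnums=(0, 1), wrt=nnx.Param)

x = jax.random.uniform(rngs(), (1, 2))
y = jnp.ones((1, 3))
grads_m1, grads_m2 = grad_fn(m1, m2, x, y)  # one gradient State per selected argnum
```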