Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] grad accepts argnums #3798

Merged
merged 1 commit into from
Apr 2, 2024
Merged

[nnx] grad accepts argnums #3798

merged 1 commit into from
Apr 2, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 29, 2024

What does this PR do?

nnx.grad now accepts argnums and multiple graph nodes can be passed.

@cgarciae cgarciae force-pushed the nnx-improve-grad branch 3 times, most recently from e6253a9 to 7321765 Compare March 30, 2024 10:14
@cgarciae cgarciae marked this pull request as ready for review March 30, 2024 10:15
Base automatically changed from nnx-static-goes-first to main April 1, 2024 19:26
@codecov-commenter
Copy link

codecov-commenter commented Apr 1, 2024

Codecov Report

Attention: Patch coverage is 93.42105% with 5 lines in your changes are missing coverage. Please review.

Project coverage is 60.51%. Comparing base (cc740d4) to head (3bebf97).

Files Patch % Lines
flax/experimental/nnx/nnx/transforms.py 83.33% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3798      +/-   ##
==========================================
+ Coverage   60.34%   60.51%   +0.17%     
==========================================
  Files         101      101              
  Lines       12862    12908      +46     
==========================================
+ Hits         7761     7811      +50     
+ Misses       5101     5097       -4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines +427 to +464

def test_multiple_graph_nodes(self):
rngs = nnx.Rngs(0)
m1 = nnx.Linear(2, 3, rngs=rngs)
m2 = nnx.Linear(3, 3, rngs=rngs)
loss_fn = lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2)
grad_fn = nnx.grad(loss_fn, argnums=(0, 1), wrt=nnx.Param)
x = jax.random.uniform(rngs(), (1, 2))
y = jnp.ones((1, 3))
grads_m1, grads_m2 = grad_fn(m1, m2, x, y)

assert 'kernel' in grads_m1
assert grads_m1.kernel.raw_value.shape == (2, 3)
assert 'bias' in grads_m1
assert grads_m1.bias.raw_value.shape == (3,)
assert 'kernel' in grads_m2
assert grads_m2.kernel.raw_value.shape == (3, 3)
assert 'bias' in grads_m2
assert grads_m2.bias.raw_value.shape == (3,)

def test_multiple_graph_nodes_mix_positions(self):
rngs = nnx.Rngs(0)
m1 = nnx.Linear(2, 3, rngs=rngs)
m2 = nnx.Linear(3, 3, rngs=rngs)
loss_fn = lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2)
grad_fn = nnx.grad(loss_fn, argnums=(1, 3), wrt=nnx.Param)
x = jax.random.uniform(rngs(), (1, 2))
y = jnp.ones((1, 3))
grads_m1, grads_m2 = grad_fn(x, m1, y, m2)

assert 'kernel' in grads_m1
assert grads_m1.kernel.raw_value.shape == (2, 3)
assert 'bias' in grads_m1
assert grads_m1.bias.raw_value.shape == (3,)
assert 'kernel' in grads_m2
assert grads_m2.kernel.raw_value.shape == (3, 3)
assert 'bias' in grads_m2
assert grads_m2.bias.raw_value.shape == (3,)
Copy link
Collaborator

@chiamp chiamp Apr 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_multiple_graph_nodes(self):
rngs = nnx.Rngs(0)
m1 = nnx.Linear(2, 3, rngs=rngs)
m2 = nnx.Linear(3, 3, rngs=rngs)
loss_fn = lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2)
grad_fn = nnx.grad(loss_fn, argnums=(0, 1), wrt=nnx.Param)
x = jax.random.uniform(rngs(), (1, 2))
y = jnp.ones((1, 3))
grads_m1, grads_m2 = grad_fn(m1, m2, x, y)
assert 'kernel' in grads_m1
assert grads_m1.kernel.raw_value.shape == (2, 3)
assert 'bias' in grads_m1
assert grads_m1.bias.raw_value.shape == (3,)
assert 'kernel' in grads_m2
assert grads_m2.kernel.raw_value.shape == (3, 3)
assert 'bias' in grads_m2
assert grads_m2.bias.raw_value.shape == (3,)
def test_multiple_graph_nodes_mix_positions(self):
rngs = nnx.Rngs(0)
m1 = nnx.Linear(2, 3, rngs=rngs)
m2 = nnx.Linear(3, 3, rngs=rngs)
loss_fn = lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2)
grad_fn = nnx.grad(loss_fn, argnums=(1, 3), wrt=nnx.Param)
x = jax.random.uniform(rngs(), (1, 2))
y = jnp.ones((1, 3))
grads_m1, grads_m2 = grad_fn(x, m1, y, m2)
assert 'kernel' in grads_m1
assert grads_m1.kernel.raw_value.shape == (2, 3)
assert 'bias' in grads_m1
assert grads_m1.bias.raw_value.shape == (3,)
assert 'kernel' in grads_m2
assert grads_m2.kernel.raw_value.shape == (3, 3)
assert 'bias' in grads_m2
assert grads_m2.bias.raw_value.shape == (3,)
@parameterized.parameters(
{'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (0, 1)},
{'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (1, 3)},
)
def test_multiple_graph_nodes(self, loss_fn, argnums):
rngs = nnx.Rngs(0)
m1 = nnx.Linear(2, 3, rngs=rngs)
m2 = nnx.Linear(3, 3, rngs=rngs)
grad_fn = nnx.grad(loss_fn, argnums=argnums, wrt=nnx.Param)
x = jax.random.uniform(rngs(), (1, 2))
y = jnp.ones((1, 3))
grads_m1, grads_m2 = grad_fn(m1, m2, x, y)
assert 'kernel' in grads_m1
assert grads_m1.kernel.raw_value.shape == (2, 3)
assert 'bias' in grads_m1
assert grads_m1.bias.raw_value.shape == (3,)
assert 'kernel' in grads_m2
assert grads_m2.kernel.raw_value.shape == (3, 3)
assert 'bias' in grads_m2
assert grads_m2.bias.raw_value.shape == (3,)

The two tests seem similar enough that we could combine them using parameterized.parameters

Copy link
Collaborator Author

@cgarciae cgarciae Apr 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can do it but the suggested implementation is not quite right, you also need to parametrize argnums and loss_fn is being redefined.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes good catch, I fixed it now

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still doesn't work because you need to pass arguments to grad_fn in different positions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merging for now

@copybara-service copybara-service bot merged commit 2b257b4 into main Apr 2, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-improve-grad branch April 2, 2024 11:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants