Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simple argument-only lifted nn.grad function. #3463

Merged
merged 1 commit into from
Nov 6, 2023

Conversation

levskaya
Copy link
Collaborator

@levskaya levskaya commented Nov 6, 2023

This function only peforms a lifted value-and-grad operation with respect to the arguments of a function, and does not try to calculate gradients with respect to any variables. This mirrors the behavior of the haiku grad function, and also easily works in the multi-scope setting (external modules passed in) while avoiding the complexities associated with that case for a more general vjp.

flax/linen/transforms.py Outdated Show resolved Hide resolved
flax/linen/transforms.py Outdated Show resolved Hide resolved
@levskaya levskaya force-pushed the vjp_fix branch 2 times, most recently from ea21918 to 30e54dd Compare November 6, 2023 22:07
This function only peforms a lifted value-and-grad operation with
respect to the arguments of a function, and does not try to calculate
gradients with respect to any variables.  This mirrors the behavior of
the haiku grad function, and also easily works in the multi-scope
setting (external modules passed in) while avoiding the complexities
associated with that case for a more general vjp.
@codecov-commenter
Copy link

codecov-commenter commented Nov 6, 2023

Codecov Report

Merging #3463 (5ff36ba) into main (055e28f) will increase coverage by 0.05%.
Report is 2 commits behind head on main.
The diff coverage is 90.90%.

@@            Coverage Diff             @@
##             main    #3463      +/-   ##
==========================================
+ Coverage   83.56%   83.62%   +0.05%     
==========================================
  Files          56       56              
  Lines        6768     6790      +22     
==========================================
+ Hits         5656     5678      +22     
  Misses       1112     1112              
Files Coverage Δ
flax/linen/__init__.py 100.00% <ø> (ø)
flax/linen/module.py 92.05% <100.00%> (+0.02%) ⬆️
flax/linen/transforms.py 93.65% <90.00%> (-0.30%) ⬇️

... and 1 file with indirect coverage changes

@copybara-service copybara-service bot merged commit e2c3dfd into google:main Nov 6, 2023
19 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants