Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GroupNorm to NNX normalization layers #4095

Merged
merged 1 commit into from
Aug 20, 2024
Merged

Add GroupNorm to NNX normalization layers #4095

merged 1 commit into from
Aug 20, 2024

Conversation

treigerm
Copy link
Contributor

What does this PR do?

Addresses #4086 and adds a GroupNorm layer to NNX. I tried to follow the Linen implementation and how the other layers have been ported. I also tested equivalence checks between Linen's GroupNorm implementation and the NNX implementation.

Some notes:

  • I used num_features instead of num_channels as an input argument to stay consistent with the LayerNorm implementation.
  • I deleted the reduced_feature_shape list in the _normalize function as it is unused. If there is a reason why the list was created without being used I can reverse the deletion.

Checklist

Copy link

google-cla bot commented Jul 19, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@cgarciae
Copy link
Collaborator

Thanks @treigerm! It looks great. I've left a small suggestion.
Can you also please add an entry for GroupNorm here so it renders on the documentation?

@treigerm
Copy link
Contributor Author

Thanks @cgarciae ! I have added an entry for GroupNorm and also fixed the mypy error.

@cgarciae
Copy link
Collaborator

cgarciae commented Aug 6, 2024

@treigerm can you squash your commits?

@treigerm
Copy link
Contributor Author

treigerm commented Aug 6, 2024

@cgarciae done!

@treigerm
Copy link
Contributor Author

treigerm commented Aug 6, 2024

Had to apply the fix from #4098 to make mypy tests pass locally now.

@codecov-commenter
Copy link

codecov-commenter commented Aug 16, 2024

Codecov Report

Attention: Patch coverage is 0% with 47 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (31adb00) to head (b023606).
Report is 140 commits behind head on main.

Files Patch % Lines
flax/nnx/nnx/nn/normalization.py 0.00% 46 Missing ⚠️
flax/nnx/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@          Coverage Diff           @@
##            main   #4095    +/-   ##
======================================
  Coverage   0.00%   0.00%            
======================================
  Files        106     109     +3     
  Lines      13582   14261   +679     
======================================
- Misses     13582   14261   +679     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@copybara-service copybara-service bot merged commit a3ddb0f into google:main Aug 20, 2024
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants