[feat] Adding a conv MLP, following VAN #321
@register_feedforward("ConvMLP", ConvMlpConfig)
class ConvMLP(Feedforward):
cc @fmassa, it's an interesting take I think
Codecov Report
@@            Coverage Diff             @@
##             main     #321      +/-   ##
==========================================
- Coverage   93.75%   93.70%   -0.05%
==========================================
  Files          68       69       +1
  Lines        3840     3889      +49
==========================================
+ Hits         3600     3644      +44
- Misses        240      245       +5
        )

        # This feedforward requires a context length which is squared, often due to 2D pooling
        self.requires_squared_context = True
This does 2D convolutions, meaning that the layer needs to be able to go from [Batch x Context x Embedding] to [Batch x H x W x Embedding]. A solution which is not too intrusive is to force sequence lengths to be square numbers, which essentially means that we only work with square pictures. That is pretty common in vision codebases. I think another solution would be to keep track of the original H and W prior to flattening this dimension.
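To make the square-context constraint concrete, here is a minimal sketch of the shape bookkeeping it implies. It is not the code from this PR: `square_side` and `unflatten_index` are hypothetical helper names, and row-major flattening of the image is an assumption.

```python
import math


def square_side(context_length: int) -> int:
    """Return H (== W) for a flattened square grid.

    Hypothetical helper, not part of xformers: raises if the context
    length is not a perfect square, i.e. the picture is not square.
    """
    side = math.isqrt(context_length)
    if side * side != context_length:
        raise ValueError(
            f"context length {context_length} is not a perfect square"
        )
    return side


def unflatten_index(token_index: int, side: int) -> tuple[int, int]:
    """Map a position in the flattened sequence back to (row, col).

    Assumes the [H x W] grid was flattened in row-major order.
    """
    return divmod(token_index, side)


# A 14 x 14 grid of patches gives a context of 196 tokens.
side = square_side(196)
row, col = unflatten_index(15, side)
```

With `side` known, [Batch x Context x Embedding] can be viewed as [Batch x side x side x Embedding] before the 2D convolution and flattened back afterwards. Keeping the original H and W around instead would lift the square-only restriction, at the cost of passing that extra state through the layer.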
should be fixed with the last update
LGTM! Nice!
What does this PR do?
One step towards #319, adding the MLP/Conv2d hybrid proposed by the VAN paper. Interestingly, testing this with a "Metaformer" (in true xformers fashion you can mix and match) on a tiny example does bring a measurable benefit.
Small (6M) Metaformer on Cifar10
Orange is the default (scaled dot product attention, not poolformer) + MLP. White is the same but with the ConvMLP that this PR introduces.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.