-
Notifications
You must be signed in to change notification settings - Fork 617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] Adding Visual Attention #329
Conversation
cc @mannatsingh if you're interested in these things |
c9282b7
to
bf34467
Compare
3adc3ac
to
7fb47e6
Compare
7fb47e6
to
696d178
Compare
@@ -121,8 +121,8 @@ def forward(self, x): | |||
|
|||
# Adjust batch depending on the available memory on your machine. | |||
# You can also use reversible layers to save memory | |||
REF_BATCH = 512 | |||
BATCH = 512 # lower if not enough GPU memory | |||
REF_BATCH = 768 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like a classic default for Cifar10
@@ -31,6 +31,7 @@ def __init__( | |||
num_classes=10, | |||
dim=384, | |||
attention="scaled_dot_product", | |||
feedforward="MLP", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about the defaults here, how to show that you can use these to repro "Visual Attention" for instance ? Should we show different presets ?
@@ -62,6 +62,10 @@ def __init__( | |||
# This operator does not really handle q,k,v | |||
self.requires_same_k_q_dimensions = True | |||
|
|||
# This attention requires the 2d structure out of the context, | |||
# implictly assumed to be a squared length | |||
self.requires_squared_context = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was already true before, but not formalized like this, I think it's cleaner ? "pooling" (PoolingFormer) and "visual" both recover the 2d structure of and assume a squared context length for that
H = int(math.sqrt(HW)) | ||
assert H * H == HW | ||
|
||
x = q.transpose(-2, -1).reshape(B, C, H, H) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've not benchmarked that, but maybe that it's beneficial to .contiguous() here, depending on the Conv2D kernels
What does this PR do?
Fixes #319. Note that to reproduce the paper you need the Conv2DFeedforward introduced here #321, and a metaformer-like structure
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.