[fix] Nystrom + microGPT + some additive masking #75
Conversation
Force-pushed from 5eddcff to b52b135
Force-pushed from b52b135 to 6f626d7
seq = q.shape[-2]

if not att_mask.is_sparse:
    att_mask = att_mask[:seq, :seq]
We were testing for smaller sequences, but when the mask was sparse it was not adjusted -> possible memory error, which showed up on CI later on.
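For illustration, a minimal sketch of the failure mode described here, assuming a plain PyTorch mask (`trim_mask` is a hypothetical helper, not the xformers code):

```python
import torch

def trim_mask(att_mask: torch.Tensor, seq: int) -> torch.Tensor:
    if not att_mask.is_sparse:
        return att_mask[:seq, :seq]  # plain slicing works on dense masks
    # sparse COO tensors do not support slicing, so the mask kept its full size;
    # a round-trip through dense is purely illustrative of what "adjusting" means
    return att_mask.to_dense()[:seq, :seq].to_sparse()
```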
Converting back to draft, I think that the key padding mask is not handled correctly with my changes.
mask = (
    key_padding_mask
    if mask is None
    else mask.logical_and(key_padding_mask)
)
@dianaml0 I'm not sure how that worked, since mask and key_padding_mask had different dimensions here, no?
It should automatically broadcast key_padding_mask along the mismatched dimension, similar to https://github.com/pytorch/pytorch/blob/4262c8913c2bddb8d91565888b4871790301faba/torch/nn/functional.py#L5189
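A minimal example of that broadcasting behaviour, with shapes assumed for the sketch rather than taken from the attention code:

```python
import torch

B, S = 2, 4
# (S, S) boolean attention mask: True means "attend"
att_mask = torch.ones(S, S, dtype=torch.bool)
# (B, 1, S) key padding mask, padding out the last key of sample 0
key_padding_mask = torch.ones(B, 1, S, dtype=torch.bool)
key_padding_mask[0, 0, -1] = False

# logical_and broadcasts (S, S) against (B, 1, S) -> (B, S, S)
combined = att_mask.logical_and(key_padding_mask)
print(combined.shape)  # torch.Size([2, 4, 4])
```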
Force-pushed from d2ba0e3 to 013a927
Codecov Report
@@            Coverage Diff             @@
##             main      #75      +/-   ##
==========================================
+ Coverage   87.10%   87.12%   +0.01%
==========================================
  Files          50       50
  Lines        2428     2447      +19
==========================================
+ Hits         2115     2132      +17
- Misses        313      315       +2
Thanks!! Really helpful changes! Makes masking easier to work with!
@@ -30,7 +30,7 @@ def test_core_attention():
 def test_core_attention_mask_types():

     b, s, d = 8, 900, 32
-    prob = 0.5
+    prob = 0.8  # make sure that we trigger the sparse kernels
Oh oops, thanks for catching that!
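As a rough sketch of why the value matters (how the test actually builds its mask, and whether prob counts kept or dropped entries, is assumed here rather than taken from the test):

```python
import torch

b, s = 8, 900
prob = 0.8  # assumed: fraction of attention entries dropped from the random mask

# a mask this sparse should route through the sparse attention kernels
# rather than the dense fallback, which is what the test wants to exercise
mask = torch.rand(b, s, s) > prob
print(f"density: {mask.float().mean().item():.2f}")  # ~0.20
```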
mask = (
    key_padding_mask
    if mask is None
    else mask.logical_and(key_padding_mask)
)
key_padding_mask = bool_mask_to_additive(key_padding_mask)

assert key_padding_mask is not None  # mypy is drunk
I think if a return type is added to bool_mask_to_additive it may fix the mypy error.
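A sketch of that suggestion, assuming a simplified bool_mask_to_additive rather than the exact xformers helper; with an explicit return annotation mypy knows the result is a Tensor, so the assert above becomes unnecessary:

```python
import torch

def bool_mask_to_additive(
    mask: torch.Tensor, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    # True -> 0.0 (keep), False -> -inf (masked out), the usual additive convention
    additive = torch.zeros_like(mask, dtype=dtype)
    additive.masked_fill_(~mask, float("-inf"))
    return additive
```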
Ahh good point, I'll fix that, thank you!
…the same thing, but it's not correct
Force-pushed from 72defc3 to 11f9d38
* Add some 2d-specific attention patterns
* Add notebook with examples
What does this PR do?
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.