[partial refactor] Dedicated attention mask wrapper #113
Conversation
Force-pushed from 8f9d6d4 to a146fa1
@@ -35,7 +35,7 @@
 author = "Facebook AI Research"

 # The full version, including alpha/beta/rc tags
-release = "0.0.5"
+release = "0.0.6"
just bumped to mark a difference with the released package (0.0.5)
Force-pushed from a146fa1 to 8cbae02
Codecov Report

@@            Coverage Diff             @@
##             main     #113      +/-   ##
==========================================
+ Coverage   86.90%   87.21%   +0.31%
==========================================
  Files          49       50       +1
  Lines        2497     2566      +69
==========================================
+ Hits         2170     2238      +68
- Misses        327      328       +1
@@ -114,7 +114,6 @@ def test_pytorch_encoder_parity(device=torch.device("cuda")):
     dim_feedforward=4 * EMB,
     dropout=DROP,
     activation=ACTIVATION,
-    layer_norm_eps=1e-05,
Not present in some PyTorch versions, and this is the default anyway
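To make the compatibility point concrete, here is a minimal sketch of building the reference layer without that kwarg; the constant values below are illustrative stand-ins, not the ones defined in the test file:

```python
import torch

# Illustrative values only; the real test defines its own EMB/DROP/ACTIVATION constants
EMB, HEADS, DROP, ACTIVATION = 384, 4, 0.1, "gelu"

encoder_layer = torch.nn.TransformerEncoderLayer(
    d_model=EMB,
    nhead=HEADS,
    dim_feedforward=4 * EMB,
    dropout=DROP,
    activation=ACTIVATION,
    # layer_norm_eps=1e-05 omitted: it is already the default where supported,
    # and some older PyTorch releases do not accept the kwarg at all
)
```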
# Sparsify if that makes sense
if torch.count_nonzero(matrix).item() / matrix.numel() > _DENSITY_THRESHOLD:
    return matrix
# If not sparse, then AttentionMask is the reference type
Heads up on that: basically hijacking any non-sparse case to put it under the AttentionMask umbrella
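For readers following along, a small self-contained sketch of the density-based decision described here; the function name and threshold value are illustrative, not the actual helper:

```python
import torch

_DENSITY_THRESHOLD = 0.30  # assumed value, for illustration only

def maybe_sparsify_sketch(matrix: torch.Tensor) -> torch.Tensor:
    """Keep dense-enough masks dense, turn very sparse ones into a sparse tensor."""
    density = torch.count_nonzero(matrix).item() / matrix.numel()
    if density > _DENSITY_THRESHOLD:
        # Dense enough: stays a regular tensor, and AttentionMask becomes the reference type
        return matrix
    # Sparse enough: the real code would build a SparseCS mask here; plain COO is
    # used below only to keep the sketch runnable with stock PyTorch
    return matrix.to_sparse()

# Example: a ~10% dense boolean mask ends up on the sparse path
mask = (torch.rand(64, 64) > 0.9).float()
print(maybe_sparsify_sketch(mask).is_sparse)  # True
```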
Self = TypeVar("Self", bound="AttentionMask")


class AttentionMask:
tentatively own all attention masks, but this "breaks" (not per se, rather it is not handled here) right now when things are sparsified
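To make the discussion concrete, here is a heavily simplified sketch of what such a wrapper could look like, pieced together from the bits visible in this diff (additive values, to_bool/from_bool, an is_causal flag); it is not the actual class:

```python
from typing import Type, TypeVar
import torch

Self = TypeVar("Self", bound="AttentionMaskSketch")

class AttentionMaskSketch:
    """Additive float mask: 0 keeps a position, -inf removes it."""

    def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
        self.values = additive_mask
        self.is_causal = is_causal

    def to_bool(self) -> torch.Tensor:
        # True wherever attention is allowed
        return self.values != float("-inf")

    @classmethod
    def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
        # Boolean "keep" mask -> additive representation
        additive = torch.zeros_like(x, dtype=torch.float)
        additive.masked_fill_(~x, float("-inf"))
        return cls(additive)
```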
@@ -69,12 +70,12 @@ def _broadcast_batch(mask, batch_size):


 def _matmul_with_mask(
-    a: torch.Tensor, b: torch.Tensor, mask: Optional[torch.Tensor]
+    a: torch.Tensor, b: torch.Tensor, mask: Optional[Union[torch.Tensor, "SparseCS"]]
That's where the seams break a little between AttentionMask and the sparse paths. On the bright side, it's all internal. On the dark side, ideally we could have something a little more streamlined overall.
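A rough, self-contained sketch of the seam being described: one function accepting either a dense tensor or a sparse container. SparseCS is stubbed out here and the body is illustrative, not the real implementation:

```python
from typing import Optional, Union
import torch

class SparseCS:
    """Stand-in for the real sparse container, just so the sketch runs."""
    def __init__(self, matrix: torch.Tensor):
        self.dense = matrix.bool()  # the real class keeps a CSR layout instead

def matmul_with_mask_sketch(
    a: torch.Tensor,
    b: torch.Tensor,
    mask: Optional[Union[torch.Tensor, SparseCS]],
) -> torch.Tensor:
    att = a @ b
    if mask is None:
        return att
    if isinstance(mask, SparseCS):
        # This is the seam: the real code would dispatch to sparse kernels here
        return att.masked_fill(~mask.dense, float("-inf"))
    if mask.dtype == torch.bool:
        return att.masked_fill(~mask, float("-inf"))
    # Additive float mask: simply added onto the attention logits
    return att + mask
```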
# Could also be that the mask was additive
@dianaml0 _matmul_with_mask already has an additive mask codepath actually. It was written twice, but the computation was correct (since we only passed the mask to _matmul_with_mask if it was a bool). Maybe we should keep the addition here and remove it from matmul?
Ah, I see that now. I guess it's nice to have all the mask logic live in a single place, so keeping it in _matmul_with_mask?
So, I think what happened was that different people were only focusing on different parts.
This line was added by me when we originally added support for additive masks. But I only really cared about sparse tensors back then and I had a bug in my implementation for dense.
Then it was fixed for dense, but maybe it could have been fixed in a different place, as this ended up being duplicated.
# Softmax to get the attention probabilities
att = _softmax(att, causal=causal)
is_causal = att_mask.is_causal if isinstance(att_mask, AttentionMask) else False
this fixes #91
Thanks!
One downside of this approach is that it leaks internal implementation details of AttentionMask inside a core function.
My thought was to instead modify _softmax so that the causal path is just a different function call from a CausalAttentionMask Tensor type.
In that case you pass the mask to softmax and it moves the branching there, why not. Typically softmax does not take masks though, so I think the interface becomes a little confusing (we'd pass a mask in only to say whether it's causal or not).
@fmassa is it OK if you propose a change (object-oriented if I got it right, the masks handle the dispatch problem) in another PR?
Sure, I'll send a PR with my proposed changes later this week
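For readers skimming the thread, a toy illustration of the alternative being discussed: let the mask type carry the causal information, so the core attention call never needs a causal flag. All names here are made up for the sketch:

```python
import torch

def softmax_sketch(att: torch.Tensor, causal: bool = False) -> torch.Tensor:
    # The branching that would rather be driven by the mask type than by a flag
    if causal:
        seq = att.shape[-1]
        bias = torch.triu(
            torch.full((seq, seq), float("-inf"), device=att.device), diagonal=1
        )
        att = att + bias
    return torch.softmax(att, dim=-1)

class CausalMaskSketch:
    """Hypothetical mask subtype whose mere presence means 'take the causal path'."""
    is_causal = True

def attention_probs(att: torch.Tensor, att_mask=None) -> torch.Tensor:
    # Dispatch driven by the mask object, no explicit causal flag at the call site
    return softmax_sketch(att, causal=getattr(att_mask, "is_causal", False))

# Example: causal handling triggered purely by the mask type
scores = torch.randn(1, 4, 4)
probs = attention_probs(scores, CausalMaskSketch())
```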
@@ -91,10 +92,24 @@ def forward(
         self.attention_mask = self.attention_mask.to(q.device)

         # Mask-aware attention
-        mask = (
-            self.attention_mask if att_mask is None else self.attention_mask & att_mask
+        if att_mask is not None:
Same, sparse/dense handling discrepancy; fixing this would be a nice follow-up
-            self.rand_attention_mask
-            if att_mask is None
-            else self.rand_attention_mask & att_mask
+        if att_mask is not None:
same, sparse/dense discrepancy
Really nice to have this all centralized!
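A dense-only sketch of the combination step discussed above; the sparse/dense discrepancy called out in the comments is precisely that a sparsified stored mask cannot go through this simple path. Names are illustrative:

```python
from typing import Optional
import torch

def combine_masks(
    layer_mask: torch.Tensor, att_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Dense case only: intersect the layer's own mask with the caller's mask."""
    if att_mask is None:
        return layer_mask
    # Boolean intersection; a sparsified layer mask would need its own branch
    return layer_mask & att_mask

# Example: a causal-style layer mask combined with a padding-style mask
own = torch.ones(4, 4).tril().bool()
extra = torch.rand(4, 4) > 0.5
combined = combine_masks(own, extra)
```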
 ) -> torch.Tensor:
     if mask is None:
         return a @ b

-    if _is_sparse_available:
+    if _is_sparse_available and mask.dtype == torch.bool:
So for now, for sparse we should always convert to a bool mask, right?
Ideally I think that we should just expose one mask type, which is easier to understand for the user, and do the appropriate thing internally from there (sparse, dense, blocksparse; I agree with @fmassa on that). I think that's beyond this PR though. So for now, yes: if you pass a bool mask it goes the maybe-sparse way, and if you pass a float mask then it's all dense. I can change that, this is a draft!
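A toy sketch of the routing rule described in this comment, with a single user-facing entry point deciding the internal path; the function name and threshold are made up for illustration:

```python
import torch

def route_user_mask(mask: torch.Tensor) -> torch.Tensor:
    """bool masks may take the maybe-sparse route, float masks stay dense additive."""
    if mask.dtype == torch.bool:
        density = mask.float().mean().item()
        # plain COO stands in for the dedicated sparse path, purely for illustration
        return mask.to_sparse() if density < 0.3 else mask
    # float masks are assumed to already be additive (0 / -inf style) and stay dense
    return mask
```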
        return self.values != float("-inf")

    @classmethod
    def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
Really nice to have these!
Thanks for the work @blefaudeux!
I had some slightly different thoughts about the implementation, but let's move forward with what you have for now. I think it might be worth getting on a call to discuss.
I propose we get this merged, but IMO we need to re-write this whole file to fully take into account the things we care about and get the abstractions that best suit our new use cases.
Force-pushed from 11f3494 to cf888ec
Force-pushed from cf888ec to feb050b
linting, web editing does not help
What does this PR do?

- a […] would both mark the fact that this needs to be computed, and be added a posteriori to the computed attention value (not done right now)
- scaled_dot_product_attention shouldn't slice attn_mask #90 (do not slice the attention mask)
- scaled_dot_product_attention shouldn't have a causal flag #91 (no causal flag in the core attention)

Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.