
[partial refactor] Dedicated attention mask wrapper #113

Merged
blefaudeux merged 5 commits into main from mask_refactor on Nov 21, 2021

Conversation


@blefaudeux blefaudeux commented Nov 18, 2021

What does this PR do?

  • add a dedicated class to handle additive attention masks and centralize some of the typical operations (type changes, combining with another mask, handling smaller sequences, ...); see the sketch after this list
  • the sparse machinery is unchanged and still uses bool masks (cc @fmassa), I did not touch that
  • I think the above needs to change eventually so that we converge on additive masks everywhere, but it means that on the sparse side we need to fully handle a float additive mask: a given value would both mark that this position needs to be computed and be added afterwards to the computed attention value (not done right now)
  • Fix scaled_dot_product_attention shouldn't slice attn_mask #90 (do not slice the attention mask)
  • Fix scaled_dot_product_attention shouldn't have a causal flag #91 (no causal flag in the core attention)
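
For reviewers skimming, a minimal sketch of the idea, assuming a simplified stand-in class (this is not the exact class added in the PR):

```python
import torch


class AdditiveAttentionMask:
    """Simplified stand-in: an additive mask stores 0.0 for visible positions
    and -inf for masked ones, so it can simply be added to the attention
    logits before the softmax."""

    def __init__(self, additive_mask: torch.Tensor):
        self.values = additive_mask

    @classmethod
    def from_bool(cls, x: torch.Tensor) -> "AdditiveAttentionMask":
        # Bool convention: True means "keep", False means "mask out"
        additive = torch.zeros_like(x, dtype=torch.float)
        additive.masked_fill_(~x, float("-inf"))
        return cls(additive)

    def to_bool(self) -> torch.Tensor:
        # Anything that is not -inf is still visible
        return self.values != float("-inf")

    def __add__(self, other: "AdditiveAttentionMask") -> "AdditiveAttentionMask":
        # Combining two additive masks is just adding their values
        return AdditiveAttentionMask(self.values + other.values)
```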

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 18, 2021
@blefaudeux blefaudeux force-pushed the mask_refactor branch 7 times, most recently from 8f9d6d4 to a146fa1 Compare November 18, 2021 22:10
@@ -35,7 +35,7 @@
author = "Facebook AI Research"

# The full version, including alpha/beta/rc tags
release = "0.0.5"
release = "0.0.6"
Contributor Author
just bumped to mark a difference with the released package (0.0.5)


codecov-commenter commented Nov 18, 2021

Codecov Report

Merging #113 (68519cc) into main (3deba0b) will increase coverage by 0.31%.
The diff coverage is 95.83%.


@@            Coverage Diff             @@
##             main     #113      +/-   ##
==========================================
+ Coverage   86.90%   87.21%   +0.31%     
==========================================
  Files          49       50       +1     
  Lines        2497     2566      +69     
==========================================
+ Hits         2170     2238      +68     
- Misses        327      328       +1     
Flag Coverage Δ
Python 87.21% <95.83%> (+0.31%) ⬆️


Impacted Files Coverage Δ
xformers/components/__init__.py 100.00% <ø> (ø)
xformers/components/attention/core.py 88.40% <81.25%> (+0.72%) ⬆️
xformers/components/attention/random.py 95.12% <88.88%> (-1.94%) ⬇️
xformers/components/attention/attention_mask.py 98.50% <98.50%> (ø)
xformers/__init__.py 43.24% <100.00%> (ø)
xformers/components/attention/__init__.py 82.60% <100.00%> (+0.38%) ⬆️
xformers/components/attention/base.py 96.96% <100.00%> (ø)
xformers/components/attention/global_tokens.py 100.00% <100.00%> (ø)
...formers/components/attention/scaled_dot_product.py 100.00% <100.00%> (+4.54%) ⬆️
xformers/components/attention/_sputnik_sparse.py 94.66% <0.00%> (-0.89%) ⬇️
... and 1 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@@ -114,7 +114,6 @@ def test_pytorch_encoder_parity(device=torch.device("cuda")):
    dim_feedforward=4 * EMB,
    dropout=DROP,
    activation=ACTIVATION,
-   layer_norm_eps=1e-05,
Contributor Author
not present in some pytorch versions, and this is the default anyway

    # Sparsify if that makes sense
    if torch.count_nonzero(matrix).item() / matrix.numel() > _DENSITY_THRESHOLD:
        return matrix
    # If not sparse, then AttentionMask is the reference type
Contributor Author
heads up on that, basically hijacking any non-sparse case to put it under the AttentionMask umbrella
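
For intuition, the density check being referenced boils down to something like this (the threshold value below is made up, not the repo's):

```python
import torch

_DENSITY_THRESHOLD = 0.30  # illustrative value only

# A lower-triangular (causal) bool mask on a short sequence
seq = 8
causal = torch.tril(torch.ones(seq, seq)).bool()
density = torch.count_nonzero(causal).item() / causal.numel()

print(density)  # 0.5625 for seq=8: denser than the threshold, so it stays dense
# Dense cases like this one are the ones folded under the AttentionMask umbrella
```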

Self = TypeVar("Self", bound="AttentionMask")


class AttentionMask:
Contributor Author
tentatively own all attention masks, but this "breaks" (not per se, it is just not handled here) right now when things are sparsified

@@ -69,12 +70,12 @@ def _broadcast_batch(mask, batch_size):


def _matmul_with_mask(
-    a: torch.Tensor, b: torch.Tensor, mask: Optional[torch.Tensor]
+    a: torch.Tensor, b: torch.Tensor, mask: Optional[Union[torch.Tensor, "SparseCS"]]
Contributor Author
that's where the seams show a little between AttentionMask and the sparse path. On the bright side, it's all internal. On the dark side, ideally we could have something a little more streamlined overall.
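
A dense-only sketch of the branching this signature implies (illustrative; the SparseCS path is omitted and this is not the exact code in core.py):

```python
from typing import Optional

import torch


def masked_matmul_dense(
    a: torch.Tensor, b: torch.Tensor, mask: Optional[torch.Tensor]
) -> torch.Tensor:
    # Dense-only illustration of a mask-aware matmul
    att = a @ b
    if mask is None:
        return att
    if mask.dtype == torch.bool:
        # Boolean mask: True = keep, so drop everything else
        att = att.masked_fill(~mask, float("-inf"))
    else:
        # Additive (float) mask: add it to the logits
        att = att + mask
    return att
```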


# Could also be that the mask was additive
Contributor Author
@dianaml0 _matmul_with_mask already has an additive mask codepath, actually. It was written twice, but the computation was correct (since we only passed the mask to _matmul_with_mask when it was a bool). Maybe we should keep the addition here and remove it from the matmul?
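
For context, the equivalence both codepaths rely on, checked with plain tensors (not code from this PR):

```python
import torch

logits = torch.randn(2, 4, 4)
# True = keep; keep the diagonal so no row ends up fully masked
keep = (torch.rand(4, 4) > 0.3) | torch.eye(4, dtype=torch.bool)
additive = torch.zeros(4, 4).masked_fill(~keep, float("-inf"))

# Boolean path: fill the masked logits with -inf, then softmax
p_bool = torch.softmax(logits.masked_fill(~keep, float("-inf")), dim=-1)

# Additive path: add the 0 / -inf mask to the logits, then softmax
p_add = torch.softmax(logits + additive, dim=-1)

assert torch.allclose(p_bool, p_add)
```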

Contributor
Ah, I see that now. I guess it's nice to have all the mask logic live in a single place, so keep it in _matmul_with_mask?

Contributor
So, I think what happened was that different people were only focusing on different parts.

This line was added by me when we originally added support for additive masks. But I only really cared about sparse tensors back then and I had a bug in my implementation for dense.

Then it was fixed for dense, but maybe it could have been fixed in a different place, as this ended up being duplicated.


# Softmax to get the attention probabilities
att = _softmax(att, causal=causal)
is_causal = att_mask.is_causal if isinstance(att_mask, AttentionMask) else False
Contributor Author
this fixes #91

Contributor
Thanks!

One downside of this approach is that it leaks internal implementation details of AttentionMask inside a core function.

My thoughts were to instead modify _softmax so that the causal path is just a different function call, dispatched from a CausalAttentionMask tensor type.
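
A rough sketch of what that could look like, assuming a hypothetical CausalAttentionMask subclass and a mask-aware softmax helper (none of these names are taken from the PR):

```python
import torch


class MaskSketch:
    """Stand-in for AttentionMask: only carries an is_causal attribute."""

    is_causal: bool = False


class CausalMaskSketch(MaskSketch):
    """Stand-in for the hypothetical CausalAttentionMask."""

    is_causal: bool = True


def softmax_with_mask_dispatch(att: torch.Tensor, mask: MaskSketch) -> torch.Tensor:
    # The core function no longer takes a causal flag: the mask type decides
    if mask.is_causal:
        seq = att.shape[-1]
        causal = torch.tril(torch.ones(seq, seq)).bool()
        att = att.masked_fill(~causal, float("-inf"))
    return torch.softmax(att, dim=-1)
```

The trade-off discussed below is whether this is clearer than keeping softmax mask-free and reading is_causal off the mask in the caller.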

Contributor Author
In that case you pass the mask to softmax and it moves the branching there, why not. Typically softmax does not take a mask though, so I think the interface becomes a little confusing (we'd pass a mask in only to say whether it's causal or not).

Contributor Author
@fmassa is it ok if you propose a change (object-oriented, if I got it right: the masks handle the dispatch) in another PR?

Contributor
Sure, I'll send a PR with my proposed changes later this week

xformers/components/attention/core.py (outdated review thread, resolved)
@@ -91,10 +92,24 @@ def forward(
self.attention_mask = self.attention_mask.to(q.device)

# Mask-aware attention
-        mask = (
-            self.attention_mask if att_mask is None else self.attention_mask & att_mask
+        if att_mask is not None:
Contributor Author
same, sparse/dense handling discrepancy; fixing this would be a nice follow-up

-            self.rand_attention_mask
-            if att_mask is None
-            else self.rand_attention_mask & att_mask
+        if att_mask is not None:
Contributor Author
same, sparse/dense discrepancy

@dianaml0 dianaml0 left a comment
Really nice to have this all centralized!

) -> torch.Tensor:
    if mask is None:
        return a @ b

-    if _is_sparse_available:
+    if _is_sparse_available and mask.dtype == torch.bool:
Contributor
So for now, for sparse, we should always convert to a bool mask, right?

Contributor Author
Ideally I think we should expose just one mask type, which is easier for the user to understand, and do the appropriate thing internally from there (sparse, dense, blocksparse; I agree with @fmassa on that). That is beyond this PR though, so for now yes: if you pass a bool mask it goes the maybe-sparse way, and if you pass a float mask it is all dense. I can change that, this is a draft!


xformers/components/attention/core.py (two outdated review threads, resolved)
        return self.values != float("-inf")

    @classmethod
    def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
Contributor
Really nice to have these!
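
A quick round-trip illustration of what these helpers enable, using plain tensors rather than the class itself:

```python
import torch

# True = keep, False = masked out
keep = torch.tensor([[True, False], [True, True]])

# from_bool direction: bool -> additive float mask (0 where kept, -inf elsewhere)
additive = torch.zeros_like(keep, dtype=torch.float).masked_fill(~keep, float("-inf"))

# to_bool direction, matching the snippet above: additive -> bool
recovered = additive != float("-inf")

assert torch.equal(recovered, keep)
```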

@fmassa fmassa left a comment
Thanks for the work @blefaudeux !

I had some slightly different thoughts about the implementation, but let's move forward with what you have for now. I think it might be worth getting on a call to discuss.

I propose we get this merged, but IMO we need to rewrite this whole file to fully take into account the things we care about, and get the abstractions that best suit our new use cases.



xformers/components/attention/core.py (outdated review thread, resolved)
@blefaudeux
Contributor Author
Landing this, as I understood that @dianaml0 and @fmassa were ok with it as a first take; @fmassa will propose a rewrite where the attention mask(s) handle the dispatch to the specific computations (personally ok with that).

@blefaudeux blefaudeux changed the title [DRAFT] Dedicated attention mask wrapper [partial refactor] Dedicated attention mask wrapper Nov 21, 2021
linting, web editing does not help
@blefaudeux blefaudeux merged commit e640d1b into main Nov 21, 2021
@fmassa fmassa deleted the mask_refactor branch November 22, 2021 14:14