Hello!
I noticed that the FlashAttention forward supports the key_padding_mask argument, while your implementation only allows for a causal mask. Could key_padding_mask also (easily) be supported in xformers? A sketch of what I mean by key_padding_mask is below.
Yes, it's something we plan on adding. As you mentioned, the kernels already support it, so it's more a matter of figuring out the right API for this - we're looking into PyTorch's nested tensors to see if they're the right abstraction.
We just merged this support - through a special attention bias.
You can learn more on the doc website.
Note that the backward pass is only supported by the Flash kernel, and hence won't work on V100 or older devices, or with fp32.
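For anyone landing here later, here is a rough sketch of how the variable-length support looks with the block-diagonal attention bias. Treat the exact names and shapes as an assumption and check the docs, since the API may have changed: sequences of different lengths are concatenated along the token axis instead of being padded, and the bias tells the kernel where each sequence starts and ends.

```python
import torch
import xformers.ops as xops

seqlens = [3, 5, 2]  # per-sample lengths in the batch
attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(seqlens)

heads, head_dim = 8, 64
total_tokens = sum(seqlens)
# Shape: (1, total_tokens, heads, head_dim) -- the batch dim is folded into tokens
q = torch.randn(1, total_tokens, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
```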