Use cutlass for memory-efficient attention #362
Conversation
* Add attention bias in memory-efficient attention
* Add gradient for attn_mask support
* Add CPU implementation
* clang-format
* Add benchmark scripts
* Add extra loop in benchmarks
* Move zeros array out of helper function
* clang-format
* Merge compute_scaling_coeffs and update_scaling_coeffs into a single function (it wasn't needed to break it in two functions to begin with)
* Add CUDA implementation for dropout
* clang-format
* Make p be drop probability
* Only CUDA supports dropout
* Add benchmarks
* Remove unused variables
* Fix test
* Cleanups and comments
* …ted build which is a pain
@@ -53,6 +53,9 @@ There are two ways you can install xFormers locally:

```bash
git clone git@github.com:facebookresearch/xformers.git
git submodule update --init --recursive
```
nice, I was checking that when I saw that xformers now has two submodules, perfect. Thanks
DEFAULT_ARCHS_LIST = ""
if cuda_version > 1100:
    DEFAULT_ARCHS_LIST = "7.5;8.0;8.6"
elif cuda_version >= 1100:
nit, but cuda_version == 1100 in that case, right ?
num = 10 * int(arch[0]) + int(arch[2])
# Need at least 7.5
if num < 75:
    continue
could we print out some warnings here (or in the main setup) to recap what's being built and possibly why? I feel like a lot of issues could be raised around this, with the build process silently skipping FlashAttention because of an old CUDA version and users not seeing it.
Good idea, I'll add some log messages.
But in general, we need to improve the packaging of xformers, especially now that a lot of hardware-specific kernels are being used. @bottler might look into improving this.
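For illustration, such a warning could look roughly like the sketch below; the helper name and structure are hypothetical, not the actual setup.py code.

```python
# Illustrative sketch only -- not the actual xformers setup.py. It mirrors the
# arch-filtering loop shown in the diff and warns instead of silently skipping.
import warnings


def filter_flash_archs(archs_list: str) -> str:
    kept = []
    for arch in archs_list.split(";"):  # e.g. "7.5;8.0;8.6"
        num = 10 * int(arch[0]) + int(arch[2])
        # Flash attention needs at least compute capability 7.5
        if num < 75:
            warnings.warn(
                f"Skipping flash attention kernels for sm_{num}: "
                "compute capability >= 7.5 is required"
            )
            continue
        kept.append(arch)
    return ";".join(kept)
```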
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]

def ref_attention(q, k, v):
def assert_allclose(
nit, but could this be moved into some test utils? It feels like it could be used in a few places already, beyond this PR.
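For reference, a shared helper along those lines might look like the sketch below; the name mirrors the test's assert_allclose, but the signature and default tolerances here are illustrative.

```python
# Illustrative sketch of a reusable comparison helper; the exact signature in
# the PR may differ.
import torch


def assert_allclose(
    out: torch.Tensor,
    ref: torch.Tensor,
    msg: str = "failed",
    atol: float = 1e-8,
    rtol: float = 1e-5,
) -> None:
    # Report how far the worst element exceeds the tolerance, which is more
    # useful than a bare torch.allclose failure.
    excess = ((out - ref).abs() - atol - ref.abs() * rtol).max().item()
    assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
        f"{msg}: max excess over tolerance is {excess}"
    )
```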
@pytest.mark.parametrize(
    "attn_bias_type", [None, xformers.ops.LowerTriangularMask, torch.Tensor]
)
@pytest.mark.parametrize("k_len", [5, 6, 32, 128])
nice test cascade ! that's some serious coverage
Yeah, there are a lot of combinations being tested now. Tests are not instantaneous now (~1min ?), but it's not too bad I think
    dtype,
    op: xformers.ops.MemoryEfficientAttentionOp,
):
    scale = 3
for my understanding, how is this scale chosen ?
This is just to stress test the numerics of MHA a bit more. I could have left it as 1, but larger scales push harder on the `query @ key.T` part, so that we could hit overflows if the softmax is not done properly, for larger `K` dimensions.
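A standalone illustration of that point (not from the test file): with scale = 3 and a moderate K, a softmax that skips the max subtraction can overflow, while a stabilized one stays finite.

```python
import torch

# Standalone illustration: larger input scales blow up the query @ key.T logits,
# so exp() can overflow unless the softmax subtracts the row-wise max first.
scale = 3
q = torch.randn(1, 128, 64) * scale
k = torch.randn(1, 128, 64) * scale
logits = q @ k.transpose(-2, -1)  # magnitude grows with scale**2 and K

naive = logits.exp() / logits.exp().sum(-1, keepdim=True)  # can produce inf/nan
stable = torch.softmax(logits, dim=-1)                     # max-subtracted, stays finite
print(torch.isfinite(naive).all().item(), torch.isfinite(stable).all().item())
```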
grad_out = torch.ones_like(query)
if grad_out_contiguous is False:
    grad_out = torch.tensor([1.0], device=device)[None, None, :].expand_as(query)
(updated) that works indeed, not super intuitive to me but the `.expand_as()` call is the one which breaks the contiguity, interesting. I would have done something like `.transpose(0,1).contiguous().transpose(0,1)`. Curious about your take @fmassa, how did you think of that formulation?
The `expand_as` appears very often when quick-testing the backward, as it's in the gradient of `.sum()`. So doing `op(inputs).sum().backward()` yields gradients which are expanded tensors, which is a particular case of non-contiguous tensor. Given that the kernel for now just calls `.contiguous()` on the gradients, any non-contiguous tensor is fine to exercise this codepath.
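A quick, illustrative way to see the non-contiguity (not part of the test itself):

```python
import torch

# expand_as on a size-1 tensor gives a stride-0 view, i.e. a non-contiguous
# tensor -- the same kind of grad_out that op(inputs).sum().backward() produces.
query = torch.randn(2, 8, 4)
grad_out = torch.tensor([1.0])[None, None, :].expand_as(query)
print(grad_out.shape)            # torch.Size([2, 8, 4])
print(grad_out.stride())         # (0, 0, 0)
print(grad_out.is_contiguous())  # False
```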
mask = torch.ops.xformers._temp_dropout(mask, p)
masks.append(mask.clone().cpu())
masks = torch.stack(masks, dim=0)
p_value = binom_test(masks.sum(), masks.numel(), p=keep_prob)
nice, seems much better than these tests (my bad), thanks for this very thorough take
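For context, a rough sketch of the statistical check under discussion; the mask below is a stand-in, while the real test drives torch.ops.xformers._temp_dropout.

```python
# Illustrative version of the binomial check: each kept element is a Bernoulli
# trial with probability keep_prob, so the total keep count can be tested
# against that probability.
import torch
from scipy.stats import binom_test

p = 0.3                                      # drop probability
keep_prob = 1 - p
mask = (torch.rand(100, 1000) > p).float()   # stand-in for the dropout mask
p_value = binom_test(int(mask.sum()), mask.numel(), p=keep_prob)
assert p_value > 1e-6, "observed keep rate inconsistent with keep_prob"
```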
import xformers.ops

torch.backends.cuda.matmul.allow_tf32 = False
not directly related, but did you get to test out the perf effect of tf32 accumulation on A100 ? asking just in case to learn a bit more
The CUTLASS-based kernels do use a trick that benefits from tf32 while keeping fp32 numerics, by performing 3 matmuls in tf32 to leverage tensor cores. The trick (which was present in the cutlass examples) is to decompose an fp32 tensor as `fp32 = fp32_low_bits + fp32_high_bits`, where `fp32_low_bits` and `fp32_high_bits` are tf32, so that the multiplication can be approximated by 3 matmuls in tf32 (dropping the `low_bits * low_bits` part).
The implementation that uses only a single tf32 instruction is not implemented yet, but we were thinking it could be exposed by reading `torch.backends.cuda.matmul.allow_tf32` and dispatching to different kernels.
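For intuition, a rough Python emulation of that decomposition; the actual implementation lives inside the CUTLASS kernels, and the bit masking below only mimics tf32's 10-bit mantissa.

```python
import torch

# Rough emulation of the "3-matmul tf32" trick, for intuition only.
def to_tf32(x: torch.Tensor) -> torch.Tensor:
    # Truncate the fp32 mantissa (23 bits) down to tf32's 10 bits by zeroing
    # the 13 lowest mantissa bits.
    bits = x.contiguous().view(torch.int32)
    return (bits & ~0x1FFF).view(torch.float32)


def matmul_3xtf32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Split each operand as fp32 ~= hi + lo, with hi and lo representable in
    # tf32, then approximate a @ b with three matmuls, dropping the lo @ lo term.
    a_hi, b_hi = to_tf32(a), to_tf32(b)
    a_lo, b_lo = to_tf32(a - a_hi), to_tf32(b - b_hi)
    return a_hi @ b_hi + a_hi @ b_lo + a_lo @ b_hi
```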
sub_label = f"{dtype_str} {op.NAME} B={B}, M={M}, K={K}"

if True:
    r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op).float()
nit: I would be curious to get numbers on self vs. non-self attention. If anything, I think that benchmarking your work in the self-attention case (while very relevant for vision of course) sells you short, since there's more opportunity for the GPU to have a hot cache (and the vanilla computation is IO-bottlenecked, so it will benefit a lot from that).
q.grad = None
del r, rr, grad

out = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, p, op=op)
same as above (self vs. non-self attention): maybe the numbers end up being similar, but I would be curious to see that for real. My guess is that you would slightly increase the gap vs. PyTorch for non-self-attention.
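For reference, the non-self-attention variant would simply pass independent key/value tensors instead of reusing q; the shapes, dtype, and use of default arguments below are assumptions, not the benchmark's exact setup.

```python
# Illustrative comparison point for the benchmark: self-attention reuses the
# same tensor three times, non-self attention uses independent k and v.
import torch
import xformers.ops

B, M, K = 32, 1024, 64
q = torch.randn(B, M, K, device="cuda", dtype=torch.half)
k = torch.randn(B, M, K, device="cuda", dtype=torch.half)
v = torch.randn(B, M, K, device="cuda", dtype=torch.half)

out_self = xformers.ops.memory_efficient_attention(q, q, q)   # self-attention
out_cross = xformers.ops.memory_efficient_attention(q, k, v)  # non-self attention
```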
@@ -142,6 +146,36 @@ __device__ void compute_dot(
  }
}

/*
nit: useful ?
Not used, I was planning on doing some refactorings to make things simpler but didn't finish it and kept it there to eventually go back to it. I could just remove it though
    int64_t N,
    scalar_t p,
    int64_t col_offset) {
  // strategy: initialize the rng so that each element in the attention
nice, dropout/triton is doing the same
// we will always sample 4 random floats at a time
// as it's more efficient
constexpr int kSampled = 4;
same in triton, I'm guessing that this is HW dependent (big random word cut into pieces ?)
Yeah, it turns out that curand always generates 4 floats at a time internally, even if you call `curand_random` instead of `curand_random4`, and thus getting only 1 float was much more expensive.
It shouldn't be much more expensive, because curand saves unused randoms and yields them on future calls, without going through generation.
@ngimel indeed, but in the current setup we have to reset the state of the philox generation very often (almost for every element), and this would make up for the slowdown, as we would reset the state (expensive), generate 4 floats and only use 1.
The strategy here is to reset the state only once every 4 elements in the output
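A stripped-down CUDA sketch of that strategy (standalone and hypothetical, not the PR's kernel): one curand_init per group of four output elements, with all four uniforms drawn by a single curand_uniform4 call.

```cuda
#include <curand_kernel.h>

// Standalone sketch, not the PR's kernel: amortize the (expensive) curand_init
// over kSampled = 4 consecutive elements by drawing four uniforms at once.
__global__ void dropout_sketch(
    float* out, int64_t n, float p, uint64_t seed, uint64_t offset) {
  constexpr int kSampled = 4;
  const int64_t base =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * kSampled;
  if (base >= n)
    return;

  curandStatePhilox4_32_10_t state;
  // One state reset per group of 4 elements instead of one per element.
  curand_init(seed, base / kSampled, offset, &state);
  const float4 rnd = curand_uniform4(&state);
  const float r[kSampled] = {rnd.x, rnd.y, rnd.z, rnd.w};

  for (int i = 0; i < kSampled && base + i < n; ++i) {
    out[base + i] = r[i] < p ? 0.f : out[base + i] / (1.f - p);
  }
}
```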
// guarantees than by doing it properly, but is much faster
curand_init(
    std::get<0>(seeds), offset, std::get<1>(seeds) + delta, &state);
// curand_init(std::get<0>(seeds) + (offset << 8) + std::get<1>(seeds), 0,
old API I suppose, still needed around ?
That should be removed, actually; it was a workaround to get faster generation back when it was slower (but had fewer guarantees regarding randomness).
if (index >= M)
  break;

auto out_i = reinterpret_cast<vec_t*>(output[batch_idx][index].data());
nit: would `const vec_t*` make sense here and wherever this pattern appears, since it's read-only? It may not be idiomatic CUDA, but in C++ it's quite typical in some codebases to be a little strict around that.
// if (l < end_iter) {
{
  for (; l < end_iter; l += step) {
    for (int jj = 0; jj < kBlockSizeQ; jj++) {
I think I wrote the same in the previous PR related to that @fmassa, but it feels crazy that this is the best CUDA has to offer to init s_delta. I remember you wrote that was the gist of it; not really a question here, just thinking out loud.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

constexpr int WARP_SIZE = 4;
nit, but this looks like it's the number of warps, and WARP_SIZE is typically used (even in this PR, see the next file) to describe the number of threads in a warp, right? Not super important, but for consistency's sake.
Indeed. I'll clean this up in a follow-up PR; I have some changes laid out that I'll be pushing soon.
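For what it's worth, the rename could look like the snippet below; the constant names are illustrative, not the actual follow-up change.

```cuda
#include <cuda_runtime.h>

// Illustrative naming only: keep the hardware warp width and the per-block
// warp count as distinct constants.
constexpr int kThreadsPerWarp = 32;  // what "WARP_SIZE" conventionally means
constexpr int kNumWarps = 4;         // what the constant in this kernel represents

const dim3 kBlockDim(kThreadsPerWarp, kNumWarps);  // 32 x 4 = 128 threads per block
```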
@lucidrains Just let me know the time and place.
looks great, minor cosmetic or understanding questions, but that's mostly for my sake.
The cutlass part feels very different, almost like another language, and hard to review for me. I do like that it seems to be one level up in terms of abstraction, but I would be hard-pressed to find bugs in there. Would it help to have other eyes on it? (@ngimel?)
Thanks for all this work @fmassa @danthe3rd (and @tridao of course), I hope and think that it can be super impactful, a game changer for the attention mechanism.
@MarkusRabe shoot me an email! Your old email at Saarland no longer works (I tried to email you some time ago).
Hi @fmassa, I see a loss in accuracy when training a Swin-T model. Did you test this?
Hi, can you describe your setup more precisely? (GPU, head dimension, data type, options like causality)
GPU: A100
* Fast again on V100
* Fix correctness - missing syncthreads
* Get rid of AttentionInfo

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Hi @lw921014, for Swin-Transformer you need a relative positional embedding, which requires adding an attention bias to the attention matrix, and this configuration is not exposed yet in the setup you mentioned. If you remove the relative positional embedding, you'll indeed see a loss in accuracy.
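A sketch of what the Swin-style call could look like once the bias path is exposed; the shapes and the positional attn_bias argument below are assumptions based on the PR's tests, which accept a torch.Tensor bias.

```python
# Hedged sketch: pass the relative positional bias as an additive attention
# bias. Shapes here are illustrative (one 7x7 Swin window per batch entry).
import torch
import xformers.ops

B, M, K = 8, 49, 32
q = torch.randn(B, M, K, device="cuda", dtype=torch.half)
k = torch.randn(B, M, K, device="cuda", dtype=torch.half)
v = torch.randn(B, M, K, device="cuda", dtype=torch.half)
rel_pos_bias = torch.randn(B, M, M, device="cuda", dtype=torch.half)

out = xformers.ops.memory_efficient_attention(q, k, v, rel_pos_bias)
```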
What does this PR do?
This (massive) PR adds a number of improvements to memory-efficient attention that have been developed over the last few months.
It contains:
A100
For the configurations below and the forward pass on fp16, the CUTLASS-based kernels are on average 31% faster than vanilla PyTorch (10% faster at the median), and 5% slower than FlashAttention on average (with the median being 1% faster than FlashAttention).
For the backward pass, there is still room for improvement for the CUTLASS-based kernels, which are 15% slower than vanilla PyTorch on average (7% slower at the median), and 55% slower than FlashAttention on average and at the median.
The breakdown of the details can be found below.
CUTLASS-based kernels
FlashAttention-based kernels
FlashAttention and lower triangular
V100
FlashAttention is not supported on V100, so in this case we only compare against the baseline PyTorch implementation, on both fp16 and fp32.
For the configurations below and fp16 on the forward pass, the CUTLASS-based implementation is 25% faster on average compared to the vanilla implementation (5% slower at the median).
For fp32, it's 13% faster on average (4% slower at the median).
For the backward pass and fp16, the CUTLASS-based implementation is 19% slower on average (15% slower at the median).
For fp32, it's 27% slower on average (30% slower at the median).
CUTLASS-based kernels
P100
For the configurations below, as before, we only compare against the vanilla PyTorch implementation, as FlashAttention doesn't support P100s.
For the forward pass, on fp16, the CUTLASS-based kernels are 18% slower on average (22% slower at the median), while for fp32 they are 10% slower on average (13% slower at the median).
For the backward pass, on fp16, the CUTLASS-based kernels are 40% slower on average (45% slower at the median), while on fp32 they are 33% slower on average (38% slower at the median).
CUTLASS-based kernels
cc @blefaudeux @danthe3rd @tridao