Memory-efficient attention - forward pass #267
Conversation
It's for now 1000x slower than the baseline
Now we are *only* 6x slower than baseline
Now we are only 50% slower than baseline
Need to fix the buffer size, which is hard-coded for now
The use of Dot makes it 2.5% faster already
Still need to make it generic w.r.t. query size, and allow values of K that go beyond the buffer limit
This is commented out for now as it brings a slowdown to the implementation
@fmassa the doc build issue should disappear after a rebase, this was fixed on main a week ago or so. Otherwise this is really great, taking a deeper look!
out = torch.ops.xformers.efficient_attention(query, key, value)
ref = ref_attention(query, key, value)

assert torch.allclose(out, ref, atol=2e-4)
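For readers of this excerpt, ref_attention is not defined here; it is presumably a plain PyTorch baseline along the lines of the sketch below (a hedged sketch only; the exact reference used in the tests, including whether it applies the 1/sqrt(K) scaling, may differ):

```python
import torch

def ref_attention(query, key, value):
    # Naive baseline: materializes the full (M, N) attention matrix,
    # which is exactly what the memory-efficient kernel avoids.
    scores = query @ key.transpose(-2, -1)
    return torch.softmax(scores, dim=-1) @ value
```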
Just to get an idea, do you have a gut feeling on where the small differences come from? The softmax renormalization being handled a little differently with the paper's method seems like an easy explanation, but is there something else?
That is a good question, and my best explanation so far is indeed that we accumulate a bit more error because we don't know the max value over a row ahead of time, so we need to renormalize on the fly (introducing a bit more rounding error).
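To make that concrete, here is a minimal sketch (plain PyTorch, not the kernel's actual code) of the running-max renormalization for a single query row; each rescaling of the partial sums is an extra source of rounding error compared with a single softmax over the full row, which is consistent with the 2e-4 tolerance used in the test above:

```python
import torch

def streaming_softmax_av(q_row, keys, values, block=32):
    # Online softmax over one query row, processing keys/values block by block.
    # m: running max of the logits, s: running sum of exp, acc: running output.
    m = torch.tensor(float("-inf"))
    s = torch.tensor(0.0)
    acc = torch.zeros(values.shape[-1])
    for start in range(0, keys.shape[0], block):
        logits = keys[start:start + block] @ q_row
        m_new = torch.maximum(m, logits.max())
        # Rescale the previous partial sums to the new running max; this is the
        # renormalization step that introduces the extra rounding error.
        correction = torch.exp(m - m_new)
        p = torch.exp(logits - m_new)
        s = s * correction + p.sum()
        acc = acc * correction + p @ values[start:start + block]
        m = m_new
    # Matches torch.softmax(keys @ q_row, dim=0) @ values up to rounding.
    return acc / s
```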
at::TensorAccessor<scalar_t, 3> buffer //,
// at::TensorAccessor<int64_t, 2> mask
) {
  constexpr int64_t BLOCK = 1; // 8;
I'm guessing that there's some speed to be gained here, by removing some reads / reusing them across a couple of rows?
Actually this is the size of the fetch over N, my bad. I'm guessing (hoping) that the compiler groups them automatically.
Exactly. Moving from 1 -> 8 brought a significant speedup. I moved it back to 1 before sending the PR because I was a bit lazy and didn't want to bother handling the remainder cases like I do in the GPU code.
That being said, now that I've unrolled both dimensions in the CUDA kernel, I could probably copy-paste the CUDA code and change a few things in the hope of making the CPU code faster.
But given that the CPU path is mostly used for prototyping, I didn't prioritize it further.
Makes sense. As it is, the memory accesses are a little too granular, but that's not a big issue. I think it's fine not to get it super optimized right now, but we could add a comment about it for anyone reading the code out of context in the future?
int64_t M = query.size(1);
int64_t N = key.size(1);
int64_t grain_size = 1;
at::parallel_for(0, B, grain_size, [&](int64_t start, int64_t end) {
Looks good to me, I kind of recognize the pattern from the Triton version I think; it's definitely a bit more complicated to follow, but that works! I'm not super familiar with at::Tensor, but it feels a little strange that one has to call .data() all the time?
The .data_ptr<scalar_t>() call is there to get the raw pointers to the tensor. I originally used the TensorAccessor as it adds extra robustness to different strides, but it also adds the overhead of stride computation, so I removed it for now.
Ah, interesting about the stride part, not a zero-cost abstraction then. It was just a free question, thanks for the context.
vec_t k_i = keys[k + K / kVecSize * k_item_idx];
#pragma unroll
for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) {
  sputnik::VectorCompute<vec_t>::Dot(
@tgale96 pulling you in, just in case there's something here that we can do better (it's a black box for me; does VectorCompute imply tensor cores, for instance, or normal float ops?)
To the best of my knowledge, we don't use any tensor cores explicitly here.
This was a refactoring that I did in cec04e9 to simplify a few things, and it turned out that sputnik already implemented some of the things I needed, so I just took those functions.
Alright, I think the A100 may show slightly worse perf (ping @suchenzang if you're interested in giving this a go?), but in fp32 it should be fine performance-wise up to and including the V100 (personal guess), so not a big problem. The rest of the model can still be fp16 with the torch amp guards, so all in all I think it's probably still very fast and useful.
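As a side note on the fp32-only support, a minimal sketch of how a model otherwise running under torch amp could still call this op (the wrapper below is hypothetical; only the torch.ops.xformers.efficient_attention call itself appears in this PR):

```python
import torch

def attention_fp32(query, key, value):
    # The surrounding model may run under autocast (fp16), but the
    # memory-efficient kernel currently only supports fp32, so disable
    # autocast locally and cast the inputs up before calling it.
    with torch.cuda.amp.autocast(enabled=False):
        q, k, v = (t.float() for t in (query, key, value))
        return torch.ops.xformers.efficient_attention(q, k, v)
```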
template <
    typename scalar_t,
    typename vec_t = float4,
    int kBlockSizeK = 32,
Interesting, I was wondering about these sizes for most of the code read; it makes sense now. I'm wondering how architecture-dependent the ideal values would be: depending on the shared memory and memory bandwidth, could it be that the optimal set moves a little?
yeah, those values worked well for P100s, but might not be ideal for other architectures, and should definitely be better tuned for different systems
if ((K % 4) == 0) {
  TORCH_CHECK(
      K / 4 <= BUFFER_SIZE,
      "For now only a certain number of K values are supported. Let us know if you hit this and we will fix it");
<3
Super nice error message, I meant.
Looks great to me, it's been a while since I went through this much CUDA but it kind of works :) A couple of nits here and there, curious for your feedback/opinion really; good to go as far as I'm concerned.
Here are the results on a V100: it seems like we might indeed want to do some device-dependent tuning to get better performance on V100s. The current hyperparameters are not too bad, but could probably be improved.
V100 results
Improve code comments
@blefaudeux I've added a user-facing function (which just dispatches to the kernel implementation for now), plus the benchmark script. I added some more comments and also added the division by sqrt(K).
Great work, thanks @fmassa !
What does this PR do?
This PR implements the memory-efficient attention mechanism from https://arxiv.org/pdf/2112.05682v2.pdf, with both CPU and CUDA kernels. For now, only fp32 is supported.
The CPU implementation is fairly naive and I haven't focused on optimizing it (yet), so you should expect it to be quite a bit slower than a baseline CPU implementation in PyTorch. But it is generic and should support all cases.
For the CUDA implementation, the performance is quite competitive with a baseline pytorch implementation for fp32 in terms of runtime (within 10% for most cases), while the memory savings are quite significant (10x+).
Here are some numbers (run on a P100 GPU):
Speed / memory improvements on the CUDA case
You can see up to 20x memory savings for larger configurations, while the runtime is on the order of 10% slower than the baseline (which leverages cuBLAS internally).
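For anyone who wants to reproduce the comparison locally, here is a minimal sketch of how the peak-memory numbers can be measured (the shapes are illustrative and the actual benchmark script added in this PR may differ; ref_attention refers to the baseline sketched earlier):

```python
import torch

def peak_memory_mib(fn, *args):
    # Reset the allocator stats, run the op, and report peak usage in MiB.
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

B, M, N, K = 32, 1024, 1024, 32  # K = 32 satisfies the K / 4 <= 8 constraint below
query = torch.randn(B, M, K, device="cuda")
key = torch.randn(B, N, K, device="cuda")
value = torch.randn(B, N, K, device="cuda")

mem_eff = peak_memory_mib(torch.ops.xformers.efficient_attention, query, key, value)
mem_ref = peak_memory_mib(ref_attention, query, key, value)  # materializes the (M, N) matrix
print(f"efficient: {mem_eff:.1f} MiB, baseline: {mem_ref:.1f} MiB")
```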
Next steps
This PR has some assumptions on the dimensionality of K (the feature map after splitting into heads). For now, it should satisfy one of the following, depending on the divisibility of K:
- K / 4 <= 8
- K / 2 <= 8
- K <= 8
This can be fixed in the future if needed.
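To make the constraint concrete, here is a small helper mirroring the checks visible in the kernel snippet above (BUFFER_SIZE = 8 is assumed from the K / 4 <= 8 bound, and the exact dispatch on K % 2 is an assumption; purely illustrative):

```python
BUFFER_SIZE = 8  # assumed from the K / 4 <= 8 bound above

def k_is_supported(K: int) -> bool:
    # Mirror the TORCH_CHECK dispatch: vectorized-by-4, vectorized-by-2, or scalar path.
    if K % 4 == 0:
        return K // 4 <= BUFFER_SIZE
    if K % 2 == 0:
        return K // 2 <= BUFFER_SIZE
    return K <= BUFFER_SIZE

# e.g. k_is_supported(32) -> True, k_is_supported(40) -> False
```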
Fixes #161.