Memory-efficient attention - forward pass #267

Merged
merged 39 commits into from
Apr 12, 2022

@fmassa fmassa commented Apr 11, 2022

What does this PR do?

This PR implements the memory-efficient attention mechanism from https://arxiv.org/pdf/2112.05682v2.pdf, with both CPU and CUDA kernels. For now, only fp32 is supported.

The CPU implementation is fairly naive and I haven't focused on optimizing it (yet), so expect it to be quite a bit slower than a baseline CPU implementation in PyTorch. It is generic, though, and should support all cases.

For the CUDA implementation, runtime is quite competitive with a baseline PyTorch implementation for fp32 (within 10% for most cases), while the memory savings are quite significant (10x+).
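
For readers unfamiliar with the paper, the core idea can be sketched in a few lines of pure Python. This is an illustrative sketch under stated assumptions, not the PR's kernel: it drops the batch dimension, the names `reference_attention`, `chunked_attention` and `chunk` are made up for the example, and the default 1/sqrt(K) scaling is the usual scaled dot-product convention rather than something read from the kernel source.

```python
import math


def reference_attention(q, k, v, scale=None):
    """Baseline: compute a full row of softmax weights, then average v."""
    scale = 1.0 / math.sqrt(len(q[0])) if scale is None else scale
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v_row[d] for w, v_row in zip(weights, v)) / total
            for d in range(len(v[0]))
        ])
    return out


def chunked_attention(q, k, v, chunk=2, scale=None):
    """Visit keys/values `chunk` rows at a time, keeping only a running max,
    a running normalizer and a running weighted sum per query row."""
    scale = 1.0 / math.sqrt(len(q[0])) if scale is None else scale
    dim_v = len(v[0])
    out = []
    for q_row in q:
        running_max = -math.inf
        running_sum = 0.0
        acc = [0.0] * dim_v
        for start in range(0, len(k), chunk):
            k_chunk = k[start:start + chunk]
            v_chunk = v[start:start + chunk]
            scores = [scale * sum(a * b for a, b in zip(q_row, k_row))
                      for k_row in k_chunk]
            new_max = max(running_max, max(scores))
            # Rescale everything accumulated under the previous max.
            correction = math.exp(running_max - new_max)
            running_sum *= correction
            acc = [a * correction for a in acc]
            for s, v_row in zip(scores, v_chunk):
                w = math.exp(s - new_max)
                running_sum += w
                for d in range(dim_v):
                    acc[d] += w * v_row[d]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out


# Toy inputs: 3 queries, 5 keys/values, K = 2, value width 3.
q = [[0.1, 0.2], [0.3, -0.4], [1.0, 0.5]]
k = [[0.5, 0.1], [-0.2, 0.7], [0.9, -0.3], [0.0, 0.4], [1.2, 0.8]]
v = [[1.0, 0.0, 2.0], [0.5, 1.5, -1.0], [2.0, 2.0, 0.0],
     [-1.0, 0.5, 0.5], [0.0, 1.0, 3.0]]
ref = reference_attention(q, k, v)
out = chunked_attention(q, k, v, chunk=2)
assert all(abs(a - b) < 1e-9 for r, o in zip(ref, out) for a, b in zip(r, o))
```

The chunked version never holds more than `chunk` scores per query at once, which is where the memory saving comes from; the price is the extra rescaling step whenever the running max changes.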

Here are some numbers (run on a P100 GPU):

Speed / memory improvements on the CUDA case
Optimized Memory used: 0.0634765625 MB
Vanilla Memory used: 0.171875 MB
===== (1, 128, 16) =====
Optimized Memory used: 0.0322265625 MB
Vanilla Memory used: 0.1494140625 MB
===== (1, 128, 32) =====
Optimized Memory used: 0.0634765625 MB
Vanilla Memory used: 0.1728515625 MB
===== (1, 512, 16) =====
Optimized Memory used: 0.1259765625 MB
Vanilla Memory used: 2.0947265625 MB
===== (1, 512, 32) =====
Optimized Memory used: 0.2509765625 MB
Vanilla Memory used: 2.1884765625 MB
===== (1, 513, 16) =====
Optimized Memory used: 0.1279296875 MB
Vanilla Memory used: 2.10498046875 MB
===== (1, 513, 32) =====
Optimized Memory used: 0.2529296875 MB
Vanilla Memory used: 2.19873046875 MB
===== (1, 1023, 16) =====
Optimized Memory used: 0.2509765625 MB
Vanilla Memory used: 8.173828125 MB
===== (1, 1023, 32) =====
Optimized Memory used: 0.5009765625 MB
Vanilla Memory used: 8.361328125 MB
===== (1, 1024, 16) =====
Optimized Memory used: 0.2509765625 MB
Vanilla Memory used: 8.1884765625 MB
===== (1, 1024, 32) =====
Optimized Memory used: 0.5009765625 MB
Vanilla Memory used: 8.3759765625 MB
===== (8, 127, 16) =====
Optimized Memory used: 0.2490234375 MB
Vanilla Memory used: 1.17236328125 MB
===== (8, 127, 32) =====
Optimized Memory used: 0.4970703125 MB
Vanilla Memory used: 1.3583984375 MB
===== (8, 128, 16) =====
Optimized Memory used: 0.2509765625 MB
Vanilla Memory used: 1.1884765625 MB
===== (8, 128, 32) =====
Optimized Memory used: 0.5009765625 MB
Vanilla Memory used: 1.3759765625 MB
===== (8, 512, 16) =====
Optimized Memory used: 1.0009765625 MB
Vanilla Memory used: 16.7509765625 MB
===== (8, 512, 32) =====
Optimized Memory used: 2.0009765625 MB
Vanilla Memory used: 17.5009765625 MB
===== (8, 513, 16) =====
Optimized Memory used: 1.0029296875 MB
Vanilla Memory used: 16.81591796875 MB
===== (8, 513, 32) =====
Optimized Memory used: 2.0048828125 MB
Vanilla Memory used: 17.5673828125 MB
===== (8, 1023, 16) =====
Optimized Memory used: 1.9990234375 MB
Vanilla Memory used: 65.49951171875 MB
===== (8, 1023, 32) =====
Optimized Memory used: 3.9970703125 MB
Vanilla Memory used: 66.998046875 MB
===== (8, 1024, 16) =====
Optimized Memory used: 2.0009765625 MB
Vanilla Memory used: 65.5009765625 MB
===== (8, 1024, 32) =====
Optimized Memory used: 4.0009765625 MB
Vanilla Memory used: 67.0009765625 MB
===== (32, 127, 16) =====
Optimized Memory used: 0.9931640625 MB
Vanilla Memory used: 4.68359375 MB
===== (32, 127, 32) =====
Optimized Memory used: 1.9853515625 MB
Vanilla Memory used: 5.427734375 MB
===== (32, 128, 16) =====
Optimized Memory used: 1.0009765625 MB
Vanilla Memory used: 4.7509765625 MB
===== (32, 128, 32) =====
Optimized Memory used: 2.0009765625 MB
Vanilla Memory used: 5.5009765625 MB
===== (32, 512, 16) =====
Optimized Memory used: 4.0009765625 MB
Vanilla Memory used: 67.0009765625 MB
===== (32, 512, 32) =====
Optimized Memory used: 8.0009765625 MB
Vanilla Memory used: 70.0009765625 MB
===== (32, 513, 16) =====
Optimized Memory used: 5.87939453125 MB
Vanilla Memory used: 69.12841796875 MB
===== (32, 513, 32) =====
Optimized Memory used: 8.0166015625 MB
Vanilla Memory used: 70.263671875 MB
===== (32, 1023, 16) =====
Optimized Memory used: 8.0068359375 MB
Vanilla Memory used: 262.0087890625 MB
===== (32, 1023, 32) =====
Optimized Memory used: 15.9931640625 MB
Vanilla Memory used: 267.9970703125 MB
===== (32, 1024, 16) =====
Optimized Memory used: 8.0048828125 MB
Vanilla Memory used: 262.0048828125 MB
===== (32, 1024, 32) =====
Optimized Memory used: 16.0087890625 MB
Vanilla Memory used: 268.0087890625 MB
===== (256, 127, 16) =====
Optimized Memory used: 7.9892578125 MB
Vanilla Memory used: 38.0048828125 MB
===== (256, 127, 32) =====
Optimized Memory used: 15.9462890625 MB
Vanilla Memory used: 43.9423828125 MB
===== (256, 128, 16) =====
Optimized Memory used: 8.0556640625 MB
Vanilla Memory used: 38.0556640625 MB
===== (256, 128, 32) =====
Optimized Memory used: 16.0009765625 MB
Vanilla Memory used: 44.0009765625 MB
===== (256, 512, 16) =====
Optimized Memory used: 32.0009765625 MB
Vanilla Memory used: 536.0009765625 MB
===== (256, 512, 32) =====
Optimized Memory used: 64.0009765625 MB
Vanilla Memory used: 560.0009765625 MB
===== (256, 513, 16) =====
Optimized Memory used: 32.0634765625 MB
Vanilla Memory used: 540.0478515625 MB
===== (256, 513, 32) =====
Optimized Memory used: 64.1259765625 MB
Vanilla Memory used: 564.0947265625 MB
===== (256, 1023, 16) =====
Optimized Memory used: 63.9384765625 MB
Vanilla Memory used: 2091.9560546875 MB
===== (256, 1023, 32) =====
Optimized Memory used: 127.9384765625 MB
Vanilla Memory used: 2139.9716796875 MB
===== (256, 1024, 16) =====
Optimized Memory used: 64.0009765625 MB
Vanilla Memory used: 2096.0009765625 MB
===== (256, 1024, 32) =====
Optimized Memory used: 128.0009765625 MB
Vanilla Memory used: 2144.0009765625 MB
[------------------- attention -------------------]
                           |  optimized  |  vanilla
1 threads: ----------------------------------------
      B=1, M=127, K=16     |      24.8   |     43.1
      B=1, M=127, K=32     |      38.3   |     42.8
      B=1, M=128, K=16     |      22.3   |     42.7
      B=1, M=128, K=32     |      40.6   |     42.6
      B=1, M=512, K=16     |      68.0   |     43.3
      B=1, M=512, K=32     |     124.9   |     43.0
      B=1, M=513, K=16     |      69.3   |     43.4
      B=1, M=513, K=32     |     127.7   |     44.2
      B=1, M=1023, K=16    |     129.5   |     65.6
      B=1, M=1023, K=32    |     237.6   |     77.6
      B=1, M=1024, K=16    |     127.0   |     60.7
      B=1, M=1024, K=32    |     236.3   |     67.6
      B=8, M=127, K=16     |      24.6   |     44.6
      B=8, M=127, K=32     |      38.4   |     44.6
      B=8, M=128, K=16     |      22.0   |     44.9
      B=8, M=128, K=32     |      40.4   |     44.4
      B=8, M=512, K=16     |     104.9   |    161.9
      B=8, M=512, K=32     |     191.1   |    167.8
      B=8, M=513, K=16     |     107.7   |    176.2
      B=8, M=513, K=32     |     195.2   |    184.1
      B=8, M=1023, K=16    |     269.7   |    504.8
      B=8, M=1023, K=32    |     570.8   |    526.6
      B=8, M=1024, K=16    |     262.0   |    490.4
      B=8, M=1024, K=32    |     555.5   |    517.5
      B=32, M=127, K=16    |      32.8   |     67.9
      B=32, M=127, K=32    |      58.3   |     68.8
      B=32, M=128, K=16    |      31.7   |     67.7
      B=32, M=128, K=32    |      55.5   |     68.5
      B=32, M=512, K=16    |     343.2   |    495.0
      B=32, M=512, K=32    |     829.6   |    523.0
      B=32, M=513, K=16    |     396.6   |    565.0
      B=32, M=513, K=32    |     791.6   |    606.9
      B=32, M=1023, K=16   |    1216.4   |   1883.1
      B=32, M=1023, K=32   |    2761.4   |   1978.6
      B=32, M=1024, K=16   |    1173.2   |   1826.1
      B=32, M=1024, K=32   |    2471.5   |   1936.5
      B=256, M=127, K=16   |     266.8   |    259.7
      B=256, M=127, K=32   |     708.9   |    275.0
      B=256, M=128, K=16   |     207.5   |    250.3
      B=256, M=128, K=32   |     442.1   |    267.8
      B=256, M=512, K=16   |    2039.8   |   3428.5
      B=256, M=512, K=32   |    4461.3   |   3673.7
      B=256, M=513, K=16   |    2056.5   |   3921.6
      B=256, M=513, K=32   |    4236.2   |   4294.6
      B=256, M=1023, K=16  |    7886.6   |  13886.5
      B=256, M=1023, K=32  |   16808.8   |  14857.4
      B=256, M=1024, K=16  |    7540.8   |  13333.9
      B=256, M=1024, K=32  |   15666.1   |  14408.2

Times are in microseconds (us).

You can see up to 20x memory savings for larger configurations, while the runtime is on the order of 10% slower than the baseline (which leverages cuBLAS internally).
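
A back-of-the-envelope model helps explain where the savings in the memory numbers come from. The breakdown into tensors below is an assumption (it is not stated in the PR): the vanilla path is modeled as two B x M x M fp32 intermediates plus three B x M x K tensors, and the optimized path as four B x M x K tensors. The function names are hypothetical.

```python
BYTES_FP32 = 4
MB = 1024 * 1024


def vanilla_mb(B, M, K):
    # Assumed: two B x M x M intermediates (scores and softmax output)
    # plus three B x M x K tensors.
    return (2 * B * M * M + 3 * B * M * K) * BYTES_FP32 / MB


def optimized_mb(B, M, K):
    # Assumed: four B x M x K tensors and no B x M x M intermediate.
    return 4 * B * M * K * BYTES_FP32 / MB


print(vanilla_mb(256, 1024, 16))    # 2096.0 (reported: 2096.0009765625)
print(optimized_mb(256, 1024, 16))  # 64.0   (reported: 64.0009765625)
print(vanilla_mb(8, 512, 16))       # 16.75  (reported: 16.7509765625)
```

For these shapes the model lands within about 1 KB of the reported figures. The key point is that the vanilla cost grows with M squared while the optimized cost grows only linearly in M.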

Next steps

This PR makes some assumptions about K (the feature dimension after splitting into heads). For now, it must satisfy:

  • if K is a multiple of 4, K / 4 <= 8
  • otherwise, if K is a multiple of 2, K / 2 <= 8
  • otherwise, K <= 8

This can be fixed in the future if needed.
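
The constraint on K can be written as a small predicate. This is a sketch for illustration only: `is_supported_k` is a hypothetical helper, and the limit of 8 is taken from the list above rather than from the kernel source.

```python
BUFFER_SIZE = 8  # assumed from the "<= 8" limits listed above


def is_supported_k(K, buffer_size=BUFFER_SIZE):
    """Return True if K satisfies the limits listed in the PR description."""
    if K % 4 == 0:
        return K // 4 <= buffer_size
    if K % 2 == 0:
        return K // 2 <= buffer_size
    return K <= buffer_size


supported = [K for K in range(1, 65) if is_supported_k(K)]
print(supported)
# [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32]
```

In other words, the benchmarked values K=16 and K=32 are covered, while a common head dimension such as K=64 is not under these limits.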

Fixes #161.

fmassa added 30 commits March 16, 2022 08:51
It's for now 1000x slower than the baseline
Now we are *only* 6x slower than baseline
Now we are only 50% slower than baseline
Need to fix the buffer size, which is hard-coded for now
The use of Dot makes it 2.5% faster already
Still need to make it generic wrt query size, and allow further values of K that go beyond the buffer limit
This is commented out for now as it brings a slowdown to the implementation
@fmassa fmassa requested review from blefaudeux and dianaml0 April 11, 2022 13:28
@blefaudeux
Contributor

@fmassa the doc build issue should disappear after a rebase; this was fixed on main a week or so ago. Otherwise this is really great, having a deeper look!

out = torch.ops.xformers.efficient_attention(query, key, value)
ref = ref_attention(query, key, value)

assert torch.allclose(out, ref, atol=2e-4)
Contributor

Just to get an idea, do you have a gut feeling about where the small differences come from? Softmax renormalization being a little different with the paper's method seems like an easy explanation, but is there something else?

Contributor Author

That is a good question. My best explanation so far is indeed that we accumulate a bit more error because we don't know the max value over a row ahead of time, so we need to renormalize (introducing a bit more rounding error).
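
One way to get a feel for the size of this effect is to emulate fp32 rounding in pure Python and compare a two-pass softmax-weighted average (max known up front) with a streaming one (max discovered on the fly, with rescaling). This is only an emulation under stated assumptions, not the PR's kernel: `f32`, `two_pass` and `streaming` are hypothetical names, and the inputs are arbitrary.

```python
import math
import struct


def f32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def two_pass(scores, values):
    """Softmax-weighted average with the row max known in advance."""
    m = max(scores)
    num = den = 0.0
    for s, v in zip(scores, values):
        w = f32(math.exp(f32(s - m)))
        num = f32(num + f32(w * v))
        den = f32(den + w)
    return f32(num / den)


def streaming(scores, values):
    """Same quantity, but the max is updated as we go and the
    accumulators are rescaled each step (an extra rounding per step)."""
    m = -math.inf
    num = den = 0.0
    for s, v in zip(scores, values):
        new_m = max(m, s)
        c = f32(math.exp(f32(m - new_m)))  # rescale factor for old sums
        w = f32(math.exp(f32(s - new_m)))
        num = f32(f32(num * c) + f32(w * v))
        den = f32(f32(den * c) + w)
        m = new_m
    return f32(num / den)


scores = [f32(3.0 * math.sin(i)) for i in range(64)]
values = [f32(math.cos(i)) for i in range(64)]

# Double-precision reference for comparison.
m = max(scores)
weights = [math.exp(s - m) for s in scores]
exact = sum(w * v for w, v in zip(weights, values)) / sum(weights)

# Both variants stay well inside the 2e-4 tolerance used in the test above.
assert abs(two_pass(scores, values) - exact) < 2e-4
assert abs(streaming(scores, values) - exact) < 2e-4
```

The streaming variant performs extra multiplications by the rescale factor, each of which can round, which is consistent with the explanation given here; the emulation does not claim to reproduce the exact error of the CUDA kernel.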

at::TensorAccessor<scalar_t, 3> buffer //,
// at::TensorAccessor<int64_t, 2> mask
) {
constexpr int64_t BLOCK = 1; // 8;
Contributor

I'm guessing that there's some speed to gain here, to remove some reads / reuse them across a couple of rows ?

Contributor

actually this is the size of the fetch over N, my bad, I'm guessing (hoping) that the compiler groups them automatically

Contributor Author

Exactly. Moving from 1 -> 8 brought a significant speedup. I moved it back to 1 before sending the PR because I was a bit lazy and didn't want to bother handling the remainder cases like I do in the GPU code.

That being said, now that I performed unrolling of both dimensions in the CUDA kernel, I could probably copy-paste the CUDA code and change a few things in the hope of making the CPU code faster.

But given that CPU is normally used for prototyping most of the time, I didn't prioritize it further.
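
The remainder handling mentioned above can be sketched as follows. This is a generic illustration in Python, not the PR's C++ or CUDA code: `block_ranges` is a hypothetical helper that splits the N dimension into full blocks of size BLOCK followed by one shorter tail block.

```python
def block_ranges(n, block):
    """Yield (start, size) pairs covering range(n): full blocks first,
    then a single shorter block for whatever is left over."""
    full = n - n % block
    for start in range(0, full, block):
        yield start, block
    if full < n:
        yield full, n - full


print(list(block_ranges(19, 8)))  # [(0, 8), (8, 8), (16, 3)]
print(list(block_ranges(16, 8)))  # [(0, 8), (8, 8)]
```

With BLOCK = 1 there is never a tail, which is why that setting avoids the extra code path at the cost of finer-grained memory accesses.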

Contributor

Makes sense. As it is, the memory accesses are a little too granular, but that's not a big issue. I think it's fine not to get it super duper optimized right now, but we could add a comment about it for our future selves, who won't have this context?

int64_t M = query.size(1);
int64_t N = key.size(1);
int64_t grain_size = 1;
at::parallel_for(0, B, grain_size, [&](int64_t start, int64_t end) {
Contributor

looks good to me, I kind of recognize the pattern from the triton version I think, it's definitely a bit more complicated to follow I believe but that works ! I'm not super familiar with at::Tensor but it feels a little strange that one has to call .data() all the time ?

Contributor Author

The .data_ptr<scalar_t>() call gets the raw pointers to the tensor. I originally used TensorAccessor, as it adds extra robustness to different strides, but it also adds striding overhead, so I removed it for now.
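
A rough model of the trade-off, in Python for illustration: accessor-style indexing multiplies every index by its stride on each access, which works for any layout, whereas raw-pointer code typically assumes a contiguous layout and simply steps through memory. The helpers `contiguous_strides` and `strided_offset` are hypothetical and only model the index arithmetic; they are not ATen's implementation.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) for a tensor of the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def strided_offset(index, strides):
    """Flat element offset for a multi-dimensional index: one multiply
    and one add per dimension, paid on every access."""
    return sum(i * s for i, s in zip(index, strides))


strides = contiguous_strides((2, 3, 4))
print(strides)                             # (12, 4, 1)
print(strided_offset((1, 2, 3), strides))  # 23

# A transposed view of the same storage swaps the last two strides;
# the generic formula still finds the right element.
print(strided_offset((1, 3, 2), (12, 1, 4)))  # 23
```

On a contiguous tensor the innermost loop can advance a pointer by one element per step instead of recomputing the full offset, which is the overhead being avoided here.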

Contributor

ah interesting for the stride part, not a zero cost abstraction.. just a free question, thanks for the context

vec_t k_i = keys[k + K / kVecSize * k_item_idx];
#pragma unroll
for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) {
sputnik::VectorCompute<vec_t>::Dot(
Contributor

@tgale96 pulling you in, just in case there's something here that we can do better (black box for me, does VectorCompute imply tensor cores for instance or normal float ops ?)

Contributor Author

to the best of my knowledge we don't use any tensor cores explicitly here.

This was a refactoring that I did in cec04e9 to simplify a few things, and it turned out that sputnik already implemented some of the things I needed, so I just took those functions

Contributor

alrighty, I think that A100 may show a little worse perf (ping @suchenzang if you're interested in giving this a go ?), but in fp32 it will be fine perf wise up to V100 included (personal guess), not a big problem. The rest of the model can still be fp16 with the torch amp guards, so all in all I think that it's probably still very fast and useful

template <
typename scalar_t,
typename vec_t = float4,
int kBlockSizeK = 32,
Contributor

interesting, I was wondering about these sizes for most of the code read, makes sense now. I'm wondering how architecture dependent the ideal values would be, depending on the shared memory and memory bandwidth could be that the optimal set moves a little ?

Contributor Author

yeah, those values worked well for P100s, but might not be ideal for other architectures, and should definitely be better tuned for different systems

if ((K % 4) == 0) {
TORCH_CHECK(
K / 4 <= BUFFER_SIZE,
"For now only a certain number of K values are supported. Let us know if you hit this and we will fix it");
Contributor

<3

Contributor

super nice error message I meant

@blefaudeux
Contributor

looks great to me, it's been a while since I went through this much cuda but kind of works :) Couple of nits here and there, curious for your feedback/opinion really, good to go as far as I'm concerned

@fmassa
Contributor Author

fmassa commented Apr 11, 2022

Here are the results on a V100. It does look like we might want some extra device-dependent tuning to get better performance on V100s. The current hyperparameters are not too bad, but could probably be improved.

V100 results
[------------------- attention -------------------]
                           |  optimized  |  vanilla
1 threads: ----------------------------------------
      B=1, M=127, K=16     |      25.9   |     69.1
      B=1, M=127, K=32     |      29.0   |     84.8
      B=1, M=128, K=16     |      25.5   |    107.7
      B=1, M=128, K=32     |      25.8   |    107.1
      B=1, M=512, K=16     |      44.4   |     92.7
      B=1, M=512, K=32     |      79.3   |     98.4
      B=1, M=513, K=16     |      46.6   |    135.2
      B=1, M=513, K=32     |      79.8   |    141.8
      B=1, M=1023, K=16    |     102.1   |     94.0
      B=1, M=1023, K=32    |     188.1   |    100.5
      B=1, M=1024, K=16    |      99.4   |    141.7
      B=1, M=1024, K=32    |     178.1   |    140.9
      B=8, M=127, K=16     |      25.9   |     80.3
      B=8, M=127, K=32     |      32.9   |     82.3
      B=8, M=128, K=16     |      25.6   |     82.5
      B=8, M=128, K=32     |      28.3   |     81.7
      B=8, M=512, K=16     |      78.8   |     97.5
      B=8, M=512, K=32     |     148.6   |    102.5
      B=8, M=513, K=16     |      78.7   |    117.1
      B=8, M=513, K=32     |     150.4   |    122.9
      B=8, M=1023, K=16    |     201.8   |    357.8
      B=8, M=1023, K=32    |     395.0   |    362.0
      B=8, M=1024, K=16    |     194.9   |    315.5
      B=8, M=1024, K=32    |     378.8   |    328.5
      B=32, M=127, K=16    |      29.2   |     82.3
      B=32, M=127, K=32    |      52.1   |     83.0
      B=32, M=128, K=16    |      37.1   |     82.6
      B=32, M=128, K=32    |      44.7   |     92.1
      B=32, M=512, K=16    |     218.9   |    236.8
      B=32, M=512, K=32    |     408.9   |    248.9
      B=32, M=513, K=16    |     239.9   |    291.8
      B=32, M=513, K=32    |     453.6   |    298.9
      B=32, M=1023, K=16   |     759.2   |   1107.0
      B=32, M=1023, K=32   |    1453.7   |   1121.0
      B=32, M=1024, K=16   |     704.6   |    747.8
      B=32, M=1024, K=32   |    1383.6   |    807.2
      B=256, M=127, K=16   |     123.6   |    141.5
      B=256, M=127, K=32   |     264.9   |    153.3
      B=256, M=128, K=16   |     105.1   |    112.9
      B=256, M=128, K=32   |     205.5   |    125.4
      B=256, M=512, K=16   |    1264.4   |   1728.8
      B=256, M=512, K=32   |    2499.5   |   1875.4
      B=256, M=513, K=16   |    1320.2   |   2273.2
      B=256, M=513, K=32   |    2556.9   |   2322.5
      B=256, M=1023, K=16  |    5250.0   |   9485.3
      B=256, M=1023, K=32  |    9909.8   |   9807.8
      B=256, M=1024, K=16  |    4830.7   |   6888.3
      B=256, M=1024, K=32  |    9498.7   |   7564.2

Times are in microseconds (us).

@fmassa
Contributor Author

fmassa commented Apr 12, 2022

@blefaudeux I've added a user-facing function (which just dispatches to the kernel implementation for now) plus the benchmark script. I also added some more comments and the division by sqrt(K), so that it can be a straight replacement for scaled_dot_product_attention in inference mode (for the configurations of K that it supports).

Contributor

@blefaudeux blefaudeux left a comment

Great work, thanks @fmassa !

Successfully merging this pull request may close these issues.

[feat] Add a fast implementation of Rabe and Staats algorigthm (mem efficient attention) on GPU
3 participants