Memory-efficient attention - forward pass (#267)
* Add initial CPU version of memory efficient attention
* Try adding naive vectorization
* Cleanups
* Initial vectorization
* Factorize into unfoldable loops
* Minor cleanup
* Add super naive CUDA implementation. It's for now 1000x slower than the baseline
* 10x speedup on CUDA kernel
* Another 20x speedup on CUDA kernel. Now we are *only* 6x slower than baseline
* Make it 13% faster with larger BLOCK size
* Get extra 25% speedup
* Make it 2x faster. Now we are only 50% slower than baseline
* Minor code cleanups
* Make it 30% faster. Need to fix the buffer size, which is hard-coded for now
* Remove hard-coded constants and use some sputnik helpers. The use of Dot makes it 2.5% faster already
* Start doing some refactoring
* Further cleanups
* More cleanup
* Run clang-format
* Use vec_t + cleanups
* Rename some variables
* clang-format
* Some more variable renaming
* clang-format
* More refactoring
* Make implementation generic wrt key size and feature dim. Still need to make it generic wrt query size, and allow further values of K that go beyond the buffer limit
* Speedup by almost 2x when key size is not a multiple of 32
* Make kernel generic wrt query size
* Statically unroll the different cases. This is commented out for now as it brings a slowdown to the implementation
* More cleanups
* Add tests + bugfixes
* Add more checks and cleanups
* clang-format
* Address reviewer comments. Improve code comments
* Add benchmark script
* Divide by sqrt(K) and add user-facing API
* Appease mypy
* Add copyright notice
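The user-facing entry point added by this commit is `xformers.ops.memory_efficient_attention`, which computes softmax(Q K^T / sqrt(K)) V without materializing the full attention matrix. A minimal usage sketch based on the tests and benchmark below (tensor sizes are illustrative):

import torch
import xformers.ops

# (batch, sequence length, embedding size per head); sizes are illustrative
query = torch.randn(4, 128, 32)
key = torch.randn(4, 256, 32)
value = torch.randn(4, 256, 32)

# equivalent to softmax(query @ key.transpose(-2, -1) / sqrt(32)) @ value,
# but without building the 4 x 128 x 256 attention matrix explicitly
out = xformers.ops.memory_efficient_attention(query, key, value)
print(out.shape)  # torch.Size([4, 128, 32])

The scaling by 1 / sqrt(K) is applied inside the kernel, so callers pass the raw query, key and value tensors.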
Showing 6 changed files with 921 additions and 0 deletions.
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

import xformers.ops

_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


def ref_attention(q, k, v):
    q = q * (1 / q.shape[-1] ** 0.5)
    return (q @ k.transpose(-2, -1)).softmax(-1) @ v


@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 3, 5])
@pytest.mark.parametrize("device", _devices)
def test_memory_efficient_attention(device, q_len, kv_len, batch_size, k_len):
    scale = 3
    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

    out = xformers.ops.memory_efficient_attention(query, key, value)
    ref = ref_attention(query, key, value)

    assert torch.allclose(out, ref, atol=2e-4)


@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [128, 512])
@pytest.mark.parametrize("q_len", [128, 512])
@pytest.mark.parametrize("device", _devices)
def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len):
    scale = 3
    query = torch.ones((batch_size, q_len, k_len), device=device)
    key = torch.ones((batch_size, kv_len, k_len), device=device)
    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

    out = xformers.ops.memory_efficient_attention(query, key, value)
    # this should be equivalent to the average over value
    ref = value.mean(1, keepdim=True).expand_as(query)

    assert torch.allclose(out, ref, atol=1e-5)
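The second test exploits a softmax identity: when query and key are all ones, every logit in a row is the same constant, so the softmax weights are uniform (1 / kv_len each) and the attention output collapses to the mean of value over the key axis. A small standalone check of that property (shapes are illustrative):

import torch

kv_len, k_len = 7, 32
query = torch.ones(1, 3, k_len)
key = torch.ones(1, kv_len, k_len)
value = torch.randn(1, kv_len, k_len)

logits = (query / k_len ** 0.5) @ key.transpose(-2, -1)  # every entry equals sqrt(k_len)
weights = logits.softmax(-1)                             # uniform: 1 / kv_len per key
out = weights @ value

assert torch.allclose(weights, torch.full_like(weights, 1.0 / kv_len))
assert torch.allclose(out, value.mean(1, keepdim=True).expand_as(query), atol=1e-6)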
@@ -0,0 +1,95 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
import pprint
from typing import Dict

import torch
from torch.utils import benchmark

import xformers.ops


def ref_attention(q, k, v):
    q = q * (1.0 / q.shape[-1] ** 0.5)
    return (q @ k.transpose(-2, -1)).softmax(-1) @ v


min_run_time = 2
device = torch.device("cuda")

NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
SHAPES = list(
    itertools.product([1, 8, 32, 256], [127, 128, 512, 513, 1023, 1024], [16, 32])
)

results = []
mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={})

print(f"Processing {len(SHAPES)} cases")
for num_threads in NUM_THREADS:
    for shape in SHAPES:
        print(f"===== {shape} =====")
        B, M, K = shape
        q = torch.rand(shape, device=device)
        sub_label = f"B={B}, M={M}, K={K}"

        if True:
            r = xformers.ops.memory_efficient_attention(q, q, q)

            rr = ref_attention(q, q, q)
            assert (r - rr).abs().max() < 1e-5

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        results.append(
            benchmark.Timer(
                stmt="fn(q, q, q)",
                globals={
                    "q": q,
                    "fn": torch.ops.xformers.efficient_attention,
                },
                label="attention",
                description="optimized",
                sub_label=sub_label,
                num_threads=num_threads,
            ).blocked_autorange(min_run_time=min_run_time)
        )
        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated() / 2 ** 20
        mem_use["optimized"][sub_label] = memory
        memory_str = f"Memory used: {memory} MB"

        print("Optimized", memory_str)

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        results.append(
            benchmark.Timer(
                stmt="fn(q, q, q)",
                globals={
                    "q": q,
                    "fn": ref_attention,
                },
                label="attention",
                description="vanilla",
                sub_label=sub_label,
                num_threads=num_threads,
            ).blocked_autorange(min_run_time=min_run_time)
        )

        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated() / 2 ** 20
        mem_use["vanilla"][sub_label] = memory
        memory_str = f"Memory used: {memory} MB"
        print("Vanilla", memory_str)


compare = benchmark.Compare(results)
compare.print()

pprint.pprint(mem_use)
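For context on the memory numbers tracked above: the vanilla path materializes the full B x M x M attention matrix before the softmax, while the memory-efficient kernel only keeps per-query running statistics, so its footprint stays close to the inputs and output. A rough, illustrative back-of-the-envelope for one of the benchmarked shapes (not measured output):

# storage for the attention matrix alone at B=32, M=1024, float32
B, M, bytes_per_float = 32, 1024, 4
attn_matrix_mib = B * M * M * bytes_per_float / 2 ** 20
print(attn_matrix_mib)  # 128.0 MiB, before counting the softmax output tensor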
@@ -0,0 +1,6 @@
#include <torch/types.h>

TORCH_LIBRARY_FRAGMENT(xformers, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "xformers::efficient_attention(Tensor query, Tensor key, Tensor value) -> Tensor"));
}
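This fragment only declares the operator schema with the PyTorch dispatcher; the kernels are registered per backend separately (the CPU one below via TORCH_LIBRARY_IMPL). Once the schema and an implementation are registered, the raw op is reachable from Python through torch.ops, which is how the benchmark above invokes it. A minimal sketch, assuming importing xformers.ops loads the compiled extension:

import torch
import xformers.ops  # assumed to load the C++ extension that registers the op

q = torch.randn(2, 16, 32)  # CPU tensor, matching the CPU implementation below
out = torch.ops.xformers.efficient_attention(q, q, q)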
@@ -0,0 +1,156 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/library.h>
#include <cmath>
#include <vector>

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

namespace {

template <typename scalar_t>
void fill_zero(scalar_t* buf, int64_t K) {
  for (int64_t k = 0; k < K; k++) {
    buf[k] = 0;
  }
}

template <typename scalar_t, int K>
scalar_t max(scalar_t* buf) {
  scalar_t m = buf[0];
  for (int64_t k = 1; k < K; k++) {
    m = buf[k] > m ? buf[k] : m;
  }
  return m;
}

template <typename scalar_t>
void attention_kernel(
    at::TensorAccessor<scalar_t, 3> output,
    at::TensorAccessor<scalar_t, 3> query,
    at::TensorAccessor<scalar_t, 3> key,
    at::TensorAccessor<scalar_t, 3> value,
    at::TensorAccessor<scalar_t, 3> buffer //,
    // at::TensorAccessor<int64_t, 2> mask
) {
  // TODO: optimize the code by adding blocking
  // over multiple dimensions. Doing this allows
  // the compiler to group reads and operations
  // for vectorization
  constexpr int64_t BLOCK = 1; // 8;
  int64_t K = query.size(2);
  int64_t B = query.size(0);
  int64_t M = query.size(1);
  int64_t N = key.size(1);
  int64_t grain_size = 1;
  scalar_t scale = 1.0 / std::sqrt(scalar_t(K));
  at::parallel_for(0, B, grain_size, [&](int64_t start, int64_t end) {
    auto buf = buffer[at::get_thread_num()][0].data();
    for (int64_t i = start; i < end; i++) {
      for (int64_t j = 0; j < M; j++) {
        fill_zero<scalar_t>(buf, K);
        auto aar = query[i][j].data();
        scalar_t s_prime = 0;
        scalar_t m_prime = -std::numeric_limits<scalar_t>::infinity();
        for (int64_t l = 0; l < N; l += BLOCK) {
          auto bar = key[i][l].data();
          scalar_t si[BLOCK] = {0};
          for (int64_t k = 0; k < K; k++) {
            auto aaar = aar[k] * scale;
            for (int64_t rr = 0; rr < BLOCK; rr++)
              si[rr] += aaar * bar[k + K * rr];
          }
          scalar_t m_i = si[0] > m_prime ? si[0] : m_prime;
          for (int64_t rr = 1; rr < BLOCK; rr++) {
            m_i = si[rr] > m_i ? si[rr] : m_i;
          }

          auto vi = value[i][l].data();

          scalar_t m_delta;
          scalar_t s_delta[BLOCK];
          m_delta = std::exp(m_prime - m_i);

          for (int64_t rr = 0; rr < BLOCK; rr++)
            s_delta[rr] = std::exp(si[rr] - m_i);

          for (int64_t k = 0; k < K; k++) {
            buf[k] = buf[k] * m_delta;
            for (int64_t rr = 0; rr < BLOCK; rr++)
              buf[k] += vi[k + K * rr] * s_delta[rr];
          }
          s_prime = s_prime * m_delta;
          for (int64_t rr = 0; rr < BLOCK; rr++)
            s_prime += s_delta[rr];

          m_prime = m_i;
        }
        auto oo = output[i][j].data();
        for (int64_t k = 0; k < K; k++) {
          oo[k] = buf[k] / s_prime;
        }
      }
    }
  });
}

at::Tensor attention(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value
    // const at::Tensor& mask
) {
  TORCH_CHECK(query.dim() == key.dim());
  TORCH_CHECK(query.dim() == value.dim());
  // TORCH_CHECK(query.dim() == mask.dim());
  TORCH_CHECK(query.dim() == 3);
  TORCH_CHECK(query.size(2) == key.size(2));
  TORCH_CHECK(query.size(0) == key.size(0));

  TORCH_CHECK(query.size(0) == value.size(0));
  TORCH_CHECK(key.size(1) == value.size(1));
  TORCH_CHECK(
      query.size(2) ==
      value.size(2)); // TODO: drop this limitation in the future

  TORCH_CHECK(!query.is_cuda(), "query must be a CPU tensor");
  TORCH_CHECK(!key.is_cuda(), "key must be a CPU tensor");
  TORCH_CHECK(!value.is_cuda(), "value must be a CPU tensor");

  TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor");
  TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor");
  TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor");

  TORCH_CHECK(query.is_contiguous());
  TORCH_CHECK(key.is_contiguous());
  TORCH_CHECK(value.is_contiguous());

  int64_t B = query.size(0);
  int64_t M = query.size(1);
  int64_t N = key.size(1);
  int64_t K = query.size(2);

  at::Tensor res = at::empty({B, M, K}, query.options());

  at::Tensor buffer = at::empty({at::get_num_threads(), 1, K}, query.options());

  AT_DISPATCH_FLOATING_TYPES(query.scalar_type(), "attention_kernel", [&] {
    attention_kernel<scalar_t>(
        res.accessor<scalar_t, 3>(),
        query.accessor<scalar_t, 3>(),
        key.accessor<scalar_t, 3>(),
        value.accessor<scalar_t, 3>(),
        buffer.accessor<scalar_t, 3>());
  });

  return res;
}

} // namespace

TORCH_LIBRARY_IMPL(xformers, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("xformers::efficient_attention"),
      TORCH_FN(attention));
}
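The heart of attention_kernel is a streaming ("online") softmax: for each query row it makes a single pass over the keys, keeping a running maximum m_prime, a running normalizer s_prime, and an unnormalized weighted sum of values in buf, rescaling the accumulators whenever a larger logit is encountered. A plain-Python sketch of the same recurrence for one query row (illustrative, processing one key at a time, i.e. BLOCK = 1):

import math
import torch

def online_softmax_attention_row(q_row, keys, values, scale):
    # q_row: (K,), keys/values: (N, K)
    K = q_row.shape[0]
    buf = torch.zeros(K)          # unnormalized sum of exp(s_l - m_prime) * values[l]
    s_prime = 0.0                 # running softmax normalizer
    m_prime = -math.inf           # running max of the logits seen so far
    for l in range(keys.shape[0]):
        s_l = float(torch.dot(q_row * scale, keys[l]))
        m_i = max(s_l, m_prime)
        m_delta = math.exp(m_prime - m_i)   # rescale old accumulators to the new max
        s_delta = math.exp(s_l - m_i)       # weight of the new key
        buf = buf * m_delta + values[l] * s_delta
        s_prime = s_prime * m_delta + s_delta
        m_prime = m_i
    return buf / s_prime

# matches a direct softmax((q / sqrt(K)) @ keys.T) @ values for one row
q = torch.randn(32)
keys = torch.randn(50, 32)
values = torch.randn(50, 32)
out = online_softmax_attention_row(q, keys, values, 1.0 / math.sqrt(32))
ref = ((q / math.sqrt(32)) @ keys.T).softmax(-1) @ values
assert torch.allclose(out, ref, atol=1e-5)

Because only buf (size K), s_prime and m_prime live across iterations, the extra state per query row is O(K) rather than the O(N) row of attention weights that the reference implementation materializes.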