Memory-efficient attention - forward pass (#267)
* Add initial CPU version of memory efficient attention

* Try adding naive vectorization

* Cleanups

* Initial vectorization

* Factorize into unfoldable loops

* Minor cleanup

* Add super naive CUDA implementation

For now it is 1000x slower than the baseline

* 10x speedup on CUDA kernel

* Another 20x speedup on CUDA kernel

Now we are *only* 6x slower than baseline

* Make it 13% faster with larger BLOCK size

* Get extra 25% speedup

* Make it 2x faster

Now we are only 50% slower than baseline

* Minor code cleanups

* Make it 30% faster

Need to fix the buffer size, which is hard-coded for now

* Remove hard-coded constants and use some sputnik helpers

The use of Dot makes it 2.5% faster already

* Start doing some refactoring

* Further cleanups

* More cleanup

* Run clang-format

* Use vec_t + cleanups

* Rename some variables

* clang-format

* Some more variable renaming

* clang-format

* More refactoring

* Make implementation generic wrt key size and feature dim

Still need to make it generic wrt query size, and to support values of K that go beyond the buffer limit

* Speedup by almost 2x when key size is not a multiple of 32

* Make kernel generic wrt query size

* Statically unroll the different cases

This is commented out for now as it slows down the implementation

* More cleanups

* Add tests + bugfixes

* Add more checks and cleanups

* clang-format

* Address reviewer comments

Improve code comments

* Add benchmark script

* Divide by sqrt(K) and add user-facing API (a usage sketch follows this list)

* Appease mypy

* Add copyright notice
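
A minimal usage sketch of the user-facing API added in this commit (illustration only, not part of the diff): shapes follow the tests below, (batch, sequence length, head dim), and the op applies the 1/sqrt(K) scaling internally.

import torch

import xformers.ops

# (batch, query_len, head_dim) and (batch, kv_len, head_dim) inputs
query = torch.randn(1, 128, 32)
key = torch.randn(1, 128, 32)
value = torch.randn(1, 128, 32)

# Equivalent to softmax(query @ key.transpose(-2, -1) / sqrt(32)) @ value,
# computed without materializing the full attention matrix.
out = xformers.ops.memory_efficient_attention(query, key, value)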
fmassa authored Apr 12, 2022
1 parent fb7bbcb commit f78ba0a
Showing 6 changed files with 921 additions and 0 deletions.
51 changes: 51 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

import xformers.ops

_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


def ref_attention(q, k, v):
    q = q * (1 / q.shape[-1] ** 0.5)
    return (q @ k.transpose(-2, -1)).softmax(-1) @ v


@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 3, 5])
@pytest.mark.parametrize("device", _devices)
def test_memory_efficient_attention(device, q_len, kv_len, batch_size, k_len):
    scale = 3
    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

    out = xformers.ops.memory_efficient_attention(query, key, value)
    ref = ref_attention(query, key, value)

    assert torch.allclose(out, ref, atol=2e-4)


@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [128, 512])
@pytest.mark.parametrize("q_len", [128, 512])
@pytest.mark.parametrize("device", _devices)
def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len):
    scale = 3
    query = torch.ones((batch_size, q_len, k_len), device=device)
    key = torch.ones((batch_size, kv_len, k_len), device=device)
    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

    out = xformers.ops.memory_efficient_attention(query, key, value)
    # With all-ones query and key, every attention logit is identical, so the
    # softmax weights are uniform and the output reduces to the mean of value
    # over the kv dimension.
    ref = value.mean(1, keepdim=True).expand_as(query)

    assert torch.allclose(out, ref, atol=1e-5)
95 changes: 95 additions & 0 deletions xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -0,0 +1,95 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
import pprint
from typing import Dict

import torch
from torch.utils import benchmark

import xformers.ops


def ref_attention(q, k, v):
    q = q * (1.0 / q.shape[-1] ** 0.5)
    return (q @ k.transpose(-2, -1)).softmax(-1) @ v


min_run_time = 2
device = torch.device("cuda")

NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
SHAPES = list(
    itertools.product([1, 8, 32, 256], [127, 128, 512, 513, 1023, 1024], [16, 32])
)

results = []
mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={})

print(f"Processing {len(SHAPES)} cases")
for num_threads in NUM_THREADS:
    for shape in SHAPES:
        print(f"===== {shape} =====")
        B, M, K = shape
        q = torch.rand(shape, device=device)
        sub_label = f"B={B}, M={M}, K={K}"

        if True:
            r = xformers.ops.memory_efficient_attention(q, q, q)

            rr = ref_attention(q, q, q)
            assert (r - rr).abs().max() < 1e-5

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        results.append(
            benchmark.Timer(
                stmt="fn(q, q, q)",
                globals={
                    "q": q,
                    "fn": torch.ops.xformers.efficient_attention,
                },
                label="attention",
                description="optimized",
                sub_label=sub_label,
                num_threads=num_threads,
            ).blocked_autorange(min_run_time=min_run_time)
        )
        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated() / 2 ** 20
        mem_use["optimized"][sub_label] = memory
        memory_str = f"Memory used: {memory} MB"

        print("Optimized", memory_str)

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        results.append(
            benchmark.Timer(
                stmt="fn(q, q, q)",
                globals={
                    "q": q,
                    "fn": ref_attention,
                },
                label="attention",
                description="vanilla",
                sub_label=sub_label,
                num_threads=num_threads,
            ).blocked_autorange(min_run_time=min_run_time)
        )

        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated() / 2 ** 20
        mem_use["vanilla"][sub_label] = memory
        memory_str = f"Memory used: {memory} MB"
        print("Vanilla", memory_str)


compare = benchmark.Compare(results)
compare.print()

pprint.pprint(mem_use)
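
A back-of-the-envelope illustration (not a measurement from this script) of where the memory gap comes from: ref_attention materializes the full attention matrix q @ k.transpose(-2, -1) of shape (B, M, M), which the fused op never allocates. For the largest benchmarked shape:

B, M, K = 256, 1024, 32               # largest entry of SHAPES above
attn_bytes = B * M * M * 4            # fp32 attention matrix inside ref_attention
print(f"{attn_bytes / 2 ** 20} MiB")  # 1024.0 MiB for that intermediate alone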
6 changes: 6 additions & 0 deletions xformers/components/attention/csrc/attention.cpp
@@ -0,0 +1,6 @@
#include <torch/types.h>

TORCH_LIBRARY_FRAGMENT(xformers, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "xformers::efficient_attention(Tensor query, Tensor key, Tensor value) -> Tensor"));
}
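
The schema registered above is the operator the benchmark script reaches through the PyTorch op registry. A minimal sketch of calling it directly (it assumes that importing xformers.ops loads the compiled extension, as the benchmark above does):

import torch

import xformers.ops  # noqa: F401  importing registers the custom op

q = torch.rand(8, 128, 32)
# dispatches to the implementation registered for the tensor's device
out = torch.ops.xformers.efficient_attention(q, q, q)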
156 changes: 156 additions & 0 deletions xformers/components/attention/csrc/cpu/attention.cpp
@@ -0,0 +1,156 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/library.h>
#include <cmath>
#include <vector>

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

namespace {

template <typename scalar_t>
void fill_zero(scalar_t* buf, int64_t K) {
  for (int64_t k = 0; k < K; k++) {
    buf[k] = 0;
  }
}

template <typename scalar_t, int K>
scalar_t max(scalar_t* buf) {
  scalar_t m = buf[0];
  for (int64_t k = 1; k < K; k++) {
    m = buf[k] > m ? buf[k] : m;
  }
  return m;
}

template <typename scalar_t>
void attention_kernel(
    at::TensorAccessor<scalar_t, 3> output,
    at::TensorAccessor<scalar_t, 3> query,
    at::TensorAccessor<scalar_t, 3> key,
    at::TensorAccessor<scalar_t, 3> value,
    at::TensorAccessor<scalar_t, 3> buffer //,
    // at::TensorAccessor<int64_t, 2> mask
) {
  // TODO: optimize the code by adding blocking
  // over multiple dimensions. Doing this allows
  // the compiler to group reads and operations
  // for vectorization
  constexpr int64_t BLOCK = 1; // 8;
  int64_t K = query.size(2);
  int64_t B = query.size(0);
  int64_t M = query.size(1);
  int64_t N = key.size(1);
  int64_t grain_size = 1;
  scalar_t scale = 1.0 / std::sqrt(scalar_t(K));
  at::parallel_for(0, B, grain_size, [&](int64_t start, int64_t end) {
    auto buf = buffer[at::get_thread_num()][0].data();
    for (int64_t i = start; i < end; i++) {
      for (int64_t j = 0; j < M; j++) {
        // Streaming ("online") softmax: buf holds the un-normalized weighted
        // sum of values, s_prime the running softmax normalizer and m_prime
        // the running maximum logit, so the full [M, N] attention matrix is
        // never materialized.
        fill_zero<scalar_t>(buf, K);
        auto aar = query[i][j].data();
        scalar_t s_prime = 0;
        scalar_t m_prime = -std::numeric_limits<scalar_t>::infinity();
        for (int64_t l = 0; l < N; l += BLOCK) {
          auto bar = key[i][l].data();
          // si[rr] = dot(query[i][j], key[i][l + rr]) * scale
          scalar_t si[BLOCK] = {0};
          for (int64_t k = 0; k < K; k++) {
            auto aaar = aar[k] * scale;
            for (int64_t rr = 0; rr < BLOCK; rr++)
              si[rr] += aaar * bar[k + K * rr];
          }
          // m_i is the new running maximum over all logits seen so far
          scalar_t m_i = si[0] > m_prime ? si[0] : m_prime;
          for (int64_t rr = 1; rr < BLOCK; rr++) {
            m_i = si[rr] > m_i ? si[rr] : m_i;
          }

          auto vi = value[i][l].data();

          scalar_t m_delta;
          scalar_t s_delta[BLOCK];
          // m_delta rescales the previously accumulated terms to the new
          // maximum; s_delta holds the weights exp(si - m_i) of this block
          m_delta = std::exp(m_prime - m_i);

          for (int64_t rr = 0; rr < BLOCK; rr++)
            s_delta[rr] = std::exp(si[rr] - m_i);

          for (int64_t k = 0; k < K; k++) {
            buf[k] = buf[k] * m_delta;
            for (int64_t rr = 0; rr < BLOCK; rr++)
              buf[k] += vi[k + K * rr] * s_delta[rr];
          }
          s_prime = s_prime * m_delta;
          for (int64_t rr = 0; rr < BLOCK; rr++)
            s_prime += s_delta[rr];

          m_prime = m_i;
        }
        // normalize the accumulated sum only once, at the end
        auto oo = output[i][j].data();
        for (int64_t k = 0; k < K; k++) {
          oo[k] = buf[k] / s_prime;
        }
      }
    }
  });
}

at::Tensor attention(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value
    // const at::Tensor& mask
) {
  TORCH_CHECK(query.dim() == key.dim());
  TORCH_CHECK(query.dim() == value.dim());
  // TORCH_CHECK(query.dim() == mask.dim());
  TORCH_CHECK(query.dim() == 3);
  TORCH_CHECK(query.size(2) == key.size(2));
  TORCH_CHECK(query.size(0) == key.size(0));

  TORCH_CHECK(query.size(0) == value.size(0));
  TORCH_CHECK(key.size(1) == value.size(1));
  TORCH_CHECK(
      query.size(2) ==
      value.size(2)); // TODO: drop this limitation in the future

  TORCH_CHECK(!query.is_cuda(), "query must be a CPU tensor");
  TORCH_CHECK(!key.is_cuda(), "key must be a CPU tensor");
  TORCH_CHECK(!value.is_cuda(), "value must be a CPU tensor");

  TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor");
  TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor");
  TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor");

  TORCH_CHECK(query.is_contiguous());
  TORCH_CHECK(key.is_contiguous());
  TORCH_CHECK(value.is_contiguous());

  int64_t B = query.size(0);
  int64_t M = query.size(1);
  int64_t N = key.size(1);
  int64_t K = query.size(2);

  at::Tensor res = at::empty({B, M, K}, query.options());

  at::Tensor buffer = at::empty({at::get_num_threads(), 1, K}, query.options());

  AT_DISPATCH_FLOATING_TYPES(query.scalar_type(), "attention_kernel", [&] {
    attention_kernel<scalar_t>(
        res.accessor<scalar_t, 3>(),
        query.accessor<scalar_t, 3>(),
        key.accessor<scalar_t, 3>(),
        value.accessor<scalar_t, 3>(),
        buffer.accessor<scalar_t, 3>());
  });

  return res;
}

} // namespace

TORCH_LIBRARY_IMPL(xformers, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("xformers::efficient_attention"),
      TORCH_FN(attention));
}
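
For reference, the per-query-row recurrence implemented by attention_kernel above can be restated in plain Python. This is an illustrative sketch only (the function name and its list-based types are not part of the commit); acc, m_prime and s_prime play the same roles as buf, m_prime and s_prime in the kernel.

import math

def streaming_attention_row(q_row, keys, values, scale):
    """One-pass, numerically stable attention for a single query row."""
    acc = [0.0] * len(q_row)   # un-normalized weighted sum of value rows
    m_prime = -math.inf        # running maximum logit
    s_prime = 0.0              # running softmax normalizer
    for k_row, v_row in zip(keys, values):
        s_i = scale * sum(qc * kc for qc, kc in zip(q_row, k_row))
        m_i = max(m_prime, s_i)
        m_delta = math.exp(m_prime - m_i)  # rescale old terms to the new max
        s_delta = math.exp(s_i - m_i)      # weight of the current key row
        acc = [a * m_delta + vc * s_delta for a, vc in zip(acc, v_row)]
        s_prime = s_prime * m_delta + s_delta
        m_prime = m_i
    return [a / s_prime for a in acc]  # normalize once at the end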
