Memory efficient attention - backward pass #281

Merged: 45 commits, Apr 25, 2022

Commits
df60515
Add naive CPU implementation for memory-efficient attention backward
fmassa Apr 13, 2022
77535fe
Optimize (at least) by a factor 2
fmassa Apr 13, 2022
bd6d1f2
More cleanups
fmassa Apr 13, 2022
54988fe
A few more comments
fmassa Apr 13, 2022
9ab94a2
Add very naive CUDA implementation
fmassa Apr 13, 2022
aeb49e8
Speedup CUDA kernel by 5x
fmassa Apr 13, 2022
643baaa
Make logsumexp an argument
fmassa Apr 13, 2022
788e756
Make it 30% faster
fmassa Apr 14, 2022
ade14e7
3.5x speedup by blocking strategy
fmassa Apr 14, 2022
f6427d7
Use vector loads and improve tile selection
fmassa Apr 14, 2022
70bfeda
Recompute attention for grad_q computation
fmassa Apr 15, 2022
7628805
Small cleanups
fmassa Apr 15, 2022
bd749c9
clang-format
fmassa Apr 15, 2022
86e87f9
Make it 0.5% faster
fmassa Apr 15, 2022
d3e2140
Make it 1% faster by caching the loads
fmassa Apr 15, 2022
6cc5768
Make it 6% faster with better hyperparameters
fmassa Apr 15, 2022
8d493a7
Slightly better hyperparameter
fmassa Apr 15, 2022
40b9f43
axpy == FMA
fmassa Apr 17, 2022
ca2bb72
Separate grad_q into its own kernel
fmassa Apr 19, 2022
63bd286
Avoid additional global writes by recomputing grad_aatn_v in grad_k
fmassa Apr 19, 2022
f1e7c7c
Trying out new idea
fmassa Apr 20, 2022
b6b0cfc
Almost on par with my previous best implementation
fmassa Apr 20, 2022
c83ebb3
Improve perf by 5%
fmassa Apr 20, 2022
2497f96
Remove query-key from shared memory and increase tile size
fmassa Apr 20, 2022
24ed9bb
Make it 20% faster with better hyperparameters
fmassa Apr 20, 2022
33f0c71
Make it another 12% faster
fmassa Apr 20, 2022
253b3eb
Code cleanup
fmassa Apr 20, 2022
e94d0cd
Further cleanups
fmassa Apr 20, 2022
5706777
Variable rename
fmassa Apr 20, 2022
69d1aa8
clang-format
fmassa Apr 20, 2022
220e046
Add alternative implementation for grad_v
fmassa Apr 20, 2022
0b38bf1
Speed it up by 10% with better hyperparameters
fmassa Apr 20, 2022
a7d1eac
Delete old implementation
fmassa Apr 20, 2022
99a6418
Centralize all input accesses in the beginning
fmassa Apr 21, 2022
5bf4431
Bugfix
fmassa Apr 21, 2022
222c136
Make kernels generic wrt sequence length
fmassa Apr 22, 2022
011b2cd
Add template argument to skip bound checking
fmassa Apr 22, 2022
2552da0
Make it support all use-cases
fmassa Apr 22, 2022
a8c1f4b
Let logsumexp be returned by forward
fmassa Apr 22, 2022
726d3c5
clang-format
fmassa Apr 22, 2022
3f8f954
Add scaling factor
fmassa Apr 22, 2022
a43d72b
Add tests + silly bugfix
fmassa Apr 22, 2022
45ed14c
Add benchmark function for backward
fmassa Apr 22, 2022
05ea687
Add comment
fmassa Apr 22, 2022
8e7bbb9
clang-format
fmassa Apr 22, 2022
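
Several of the commits above ("Make logsumexp an argument", "Recompute attention for grad_q computation", "Let logsumexp be returned by forward") revolve around the same idea: the forward pass keeps only a per-query logsumexp of the attention scores, and the backward pass recomputes the attention probabilities on the fly instead of reading back a stored (M, N) attention matrix. A rough PyTorch sketch of that recomputation (not the CUDA kernel itself; the shapes and the 1/sqrt(K) scaling are assumed to match the reference used in the tests):

    import torch

    def recompute_attention_probs(query, key, logsumexp):
        # query: (B, M, K), key: (B, N, K), logsumexp: (B, M)
        scores = (query / query.shape[-1] ** 0.5) @ key.transpose(-2, -1)  # (B, M, N)
        # softmax(scores) == exp(scores - logsumexp), so the full (B, M, N)
        # probability matrix never needs to be written out by the forward pass
        return torch.exp(scores - logsumexp[:, :, None])

test_logsumexp below checks that the logsumexp returned by the forward kernel matches this definition of the scores.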
55 changes: 55 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -49,3 +49,58 @@ def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len):
    ref = value.mean(1, keepdim=True).expand_as(query)

    assert torch.allclose(out, ref, atol=1e-5)


@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 3, 5])
@pytest.mark.parametrize("device", _devices)
def test_logsumexp(device, q_len, kv_len, batch_size, k_len):
    scale = 3
    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

    _, lse = torch.ops.xformers.efficient_attention(query, key, value, True)
    ref_lse = ((query / k_len ** 0.5) @ key.transpose(-2, -1)).logsumexp(-1)

    assert torch.allclose(lse, ref_lse, atol=2e-4)


@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 3, 5])
@pytest.mark.parametrize("device", _devices)
def test_memory_efficient_attention_backward(device, q_len, kv_len, batch_size, k_len):
    scale = 3
    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

    query.requires_grad_(True)
    key.requires_grad_(True)
    value.requires_grad_(True)

    out = xformers.ops.memory_efficient_attention(query, key, value)
    out.backward(torch.ones_like(query))

    grad_q = query.grad
    grad_k = key.grad
    grad_v = value.grad

    query.grad = None
    key.grad = None
    value.grad = None

    ref = ref_attention(query, key, value)
    ref.backward(torch.ones_like(query))

    # there is some extra precision loss in the CPU implementation due to an
    # extra accumulation step in grad_q, which is not present in the CUDA
    # implementation
    atol = 3e-4 if device == "cuda" else 4e-4
    assert torch.allclose(grad_q, query.grad, atol=atol), "grad_q doesn't match"
    assert torch.allclose(grad_k, key.grad, atol=atol), "grad_k doesn't match"
    assert torch.allclose(grad_v, value.grad, atol=atol), "grad_v doesn't match"

Reviewer (on the precision comment above): thanks for the comment, helpful!
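
For context (not shown in this diff), ref_attention is the dense fp32 reference that the tests and benchmarks compare against; it is presumably the usual quadratic-memory attention, something along the lines of:

    def ref_attention(q, k, v):
        # plain attention that materializes the full (M, N) probability matrix
        q = q / q.shape[-1] ** 0.5
        attn = (q @ k.transpose(-2, -1)).softmax(-1)
        return attn @ v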
207 changes: 144 additions & 63 deletions xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -30,66 +30,147 @@ def ref_attention(q, k, v):
results = []
mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={})

print(f"Processing {len(SHAPES)} cases")
for num_threads in NUM_THREADS:
for shape in SHAPES:
print(f"===== {shape} =====")
B, M, K = shape
q = torch.rand(shape, device=device)
sub_label = f"B={B}, M={M}, K={K}"

if True:
r = xformers.ops.memory_efficient_attention(q, q, q)

rr = ref_attention(q, q, q)
assert (r - rr).abs().max() < 1e-5

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="fn(q, q, q)",
globals={
"q": q,
"fn": torch.ops.xformers.efficient_attention,
},
label="attention",
description="optimized",
sub_label=sub_label,
num_threads=num_threads,
).blocked_autorange(min_run_time=min_run_time)
)
torch.cuda.synchronize()
memory = torch.cuda.max_memory_allocated() / 2 ** 20
mem_use["optimized"][sub_label] = memory
memory_str = f"Memory used: {memory} MB"

print("Optimized", memory_str)

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="fn(q, q, q)",
globals={
"q": q,
"fn": ref_attention,
},
label="attention",
description="vanilla",
sub_label=sub_label,
num_threads=num_threads,
).blocked_autorange(min_run_time=min_run_time)
)

torch.cuda.synchronize()
memory = torch.cuda.max_memory_allocated() / 2 ** 20
mem_use["vanilla"][sub_label] = memory
memory_str = f"Memory used: {memory} MB"
print("Vanilla", memory_str)


compare = benchmark.Compare(results)
compare.print()

pprint.pprint(mem_use)

def benchmark_forward():
    print(f"Processing {len(SHAPES)} cases")
    print("Forward")
    for num_threads in NUM_THREADS:
        for shape in SHAPES:
            print(f"===== {shape} =====")
            B, M, K = shape
            q = torch.rand(shape, device=device)
            sub_label = f"B={B}, M={M}, K={K}"

            if True:
Reviewer: debug?

fmassa (author): Yeah, it's a debug flag that is sometimes helpful: sometimes I "break" the kernel by removing parts of the computation to see what speedup I would get. Doing so means the computation is no longer correct, so it was useful to be able to disable the correctness checks. I can remove this if you want, but as I expect to still do some more performance tuning, I'd like to keep it around for a bit longer if that's ok with you?

Reviewer: It's totally ok, flagging it just in case, but understood, no worries.
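
# Aside, not part of the diff: one way the hard-coded `if True:` switch could be
# made explicit later on (the flag name is hypothetical, not an xformers API;
# assumes `import os` at the top of the file):
#
#     CHECK_CORRECTNESS = os.environ.get("XFORMERS_BENCHMARK_CHECKS", "1") == "1"
#     ...
#     if CHECK_CORRECTNESS:
#         r = xformers.ops.memory_efficient_attention(q, q, q)
#         rr = ref_attention(q, q, q)
#         assert (r - rr).abs().max() < 1e-5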

                r = xformers.ops.memory_efficient_attention(q, q, q)

                rr = ref_attention(q, q, q)
                assert (r - rr).abs().max() < 1e-5
Reviewer: This does not pass on a 3080 / CUDA 11.6; it could be interesting to test with the T4s on CircleCI. It could well be because of TF32 (you would need to switch the torch flag forcing fp32 computations). Implicitly this probably means that the torch implementation switched to tensor cores, which changes the time difference between the two implementations (but it's not a fundamental issue).

fmassa (author): Good point. I should probably change those defaults, or just disable TF32 in the benchmarks (but that makes for slower baselines), or just disable this correctness check by default. Which one would you prefer?

Reviewer: I would switch TF32 off here; I think that's the best correctness check: you assume fp32 in the kernel, so let's check correctness against fp32 (torch.backends.cuda.matmul.allow_tf32 = False). Good to keep in mind that the benchmark comparison is not iso-accuracy, by the way: your implementation is actually more precise :)
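
# Aside, not part of the diff: a minimal sketch of the reviewer's suggestion.
# On Ampere GPUs (e.g. a 3080) matmuls default to TF32, so the fp32 reference
# can differ from the fp32 kernel by more than 1e-5; disabling TF32 keeps the
# correctness check in full fp32:
#
#     torch.backends.cuda.matmul.allow_tf32 = False
#     torch.backends.cudnn.allow_tf32 = False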


            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            results.append(
                benchmark.Timer(
                    stmt="fn(q, q, q)",
                    globals={
                        "q": q,
                        "fn": xformers.ops.memory_efficient_attention,
                    },
                    label="attention",
                    description="optimized",
                    sub_label=sub_label,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=min_run_time)
            )
            torch.cuda.synchronize()
            memory = torch.cuda.max_memory_allocated() / 2 ** 20
            mem_use["optimized"][sub_label] = memory
            memory_str = f"Memory used: {memory} MB"

            print("Optimized", memory_str)

            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            results.append(
                benchmark.Timer(
                    stmt="fn(q, q, q)",
                    globals={
                        "q": q,
                        "fn": ref_attention,
                    },
                    label="attention",
                    description="vanilla",
                    sub_label=sub_label,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=min_run_time)
            )

            torch.cuda.synchronize()
            memory = torch.cuda.max_memory_allocated() / 2 ** 20
            mem_use["vanilla"][sub_label] = memory
            memory_str = f"Memory used: {memory} MB"
            print("Vanilla", memory_str)

    compare = benchmark.Compare(results)
    compare.print()

    pprint.pprint(mem_use)


def benchmark_backward():
    print(f"Processing {len(SHAPES)} cases")
    print("Backward")
    for num_threads in NUM_THREADS:
        for shape in SHAPES:
            print(f"===== {shape} =====")
            B, M, K = shape
            q = torch.rand(shape, device=device, requires_grad=True)
            sub_label = f"B={B}, M={M}, K={K}"

            if True:
                r = xformers.ops.memory_efficient_attention(q, q, q)
                r.backward(torch.ones_like(q))

                grad = q.grad
                q.grad = None

                rr = ref_attention(q, q, q)
                rr.backward(torch.ones_like(q))
                assert (grad - q.grad).abs().max() < 1e-5
blefaudeux (reviewer), Apr 23, 2022: Same as above, this does not pass on a 3080; my guess is that it's because of TF32 vs. float32 (it would be the same with an A100; not sure about TF32 on a V100).


            out = xformers.ops.memory_efficient_attention(q, q, q)
            grad = torch.ones_like(q)

            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            results.append(
                benchmark.Timer(
                    stmt="out.backward(grad, retain_graph=True)",
                    globals={
                        "out": out,
                        "grad": grad,
                    },
                    label="attention",
                    description="optimized",
                    sub_label=sub_label,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=min_run_time)
            )
            torch.cuda.synchronize()
            memory = torch.cuda.max_memory_allocated() / 2 ** 20
            mem_use["optimized"][sub_label] = memory
            memory_str = f"Memory used: {memory} MB"

            print("Optimized", memory_str)

            out = ref_attention(q, q, q)
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            results.append(
                benchmark.Timer(
                    stmt="out.backward(grad, retain_graph=True)",
                    globals={
                        "out": out,
                        "grad": grad,
                    },
                    label="attention",
                    description="vanilla",
                    sub_label=sub_label,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=min_run_time)
            )

            torch.cuda.synchronize()
            memory = torch.cuda.max_memory_allocated() / 2 ** 20
            mem_use["vanilla"][sub_label] = memory
            memory_str = f"Memory used: {memory} MB"
            print("Vanilla", memory_str)

    compare = benchmark.Compare(results)
    compare.print()

    pprint.pprint(mem_use)


benchmark_forward()
Reviewer: [nit] It may be possible to factorize the two, but that's not super important; it's a good tool to have already!

fmassa (author): Yeah, it's totally possible. I've also added a benchmark_forward_and_backward case in a separate branch, and it started to have quite a bit of duplication. I can look into refactoring this in a follow-up PR.

Reviewer: Not urgent and not blocking; same as above, more of a mental note. Sounds good.

benchmark_backward()
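
Not part of this PR, but a rough sketch of the factorization discussed above: the timing and peak-memory bookkeeping that benchmark_forward and benchmark_backward duplicate could live in one helper. The helper name is hypothetical; results, mem_use, min_run_time and benchmark are the module-level objects already used above.

    def _run_case(stmt, globals_, sub_label, num_threads, description):
        # shared timing + peak-memory measurement for one benchmark case
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=globals_,
                label="attention",
                description=description,
                sub_label=sub_label,
                num_threads=num_threads,
            ).blocked_autorange(min_run_time=min_run_time)
        )
        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated() / 2 ** 20
        mem_use[description][sub_label] = memory
        print(description, f"Memory used: {memory} MB")

The forward benchmark would then call, for example, _run_case("fn(q, q, q)", {"q": q, "fn": xformers.ops.memory_efficient_attention}, sub_label, num_threads, "optimized").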
4 changes: 3 additions & 1 deletion xformers/components/attention/csrc/attention.cpp
@@ -2,5 +2,7 @@

TORCH_LIBRARY_FRAGMENT(xformers, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
-     "xformers::efficient_attention(Tensor query, Tensor key, Tensor value) -> Tensor"));
+     "xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp) -> (Tensor, Tensor)"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+     "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp) -> (Tensor, Tensor, Tensor)"));
}
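
The schema change above is what ties the two ops together for autograd: the forward op now returns (output, logsumexp), and the backward op consumes that logsumexp. A hedged Python-side sketch of how xformers.ops.memory_efficient_attention could be wired on top of these ops (the actual binding in xformers may differ, e.g. it could skip computing logsumexp when no gradient is required):

    import torch

    class _MemoryEfficientAttention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, query, key, value):
            # compute_logsumexp=True so the backward pass can recompute the
            # attention probabilities without storing them
            out, lse = torch.ops.xformers.efficient_attention(query, key, value, True)
            ctx.save_for_backward(query, key, value, lse)
            return out

        @staticmethod
        def backward(ctx, grad_out):
            query, key, value, lse = ctx.saved_tensors
            grad_q, grad_k, grad_v = torch.ops.xformers.efficient_attention_backward(
                grad_out, query, key, value, lse
            )
            return grad_q, grad_k, grad_v

    # usage: out = _MemoryEfficientAttention.apply(query, key, value)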