Memory efficient attention - backward pass #281
@@ -30,66 +30,147 @@ def ref_attention(q, k, v):
results = []
mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={})

print(f"Processing {len(SHAPES)} cases")
for num_threads in NUM_THREADS:
    for shape in SHAPES:
        print(f"===== {shape} =====")
        B, M, K = shape
        q = torch.rand(shape, device=device)
        sub_label = f"B={B}, M={M}, K={K}"

        if True:
            r = xformers.ops.memory_efficient_attention(q, q, q)

            rr = ref_attention(q, q, q)
            assert (r - rr).abs().max() < 1e-5

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        results.append(
            benchmark.Timer(
                stmt="fn(q, q, q)",
                globals={
                    "q": q,
                    "fn": torch.ops.xformers.efficient_attention,
                },
                label="attention",
                description="optimized",
                sub_label=sub_label,
                num_threads=num_threads,
            ).blocked_autorange(min_run_time=min_run_time)
        )
        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated() / 2 ** 20
        mem_use["optimized"][sub_label] = memory
        memory_str = f"Memory used: {memory} MB"

        print("Optimized", memory_str)

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        results.append(
            benchmark.Timer(
                stmt="fn(q, q, q)",
                globals={
                    "q": q,
                    "fn": ref_attention,
                },
                label="attention",
                description="vanilla",
                sub_label=sub_label,
                num_threads=num_threads,
            ).blocked_autorange(min_run_time=min_run_time)
        )

        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated() / 2 ** 20
        mem_use["vanilla"][sub_label] = memory
        memory_str = f"Memory used: {memory} MB"
        print("Vanilla", memory_str)


compare = benchmark.Compare(results)
compare.print()

pprint.pprint(mem_use)


def benchmark_forward():
    print(f"Processing {len(SHAPES)} cases")
    print("Forward")
    for num_threads in NUM_THREADS:
        for shape in SHAPES:
            print(f"===== {shape} =====")
            B, M, K = shape
            q = torch.rand(shape, device=device)
            sub_label = f"B={B}, M={M}, K={K}"

            if True:
Reviewer: debug?
Author: Yeah, it's a debug flag that is sometimes helpful: sometimes I "break" the kernel by removing some parts of the computation and see what speedup I would get. But doing so means that the computation won't be correct anymore, so it was useful to be able to disable the correctness checks. I can remove it if you want, but as I expect to still do some more performance tuning, I'd like to keep this around for a bit longer if that's ok with you?
Reviewer: It's totally ok, flagging it just in case, but understood, no worries.
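As a purely illustrative sketch (not part of this diff), the hard-coded "if True:" could be replaced by an environment-driven flag so the correctness checks can be skipped while profiling an intentionally broken kernel; the flag and variable names below are made up:

import os

# Hypothetical flag: run with XFORMERS_SKIP_CHECKS=1 to skip the asserts
# while profiling a kernel that is deliberately made incorrect.
CHECK_CORRECTNESS = os.environ.get("XFORMERS_SKIP_CHECKS", "0") != "1"

if CHECK_CORRECTNESS:
    r = xformers.ops.memory_efficient_attention(q, q, q)
    rr = ref_attention(q, q, q)
    assert (r - rr).abs().max() < 1e-5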
                r = xformers.ops.memory_efficient_attention(q, q, q)

                rr = ref_attention(q, q, q)
                assert (r - rr).abs().max() < 1e-5
Reviewer: This does not pass on a 3080 / CUDA 11.6; it could be interesting to test with the T4s on CircleCI. It could well be because of TF32 (you would need to switch the torch flag forcing fp32 computations). Implicitly this probably means that the torch implementation switched to tensor cores, which changes the time difference between the two implementations (but it's not a fundamental issue).
Author: Good point, I should probably change those defaults, or just disable…
Reviewer: I would switch TF32 off here; I think that it's the best correctness check: you assume fp32 in the kernel, so let's check correctness against fp32?
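For reference, switching TF32 off for the correctness check would look roughly like the sketch below (these are standard PyTorch backend flags; putting them at the top of the benchmark script is only a suggestion, not something this PR does):

import torch

# Force fp32 matmuls: on Ampere GPUs (3080 / A100) TF32 is enabled by
# default and makes the reference attention too inaccurate for a 1e-5 bound.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False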

            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            results.append(
                benchmark.Timer(
                    stmt="fn(q, q, q)",
                    globals={
                        "q": q,
                        "fn": xformers.ops.memory_efficient_attention,
                    },
                    label="attention",
                    description="optimized",
                    sub_label=sub_label,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=min_run_time)
            )
            torch.cuda.synchronize()
            memory = torch.cuda.max_memory_allocated() / 2 ** 20
            mem_use["optimized"][sub_label] = memory
            memory_str = f"Memory used: {memory} MB"

            print("Optimized", memory_str)

            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            results.append(
                benchmark.Timer(
                    stmt="fn(q, q, q)",
                    globals={
                        "q": q,
                        "fn": ref_attention,
                    },
                    label="attention",
                    description="vanilla",
                    sub_label=sub_label,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=min_run_time)
            )

            torch.cuda.synchronize()
            memory = torch.cuda.max_memory_allocated() / 2 ** 20
            mem_use["vanilla"][sub_label] = memory
            memory_str = f"Memory used: {memory} MB"
            print("Vanilla", memory_str)

    compare = benchmark.Compare(results)
    compare.print()

    pprint.pprint(mem_use)


def benchmark_backward():
    print(f"Processing {len(SHAPES)} cases")
    print("Backward")
    for num_threads in NUM_THREADS:
        for shape in SHAPES:
            print(f"===== {shape} =====")
            B, M, K = shape
            q = torch.rand(shape, device=device, requires_grad=True)
            sub_label = f"B={B}, M={M}, K={K}"

            if True:
                r = xformers.ops.memory_efficient_attention(q, q, q)
                r.backward(torch.ones_like(q))

                grad = q.grad
                q.grad = None

                rr = ref_attention(q, q, q)
                rr.backward(torch.ones_like(q))
                assert (grad - q.grad).abs().max() < 1e-5
Reviewer: Same as above, this does not pass on a 3080; my guess is that it's because of TF32 vs. float32 (it would be the same with an A100, not sure about TF32 on a V100).
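If TF32 stays enabled, one alternative (again only a sketch, not what this diff does) is to compare the gradients with explicit tolerances instead of a hard 1e-5 bound:

# torch.testing.assert_close reports the max absolute/relative error on
# failure; the tolerances can be loosened when TF32 is in play.
torch.testing.assert_close(grad, q.grad, atol=1e-3, rtol=1e-3)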

            out = xformers.ops.memory_efficient_attention(q, q, q)
            grad = torch.ones_like(q)

            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            results.append(
                benchmark.Timer(
                    stmt="out.backward(grad, retain_graph=True)",
                    globals={
                        "out": out,
                        "grad": grad,
                    },
                    label="attention",
                    description="optimized",
                    sub_label=sub_label,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=min_run_time)
            )
            torch.cuda.synchronize()
            memory = torch.cuda.max_memory_allocated() / 2 ** 20
            mem_use["optimized"][sub_label] = memory
            memory_str = f"Memory used: {memory} MB"

            print("Optimized", memory_str)

            out = ref_attention(q, q, q)
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            results.append(
                benchmark.Timer(
                    stmt="out.backward(grad, retain_graph=True)",
                    globals={
                        "out": out,
                        "grad": grad,
                    },
                    label="attention",
                    description="vanilla",
                    sub_label=sub_label,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=min_run_time)
            )

            torch.cuda.synchronize()
            memory = torch.cuda.max_memory_allocated() / 2 ** 20
            mem_use["vanilla"][sub_label] = memory
            memory_str = f"Memory used: {memory} MB"
            print("Vanilla", memory_str)

    compare = benchmark.Compare(results)
    compare.print()

    pprint.pprint(mem_use)


benchmark_forward()
Reviewer: [nit] Maybe possible to factorize the two, but not super important, it's a good tool to have already!
Author: Yeah, it's totally possible. I've also added in a separate branch a…
Reviewer: Not urgent and not blocking, same as above, more of a mental note. Sounds good.
benchmark_backward()
Thanks for the comment, helpful!
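On the factorization nit above, one possible shape for it, purely as a sketch (the helper name and its arguments are invented, and it reuses the script's globals), is a single driver parameterized by the statement to time:

def run_attention_benchmark(stmt, make_globals, description):
    # Hypothetical shared driver for benchmark_forward / benchmark_backward:
    # it owns the timing, the peak-memory bookkeeping and the printing,
    # reusing SHAPES, NUM_THREADS, results, mem_use and min_run_time
    # from the surrounding benchmark script.
    for num_threads in NUM_THREADS:
        for shape in SHAPES:
            B, M, K = shape
            sub_label = f"B={B}, M={M}, K={K}"
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            results.append(
                benchmark.Timer(
                    stmt=stmt,
                    globals=make_globals(shape),
                    label="attention",
                    description=description,
                    sub_label=sub_label,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=min_run_time)
            )
            torch.cuda.synchronize()
            memory = torch.cuda.max_memory_allocated() / 2 ** 20
            mem_use[description][sub_label] = memory
            print(description, f"Memory used: {memory} MB")

The forward case would then be something like run_attention_benchmark("fn(q, q, q)", lambda shape: {"q": torch.rand(shape, device=device), "fn": xformers.ops.memory_efficient_attention}, "optimized"), with the backward case passing an out.backward(grad, retain_graph=True) statement and the corresponding globals instead.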