[TRITON][OPS] add Flash Attention v2 to Ops (#1970)
I also dropped do_scaled, as it is no longer needed: in v2, no scaling is applied to do (see the sketch below).

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
IzzyPutterman and ptillet authored Jul 23, 2023
1 parent c9ab448 commit de6f053
Showing 3 changed files with 84 additions and 129 deletions.
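
Background on the do_scaled removal: in the v1 backward pass, a preprocessing kernel rescaled dO by the softmax denominator (the l statistic stored by the forward pass) before computing the per-row delta term. In v2 the forward pass writes out a fully normalized O, so the preprocessing reduces to a plain row-wise reduction. A minimal PyTorch sketch of that delta computation, for intuition only (the function name and the use of PyTorch are illustrative; the actual op is a Triton kernel):

import torch

def bwd_preprocess_v2(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # delta[i] = sum_j dO[i, j] * O[i, j]; no do_scaled / denominator
    # rescaling step is needed because O is already normalized in v2.
    return (o.float() * do.float()).sum(dim=-1)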
python/test/regression/test_performance.py (12 additions, 12 deletions)

@@ -155,27 +155,27 @@ def test_elementwise(N, dtype_str):
 
 flash_attention_data = {
     "a100": {
-        (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.433,
-        (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.392,
-        (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.106,
+        (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.532,
+        (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
+        (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.150,
         (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.204,
         (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
         (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.089,
-        (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.242,
-        (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.220,
-        (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.069,
+        (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.298,
+        (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.263,
+        (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.095,
         (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.136,
         (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
         (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052,
-        (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.432,
-        (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.392,
-        (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.107,
+        (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.525,
+        (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
+        (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,
         (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.265,
         (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.257,
         (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.128,
-        (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.251,
-        (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.220,
-        (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.069,
+        (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.297,
+        (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.263,
+        (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.095,
         (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
         (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.138,
         (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.076,
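
For orientation (inferred from the key shapes, not documented in this diff): each key appears to encode the benchmark configuration — batch, heads, sequence length, head dimension, two boolean flags, mode, and dtype — and each value the expected fraction of peak A100 throughput. A hypothetical sketch of how a regression test might consume one of these reference values; the function name and tolerance are illustrative, not the actual test code:

def check_perf(measured_util: float, ref_util: float, tol: float = 0.05) -> None:
    # Flag a regression if measured utilization drops more than tol
    # below the recorded reference value.
    assert measured_util >= ref_util * (1 - tol), \
        f"perf regression: {measured_util:.3f} < {ref_util:.3f}"

# e.g.:
# check_perf(measured_util,
#            flash_attention_data["a100"][(4, 48, 4096, 64, True, True, 'forward', 'float16')])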
