Re-organize SLL ops, pt 2
Summary:
- Re-organize `jagged_dense_flash_attention`

Differential Revision: D68916405
q10 authored and facebook-github-bot committed Jan 30, 2025
1 parent e2495d7 commit 85a9626
Showing 4 changed files with 874 additions and 854 deletions.
fbgemm_gpu/fbgemm_gpu/sll/__init__.py (5 changes: 0 additions & 5 deletions)
@@ -41,7 +41,6 @@
     jagged_dense_bmm,
     jagged_dense_elementwise_add,
     jagged_dense_elementwise_mul_jagged_out,
-    jagged_dense_flash_attention,
     jagged_flash_attention_basic,
     jagged_jagged_bmm,
     jagged_jagged_bmm_jagged_out,
@@ -321,10 +320,6 @@
         "CUDA": jagged_dense_elementwise_add,
         "AutogradCUDA": jagged_dense_elementwise_add,
     },
-    "sll_jagged_dense_flash_attention": {
-        "CUDA": jagged_dense_flash_attention,
-        "AutogradCUDA": jagged_dense_flash_attention,
-    },
 }
 
 for op_name, dispatches in sll_cpu_registrations.items():
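Aside: the `sll_cpu_registrations` and `op_registrations` tables in these files map an op name to one implementation per PyTorch dispatch key ("CUDA", "AutogradCUDA", ...), and a loop like the one above hands each entry to the dispatcher. Below is a minimal, self-contained sketch of that pattern, using a demo namespace and a toy op rather than FBGEMM's actual registration code:

    import torch

    # Define an op schema once, then attach one implementation per dispatch
    # key, mirroring the {op_name: {dispatch_key: fn}} tables in these files.
    lib = torch.library.Library("sll_demo", "DEF")  # demo namespace, not FBGEMM's
    lib.define("double_it(Tensor x) -> Tensor")

    def double_it(x: torch.Tensor) -> torch.Tensor:
        # Toy kernel standing in for a real SLL op.
        return x * 2

    demo_registrations = {
        "double_it": {
            "CPU": double_it,
            "CUDA": double_it,
        },
    }

    for op_name, dispatches in demo_registrations.items():
        for dispatch_key, fn in dispatches.items():
            lib.impl(op_name, fn, dispatch_key)

    print(torch.ops.sll_demo.double_it(torch.ones(3)))  # tensor([2., 2., 2.])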
fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py (14 changes: 12 additions & 2 deletions)
@@ -8,12 +8,22 @@
 # pyre-strict
 
 
-from fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention import (  # noqa F401
+from fbgemm_gpu.sll.triton.jagged_dense_flash_attention import (
+    jagged_dense_flash_attention,
+    JaggedDenseFlashAttention,  # noqa F401
+)
+
+from fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention import (
     multi_head_jagged_flash_attention,
-    MultiHeadJaggedFlashAttention,
+    MultiHeadJaggedFlashAttention,  # noqa F401
 )
 
 # pyre-ignore[5]
 op_registrations = {
+    "sll_jagged_dense_flash_attention": {
+        "CUDA": jagged_dense_flash_attention,
+        "AutogradCUDA": jagged_dense_flash_attention,
+    },
     "sll_multi_head_jagged_flash_attention": {
         "CUDA": multi_head_jagged_flash_attention,
         "AutogradCUDA": multi_head_jagged_flash_attention,

Check failure (GitHub Actions / run-lint (3.13)) on line 11 in fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py: F401 'fbgemm_gpu.sll.triton.jagged_dense_flash_attention.JaggedDenseFlashAttention' imported but unused
Check failure (GitHub Actions / run-lint (3.13)) on line 16 in fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py: F401 'fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention.MultiHeadJaggedFlashAttention' imported but unused
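A note on those lint failures: flake8 reports F401 on the line where the `from ... import (` statement begins (lines 11 and 16 of the new file), so the `# noqa F401` comments placed on the imported symbols' own lines likely do not suppress it. Two common ways to mark an intentional re-export, shown here as an illustrative sketch rather than as how this diff was ultimately fixed:

    # Option 1: put the noqa on the statement line that the linter flags.
    from fbgemm_gpu.sll.triton.jagged_dense_flash_attention import (  # noqa: F401
        JaggedDenseFlashAttention,
        jagged_dense_flash_attention,
    )

    # Option 2: list the re-exported names in __all__; most linters treat
    # names appearing in __all__ as used.
    __all__ = [
        "JaggedDenseFlashAttention",
        "jagged_dense_flash_attention",
    ]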
