Re-organize SLL ops, pt 4 (pytorch#3644)
Summary:
X-link: facebookresearch/FBGEMM#720

Pull Request resolved: pytorch#3644

- Re-organize `jagged_flash_attention_basic`, `jagged_softmax`, and `jagged2_softmax`

Differential Revision: D68924000
q10 authored and facebook-github-bot committed Jan 31, 2025
1 parent e7312a3 commit f601628
Showing 5 changed files with 1,156 additions and 1,121 deletions.
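For readers following the series, the practical effect of this re-organization is that the three ops now live in dedicated modules under the `fbgemm_gpu.sll.triton` package instead of the monolithic `fbgemm_gpu.sll.triton_sll` module. A minimal import sketch, based only on the removed and added import lines in the two files diffed below:

# Old location (import removed below from fbgemm_gpu/fbgemm_gpu/sll/__init__.py):
#     from fbgemm_gpu.sll.triton_sll import jagged_softmax

# New location (re-exported by fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py):
from fbgemm_gpu.sll.triton import (
    jagged2_softmax,
    jagged_flash_attention_basic,
    jagged_softmax,
)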
15 changes: 0 additions & 15 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -36,14 +36,11 @@
from fbgemm_gpu.sll.triton_sll import ( # noqa F401
    array_jagged_bmm_jagged_out,
    dense_jagged_cat_jagged_out,
    jagged2_softmax,
    jagged2_to_padded_dense,
    jagged_dense_bmm,
    jagged_dense_elementwise_mul_jagged_out,
    jagged_flash_attention_basic,
    jagged_jagged_bmm,
    jagged_jagged_bmm_jagged_out,
    jagged_softmax,
    triton_jagged_self_substraction_jagged_out,
)

@@ -295,14 +292,6 @@
        "CUDA": jagged_dense_elementwise_mul_jagged_out,
        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
    },
    "sll_jagged_softmax": {
        "CUDA": jagged_softmax,
        "AutogradCUDA": jagged_softmax,
    },
    "sll_jagged2_softmax": {
        "CUDA": jagged2_softmax,
        "AutogradCUDA": jagged2_softmax,
    },
    "sll_array_jagged_bmm_jagged_out": {
        "CUDA": array_jagged_bmm_jagged_out,
        "AutogradCUDA": array_jagged_bmm_jagged_out,
@@ -311,10 +300,6 @@
        "CUDA": jagged_jagged_bmm_jagged_out,
        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
    },
    "sll_jagged_flash_attention_basic": {
        "CUDA": jagged_flash_attention_basic,
        "AutogradCUDA": jagged_flash_attention_basic,
    },
}

for op_name, dispatches in sll_cpu_registrations.items():
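The registration tables above pair an op name with one implementation per dispatch key, and the loop at the end of the file applies them. The FBGEMM registration helper itself is not part of this diff, so the following is only a hypothetical sketch of how such a table is commonly applied with `torch.library`; the `fbgemm` namespace string, the `sll_my_scale` op, and its schema are made-up assumptions for illustration.

import torch

# Example table in the same shape as sll_cpu_registrations / op_registrations;
# the op and kernel are placeholders, not real FBGEMM code.
def my_scale_cuda(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

example_registrations = {
    "sll_my_scale": {
        "CUDA": my_scale_cuda,
        "AutogradCUDA": my_scale_cuda,
    },
}

lib = torch.library.Library("fbgemm", "FRAGMENT")  # assumed namespace
for op_name, dispatches in example_registrations.items():
    # Define the schema once, then register one kernel per dispatch key
    # (e.g. CUDA, AutogradCUDA).
    lib.define(f"{op_name}(Tensor x) -> Tensor")
    for dispatch_key, fn in dispatches.items():
        lib.impl(op_name, fn, dispatch_key)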
26 changes: 25 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
@@ -7,23 +7,43 @@

# pyre-strict


from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import ( # noqa F401
    jagged_dense_elementwise_add,
    JaggedDenseAdd, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import ( # noqa F401
    jagged_dense_flash_attention,
    JaggedDenseFlashAttention, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_flash_attention_basic import ( # noqa F401
    jagged_flash_attention_basic,
    JaggedFlashAttentionBasic, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_softmax import ( # noqa F401
    jagged2_softmax,
    Jagged2Softmax, # noqa F401
    jagged_softmax,
    JaggedSoftmax, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import ( # noqa F401
    multi_head_jagged_flash_attention,
    MultiHeadJaggedFlashAttention, # noqa F401
)

# pyre-ignore[5]
op_registrations = {
    "sll_jagged_softmax": {
        "CUDA": jagged_softmax,
        "AutogradCUDA": jagged_softmax,
    },
    "sll_jagged2_softmax": {
        "CUDA": jagged2_softmax,
        "AutogradCUDA": jagged2_softmax,
    },
    "sll_jagged_dense_elementwise_add": {
        "CUDA": jagged_dense_elementwise_add,
        "AutogradCUDA": jagged_dense_elementwise_add,
@@ -32,6 +52,10 @@
        "CUDA": jagged_dense_flash_attention,
        "AutogradCUDA": jagged_dense_flash_attention,
    },
    "sll_jagged_flash_attention_basic": {
        "CUDA": jagged_flash_attention_basic,
        "AutogradCUDA": jagged_flash_attention_basic,
    },
    "sll_multi_head_jagged_flash_attention": {
        "CUDA": multi_head_jagged_flash_attention,
        "AutogradCUDA": multi_head_jagged_flash_attention,
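Each relocated `triton_*` module exports both a lower-case functional op (e.g. `jagged_softmax`) and a capitalized class (e.g. `JaggedSoftmax`). The module bodies are not part of this diff, so the following is only a hedged sketch of the conventional `torch.autograd.Function` plus thin-wrapper pairing that this naming suggests; the class name, wrapper name, and the dense-softmax placeholder bodies are made up and stand in for the real jagged Triton kernels.

import torch

class JaggedSoftmaxSketch(torch.autograd.Function):
    # Illustrative stand-in for the exported Function classes; the real
    # forward/backward run Triton kernels over jagged tensors.
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        y = torch.softmax(x, dim=-1)  # placeholder for the jagged kernel
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (y,) = ctx.saved_tensors
        # Standard softmax backward: y * (g - sum(g * y))
        return y * (grad_output - (grad_output * y).sum(dim=-1, keepdim=True))

def jagged_softmax_sketch(x: torch.Tensor) -> torch.Tensor:
    # Thin functional wrapper, mirroring how a jagged_softmax / JaggedSoftmax
    # pair is typically exposed.
    return JaggedSoftmaxSketch.apply(x)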
