-
Notifications
You must be signed in to change notification settings - Fork 96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Minor SDPA optimizations #16566
Minor SDPA optimizations #16566
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good!
dht_granularity = 1; | ||
log2_dht_granularity = 0; | ||
} | ||
TT_FATAL(dht_granularity == (1 << log2_dht_granularity), "Error"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe better error messaging?
…lock by aliasing mm2 output cb.
c293e7c
to
00486eb
Compare
I found that one of the optimizations in this branch, using mul_block_bcast_cols to write directly to cb_out, leads to inexplicable PCC issues in Llama tests. I was able to reproduce this in a chunked prefill unit test, but it's unclear why this optimization leads to different outputs from before. |
00486eb
to
ece57f8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Ticket
Subtask of #16557
Problem description
SDPA has quite a few unnecessary operations which make it inefficient, especially as sequence length grows.
What's changed
std::swap
mul_block_bcast_cols_accumulate
dhead=96
For the following test case, we get a nice 1.084x speedup.
tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py::test_sdpa_tt_large_seq[1-8-1-131072-128-k128-q128-bf16]
Checklist