
Minor SDPA optimizations #16566

Merged: cglagovichTT merged 10 commits into main from cglagovich/sdpa_opt on Jan 17, 2025

Conversation

@cglagovichTT (Contributor) commented on Jan 9, 2025

Ticket

Subtask of #16557

Problem description

SDPA has quite a few unnecessary operations which make it inefficient, especially as sequence length grows.

What's changed

  • Remove all block copies by efficiently ping-ponging buffers with aliases and std::swap (see the sketch after this list)
  • Use L1 accumulation to update intermediate output to avoid an extra unpack/add/pack
  • Reuse DST in mul_block_bcast_cols_accumulate
  • Fix bug with DST reuse and dhead=96
  • Don't allocate a CB for the mask if no mask is used; this reduces memory waste and enables larger chunk sizes
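
A rough illustration of the buffer ping-pong from the first bullet (a sketch with hypothetical buffer names, not the actual kernel code; the real change swaps circular-buffer aliases inside the SDPA compute kernel so that the running statistics never need a block copy):

```cpp
// Minimal sketch: instead of copying the "current" running-max / running-sum
// tiles into a "previous" buffer on every KV chunk, keep two buffer handles
// and swap which one plays each role.
#include <cstdint>
#include <utility>

struct StatsBuffers {
    uint32_t cb_cur_max;   // hypothetical circular-buffer indices
    uint32_t cb_prev_max;
    uint32_t cb_cur_sum;
    uint32_t cb_prev_sum;
};

void process_kv_chunks(StatsBuffers& bufs, int num_chunks) {
    for (int chunk = 0; chunk < num_chunks; ++chunk) {
        // ... compute new max/sum into bufs.cb_cur_* using bufs.cb_prev_* ...

        // Before: copy cb_cur_max -> cb_prev_max and cb_cur_sum -> cb_prev_sum.
        // After: just exchange the roles of the two buffers.
        std::swap(bufs.cb_cur_max, bufs.cb_prev_max);
        std::swap(bufs.cb_cur_sum, bufs.cb_prev_sum);
    }
}
```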

For the following test case, we get a nice 1.084x speedup.
tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py::test_sdpa_tt_large_seq[1-8-1-131072-128-k128-q128-bf16]

| ID | Total % | OP Code | Device Time | Cores | Math Fidelity |
|---:|--------:|---------|------------:|------:|---------------|
| 2 | 100.0 % | OLD ScaledDotProductAttention | 1,949,831 us | 64 | BF16, BF16 => BF16 |
| 2 | 100.0 % | NEW ScaledDotProductAttention | 1,798,710 us | 64 | BF16, BF16 => BF16 |
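
(The 1.084x figure above is the ratio of device times: 1,949,831 / 1,798,710 ≈ 1.084.)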

Checklist

@caixunshiren (Contributor) left a comment:


Overall looks good!

Review comment on this hunk:

```cpp
        dht_granularity = 1;
        log2_dht_granularity = 0;
    }
    TT_FATAL(dht_granularity == (1 << log2_dht_granularity), "Error");
```

Maybe better error messaging?
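
One way to make the message more descriptive would be to report the offending values, for example (a hypothetical rewording; the exact message wasn't settled in this thread):

```cpp
// Sketch of a more informative assertion than the bare "Error": include the
// actual values so a failure explains itself.
TT_FATAL(
    dht_granularity == (1 << log2_dht_granularity),
    "dht_granularity ({}) must equal 2^log2_dht_granularity (2^{} = {})",
    dht_granularity,
    log2_dht_granularity,
    1 << log2_dht_granularity);
```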

cglagovichTT requested a review from a team as a code owner on January 16, 2025, 20:01

@cglagovichTT (Contributor, Author) commented:

I found that one of the optimizations in this branch, using mul_block_bcast_cols to write directly to cb_out, leads to inexplicable PCC issues in Llama tests. I was able to reproduce this in a chunked prefill unit test, but it's unclear why this optimization leads to different outputs from before.

@caixunshiren (Contributor) left a comment:


LGTM

cglagovichTT merged commit 9a3766d into main on Jan 17, 2025
219 of 223 checks passed
cglagovichTT deleted the cglagovich/sdpa_opt branch on January 17, 2025, 18:51