Skip to content

Commit

Permalink
#0: Fix reconfig DF, fix granularity calc when not power of 2
Browse files Browse the repository at this point in the history
  • Loading branch information
cglagovichTT committed Jan 16, 2025
1 parent 705f064 commit 94b6e96
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ void recip_block_inplace(uint32_t in_cb, uint32_t num_tiles) {
// Postcondition: in_cb has num_tiles produced
copy_tile_to_dst_init_short(in_cb);
recip_tile_init();
reconfig_data_format_srca(in_cb);
pack_reconfig_data_format(in_cb);

cb_wait_front(in_cb, num_tiles);
for (uint32_t i = 0; i < num_tiles; ++i) {
Expand Down Expand Up @@ -505,8 +507,8 @@ void MAIN {
out_subblock_w,
false /*transpose*/);

reconfig_data_format_srca(alias_mm2_cur_out);
cb_pop_front(cb_qk_im, qk_chunk_tiles);
reconfig_data_format(alias_prev_max, alias_cur_max);

/* OUT_ACC += OUT_IM */
if (k_chunk > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,13 @@ operation::ProgramWithCallbacks sdpa_multi_core(
const uint32_t log2_mul_bcast_granularity = std::log2(mul_bcast_granularity);
TT_FATAL(mul_bcast_granularity == (1 << log2_mul_bcast_granularity), "Error");

const uint32_t dht_granularity = std::min(DHt, dst_size);
const uint32_t log2_dht_granularity = std::log2(dht_granularity);
uint32_t dht_granularity = std::min(DHt, dst_size);
uint32_t log2_dht_granularity = std::log2(dht_granularity);
// Sometimes DHt is not a power of 2, so granularity should be 1
if (dht_granularity != (1 << log2_dht_granularity)) {
dht_granularity = 1;
log2_dht_granularity = 0;
}
TT_FATAL(dht_granularity == (1 << log2_dht_granularity), "Error");

// Log these
Expand Down

0 comments on commit 94b6e96

Please sign in to comment.