From 47359a661b171cb1698e186fb7f2052a34cfbc3a Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 9 Jan 2025 08:17:15 -0800 Subject: [PATCH 01/10] #0: Don't allocate CB for mask if it won't be used --- .../transformer/sdpa/device/kernels/compute/sdpa.cpp | 4 +--- .../transformer/sdpa/device/sdpa_program_factory.cpp | 11 +++++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 1fa196dc220..f5d2321de67 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -58,7 +58,7 @@ void reduce_c() { reduce_init_delta(in0_cb, scale_cb, out_cb); - const uint32_t num_tiles = rows * cols; + constexpr uint32_t num_tiles = rows * cols; cb_wait_front(scale_cb, 1); cb_wait_front(in0_cb, num_tiles); cb_reserve_back(out_cb, rows); @@ -105,7 +105,6 @@ void sub_exp_block_bcast_cols_inplace() { // Precondition: in1_cb has rows produced // Postcondition: in0_cb has rows*cols produced // Postcondition: in1_cb has rows produced - sub_bcast_cols_init_short(in0_cb, in1_cb); exp_tile_init(); cb_wait_front(in0_cb, rows * cols); @@ -290,7 +289,6 @@ void matmul_blocks( // preconditino: in1_cb has K*N produced // postcondition: in0_cb is full, in1_cb is empty // postcondition: out_cb has M*N produced - mm_block_init_short( in0_cb, in1_cb, transpose /*transpose*/, subblock_w /*ct_dim*/, subblock_h /*rt_dim*/, in0_block_w /*kt_dim*/); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp index cc6d481809a..4a9232573e5 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp @@ -417,10 +417,13 @@ 
operation::ProgramWithCallbacks sdpa_multi_core( .set_page_size(tt::CBIndex::c_2, v_tile_size); auto cb_in2_id = CreateCircularBuffer(program, core_grid, c_in2_config); - // attn_mask input - auto c_in3_config = CircularBufferConfig(mask_tiles * mask_tile_size, {{tt::CBIndex::c_3, mask_df}}) - .set_page_size(tt::CBIndex::c_3, mask_tile_size); - auto cb_in3_id = CreateCircularBuffer(program, core_grid, c_in3_config); + // Only create mask buffer if it's going to be used + if (use_provided_mask or is_causal) { + // attn_mask input + auto c_in3_config = CircularBufferConfig(mask_tiles * mask_tile_size, {{tt::CB::c_in3, mask_df}}) + .set_page_size(tt::CB::c_in3, mask_tile_size); + auto cb_in3_id = CreateCircularBuffer(program, core_grid, c_in3_config); + } // scale input auto c_in4_config = CircularBufferConfig(scale_tiles * scalar_tile_size, {{tt::CBIndex::c_4, scalar_df}}) From 3dce46e4a27a73e330248efcbc6275ffaed9acc7 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 9 Jan 2025 08:19:39 -0800 Subject: [PATCH 02/10] #0: remove unnecessary copy_block on sum --- .../transformer/sdpa/device/kernels/compute/sdpa.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index f5d2321de67..5827f49eebb 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -520,6 +520,8 @@ void MAIN { /* OUT_ACC += OUT_IM */ if (k_chunk == 0) { copy_block(cb_out_im, cb_out_accumulate_im, out_chunk_tiles); + + copy_block(cb_cur_sum, cb_prev_sum, Sq_chunk_t); } else { /* cb_exp_max_diff = torch.exp(cb_prev_max - cb_cur_max) */ sub_exp_block(cb_prev_max, cb_cur_max, cb_exp_max_diff, Sq_chunk_t); @@ -532,7 +534,7 @@ void MAIN { mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff, 
Sq_chunk_t, DHt); /* cb_cur_sum += cb_prev_sum */ - add_block_inplace(cb_cur_sum, cb_prev_sum, Sq_chunk_t); + add_block_inplace(cb_prev_sum, cb_cur_sum, Sq_chunk_t); /* cb_out_accumulate_im += cb_out_im */ add_block_inplace(cb_out_accumulate_im, cb_out_im, out_chunk_tiles); @@ -540,22 +542,19 @@ void MAIN { // Set cb_prev_sum and cb_prev_max copy_block(cb_cur_max, cb_prev_max, Sq_chunk_t); - copy_block(cb_cur_sum, cb_prev_sum, Sq_chunk_t); } /* cb_cur_sum = 1.0 / cb_cur_sum */ - cb_push_back(cb_cur_sum, Sq_chunk_t); - recip_block_inplace(cb_cur_sum, Sq_chunk_t); + recip_block_inplace(cb_prev_sum, Sq_chunk_t); /* cb_out_accumulate_im *= cb_cur_sum */ - mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_cur_sum, Sq_chunk_t, DHt); + mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_prev_sum, Sq_chunk_t, DHt); pack_reconfig_data_format(cb_out); copy_block(cb_out_accumulate_im, cb_out, out_chunk_tiles); cb_pop_front(cb_q_in, q_chunk_tiles); // free up cb_prev_max after K chunks cb_pop_front(cb_prev_max, Sq_chunk_t); - cb_pop_front(cb_prev_sum, Sq_chunk_t); } } } From d30d9a8218abfbe74cd83decae58f4a0da01ac1d Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 9 Jan 2025 08:43:29 -0800 Subject: [PATCH 03/10] #0: v1 Get DST reuse in mul_block_bcast_cols_inplace. Remove a copy_block by aliasing mm2 output cb. 
--- .../sdpa/device/kernels/compute/sdpa.cpp | 43 ++++++++++++------- .../sdpa/device/sdpa_program_factory.cpp | 1 + 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 5827f49eebb..e71ceb1a392 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -132,25 +132,34 @@ void sub_exp_block_bcast_cols_inplace() { } } -void mul_block_bcast_cols_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t rows, uint32_t cols) { +template +void mul_block_bcast_cols_inplace(uint32_t in0_cb, uint32_t in1_cb) { // Precondition: in0_cb has rows*cols produced // Precondition: in1_cb has rows produced // Postcondition: in0_cb has rows*cols produced // Postcondition: in1_cb has rows consumed - uint32_t num_tiles = rows * cols; + constexpr uint32_t num_tiles = rows * cols; + constexpr uint32_t dst_tiles = DHT_GRANULARITY; + constexpr uint32_t granularity = cols >> LOG2_DHT_GRANULARITY; mul_bcast_cols_init_short(in0_cb, in1_cb); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, rows); for (uint32_t i = 0; i < rows; ++i) { - for (uint32_t j = 0; j < cols; ++j) { - acquire_dst(); - mul_tiles_bcast_cols(in0_cb, in1_cb, 0, i, 0); - cb_pop_front(in0_cb, 1); - cb_reserve_back(in0_cb, 1); - pack_tile(0, in0_cb); - cb_push_back(in0_cb, 1); - release_dst(); + for (uint32_t u = 0; u < granularity; ++u) { + tile_regs_acquire(); + for (uint32_t j = 0; j < dst_tiles; ++j) { + mul_tiles_bcast_cols(in0_cb, in1_cb, j, i, j); + } + tile_regs_commit(); + cb_pop_front(in0_cb, dst_tiles); + cb_reserve_back(in0_cb, dst_tiles); + tile_regs_wait(); + for (uint32_t j = 0; j < dst_tiles; ++j) { + pack_tile(j, in0_cb); + } + cb_push_back(in0_cb, dst_tiles); + tile_regs_release(); } } cb_pop_front(in1_cb, rows); @@ -426,6 +435,10 @@ 
void MAIN { } cb_wait_front(cb_q_in, q_chunk_tiles); + // On (k_chunk == 0), mm2 should store directly in cb_out_accumulate_im + uint32_t alias_mm2_out = cb_out_accumulate_im; + // TODO: alias cur_sum to prev_sum in first iteration + // loop while k_low < q_high for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) { const uint32_t k_low_idx = k_chunk * Sk_chunk_t; @@ -502,7 +515,7 @@ void MAIN { matmul_blocks( cb_qk_im, cb_v_in, - cb_out_im, + alias_mm2_out, Sq_chunk_t, DHt, Sk_chunk_t, @@ -515,12 +528,12 @@ void MAIN { false /*transpose*/); reconfig_data_format_srca(cb_out_im); + // After first iteration, mm2 should store in cb_out_im + alias_mm2_out = cb_out_im; cb_pop_front(cb_qk_im, qk_chunk_tiles); /* OUT_ACC += OUT_IM */ if (k_chunk == 0) { - copy_block(cb_out_im, cb_out_accumulate_im, out_chunk_tiles); - copy_block(cb_cur_sum, cb_prev_sum, Sq_chunk_t); } else { /* cb_exp_max_diff = torch.exp(cb_prev_max - cb_cur_max) */ @@ -531,7 +544,7 @@ void MAIN { mul_block_inplace(cb_prev_sum, cb_exp_max_diff, Sq_chunk_t); /* cb_out_accumulate_im *= cb_exp_max_diff */ - mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff, Sq_chunk_t, DHt); + mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff); /* cb_cur_sum += cb_prev_sum */ add_block_inplace(cb_prev_sum, cb_cur_sum, Sq_chunk_t); @@ -548,7 +561,7 @@ void MAIN { recip_block_inplace(cb_prev_sum, Sq_chunk_t); /* cb_out_accumulate_im *= cb_cur_sum */ - mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_prev_sum, Sq_chunk_t, DHt); + mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_prev_sum); pack_reconfig_data_format(cb_out); copy_block(cb_out_accumulate_im, cb_out, out_chunk_tiles); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp index 4a9232573e5..a902d490a45 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp +++ 
b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp @@ -242,6 +242,7 @@ operation::ProgramWithCallbacks sdpa_multi_core( const uint32_t dht_granularity = std::min(DHt, dst_size); const uint32_t log2_dht_granularity = std::log2(dht_granularity); + TT_FATAL(dht_granularity == (1 << log2_dht_granularity), "Error"); // Log these tt::log_debug("stats_granularity: {}", stats_granularity); From 0255f50ff85d45d225ef4a20a7d09fc34320824b Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 9 Jan 2025 08:47:04 -0800 Subject: [PATCH 04/10] #0: ping pong buffer cur_sum and prev_sum --- .../sdpa/device/kernels/compute/sdpa.cpp | 43 +++++++++---------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index e71ceb1a392..59d3ca70216 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -40,15 +40,8 @@ void max_block_inplace() { } } -template < - PoolType pool_type, - ReduceDim reduce_dim, - uint32_t in0_cb, - uint32_t scale_cb, - uint32_t out_cb, - uint32_t rows, - uint32_t cols> -void reduce_c() { +template +void reduce_c(uint32_t out_cb) { // Precondition: in0_cb has rows*cols produced. 
in0_cb has tiles in row-major order // Precondition: scale_cb has 1 produced // Precondition: out_cb has rows free @@ -397,10 +390,14 @@ void MAIN { constexpr uint32_t cb_out_accumulate_im = tt::CBIndex::c_26; constexpr uint32_t cb_cur_max = tt::CBIndex::c_27; constexpr uint32_t cb_prev_max = tt::CBIndex::c_28; - constexpr uint32_t cb_cur_sum = tt::CBIndex::c_29; - constexpr uint32_t cb_prev_sum = tt::CBIndex::c_30; + constexpr uint32_t cb_sum_A = tt::CBIndex::c_29; + constexpr uint32_t cb_sum_B = tt::CBIndex::c_30; constexpr uint32_t cb_exp_max_diff = tt::CBIndex::c_31; + // Set up ping pong buffers for sum + uint32_t alias_prev_sum = cb_sum_A; + uint32_t alias_cur_sum = cb_sum_B; + constexpr uint32_t cb_out = tt::CBIndex::c_16; mm_init(); @@ -489,9 +486,8 @@ void MAIN { ReduceDim::REDUCE_ROW, cb_qk_im, cb_identity_scale_in, - cb_cur_max, Sq_chunk_t, - Sk_chunk_t>(); + Sk_chunk_t>(cb_cur_max); if (k_chunk > 0) { max_block_inplace(); @@ -507,9 +503,8 @@ void MAIN { ReduceDim::REDUCE_ROW, cb_qk_im, cb_identity_scale_in, - cb_cur_sum, Sq_chunk_t, - Sk_chunk_t>(); + Sk_chunk_t>(alias_cur_sum); /* OUT_IM = QK @ V_CHUNK */ matmul_blocks( @@ -534,22 +529,24 @@ void MAIN { /* OUT_ACC += OUT_IM */ if (k_chunk == 0) { - copy_block(cb_cur_sum, cb_prev_sum, Sq_chunk_t); + // Instead of copy, swap. 
+ std::swap(alias_cur_sum, alias_prev_sum); } else { /* cb_exp_max_diff = torch.exp(cb_prev_max - cb_cur_max) */ sub_exp_block(cb_prev_max, cb_cur_max, cb_exp_max_diff, Sq_chunk_t); cb_pop_front(cb_prev_max, Sq_chunk_t); /* cb_prev_sum *= cb_exp_max_diff */ - mul_block_inplace(cb_prev_sum, cb_exp_max_diff, Sq_chunk_t); + mul_block_inplace(alias_prev_sum, cb_exp_max_diff, Sq_chunk_t); + /* cb_cur_sum += cb_prev_sum */ + add_block_inplace(alias_cur_sum, alias_prev_sum, Sq_chunk_t); + + // Swap alias_prev_sum and alias_cur_sum + std::swap(alias_prev_sum, alias_cur_sum); /* cb_out_accumulate_im *= cb_exp_max_diff */ mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff); - /* cb_cur_sum += cb_prev_sum */ - add_block_inplace(cb_prev_sum, cb_cur_sum, Sq_chunk_t); - - /* cb_out_accumulate_im += cb_out_im */ add_block_inplace(cb_out_accumulate_im, cb_out_im, out_chunk_tiles); } @@ -558,10 +555,10 @@ void MAIN { } /* cb_cur_sum = 1.0 / cb_cur_sum */ - recip_block_inplace(cb_prev_sum, Sq_chunk_t); + recip_block_inplace(alias_prev_sum, Sq_chunk_t); /* cb_out_accumulate_im *= cb_cur_sum */ - mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_prev_sum); + mul_block_bcast_cols_inplace(cb_out_accumulate_im, alias_prev_sum); pack_reconfig_data_format(cb_out); copy_block(cb_out_accumulate_im, cb_out, out_chunk_tiles); From 47b7fc3a2a6cea7c0932b93eb44fa46b98f9b51e Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 9 Jan 2025 09:44:37 -0800 Subject: [PATCH 05/10] #0: Remove special case for k_chunk 0, move swap after branch. 
--- .../transformer/sdpa/device/kernels/compute/sdpa.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 59d3ca70216..6e19d520423 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -528,10 +528,7 @@ void MAIN { cb_pop_front(cb_qk_im, qk_chunk_tiles); /* OUT_ACC += OUT_IM */ - if (k_chunk == 0) { - // Instead of copy, swap. - std::swap(alias_cur_sum, alias_prev_sum); - } else { + if (k_chunk > 0) { /* cb_exp_max_diff = torch.exp(cb_prev_max - cb_cur_max) */ sub_exp_block(cb_prev_max, cb_cur_max, cb_exp_max_diff, Sq_chunk_t); cb_pop_front(cb_prev_max, Sq_chunk_t); @@ -541,15 +538,14 @@ void MAIN { /* cb_cur_sum += cb_prev_sum */ add_block_inplace(alias_cur_sum, alias_prev_sum, Sq_chunk_t); - // Swap alias_prev_sum and alias_cur_sum - std::swap(alias_prev_sum, alias_cur_sum); - /* cb_out_accumulate_im *= cb_exp_max_diff */ mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff); add_block_inplace(cb_out_accumulate_im, cb_out_im, out_chunk_tiles); } + // Swap alias_prev_sum and alias_cur_sum + std::swap(alias_prev_sum, alias_cur_sum); // Set cb_prev_sum and cb_prev_max copy_block(cb_cur_max, cb_prev_max, Sq_chunk_t); } From 9d665aa6ed4b09b872236286cb41bab025319aac Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 9 Jan 2025 09:52:14 -0800 Subject: [PATCH 06/10] #0: ping pong mm2 out as well --- .../sdpa/device/kernels/compute/sdpa.cpp | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 6e19d520423..753bf2cfe04 100644 --- 
a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -386,18 +386,14 @@ void MAIN { constexpr uint32_t cb_identity_scale_in = tt::CBIndex::c_5; constexpr uint32_t cb_qk_im = tt::CBIndex::c_24; - constexpr uint32_t cb_out_im = tt::CBIndex::c_25; - constexpr uint32_t cb_out_accumulate_im = tt::CBIndex::c_26; + constexpr uint32_t cb_out_im_A = tt::CBIndex::c_25; + constexpr uint32_t cb_out_im_B = tt::CBIndex::c_26; constexpr uint32_t cb_cur_max = tt::CBIndex::c_27; constexpr uint32_t cb_prev_max = tt::CBIndex::c_28; constexpr uint32_t cb_sum_A = tt::CBIndex::c_29; constexpr uint32_t cb_sum_B = tt::CBIndex::c_30; constexpr uint32_t cb_exp_max_diff = tt::CBIndex::c_31; - // Set up ping pong buffers for sum - uint32_t alias_prev_sum = cb_sum_A; - uint32_t alias_cur_sum = cb_sum_B; - constexpr uint32_t cb_out = tt::CBIndex::c_16; mm_init(); @@ -430,12 +426,14 @@ void MAIN { } else { q_high_idx = Skt; } - cb_wait_front(cb_q_in, q_chunk_tiles); - // On (k_chunk == 0), mm2 should store directly in cb_out_accumulate_im - uint32_t alias_mm2_out = cb_out_accumulate_im; - // TODO: alias cur_sum to prev_sum in first iteration + // Set up ping pong buffers + uint32_t alias_prev_sum = cb_sum_A; + uint32_t alias_cur_sum = cb_sum_B; + uint32_t alias_mm2_prev_out = cb_out_im_A; + uint32_t alias_mm2_cur_out = cb_out_im_B; + cb_wait_front(cb_q_in, q_chunk_tiles); // loop while k_low < q_high for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) { const uint32_t k_low_idx = k_chunk * Sk_chunk_t; @@ -510,7 +508,7 @@ void MAIN { matmul_blocks( cb_qk_im, cb_v_in, - alias_mm2_out, + alias_mm2_cur_out, Sq_chunk_t, DHt, Sk_chunk_t, @@ -522,9 +520,7 @@ void MAIN { out_subblock_w, false /*transpose*/); - reconfig_data_format_srca(cb_out_im); - // After first iteration, mm2 should store in cb_out_im - alias_mm2_out = cb_out_im; + 
reconfig_data_format_srca(alias_mm2_cur_out); cb_pop_front(cb_qk_im, qk_chunk_tiles); /* OUT_ACC += OUT_IM */ @@ -539,13 +535,13 @@ void MAIN { add_block_inplace(alias_cur_sum, alias_prev_sum, Sq_chunk_t); /* cb_out_accumulate_im *= cb_exp_max_diff */ - mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff); - - add_block_inplace(cb_out_accumulate_im, cb_out_im, out_chunk_tiles); + mul_block_bcast_cols_inplace(alias_mm2_prev_out, cb_exp_max_diff); + add_block_inplace(alias_mm2_cur_out, alias_mm2_prev_out, out_chunk_tiles); } // Swap alias_prev_sum and alias_cur_sum std::swap(alias_prev_sum, alias_cur_sum); + std::swap(alias_mm2_prev_out, alias_mm2_cur_out); // Set cb_prev_sum and cb_prev_max copy_block(cb_cur_max, cb_prev_max, Sq_chunk_t); } @@ -554,9 +550,9 @@ void MAIN { recip_block_inplace(alias_prev_sum, Sq_chunk_t); /* cb_out_accumulate_im *= cb_cur_sum */ - mul_block_bcast_cols_inplace(cb_out_accumulate_im, alias_prev_sum); + mul_block_bcast_cols_inplace(alias_mm2_prev_out, alias_prev_sum); pack_reconfig_data_format(cb_out); - copy_block(cb_out_accumulate_im, cb_out, out_chunk_tiles); + copy_block(alias_mm2_prev_out, cb_out, out_chunk_tiles); cb_pop_front(cb_q_in, q_chunk_tiles); // free up cb_prev_max after K chunks From 0e4ad8e70bfdcf11f97e33c255c2003158a58080 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Fri, 17 Jan 2025 06:40:09 -0800 Subject: [PATCH 07/10] #0: Remove copy for prev max. 
Keep copy for cb_out due to PCC bug if mul_block_bcast_cols writes directly to cb_out --- .../sdpa/device/kernels/compute/sdpa.cpp | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 753bf2cfe04..5efa7645e72 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -17,8 +17,8 @@ #include "compute_kernel_api/reduce.h" namespace NAMESPACE { -template -void max_block_inplace() { +template +void max_block_inplace(uint32_t in0, uint32_t in1) { // inputs come in full, outputs go out full copy_tile_to_dst_init_short(in0); max_tile_init(); @@ -92,8 +92,8 @@ void recip_block_inplace(uint32_t in_cb, uint32_t num_tiles) { } } -template -void sub_exp_block_bcast_cols_inplace() { +template +void sub_exp_block_bcast_cols_inplace(uint32_t in1_cb) { // Precondition: in0_cb has rows*cols produced // Precondition: in1_cb has rows produced // Postcondition: in0_cb has rows*cols produced @@ -388,8 +388,8 @@ void MAIN { constexpr uint32_t cb_qk_im = tt::CBIndex::c_24; constexpr uint32_t cb_out_im_A = tt::CBIndex::c_25; constexpr uint32_t cb_out_im_B = tt::CBIndex::c_26; - constexpr uint32_t cb_cur_max = tt::CBIndex::c_27; - constexpr uint32_t cb_prev_max = tt::CBIndex::c_28; + constexpr uint32_t cb_max_A = tt::CBIndex::c_27; + constexpr uint32_t cb_max_B = tt::CBIndex::c_28; constexpr uint32_t cb_sum_A = tt::CBIndex::c_29; constexpr uint32_t cb_sum_B = tt::CBIndex::c_30; constexpr uint32_t cb_exp_max_diff = tt::CBIndex::c_31; @@ -430,6 +430,8 @@ void MAIN { // Set up ping pong buffers uint32_t alias_prev_sum = cb_sum_A; uint32_t alias_cur_sum = cb_sum_B; + uint32_t alias_prev_max = cb_max_A; + uint32_t alias_cur_max = cb_max_B; uint32_t alias_mm2_prev_out = cb_out_im_A; uint32_t
alias_mm2_cur_out = cb_out_im_B; @@ -485,15 +487,15 @@ void MAIN { cb_qk_im, cb_identity_scale_in, Sq_chunk_t, - Sk_chunk_t>(cb_cur_max); + Sk_chunk_t>(alias_cur_max); if (k_chunk > 0) { - max_block_inplace(); + max_block_inplace(alias_cur_max, alias_prev_max); } /* QK -= cb_cur_max */ /* QK = exp(QK)*/ - sub_exp_block_bcast_cols_inplace(); + sub_exp_block_bcast_cols_inplace(alias_cur_max); /* cb_cur_sum = sum(cb_qk_im, dim=-1) */ reduce_c< @@ -526,8 +528,8 @@ void MAIN { /* OUT_ACC += OUT_IM */ if (k_chunk > 0) { /* cb_exp_max_diff = torch.exp(cb_prev_max - cb_cur_max) */ - sub_exp_block(cb_prev_max, cb_cur_max, cb_exp_max_diff, Sq_chunk_t); - cb_pop_front(cb_prev_max, Sq_chunk_t); + sub_exp_block(alias_prev_max, alias_cur_max, cb_exp_max_diff, Sq_chunk_t); + cb_pop_front(alias_prev_max, Sq_chunk_t); /* cb_prev_sum *= cb_exp_max_diff */ mul_block_inplace(alias_prev_sum, cb_exp_max_diff, Sq_chunk_t); @@ -542,21 +544,21 @@ void MAIN { // Swap alias_prev_sum and alias_cur_sum std::swap(alias_prev_sum, alias_cur_sum); std::swap(alias_mm2_prev_out, alias_mm2_cur_out); - // Set cb_prev_sum and cb_prev_max - copy_block(cb_cur_max, cb_prev_max, Sq_chunk_t); + std::swap(alias_prev_max, alias_cur_max); } /* cb_cur_sum = 1.0 / cb_cur_sum */ recip_block_inplace(alias_prev_sum, Sq_chunk_t); /* cb_out_accumulate_im *= cb_cur_sum */ + // NOTE: PCC bug if we modify below function to directly output to cb_out. mul_block_bcast_cols_inplace(alias_mm2_prev_out, alias_prev_sum); pack_reconfig_data_format(cb_out); copy_block(alias_mm2_prev_out, cb_out, out_chunk_tiles); cb_pop_front(cb_q_in, q_chunk_tiles); // free up cb_prev_max after K chunks - cb_pop_front(cb_prev_max, Sq_chunk_t); + cb_pop_front(alias_prev_max, Sq_chunk_t); } } } From 4aae4ae57380e4029f6a430ae2834370811ce860 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 9 Jan 2025 10:11:33 -0800 Subject: [PATCH 08/10] #0: Reorder a few lines.
(cherry picked from commit 705f06493ea99ce4215a9cdb021f3ea37802f272) --- .../transformer/sdpa/device/kernels/compute/sdpa.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 5efa7645e72..98b10de21f8 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -294,13 +294,13 @@ void matmul_blocks( mm_block_init_short( in0_cb, in1_cb, transpose /*transpose*/, subblock_w /*ct_dim*/, subblock_h /*rt_dim*/, in0_block_w /*kt_dim*/); - reconfig_data_format(in1_cb, in0_cb); - cb_wait_front(in1_cb, K * N); - uint32_t output_num_tiles = M * N; uint32_t out_subblock_num_tiles = subblock_h * subblock_w; uint32_t in0_index_offset = 0; + reconfig_data_format(in1_cb, in0_cb); + cb_wait_front(in1_cb, K * N); + for (uint32_t in0_subblock = 0; in0_subblock < in0_num_subblocks; ++in0_subblock) { uint32_t in1_index_offset = 0; for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; ++in1_subblock) { From f144ccfa5c1e372246d26e47ea7dbbefac0c8b36 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 9 Jan 2025 10:24:38 -0800 Subject: [PATCH 09/10] #0: Fix reconfig DF, fix granularity calc when not power of 2 (cherry picked from commit 94b6e96d179a096e30565eac87df5b0dd9cc4485) --- .../transformer/sdpa/device/kernels/compute/sdpa.cpp | 4 +++- .../transformer/sdpa/device/sdpa_program_factory.cpp | 9 +++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 98b10de21f8..76465f534e5 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ 
b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -78,6 +78,8 @@ void recip_block_inplace(uint32_t in_cb, uint32_t num_tiles) { // Postcondition: in_cb has num_tiles produced copy_tile_to_dst_init_short(in_cb); recip_tile_init(); + reconfig_data_format_srca(in_cb); + pack_reconfig_data_format(in_cb); cb_wait_front(in_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; ++i) { @@ -522,8 +524,8 @@ void MAIN { out_subblock_w, false /*transpose*/); - reconfig_data_format_srca(alias_mm2_cur_out); cb_pop_front(cb_qk_im, qk_chunk_tiles); + reconfig_data_format(alias_prev_max, alias_cur_max); /* OUT_ACC += OUT_IM */ if (k_chunk > 0) { diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp index a902d490a45..6e7b45af947 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp @@ -240,8 +240,13 @@ operation::ProgramWithCallbacks sdpa_multi_core( const uint32_t log2_mul_bcast_granularity = std::log2(mul_bcast_granularity); TT_FATAL(mul_bcast_granularity == (1 << log2_mul_bcast_granularity), "Error"); - const uint32_t dht_granularity = std::min(DHt, dst_size); - const uint32_t log2_dht_granularity = std::log2(dht_granularity); + uint32_t dht_granularity = std::min(DHt, dst_size); + uint32_t log2_dht_granularity = std::log2(dht_granularity); + // Sometimes DHt is not a power of 2, so granularity should be 1 + if (dht_granularity != (1 << log2_dht_granularity)) { + dht_granularity = 1; + log2_dht_granularity = 0; + } TT_FATAL(dht_granularity == (1 << log2_dht_granularity), "Error"); // Log these From ece57f849770707c0b8e4f2d5f6437b147cd942f Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 16 Jan 2025 12:00:52 -0800 Subject: [PATCH 10/10] #0: enhance error messages (cherry picked from commit 
00486eb0ec1e0f45efbfaf65b64a231baab0bccc) --- .../sdpa/device/sdpa_program_factory.cpp | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp index 6e7b45af947..c1c61b13d39 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp @@ -141,7 +141,10 @@ operation::ProgramWithCallbacks sdpa_multi_core( uint32_t num_cores = grid_size.x * grid_size.y; TT_FATAL( - num_cores <= device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y, "Error"); + num_cores <= device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y, + "Provided grid must not contain more cores than the device. Got {} cores, expected at most {} cores.", + num_cores, + device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y); // Parallelization scheme // We will choose parallelization factors for batch, num_heads, and q_seq_len in that order @@ -149,7 +152,11 @@ operation::ProgramWithCallbacks sdpa_multi_core( uint32_t nh_parallel_factor = std::min(num_cores / batch_parallel_factor, NQH); uint32_t q_parallel_factor = std::min(num_cores / (batch_parallel_factor * nh_parallel_factor), q_num_chunks); - TT_FATAL(batch_parallel_factor * nh_parallel_factor * q_parallel_factor <= num_cores, "Error"); + TT_FATAL( + batch_parallel_factor * nh_parallel_factor * q_parallel_factor <= num_cores, + "Parallelism must not exceed number of cores. 
Got {}, expected at most {}.", + batch_parallel_factor * nh_parallel_factor * q_parallel_factor, + num_cores); tt::log_debug("Parallelization scheme:"); tt::log_debug("batch_parallel_factor: {}", batch_parallel_factor); @@ -230,15 +237,24 @@ operation::ProgramWithCallbacks sdpa_multi_core( // Find log2 of stats_granularity using std const uint32_t log2_stats_granularity = std::log2(stats_granularity); // Assert that this is a power of 2 - TT_FATAL(stats_granularity == (1 << log2_stats_granularity), "Error"); + TT_FATAL( + stats_granularity == (1 << log2_stats_granularity), + "stats_granularity must be a power of 2. Got {}.", + stats_granularity); const uint32_t sub_exp_granularity = std::min(Sk_chunk_t, dst_size); const uint32_t log2_sub_exp_granularity = std::log2(sub_exp_granularity); - TT_FATAL(sub_exp_granularity == (1 << log2_sub_exp_granularity), "Error"); + TT_FATAL( + sub_exp_granularity == (1 << log2_sub_exp_granularity), + "sub_exp_granularity must be a power of 2. Got {}.", + sub_exp_granularity); const uint32_t mul_bcast_granularity = std::min(Sq_chunk_t * Sk_chunk_t, dst_size); const uint32_t log2_mul_bcast_granularity = std::log2(mul_bcast_granularity); - TT_FATAL(mul_bcast_granularity == (1 << log2_mul_bcast_granularity), "Error"); + TT_FATAL( + mul_bcast_granularity == (1 << log2_mul_bcast_granularity), + "mul_bcast_granularity must be a power of 2. Got {}.", + mul_bcast_granularity); uint32_t dht_granularity = std::min(DHt, dst_size); uint32_t log2_dht_granularity = std::log2(dht_granularity); @@ -247,7 +263,10 @@ operation::ProgramWithCallbacks sdpa_multi_core( dht_granularity = 1; log2_dht_granularity = 0; } - TT_FATAL(dht_granularity == (1 << log2_dht_granularity), "Error"); + TT_FATAL( + dht_granularity == (1 << log2_dht_granularity), + "dht_granularity must be a power of 2. Got {}.", + dht_granularity); // Log these tt::log_debug("stats_granularity: {}", stats_granularity);