Minor SDPA optimizations #16566

Merged: 10 commits, Jan 17, 2025
@@ -17,8 +17,8 @@
#include "compute_kernel_api/reduce.h"

namespace NAMESPACE {
template <uint32_t in0, uint32_t in1, uint32_t num_tiles>
void max_block_inplace() {
template <uint32_t num_tiles>
void max_block_inplace(uint32_t in0, uint32_t in1) {
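// Note: in0/in1 are now runtime arguments (rather than template parameters) so the
// caller can pass ping-pong CB aliases that change from one iteration to the next.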
// inputs come in full, outputs go out full
copy_tile_to_dst_init_short(in0);
max_tile_init();
@@ -40,15 +40,8 @@ void max_block_inplace() {
}
}

template <
PoolType pool_type,
ReduceDim reduce_dim,
uint32_t in0_cb,
uint32_t scale_cb,
uint32_t out_cb,
uint32_t rows,
uint32_t cols>
void reduce_c() {
template <PoolType pool_type, ReduceDim reduce_dim, uint32_t in0_cb, uint32_t scale_cb, uint32_t rows, uint32_t cols>
void reduce_c(uint32_t out_cb) {
// Precondition: in0_cb has rows*cols produced. in0_cb has tiles in row-major order
// Precondition: scale_cb has 1 produced
// Precondition: out_cb has rows free
@@ -58,7 +51,7 @@ void reduce_c() {

reduce_init_delta<false, pool_type, reduce_dim>(in0_cb, scale_cb, out_cb);

const uint32_t num_tiles = rows * cols;
constexpr uint32_t num_tiles = rows * cols;
cb_wait_front(scale_cb, 1);
cb_wait_front(in0_cb, num_tiles);
cb_reserve_back(out_cb, rows);
@@ -85,6 +78,8 @@ void recip_block_inplace(uint32_t in_cb, uint32_t num_tiles) {
// Postcondition: in_cb has num_tiles produced
copy_tile_to_dst_init_short(in_cb);
recip_tile_init();
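// Match the unpacker (srcA) and packer data formats to in_cb before the loop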
reconfig_data_format_srca(in_cb);
pack_reconfig_data_format(in_cb);

cb_wait_front(in_cb, num_tiles);
for (uint32_t i = 0; i < num_tiles; ++i) {
@@ -99,13 +94,12 @@ void recip_block_inplace(uint32_t in_cb, uint32_t num_tiles) {
}
}

template <uint32_t in0_cb, uint32_t in1_cb, uint32_t rows, uint32_t cols>
void sub_exp_block_bcast_cols_inplace() {
template <uint32_t in0_cb, uint32_t rows, uint32_t cols>
void sub_exp_block_bcast_cols_inplace(uint32_t in1_cb) {
// Precondition: in0_cb has rows*cols produced
// Precondition: in1_cb has rows produced
// Postcondition: in0_cb has rows*cols produced
// Postcondition: in1_cb has rows produced

sub_bcast_cols_init_short(in0_cb, in1_cb);
exp_tile_init<true>();
cb_wait_front(in0_cb, rows * cols);
@@ -133,25 +127,34 @@ void sub_exp_block_bcast_cols_inplace() {
}
}

void mul_block_bcast_cols_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t rows, uint32_t cols) {
template <uint32_t rows, uint32_t cols>
void mul_block_bcast_cols_inplace(uint32_t in0_cb, uint32_t in1_cb) {
// Precondition: in0_cb has rows*cols produced
// Precondition: in1_cb has rows produced
// Postcondition: in0_cb has rows*cols produced
// Postcondition: in1_cb has rows consumed

uint32_t num_tiles = rows * cols;
constexpr uint32_t num_tiles = rows * cols;
constexpr uint32_t dst_tiles = DHT_GRANULARITY;
constexpr uint32_t granularity = cols >> LOG2_DHT_GRANULARITY;
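// Each dst acquire/pack cycle handles dst_tiles = DHT_GRANULARITY tiles, so each row
// takes granularity = cols / DHT_GRANULARITY passes (e.g. cols = 8 with DHT_GRANULARITY = 4
// gives 2 passes of 4 tiles; illustrative values only).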
mul_bcast_cols_init_short(in0_cb, in1_cb);
cb_wait_front(in0_cb, num_tiles);
cb_wait_front(in1_cb, rows);
for (uint32_t i = 0; i < rows; ++i) {
for (uint32_t j = 0; j < cols; ++j) {
acquire_dst();
mul_tiles_bcast_cols(in0_cb, in1_cb, 0, i, 0);
cb_pop_front(in0_cb, 1);
cb_reserve_back(in0_cb, 1);
pack_tile(0, in0_cb);
cb_push_back(in0_cb, 1);
release_dst();
for (uint32_t u = 0; u < granularity; ++u) {
tile_regs_acquire();
for (uint32_t j = 0; j < dst_tiles; ++j) {
mul_tiles_bcast_cols(in0_cb, in1_cb, j, i, j);
}
tile_regs_commit();
cb_pop_front(in0_cb, dst_tiles);
cb_reserve_back(in0_cb, dst_tiles);
tile_regs_wait();
for (uint32_t j = 0; j < dst_tiles; ++j) {
pack_tile(j, in0_cb);
}
cb_push_back(in0_cb, dst_tiles);
tile_regs_release();
}
}
cb_pop_front(in1_cb, rows);
@@ -290,17 +293,16 @@ void matmul_blocks(
// precondition: in1_cb has K*N produced
// postcondition: in0_cb is full, in1_cb is empty
// postcondition: out_cb has M*N produced

mm_block_init_short(
in0_cb, in1_cb, transpose /*transpose*/, subblock_w /*ct_dim*/, subblock_h /*rt_dim*/, in0_block_w /*kt_dim*/);

reconfig_data_format(in1_cb, in0_cb);
cb_wait_front(in1_cb, K * N);

uint32_t output_num_tiles = M * N;
uint32_t out_subblock_num_tiles = subblock_h * subblock_w;
uint32_t in0_index_offset = 0;

reconfig_data_format(in1_cb, in0_cb);
cb_wait_front(in1_cb, K * N);

for (uint32_t in0_subblock = 0; in0_subblock < in0_num_subblocks; ++in0_subblock) {
uint32_t in1_index_offset = 0;
for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; ++in1_subblock) {
@@ -386,12 +388,12 @@ void MAIN {
constexpr uint32_t cb_identity_scale_in = tt::CBIndex::c_5;

constexpr uint32_t cb_qk_im = tt::CBIndex::c_24;
constexpr uint32_t cb_out_im = tt::CBIndex::c_25;
constexpr uint32_t cb_out_accumulate_im = tt::CBIndex::c_26;
constexpr uint32_t cb_cur_max = tt::CBIndex::c_27;
constexpr uint32_t cb_prev_max = tt::CBIndex::c_28;
constexpr uint32_t cb_cur_sum = tt::CBIndex::c_29;
constexpr uint32_t cb_prev_sum = tt::CBIndex::c_30;
constexpr uint32_t cb_out_im_A = tt::CBIndex::c_25;
constexpr uint32_t cb_out_im_B = tt::CBIndex::c_26;
constexpr uint32_t cb_max_A = tt::CBIndex::c_27;
constexpr uint32_t cb_max_B = tt::CBIndex::c_28;
constexpr uint32_t cb_sum_A = tt::CBIndex::c_29;
constexpr uint32_t cb_sum_B = tt::CBIndex::c_30;
constexpr uint32_t cb_exp_max_diff = tt::CBIndex::c_31;

constexpr uint32_t cb_out = tt::CBIndex::c_16;
@@ -426,8 +428,16 @@ void MAIN {
} else {
q_high_idx = Skt;
}
cb_wait_front(cb_q_in, q_chunk_tiles);

// Set up ping pong buffers
uint32_t alias_prev_sum = cb_sum_A;
uint32_t alias_cur_sum = cb_sum_B;
uint32_t alias_prev_max = cb_max_A;
uint32_t alias_cur_max = cb_max_B;
uint32_t alias_mm2_prev_out = cb_out_im_A;
uint32_t alias_mm2_cur_out = cb_out_im_B;
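// The prev/cur roles are exchanged with std::swap at the end of every K chunk,
// replacing the per-iteration copy_block of cur_max/cur_sum into prev_max/prev_sum.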

cb_wait_front(cb_q_in, q_chunk_tiles);
// loop while k_low < q_high
for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) {
const uint32_t k_low_idx = k_chunk * Sk_chunk_t;
@@ -478,33 +488,31 @@
ReduceDim::REDUCE_ROW,
cb_qk_im,
cb_identity_scale_in,
cb_cur_max,
Sq_chunk_t,
Sk_chunk_t>();
Sk_chunk_t>(alias_cur_max);

if (k_chunk > 0) {
max_block_inplace<cb_cur_max, cb_prev_max, Sq_chunk_t>();
max_block_inplace<Sq_chunk_t>(alias_cur_max, alias_prev_max);
}

/* QK -= cb_cur_max */
/* QK = exp(QK)*/
sub_exp_block_bcast_cols_inplace<cb_qk_im, cb_cur_max, Sq_chunk_t, Sk_chunk_t>();
sub_exp_block_bcast_cols_inplace<cb_qk_im, Sq_chunk_t, Sk_chunk_t>(alias_cur_max);

/* cb_cur_sum = sum(cb_qk_im, dim=-1) */
reduce_c<
PoolType::SUM,
ReduceDim::REDUCE_ROW,
cb_qk_im,
cb_identity_scale_in,
cb_cur_sum,
Sq_chunk_t,
Sk_chunk_t>();
Sk_chunk_t>(alias_cur_sum);

/* OUT_IM = QK @ V_CHUNK */
matmul_blocks(
cb_qk_im,
cb_v_in,
cb_out_im,
alias_mm2_cur_out,
Sq_chunk_t,
DHt,
Sk_chunk_t,
@@ -516,48 +524,43 @@
out_subblock_w,
false /*transpose*/);

reconfig_data_format_srca(cb_out_im);
cb_pop_front(cb_qk_im, qk_chunk_tiles);
reconfig_data_format(alias_prev_max, alias_cur_max);

/* OUT_ACC += OUT_IM */
if (k_chunk == 0) {
copy_block(cb_out_im, cb_out_accumulate_im, out_chunk_tiles);
} else {
if (k_chunk > 0) {
/* cb_exp_max_diff = torch.exp(cb_prev_max - cb_cur_max) */
sub_exp_block(cb_prev_max, cb_cur_max, cb_exp_max_diff, Sq_chunk_t);
cb_pop_front(cb_prev_max, Sq_chunk_t);
sub_exp_block(alias_prev_max, alias_cur_max, cb_exp_max_diff, Sq_chunk_t);
cb_pop_front(alias_prev_max, Sq_chunk_t);

/* cb_prev_sum *= cb_exp_max_diff */
mul_block_inplace(cb_prev_sum, cb_exp_max_diff, Sq_chunk_t);

/* cb_out_accumulate_im *= cb_exp_max_diff */
mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff, Sq_chunk_t, DHt);

mul_block_inplace(alias_prev_sum, cb_exp_max_diff, Sq_chunk_t);
/* cb_cur_sum += cb_prev_sum */
add_block_inplace(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
add_block_inplace(alias_cur_sum, alias_prev_sum, Sq_chunk_t);

/* cb_out_accumulate_im += cb_out_im */
add_block_inplace(cb_out_accumulate_im, cb_out_im, out_chunk_tiles);
/* cb_out_accumulate_im *= cb_exp_max_diff */
mul_block_bcast_cols_inplace<Sq_chunk_t, DHt>(alias_mm2_prev_out, cb_exp_max_diff);
add_block_inplace(alias_mm2_cur_out, alias_mm2_prev_out, out_chunk_tiles);
}
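// For reference, across a K chunk the above implements the online-softmax (flash-attention
// style) update, with m the running row max, l the running row sum and O the running output:
//   m_i = max(m_{i-1}, rowmax(QK_i))
//   l_i = exp(m_{i-1} - m_i) * l_{i-1} + rowsum(exp(QK_i - m_i))
//   O_i = exp(m_{i-1} - m_i) * O_{i-1} + exp(QK_i - m_i) @ V_i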

// Set cb_prev_sum and cb_prev_max
copy_block(cb_cur_max, cb_prev_max, Sq_chunk_t);
copy_block(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
// Swap the ping-pong aliases (sum, mm2 output, max) for the next iteration
std::swap(alias_prev_sum, alias_cur_sum);
std::swap(alias_mm2_prev_out, alias_mm2_cur_out);
std::swap(alias_prev_max, alias_cur_max);
}

/* cb_cur_sum = 1.0 / cb_cur_sum */
cb_push_back(cb_cur_sum, Sq_chunk_t);
recip_block_inplace(cb_cur_sum, Sq_chunk_t);
recip_block_inplace(alias_prev_sum, Sq_chunk_t);

/* cb_out_accumulate_im *= cb_cur_sum */
mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_cur_sum, Sq_chunk_t, DHt);
// NOTE: PCC bug if we modify the below function to directly output to cb_out.
mul_block_bcast_cols_inplace<Sq_chunk_t, DHt>(alias_mm2_prev_out, alias_prev_sum);
pack_reconfig_data_format(cb_out);
copy_block(cb_out_accumulate_im, cb_out, out_chunk_tiles);
copy_block(alias_mm2_prev_out, cb_out, out_chunk_tiles);

cb_pop_front(cb_q_in, q_chunk_tiles);
// free up cb_prev_max after K chunks
cb_pop_front(cb_prev_max, Sq_chunk_t);
cb_pop_front(cb_prev_sum, Sq_chunk_t);
cb_pop_front(alias_prev_max, Sq_chunk_t);
}
}
}
@@ -141,15 +141,22 @@ operation::ProgramWithCallbacks sdpa_multi_core(
uint32_t num_cores = grid_size.x * grid_size.y;

TT_FATAL(
num_cores <= device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y, "Error");
num_cores <= device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y,
"Provided grid must not contain more cores than the device. Got {} cores, expected at most {} cores.",
num_cores,
device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y);

// Parallelization scheme
// We will choose parallelization factors for batch, num_heads, and q_seq_len in that order
uint32_t batch_parallel_factor = std::min(B, num_cores);
uint32_t nh_parallel_factor = std::min(num_cores / batch_parallel_factor, NQH);
uint32_t q_parallel_factor = std::min(num_cores / (batch_parallel_factor * nh_parallel_factor), q_num_chunks);

TT_FATAL(batch_parallel_factor * nh_parallel_factor * q_parallel_factor <= num_cores, "Error");
TT_FATAL(
batch_parallel_factor * nh_parallel_factor * q_parallel_factor <= num_cores,
"Parallelism must not exceed number of cores. Got {}, expected at most {}.",
batch_parallel_factor * nh_parallel_factor * q_parallel_factor,
num_cores);
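// Example (hypothetical shapes): B = 2, NQH = 8, q_num_chunks = 4 on 64 cores gives
// factors 2, 8 and 4, so all 64 cores are used.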

tt::log_debug("Parallelization scheme:");
tt::log_debug("batch_parallel_factor: {}", batch_parallel_factor);
@@ -230,18 +237,36 @@ operation::ProgramWithCallbacks sdpa_multi_core(
// Find log2 of stats_granularity using std::log2
const uint32_t log2_stats_granularity = std::log2(stats_granularity);
// Assert that this is a power of 2
TT_FATAL(stats_granularity == (1 << log2_stats_granularity), "Error");
TT_FATAL(
stats_granularity == (1 << log2_stats_granularity),
"stats_granularity must be a power of 2. Got {}.",
stats_granularity);

const uint32_t sub_exp_granularity = std::min(Sk_chunk_t, dst_size);
const uint32_t log2_sub_exp_granularity = std::log2(sub_exp_granularity);
TT_FATAL(sub_exp_granularity == (1 << log2_sub_exp_granularity), "Error");
TT_FATAL(
sub_exp_granularity == (1 << log2_sub_exp_granularity),
"sub_exp_granularity must be a power of 2. Got {}.",
sub_exp_granularity);

const uint32_t mul_bcast_granularity = std::min(Sq_chunk_t * Sk_chunk_t, dst_size);
const uint32_t log2_mul_bcast_granularity = std::log2(mul_bcast_granularity);
TT_FATAL(mul_bcast_granularity == (1 << log2_mul_bcast_granularity), "Error");

const uint32_t dht_granularity = std::min(DHt, dst_size);
const uint32_t log2_dht_granularity = std::log2(dht_granularity);
TT_FATAL(
mul_bcast_granularity == (1 << log2_mul_bcast_granularity),
"mul_bcast_granularity must be a power of 2. Got {}.",
mul_bcast_granularity);

uint32_t dht_granularity = std::min(DHt, dst_size);
uint32_t log2_dht_granularity = std::log2(dht_granularity);
// Sometimes DHt is not a power of 2, so granularity should be 1
if (dht_granularity != (1 << log2_dht_granularity)) {
dht_granularity = 1;
log2_dht_granularity = 0;
}
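// Example (hypothetical): DHt = 5 with dst_size = 8 gives dht_granularity = 5, which is not
// a power of 2, so it falls back to 1 (log2_dht_granularity = 0) and the kernel processes
// one tile per pass.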
TT_FATAL(
dht_granularity == (1 << log2_dht_granularity),
"dht_granularity must be a power of 2. Got {}.",
dht_granularity);

// Log these
tt::log_debug("stats_granularity: {}", stats_granularity);
@@ -417,10 +442,13 @@ operation::ProgramWithCallbacks sdpa_multi_core(
.set_page_size(tt::CBIndex::c_2, v_tile_size);
auto cb_in2_id = CreateCircularBuffer(program, core_grid, c_in2_config);

// attn_mask input
auto c_in3_config = CircularBufferConfig(mask_tiles * mask_tile_size, {{tt::CBIndex::c_3, mask_df}})
.set_page_size(tt::CBIndex::c_3, mask_tile_size);
auto cb_in3_id = CreateCircularBuffer(program, core_grid, c_in3_config);
// Only create mask buffer if it's going to be used
if (use_provided_mask or is_causal) {
// attn_mask input
auto c_in3_config = CircularBufferConfig(mask_tiles * mask_tile_size, {{tt::CB::c_in3, mask_df}})
.set_page_size(tt::CB::c_in3, mask_tile_size);
auto cb_in3_id = CreateCircularBuffer(program, core_grid, c_in3_config);
}

// scale input
auto c_in4_config = CircularBufferConfig(scale_tiles * scalar_tile_size, {{tt::CBIndex::c_4, scalar_df}})