use all warps to write o
guocuimi committed Oct 17, 2024
1 parent 01a884a commit cfa76d4
Showing 1 changed file with 43 additions and 36 deletions.
79 changes: 43 additions & 36 deletions src/kernels/attention/flash_infer/attention_kernel.h
@@ -200,7 +200,7 @@ __device__ __forceinline__ void load_q_global_smem(
   // | t24 | t25 | t26 | t27 | t28 | t29 | t30 | t31 |
   //
 
-  // q_smem: [num_iters_m, num_warps_m, 16, head_dim]
+  // q_smem: [num_warps_m, num_iters_m, 16, head_dim]
   uint32_t q_smem_x = warp_idx_x * num_iters_m * 16 + lane_idx / 8;
 
 #pragma unroll
@@ -212,9 +212,9 @@
           packed_offset + fx * 16 + lane_idx / 8 + j * 4;
       uint32_t q, r;
       group_size.divmod(packed_q_idx, q, r);
-      // q_idx = packed_q_idx / group_size
-      // h_idx = packed_q_idx % group_size
-      const uint32_t q_idx = q;
+      if (q >= qo_upper_bound) {
+        continue;
+      }
       // q_ptr_base: [n_tokens, n_heads, head_dim]
       // q_ptr for a given head: [head_dim]
       DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h +
@@ -223,24 +223,18 @@
       // load head_dim from global memory to shared memory using num_warps_n
       // warps; each warp loads 8 columns at once
       uint32_t q_smem_y = warp_idx_z * 8 + lane_idx % 8;
-#pragma unroll
-      // for (uint32_t fyo = 0; fyo < num_iters_k / 4 / num_warps_n; ++fyo) {
-      while(q_smem_y * num_elems_per_128b<DTypeQ>() < head_dim) {
+      while (q_smem_y * num_elems_per_128b<DTypeQ>() < head_dim) {
         const uint32_t q_smem_offset_w =
             q_smem->template get_permuted_offset<channel_size_128b_q>(q_smem_x,
                                                                       q_smem_y);
 
-        // load q fragment from gmem to smem
-        q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
-            q_smem_offset_w, q_ptr, q_idx < qo_upper_bound);
-        // move ahead by 8 int128_t for each warp
+        q_smem->load_128b_async(q_smem_offset_w, q_ptr);
+        // move ahead by 8 * int128_t for each warp
         q_smem_y += (8 * num_warps_n);
 
         // move ahead by 8 * 16 bytes
         q_ptr += (8 * num_elems_per_128b<DTypeQ>() * num_warps_n);
       }
 
-      // adjust offset for next iteration
+      // move ahead by 4 rows
       q_smem_x += 4;
     }
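The loop above stripes the head dimension across the num_warps_n warps: each warp starts at its own 8-column window (warp_idx_z * 8 + lane_idx % 8) and jumps ahead by 8 * num_warps_n 128-bit columns until head_dim is covered. Below is a host-side C++ model of that striping; the constants kNumWarpsN, kHeadDim, and kElemsPer128b are illustrative stand-ins rather than values from the kernel, and half-precision elements are assumed so one 128-bit transfer carries 8 values.

#include <cstdint>
#include <cstdio>

// Host-side model of the kernel's column striping across head_dim.
// Assumed parameters (illustrative, not taken from the kernel):
constexpr uint32_t kNumWarpsN = 2;     // num_warps_n
constexpr uint32_t kHeadDim = 256;     // head_dim
constexpr uint32_t kElemsPer128b = 8;  // num_elems_per_128b<DTypeQ>() for f16

int main() {
  for (uint32_t warp_z = 0; warp_z < kNumWarpsN; ++warp_z) {
    for (uint32_t lane_mod8 = 0; lane_mod8 < 8; ++lane_mod8) {
      // q_smem_y = warp_idx_z * 8 + lane_idx % 8, then stride 8 * num_warps_n
      std::printf("warp z=%u lane%%8=%u -> 128b cols:", warp_z, lane_mod8);
      for (uint32_t y = warp_z * 8 + lane_mod8; y * kElemsPer128b < kHeadDim;
           y += 8 * kNumWarpsN) {
        std::printf(" %2u", y);
      }
      std::printf("\n");
    }
  }
  return 0;
}

With head_dim = 256 and two warps, each lane covers two disjoint 128-bit columns (e.g. 0 and 16 for warp 0, lane 0; 8 and 24 for warp 1, lane 0), so no transfer is duplicated across warps.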
@@ -810,9 +804,13 @@ __device__ __forceinline__ void write_o_reg_gmem(
   constexpr uint32_t channel_size_128b_out =
       head_dim / num_elems_per_128b<DTypeOut>();
   const uint32_t warp_idx_x = get_warp_idx_x<num_warps_m, num_warps_n>();
+  const uint32_t warp_idx_z = get_warp_idx_z<num_warps_m, num_warps_n>();
   const uint32_t lane_idx = threadIdx.x;
 
+  // o_frag: [num_iters_m, num_iters_k, 8]
   if (get_warp_idx_z<num_warps_m, num_warps_n>() == 0) {
+    // write o from register to shared memory
+    // why doesn't every thread write to shared memory?
 #pragma unroll
     for (uint32_t fx = 0; fx < num_iters_m; ++fx) {
 #pragma unroll
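The register-to-smem store in the next hunk scatters each thread's four 32-bit words using lane_idx / 4 for the row, lane_idx % 4 for the word within a 16-byte group, ^ 0x1 to flip to the adjacent 128-bit column (plus 8 f16 columns, since fy * 2 is even), and + 8 * channel_size_128b_out to drop down 8 rows. That is consistent with the standard mma m16n8k16 accumulator layout, two n8 halves per 16x16 tile; the sketch below enumerates the mapping. The layout is taken from the PTX mma documentation, not from this file, so treat it as an assumption.

#include <cstdint>
#include <cstdio>

// Where each lane's 8 f16 accumulator values sit in a 16x16 output tile,
// assuming the standard mma.m16n8k16 C-fragment layout (two n8 halves).
// Mirrors the smem write: row = lane / 4, word = lane % 4, `^ 0x1` flips to
// the adjacent 128-bit column, `+ 8 * channel_size_128b_out` moves down 8 rows.
int main() {
  for (uint32_t lane = 0; lane < 32; ++lane) {
    const uint32_t row = lane / 4;        // lane_idx / 4 in the kernel
    const uint32_t col = (lane % 4) * 2;  // two f16 per 32-bit word
    std::printf("lane %2u: (%2u,%2u)(%2u,%2u) (%2u,%2u)(%2u,%2u) "
                "(%2u,%2u)(%2u,%2u) (%2u,%2u)(%2u,%2u)\n",
                lane,
                row, col, row, col + 1,               // o_frag_f16[0]
                row + 8, col, row + 8, col + 1,       // o_frag_f16[1]
                row, col + 8, row, col + 9,           // o_frag_f16[2]
                row + 8, col + 8, row + 8, col + 9);  // o_frag_f16[3]
  }
  return 0;
}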
@@ -832,6 +830,7 @@
                 (warp_idx_x * num_iters_m + fx) * 16 + lane_idx / 4, fy * 2);
         ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] =
             o_frag_f16[0];
+        // TODO: avoid manipulating permuted offset directly
         ((uint32_t*)(o_smem->base + o_smem_offset_w +
                      8 * channel_size_128b_out))[lane_idx % 4] = o_frag_f16[1];
         ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] =
@@ -841,34 +840,42 @@
 #endif
       }
     }
+  }
 
-    uint32_t o_smem_offset_w =
-        o_smem->get_permuted_offset<channel_size_128b_out>(
-            warp_idx_x * num_iters_m * 16 + lane_idx / 8, lane_idx % 8);
+  // write o from shared memory to global memory
+  // o_smem: [num_warps_m, num_iters_m, 16, head_dim]
+  uint32_t o_smem_x = warp_idx_x * num_iters_m * 16 + lane_idx / 8;
 
 #pragma unroll
-    for (uint32_t fx = 0; fx < num_iters_m; ++fx) {
-#pragma unroll
-      for (uint32_t j = 0; j < 4; ++j) {
-        uint32_t q, r;
-        group_size.divmod(
-            o_packed_idx_base + lane_idx / 8 + fx * 16 + j * 4, q, r);
-        const uint32_t o_idx = q;
-        DTypeOut* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h;
+  for (uint32_t fx = 0; fx < num_iters_m; ++fx) {
+    // each warp writes 4 rows; 16 rows take 4 (16/4) iterations
 #pragma unroll
-        for (uint32_t fyo = 0; fyo < num_iters_k / 4; ++fyo) {
-          if (o_idx < qo_upper_bound) {
-            o_smem->store_128b(o_smem_offset_w, o_ptr);
-          }
-          o_ptr += 8 * num_elems_per_128b<DTypeOut>();
-          o_smem_offset_w = o_smem->template advance_offset_by_column<8>(
-              o_smem_offset_w, fyo);
-        }
-        o_smem_offset_w =
-            o_smem->template advance_offset_by_row<4, channel_size_128b_out>(
-                o_smem_offset_w) -
-            2 * num_iters_k;
-      }
+    for (uint32_t j = 0; j < 4; ++j) {
+      const uint32_t packed_o_idx =
+          o_packed_idx_base + lane_idx / 8 + fx * 16 + j * 4;
+      uint32_t q, r;
+      group_size.divmod(packed_o_idx, q, r);
+      // skip if out of boundary
+      if (q >= qo_upper_bound) {
+        continue;
+      }
+
+      DTypeOut* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h +
+                        warp_idx_z * 8 * num_elems_per_128b<DTypeOut>();
+      uint32_t o_smem_y = warp_idx_z * 8 + lane_idx % 8;
+      // write head_dim from shared memory to global memory
+      while (o_smem_y * num_elems_per_128b<DTypeOut>() < head_dim) {
+        const uint32_t o_smem_offset_w =
+            o_smem->template get_permuted_offset<channel_size_128b_out>(
+                o_smem_x, o_smem_y);
+        o_smem->store_128b(o_smem_offset_w, o_ptr);
+
+        // move ahead by 8 * int128_t for each warp
+        o_smem_y += (8 * num_warps_n);
+        o_ptr += 8 * num_elems_per_128b<DTypeOut>() * num_warps_n;
+      }
+      // move ahead by 4 rows
+      o_smem_x += 4;
     }
   }
 }
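Both the q load and the rewritten o store convert a packed row index into a (token, head-within-group) pair with group_size.divmod, which is how grouped-query attention lays group_size query heads per KV head out in consecutive rows, and both skip rows whose token index falls at or beyond qo_upper_bound. A minimal sketch of that mapping, with flashinfer's uint_fastdiv modeled by plain / and %, and the group_size and qo_upper_bound values chosen purely for illustration:

#include <cstdint>
#include <cstdio>

// Model of the packed-index mapping used by both the q load and the o store.
// uint_fastdiv.divmod(x, q, r) is modeled with plain / and % (illustrative;
// the real type precomputes a multiplicative inverse).
int main() {
  const uint32_t group_size = 4;      // query heads per KV head (assumed)
  const uint32_t qo_upper_bound = 3;  // number of valid query tokens (assumed)

  for (uint32_t packed_idx = 0; packed_idx < 16; ++packed_idx) {
    const uint32_t q = packed_idx / group_size;  // token index (row in q/o)
    const uint32_t r = packed_idx % group_size;  // head index within the group
    if (q >= qo_upper_bound) {
      std::printf("packed %2u -> skipped (q=%u >= bound)\n", packed_idx, q);
      continue;  // same early-out as the kernel's boundary check
    }
    std::printf("packed %2u -> token %u, head %u\n", packed_idx, q, r);
  }
  return 0;
}

The fast-division type exists to keep this mapping cheap inside the per-row inner loops, where an ordinary integer divide per row would be comparatively expensive.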