Commit
add sync before shared memory access
guocuimi committed Oct 17, 2024
1 parent cfa76d4 commit 6b7d3d1
Showing 2 changed files with 38 additions and 16 deletions.
49 changes: 35 additions & 14 deletions src/kernels/attention/flash_infer/attention_kernel.h
@@ -789,17 +789,11 @@ template <uint32_t num_warps_m,
uint32_t num_warps_n,
uint32_t num_iters_m,
uint32_t num_iters_k,
-SwizzleMode swizzle_mode,
-typename DTypeOut>
-__device__ __forceinline__ void write_o_reg_gmem(
+typename DTypeOut,
+SwizzleMode swizzle_mode>
+__device__ __forceinline__ void write_o_reg_smem(
float (*o_frag)[num_iters_k][8],
-smem_t<swizzle_mode>* o_smem,
-DTypeOut* o_ptr_base,
-const uint32_t o_packed_idx_base,
-const uint32_t qo_upper_bound,
-const uint32_t o_stride_n,
-const uint32_t o_stride_h,
-const uint_fastdiv group_size) {
+smem_t<swizzle_mode>* o_smem) {
constexpr uint32_t head_dim = num_iters_k * 16;
constexpr uint32_t channel_size_128b_out =
head_dim / num_elems_per_128b<DTypeOut>();
@@ -808,7 +802,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
const uint32_t lane_idx = threadIdx.x;

// o_frag: [num_iters_m, num_iters_k, 8]
-if (get_warp_idx_z<num_warps_m, num_warps_n>() == 0) {
+if (warp_idx_z == 0) {
// write o from register to shared memory
// why not every thread writes to shared memory?
#pragma unroll
@@ -841,6 +835,28 @@ __device__ __forceinline__ void write_o_reg_gmem(
}
}
}
+}

+template <uint32_t num_warps_m,
+uint32_t num_warps_n,
+uint32_t num_iters_m,
+uint32_t num_iters_k,
+SwizzleMode swizzle_mode,
+typename DTypeOut>
+__device__ __forceinline__ void write_o_smem_gmem(
+smem_t<swizzle_mode>* o_smem,
+DTypeOut* o_ptr_base,
+const uint32_t o_packed_idx_base,
+const uint32_t qo_upper_bound,
+const uint32_t o_stride_n,
+const uint32_t o_stride_h,
+const uint_fastdiv group_size) {
+constexpr uint32_t head_dim = num_iters_k * 16;
+constexpr uint32_t channel_size_128b_out =
+head_dim / num_elems_per_128b<DTypeOut>();
+const uint32_t warp_idx_x = get_warp_idx_x<num_warps_m, num_warps_n>();
+const uint32_t warp_idx_z = get_warp_idx_z<num_warps_m, num_warps_n>();
+const uint32_t lane_idx = threadIdx.x;

// write o from shared memory to global memory
// o_smem: [num_warps_m, num_iters_m, 16, head_dim]
Expand Down Expand Up @@ -1285,6 +1301,13 @@ __launch_bounds__(num_warps_m* num_warps_n* warp_size) void attention_kernel(
// normalize d
normalize_d<num_iters_m, num_iters_k>(o_frag, m, d);

+// write o from register to shared memory
+write_o_reg_smem<num_warps_m, num_warps_n, num_iters_m, num_iters_k, DTypeOut>(
+o_frag, &qo_smem);

+block.sync();

+// write o from shared memory to global memory
const uint32_t num_kv_chunks =
(kv_len_safe + kv_chunk_size - 1) / kv_chunk_size;

@@ -1306,9 +1329,7 @@ __launch_bounds__(num_warps_m* num_warps_n* warp_size) void attention_kernel(
num_elems_per_128b<DTypeOut>(),
num_qo_heads * head_dim,
head_dim);
-// write_back
-write_o_reg_gmem<num_warps_m, num_warps_n, num_iters_m, num_iters_k>(
-o_frag,
+write_o_smem_gmem<num_warps_m, num_warps_n, num_iters_m, num_iters_k>(
&qo_smem,
o_ptr_base,
qo_packed_idx_base,
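The hunks above replace the single write_o_reg_gmem helper with a two-stage write: write_o_reg_smem moves the output fragment from registers into shared memory (only the warps with warp_idx_z == 0 do this), write_o_smem_gmem then has every warp copy the shared-memory tile out to global memory, and a block.sync() sits between the two calls so the copy never reads a tile that is still being written. As a rough illustration of the same stage-through-shared-memory pattern, here is a standalone CUDA sketch; it is not the FlashInfer code, and the names staged_write and kTile are invented for the example.

// Toy sketch (assumed names, not the FlashInfer kernel): warp 0 stages a tile
// in shared memory, a block-wide barrier makes it visible, then all warps copy
// it to global memory. Removing __syncthreads() reintroduces the race this
// commit guards against: consumer warps could read the tile before it is ready.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int kTile = 256;  // elements staged per block

__global__ void staged_write(const float* __restrict__ in,
                             float* __restrict__ out) {
  __shared__ float tile[kTile];
  const int tid = threadIdx.x;   // launched with 128 threads per block
  const int warp_id = tid / 32;
  const int lane_id = tid % 32;

  // Stage 1: values -> shared memory, done by warp 0 only
  // (analogous to the warp_idx_z == 0 guard in write_o_reg_smem).
  if (warp_id == 0) {
    for (int i = lane_id; i < kTile; i += 32) {
      tile[i] = 2.0f * in[blockIdx.x * kTile + i];  // stand-in "result"
    }
  }

  // Block-wide barrier (the added block.sync()): the staged tile must be
  // complete before any thread reads it back.
  __syncthreads();

  // Stage 2: shared memory -> global memory, all warps participate
  // (analogous to write_o_smem_gmem).
  for (int i = tid; i < kTile; i += blockDim.x) {
    out[blockIdx.x * kTile + i] = tile[i];
  }
}

int main() {
  constexpr int kBlocks = 4;
  constexpr int n = kBlocks * kTile;
  std::vector<float> h_in(n), h_out(n);
  for (int i = 0; i < n; ++i) h_in[i] = static_cast<float>(i);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  staged_write<<<kBlocks, 128>>>(d_in, d_out);
  cudaMemcpy(h_out.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

  bool ok = true;
  for (int i = 0; i < n; ++i) ok = ok && (h_out[i] == 2.0f * h_in[i]);
  printf("staged write: %s\n", ok ? "OK" : "MISMATCH");

  cudaFree(d_in);
  cudaFree(d_out);
  return ok ? 0 : 1;
}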
5 changes: 3 additions & 2 deletions tests/kernels/attention/flash_infer_test.py
@@ -7,7 +7,7 @@
import scalellm._C.kernels as kernels # type: ignore


-@pytest.mark.parametrize("seq_lens", [[(1, 100), (15, 15), (111, 234), (1000, 10000)]])
+@pytest.mark.parametrize("seq_lens", [[(1, 100)], [(100, 100)], [(1, 100), (15, 15), (111, 234), (1000, 10000)]])
@pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)])
@pytest.mark.parametrize("head_size", [64, 128, 256])
@pytest.mark.parametrize("n_blocks", [100])
@@ -128,4 +128,5 @@ def test_flashinfer_varlen_masked_self_attention(

if __name__ == "__main__":
pytest.main([__file__])
-# trigger package build and test

+# test_flashinfer_varlen_masked_self_attention([(1, 100)], (8, 8), 128, torch.float16, 100, 4, 0.0, -1, False)
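The seq_lens parametrization above now also covers a single short sequence ([(1, 100)]) and an equal-length pair ([(100, 100)]) in addition to the original mixed batch. As a loose sketch of what such (q_len, kv_len) pairs are usually flattened into for a varlen attention test (this is not code from flash_infer_test.py, and build_cu_seqlens is an invented helper name):

# Illustrative helper (assumed, not from the test file): turn (q_len, kv_len)
# pairs into the cumulative sequence-length tensors a varlen kernel consumes.
import torch

def build_cu_seqlens(seq_lens):
    q_lens = torch.tensor([q for q, _ in seq_lens], dtype=torch.int32)
    kv_lens = torch.tensor([kv for _, kv in seq_lens], dtype=torch.int32)
    zero = torch.zeros(1, dtype=torch.int32)
    # cu[i + 1] - cu[i] recovers the i-th sequence length
    q_cu_seq_lens = torch.cat([zero, q_lens.cumsum(0, dtype=torch.int32)])
    kv_cu_seq_lens = torch.cat([zero, kv_lens.cumsum(0, dtype=torch.int32)])
    return q_cu_seq_lens, kv_cu_seq_lens

if __name__ == "__main__":
    print(build_cu_seqlens([(1, 100)]))  # tensor([0, 1]), tensor([0, 100])
    print(build_cu_seqlens([(1, 100), (15, 15), (111, 234), (1000, 10000)]))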
