diff --git a/src/kernels/attention/flash_infer/attention_kernel.h b/src/kernels/attention/flash_infer/attention_kernel.h
index 2a90c7da..7755906c 100644
--- a/src/kernels/attention/flash_infer/attention_kernel.h
+++ b/src/kernels/attention/flash_infer/attention_kernel.h
@@ -789,17 +789,11 @@ template
-__device__ __forceinline__ void write_o_reg_gmem(
+          typename DTypeOut,
+          SwizzleMode swizzle_mode>
+__device__ __forceinline__ void write_o_reg_smem(
     float (*o_frag)[num_iters_k][8],
-    smem_t* o_smem,
-    DTypeOut* o_ptr_base,
-    const uint32_t o_packed_idx_base,
-    const uint32_t qo_upper_bound,
-    const uint32_t o_stride_n,
-    const uint32_t o_stride_h,
-    const uint_fastdiv group_size) {
+    smem_t* o_smem) {
   constexpr uint32_t head_dim = num_iters_k * 16;
   constexpr uint32_t channel_size_128b_out =
       head_dim / num_elems_per_128b();
@@ -808,7 +802,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
   const uint32_t lane_idx = threadIdx.x;
 
   // o_frag: [num_iters_m, num_iters_k, 8]
-  if (get_warp_idx_z() == 0) {
+  if (warp_idx_z == 0) {
     // write o from register to shared memory
     // why not every thread writes to shared memory?
 #pragma unroll
@@ -841,6 +835,28 @@
       }
     }
   }
+}
+
+template
+__device__ __forceinline__ void write_o_smem_gmem(
+    smem_t* o_smem,
+    DTypeOut* o_ptr_base,
+    const uint32_t o_packed_idx_base,
+    const uint32_t qo_upper_bound,
+    const uint32_t o_stride_n,
+    const uint32_t o_stride_h,
+    const uint_fastdiv group_size) {
+  constexpr uint32_t head_dim = num_iters_k * 16;
+  constexpr uint32_t channel_size_128b_out =
+      head_dim / num_elems_per_128b();
+  const uint32_t warp_idx_x = get_warp_idx_x();
+  const uint32_t warp_idx_z = get_warp_idx_z();
+  const uint32_t lane_idx = threadIdx.x;
 
   // write o from shared memory to global memory
   // o_smem: [num_warps_m, num_iters_m, 16, head_dim]
@@ -1285,6 +1301,13 @@ __launch_bounds__(num_warps_m* num_warps_n* warp_size) void attention_kernel(
   // normalize d
   normalize_d(o_frag, m, d);
 
+  // write o from register to shared memory
+  write_o_reg_smem(
+      o_frag, &qo_smem);
+
+  block.sync();
+
+  // write o from shared memory to global memory
   const uint32_t num_kv_chunks =
       (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size;
@@ -1306,9 +1329,7 @@ __launch_bounds__(num_warps_m* num_warps_n* warp_size) void attention_kernel(
           num_elems_per_128b(),
       num_qo_heads * head_dim,
       head_dim);
-  // write_back
-  write_o_reg_gmem(
-      o_frag,
+  write_o_smem_gmem(
       &qo_smem,
       o_ptr_base,
       qo_packed_idx_base,
diff --git a/tests/kernels/attention/flash_infer_test.py b/tests/kernels/attention/flash_infer_test.py
index 42150e5c..cb7a55d9 100644
--- a/tests/kernels/attention/flash_infer_test.py
+++ b/tests/kernels/attention/flash_infer_test.py
@@ -7,7 +7,7 @@ import scalellm._C.kernels as kernels  # type: ignore
 
 
-@pytest.mark.parametrize("seq_lens", [[(1, 100), (15, 15), (111, 234), (1000, 10000)]])
+@pytest.mark.parametrize("seq_lens", [[(1, 100)], [(100, 100)], [(1, 100), (15, 15), (111, 234), (1000, 10000)]])
 @pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)])
 @pytest.mark.parametrize("head_size", [64, 128, 256])
 @pytest.mark.parametrize("n_blocks", [100])
@@ -128,4 +128,5 @@ def test_flashinfer_varlen_masked_self_attention(
 
 if __name__ == "__main__":
     pytest.main([__file__])
-    # trigger package build and test
+
+    # test_flashinfer_varlen_masked_self_attention([(1, 100)], (8, 8), 128, torch.float16, 100, 4, 0.0, -1, False)
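
For context on the pattern this patch introduces: the old `write_o_reg_gmem` is split into `write_o_reg_smem` (registers -> shared memory) and `write_o_smem_gmem` (shared memory -> global memory), with an explicit `block.sync()` placed between the two calls in `attention_kernel`. The toy kernel below is a minimal sketch of that staging idea only; it is not taken from ScaleLLM or FlashInfer, and all names, sizes, and the per-block reversal are made up for illustration. Each thread stages a value in shared memory, the block synchronizes, and then threads write to global memory reading slots staged by other threads, which is why the barrier between the two stages is required.

```cuda
// staged_write.cu -- hypothetical toy example, not part of this patch.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kTile = 8;

// Two-stage output write: registers -> shared memory, block-wide sync,
// shared memory -> global memory.
__global__ void staged_write_kernel(const float* in, float* out, int n) {
  __shared__ float tile[kTile];
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Stage 1: stage the per-thread ("register") value into shared memory.
  tile[threadIdx.x] = (idx < n) ? in[idx] : 0.0f;

  // Barrier between the two stages (plays the role of block.sync()).
  __syncthreads();

  // Stage 2: write from shared memory to global memory, reading a slot
  // staged by a different thread (a per-block reversal).
  if (idx < n) {
    out[idx] = tile[blockDim.x - 1 - threadIdx.x];
  }
}

int main() {
  float h_in[kTile], h_out[kTile];
  for (int i = 0; i < kTile; ++i) h_in[i] = static_cast<float>(i);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, kTile * sizeof(float));
  cudaMalloc(&d_out, kTile * sizeof(float));
  cudaMemcpy(d_in, h_in, kTile * sizeof(float), cudaMemcpyHostToDevice);

  staged_write_kernel<<<1, kTile>>>(d_in, d_out, kTile);
  cudaMemcpy(h_out, d_out, kTile * sizeof(float), cudaMemcpyDeviceToHost);

  for (int i = 0; i < kTile; ++i) printf("%g ", h_out[i]);  // expect 7 6 5 ... 0
  printf("\n");

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

In the refactored kernel the barrier is now visible in `attention_kernel` itself rather than inside the write-back helper, and the test diff additionally exercises a single short sequence `[(1, 100)]` and an equal-length case `[(100, 100)]` alongside the existing mixed-length batch.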