From 96c3bbbb83a5d6566a27b0cbe4544859edffc6c9 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Sat, 13 Jul 2024 01:24:59 +0000 Subject: [PATCH] Add mask to merge_state_in_place --- include/flashinfer/attention/cascade.cuh | 12 ++++- python/csrc/cascade.cu | 11 ++++- python/csrc/flashinfer_ops.h | 2 +- python/flashinfer/cascade.py | 13 ++++- python/tests/test_shared_prefix_kernels.py | 57 ++++++++++++++++++++++ 5 files changed, 88 insertions(+), 7 deletions(-) diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 474f97a0..af96129a 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -81,6 +81,7 @@ __global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__ * \param s The logsumexp value to be updated in-place. (n, h) * \param v_other The other v to be merged. (n, h, d) * \param s_other The other logsumexp value to be merged. (n, h) + * \param mask Optional mask of whether to merge given sequences or not. (n) * \param num_heads The number of heads of v and v_other. * \param head_dim The dimension of each head. * \note Both s and s_other are logsumexp values with base 2. @@ -88,9 +89,14 @@ __global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__ template __global__ void MergeStateInPlaceKernel(DType* __restrict__ v, float* __restrict__ s, DType* __restrict__ v_other, float* __restrict__ s_other, + uint8_t* __restrict__ mask, uint32_t num_heads, uint32_t head_dim) { - uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t pos = blockIdx.x; + + if (mask != nullptr && mask[pos] == 0) + return; + + uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t head_idx = ty; float s_val = s[pos * num_heads + head_idx]; @@ -383,6 +389,7 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType * \param seq_len The sequence length. * \param num_heads The number of heads of v and v_other. * \param head_dim The dimension of each head. + * \param mask Optional mask of whether to merge given sequences or not. (n) * \param stream The CUDA stream to execute the kernel. * \return status Indicates whether CUDA calls are successful * \note Both s and s_other are logsumexp values with base 2. @@ -390,6 +397,7 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType template cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other, uint32_t seq_len, uint32_t num_heads, uint32_t head_dim, + uint8_t* mask = nullptr, cudaStream_t stream = nullptr) { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U); @@ -398,7 +406,7 @@ cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other dim3 nblks(seq_len); dim3 nthrs(bdx, bdy); auto kernel = MergeStateInPlaceKernel; - void* args[] = {&v, &s, &v_other, &s_other, &num_heads, &head_dim}; + void* args[] = {&v, &s, &v_other, &s_other, &mask, &num_heads, &head_dim}; FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); }); return cudaSuccess; diff --git a/python/csrc/cascade.cu b/python/csrc/cascade.cu index ec2e2caa..e5198f89 100644 --- a/python/csrc/cascade.cu +++ b/python/csrc/cascade.cu @@ -63,7 +63,7 @@ std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, tor } void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other, - torch::Tensor s_other) { + torch::Tensor s_other, std::optional mask) { CHECK_INPUT(v); CHECK_INPUT(s); CHECK_INPUT(v_other); @@ -82,6 +82,13 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe CHECK_EQ(v.size(1), s.size(1)); CHECK_EQ(s.scalar_type(), torch::kFloat32); CHECK_EQ(s_other.scalar_type(), torch::kFloat32); + uint8_t* mask_ptr = nullptr; + if (mask.has_value()) { + CHECK_DIM(1, mask.value()); + CHECK_EQ(v.size(0), mask.value().size(0)); + CHECK_EQ(mask.value().device(), device); + mask_ptr = static_cast(mask.value().data_ptr()); + } unsigned int seq_len = v.size(0); unsigned int num_heads = v.size(1); unsigned int head_dim = v.size(2); @@ -91,7 +98,7 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe cudaError_t status = MergeStateInPlace( static_cast(v.data_ptr()), static_cast(s.data_ptr()), static_cast(v_other.data_ptr()), static_cast(s_other.data_ptr()), seq_len, - num_heads, head_dim, torch_current_stream); + num_heads, head_dim, mask_ptr, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status)); return true; diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index ca4971ee..3b45a05d 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -46,7 +46,7 @@ std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, tor torch::Tensor s_b); void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other, - torch::Tensor s_other); + torch::Tensor s_other, std::optional mask = std::nullopt); std::vector merge_states(torch::Tensor v, torch::Tensor s); diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index f26913a5..55feadba 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -98,7 +98,11 @@ def merge_state( def merge_state_in_place( - v: torch.Tensor, s: torch.Tensor, v_other: torch.Tensor, s_other: torch.Tensor + v: torch.Tensor, + s: torch.Tensor, + v_other: torch.Tensor, + s_other: torch.Tensor, + mask: Optional[torch.Tensor] = None, ): r"""Merge the self-attention state ``(v, s)`` with another state ``(v_other, s_other)`` in-place. @@ -117,6 +121,11 @@ def merge_state_in_place( s_other : torch.Tensor The other logsumexp value to be merged, expected to be a float32 tensor, shape: ``(seq_len, num_heads)``. + mask : Optional[torch.Tensor] + The boolean mask tensor for whether to merge the state for a corresponding sequence + or not. Useful for CUDA graphs. If not specified (default), will merge states for + all sequences. + shape: ``[seq_len]`` Example ------- @@ -131,7 +140,7 @@ def merge_state_in_place( >>> s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") >>> flashinfer.merge_state_in_place(v, s, v_other, s_other) """ - _kernels.merge_state_in_place(v, s, v_other, s_other) + _kernels.merge_state_in_place(v, s, v_other, s_other, mask) def merge_states(v: torch.Tensor, s: torch.Tensor): diff --git a/python/tests/test_shared_prefix_kernels.py b/python/tests/test_shared_prefix_kernels.py index 3911f696..036df9b8 100644 --- a/python/tests/test_shared_prefix_kernels.py +++ b/python/tests/test_shared_prefix_kernels.py @@ -199,6 +199,63 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( o_baseline.cpu().numpy(), o_cascade.cpu().numpy(), rtol=1e-3, atol=1e-3 ) +@pytest.mark.parametrize("seed", [0]) +@pytest.mark.parametrize("num_tries", [50]) +def test_merge_state_in_place_with_mask(seed, num_tries): + seq_len = 512 + num_heads = 32 + head_dim = 128 + va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") + sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") + vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") + sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") + va_orginal = va.clone() + sa_original = sa.clone() + + # No mask. + flashinfer.merge_state_in_place(va, sa, vb, sb) + va_merged_ref = va.clone() + sa_merged_ref = sa.clone() + assert not torch.allclose(va_merged_ref, va_orginal) + assert not torch.allclose(sa_merged_ref, sa_original) + + # Mask with all 1s. Should be identical to no mask. + mask = torch.ones(seq_len, dtype=torch.bool).to("cuda:0") + va = va_orginal.clone() + sa = sa_original.clone() + flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask) + va_merged = va + sa_merged = sa + numpy.testing.assert_allclose(va_merged.cpu().numpy(), va_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3) + numpy.testing.assert_allclose(sa_merged.cpu().numpy(), sa_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3) + + # Mask with all zeros. Input and output should be identical. + mask = torch.zeros(seq_len, dtype=torch.bool).to("cuda:0") + va = va_orginal.clone() + sa = sa_original.clone() + flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask) + va_merged = va + sa_merged = sa + numpy.testing.assert_allclose(va_merged.cpu().numpy(), va_orginal.cpu().numpy(), rtol=1e-3, atol=1e-3) + numpy.testing.assert_allclose(sa_merged.cpu().numpy(), sa_original.cpu().numpy(), rtol=1e-3, atol=1e-3) + + # Test some random masks. + randgen = torch.Generator(device="cuda:0") + randgen.manual_seed(seed) + for _ in range(num_tries): + rand_mask = (torch.rand(seq_len, generator=randgen, dtype=torch.float32, device="cuda:0") > 0.5).to(dtype=torch.bool) + true_indices = rand_mask.nonzero() + false_indices = (rand_mask==0).nonzero() + va = va_orginal.clone() + sa = sa_original.clone() + flashinfer.merge_state_in_place(va, sa, vb, sb, mask=rand_mask) + va_merged = va + sa_merged = sa + + numpy.testing.assert_allclose(va_merged[false_indices].cpu().numpy(), va_orginal[false_indices].cpu().numpy(), rtol=1e-3, atol=1e-3) + numpy.testing.assert_allclose(sa_merged[false_indices].cpu().numpy(), sa_original[false_indices].cpu().numpy(), rtol=1e-3, atol=1e-3) + numpy.testing.assert_allclose(va_merged[true_indices].cpu().numpy(), va_merged_ref[true_indices].cpu().numpy(), rtol=1e-3, atol=1e-3) + numpy.testing.assert_allclose(sa_merged[true_indices].cpu().numpy(), sa_merged_ref[true_indices].cpu().numpy(), rtol=1e-3, atol=1e-3) if __name__ == "__main__": test_batch_attention_with_shared_prefix_paged_kv_cache(