feat: Add mask to merge_state_in_place #372

Merged (1 commit) · Jul 13, 2024
12 changes: 10 additions & 2 deletions include/flashinfer/attention/cascade.cuh
@@ -81,16 +81,22 @@ __global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__
  * \param s The logsumexp value to be updated in-place. (n, h)
  * \param v_other The other v to be merged. (n, h, d)
  * \param s_other The other logsumexp value to be merged. (n, h)
+ * \param mask Optional mask of whether to merge given sequences or not. (n)
  * \param num_heads The number of heads of v and v_other.
  * \param head_dim The dimension of each head.
  * \note Both s and s_other are logsumexp values with base 2.
  */
 template <uint32_t vec_size, typename DType>
 __global__ void MergeStateInPlaceKernel(DType* __restrict__ v, float* __restrict__ s,
                                         DType* __restrict__ v_other, float* __restrict__ s_other,
+                                        uint8_t* __restrict__ mask,
                                         uint32_t num_heads, uint32_t head_dim) {
-  uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t pos = blockIdx.x;
+
+  if (mask != nullptr && mask[pos] == 0)
+    return;
+
+  uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t head_idx = ty;

   float s_val = s[pos * num_heads + head_idx];
@@ -383,13 +389,15 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType
  * \param seq_len The sequence length.
  * \param num_heads The number of heads of v and v_other.
  * \param head_dim The dimension of each head.
+ * \param mask Optional mask of whether to merge given sequences or not. (n)
  * \param stream The CUDA stream to execute the kernel.
  * \return status Indicates whether CUDA calls are successful
  * \note Both s and s_other are logsumexp values with base 2.
  */
 template <typename DType>
 cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other, uint32_t seq_len,
                               uint32_t num_heads, uint32_t head_dim,
+                              uint8_t* mask = nullptr,
                               cudaStream_t stream = nullptr) {
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
     constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U);
@@ -398,7 +406,7 @@ cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other
     dim3 nblks(seq_len);
     dim3 nthrs(bdx, bdy);
     auto kernel = MergeStateInPlaceKernel<vec_size, DType>;
-    void* args[] = {&v, &s, &v_other, &s_other, &num_heads, &head_dim};
+    void* args[] = {&v, &s, &v_other, &s_other, &mask, &num_heads, &head_dim};
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
   });
   return cudaSuccess;
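For intuition, the masked merge implemented above can be sketched in PyTorch. This is a minimal reference sketch only, not the library's code (the helper name `merge_state_in_place_ref` is made up here); `s` and `s_other` are base-2 logsumexp values per the comments above:

```python
import torch

def merge_state_in_place_ref(v, s, v_other, s_other, mask=None):
    # v: (n, h, d), s: (n, h); both are updated in-place.
    if mask is None:
        mask = torch.ones(v.shape[0], dtype=torch.bool, device=v.device)
    s_max = torch.maximum(s, s_other)
    w = torch.exp2(s - s_max)               # weight of the in-place state
    w_other = torch.exp2(s_other - s_max)   # weight of the other state
    v_merged = (v * w[..., None] + v_other * w_other[..., None]) / (w + w_other)[..., None]
    s_merged = s_max + torch.log2(w + w_other)
    keep = mask.view(-1, 1)                 # broadcast the per-sequence mask over heads
    v.copy_(torch.where(keep[..., None], v_merged, v))
    s.copy_(torch.where(keep, s_merged, s))
```

Sequences where `mask` is false keep their original `v` and `s`, matching the early return in the kernel.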
11 changes: 9 additions & 2 deletions python/csrc/cascade.cu
@@ -63,7 +63,7 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
 }

 void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other,
-                          torch::Tensor s_other) {
+                          torch::Tensor s_other, std::optional<torch::Tensor> mask) {
   CHECK_INPUT(v);
   CHECK_INPUT(s);
   CHECK_INPUT(v_other);
@@ -82,6 +82,13 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   CHECK_EQ(v.size(1), s.size(1));
   CHECK_EQ(s.scalar_type(), torch::kFloat32);
   CHECK_EQ(s_other.scalar_type(), torch::kFloat32);
+  uint8_t* mask_ptr = nullptr;
+  if (mask.has_value()) {
+    CHECK_DIM(1, mask.value());
+    CHECK_EQ(v.size(0), mask.value().size(0));
+    CHECK_EQ(mask.value().device(), device);
+    mask_ptr = static_cast<uint8_t*>(mask.value().data_ptr());
+  }
   unsigned int seq_len = v.size(0);
   unsigned int num_heads = v.size(1);
   unsigned int head_dim = v.size(2);
@@ -91,7 +98,7 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   cudaError_t status = MergeStateInPlace(
       static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
       static_cast<c_type*>(v_other.data_ptr()), static_cast<float*>(s_other.data_ptr()), seq_len,
-      num_heads, head_dim, torch_current_stream);
+      num_heads, head_dim, mask_ptr, torch_current_stream);
   TORCH_CHECK(status == cudaSuccess,
               "MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status));
   return true;
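Note that the binding reads the mask's raw storage through `uint8_t*`, so any 1-D CUDA tensor of one-byte elements where a nonzero byte means "merge" satisfies the checks above; `torch.bool` (one byte per element) is the natural choice, and `torch.uint8` has the same layout. A tiny illustration (values arbitrary):

```python
import torch

mask_bool = torch.tensor([True, False, True], device="cuda:0")
mask_u8 = mask_bool.to(torch.uint8)  # bytes 1, 0, 1: the same layout the kernel reads
```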
2 changes: 1 addition & 1 deletion python/csrc/flashinfer_ops.h
@@ -46,7 +46,7 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
                           torch::Tensor s_b);

 void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other,
-                          torch::Tensor s_other);
+                          torch::Tensor s_other, std::optional<torch::Tensor> mask = std::nullopt);

 std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s);

13 changes: 11 additions & 2 deletions python/flashinfer/cascade.py
@@ -98,7 +98,11 @@ def merge_state(


 def merge_state_in_place(
-    v: torch.Tensor, s: torch.Tensor, v_other: torch.Tensor, s_other: torch.Tensor
+    v: torch.Tensor,
+    s: torch.Tensor,
+    v_other: torch.Tensor,
+    s_other: torch.Tensor,
+    mask: Optional[torch.Tensor] = None,
 ):
     r"""Merge the self-attention state ``(v, s)`` with another state
     ``(v_other, s_other)`` in-place.
@@ -117,6 +121,11 @@ def merge_state_in_place(
     s_other : torch.Tensor
         The other logsumexp value to be merged, expected to be a float32 tensor,
         shape: ``(seq_len, num_heads)``.
+    mask : Optional[torch.Tensor]
+        The boolean mask selecting which sequences to merge: the state of sequence
+        ``i`` is merged if ``mask[i]`` is true and left untouched otherwise. Useful
+        for CUDA graphs. If not specified (default), states are merged for all
+        sequences.
+        shape: ``(seq_len,)``

     Example
     -------
@@ -131,7 +140,7 @@ def merge_state_in_place(
     >>> s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
     >>> flashinfer.merge_state_in_place(v, s, v_other, s_other)
     """
-    _kernels.merge_state_in_place(v, s, v_other, s_other)
+    _kernels.merge_state_in_place(v, s, v_other, s_other, mask)


 def merge_states(v: torch.Tensor, s: torch.Tensor):
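A usage sketch for the new parameter (shapes chosen to match the docstring example; the half-and-half mask pattern is arbitrary):

```python
import torch
import flashinfer

seq_len, num_heads, head_dim = 2048, 32, 128
v = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
s = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda:0")
v_other = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
s_other = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda:0")

# Merge only the first half of the sequences; rows with a false mask
# keep their original (v, s) values.
mask = torch.zeros(seq_len, dtype=torch.bool, device="cuda:0")
mask[: seq_len // 2] = True
flashinfer.merge_state_in_place(v, s, v_other, s_other, mask=mask)
```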
57 changes: 57 additions & 0 deletions python/tests/test_shared_prefix_kernels.py
@@ -199,6 +199,63 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache(
         o_baseline.cpu().numpy(), o_cascade.cpu().numpy(), rtol=1e-3, atol=1e-3
     )

+@pytest.mark.parametrize("seed", [0])
+@pytest.mark.parametrize("num_tries", [50])
+def test_merge_state_in_place_with_mask(seed, num_tries):
+    seq_len = 512
+    num_heads = 32
+    head_dim = 128
+    va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+    sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+    vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+    sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+    va_original = va.clone()
+    sa_original = sa.clone()
+
+    # No mask.
+    flashinfer.merge_state_in_place(va, sa, vb, sb)
+    va_merged_ref = va.clone()
+    sa_merged_ref = sa.clone()
+    assert not torch.allclose(va_merged_ref, va_original)
+    assert not torch.allclose(sa_merged_ref, sa_original)
+
+    # Mask with all 1s. Should be identical to no mask.
+    mask = torch.ones(seq_len, dtype=torch.bool).to("cuda:0")
+    va = va_original.clone()
+    sa = sa_original.clone()
+    flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask)
+    va_merged = va
+    sa_merged = sa
+    numpy.testing.assert_allclose(va_merged.cpu().numpy(), va_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3)
+    numpy.testing.assert_allclose(sa_merged.cpu().numpy(), sa_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3)
+
+    # Mask with all zeros. Input and output should be identical.
+    mask = torch.zeros(seq_len, dtype=torch.bool).to("cuda:0")
+    va = va_original.clone()
+    sa = sa_original.clone()
+    flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask)
+    va_merged = va
+    sa_merged = sa
+    numpy.testing.assert_allclose(va_merged.cpu().numpy(), va_original.cpu().numpy(), rtol=1e-3, atol=1e-3)
+    numpy.testing.assert_allclose(sa_merged.cpu().numpy(), sa_original.cpu().numpy(), rtol=1e-3, atol=1e-3)
+
+    # Test some random masks.
+    randgen = torch.Generator(device="cuda:0")
+    randgen.manual_seed(seed)
+    for _ in range(num_tries):
+        rand_mask = (torch.rand(seq_len, generator=randgen, dtype=torch.float32, device="cuda:0") > 0.5).to(dtype=torch.bool)
+        true_indices = rand_mask.nonzero()
+        false_indices = (rand_mask == 0).nonzero()
+        va = va_original.clone()
+        sa = sa_original.clone()
+        flashinfer.merge_state_in_place(va, sa, vb, sb, mask=rand_mask)
+        va_merged = va
+        sa_merged = sa
+
+        numpy.testing.assert_allclose(va_merged[false_indices].cpu().numpy(), va_original[false_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
+        numpy.testing.assert_allclose(sa_merged[false_indices].cpu().numpy(), sa_original[false_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
+        numpy.testing.assert_allclose(va_merged[true_indices].cpu().numpy(), va_merged_ref[true_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
+        numpy.testing.assert_allclose(sa_merged[true_indices].cpu().numpy(), sa_merged_ref[true_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)

 if __name__ == "__main__":
     test_batch_attention_with_shared_prefix_paged_kv_cache(