Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: fix decode kernels output for empty kv cache #363

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo(
for (uint32_t batch_idx = 0; batch_idx < old_batch_size; batch_idx++) {
uint32_t num_chunks =
ceil_div(old_indptr_h[batch_idx + 1] - old_indptr_h[batch_idx], max_num_pages_per_batch);
chunk_indptr_vec.push_back(chunk_indptr_vec.back() + num_chunks);
chunk_indptr_vec.push_back(chunk_indptr_vec.back() + std::max(num_chunks, 1U));
if (num_chunks == 0) {
new_page_indptr_vec.push_back(old_indptr_h[batch_idx]);
new_last_page_len_vec.push_back(0);
Expand Down
79 changes: 79 additions & 0 deletions python/tests/test_decode_prefill_lse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import flashinfer
import numpy as np
import torch
import pytest


def test_mlc_failed_case():
kv_layout = "HND"
num_pages = 12
kv_indptr_1 = torch.tensor([0, 0, 9]).int().to(0)
kv_indices_1 = torch.tensor([3, 4, 5, 6, 7, 8, 9, 10, 11]).int().to(0)
kv_last_page_len_1 = torch.tensor([0, 1]).int().to(0)
num_qo_heads = 32
num_kv_heads = 32
page_size = 16
head_dim = 128
q = torch.randn(2, num_qo_heads, head_dim).to(0).half()
kv_data = torch.randn(12, 2, num_kv_heads, page_size, head_dim).to(0).half()

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout)
wrapper.begin_forward(
kv_indptr_1,
kv_indices_1,
kv_last_page_len_1,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
pos_encoding_mode="NONE",
data_type=torch.float16,
q_data_type=torch.float16,
)
o_1, lse_1 = wrapper.forward_return_lse(q, kv_data)

wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, kv_layout, use_tensor_cores=True
)
wrapper_tensor_cores.begin_forward(
kv_indptr_1,
kv_indices_1,
kv_last_page_len_1,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
pos_encoding_mode="NONE",
data_type=torch.float16,
q_data_type=torch.float16,
)
o_1_tc, lse_1_tc = wrapper_tensor_cores.forward_return_lse(
q, kv_data
)

np.testing.assert_allclose(
lse_1.cpu().numpy(), lse_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3
)
np.testing.assert_allclose(
o_1.cpu().numpy(), o_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3
)

if __name__ == "__main__":
test_mlc_failed_case()
8 changes: 4 additions & 4 deletions python/tests/test_tensor_cores_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_batch_decode_tensor_cores(
num_kv_heads,
head_dim,
page_size,
"NONE",
pos_encoding_mode=pos_encoding_mode,
data_type=torch.float16,
q_data_type=torch.float16,
)
Expand All @@ -121,7 +121,7 @@ def test_batch_decode_tensor_cores(
num_kv_heads,
head_dim,
page_size,
"NONE",
pos_encoding_mode=pos_encoding_mode,
data_type=torch.float16,
q_data_type=torch.float16,
)
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_batch_decode_tensor_cores_cuda_graph(
num_kv_heads,
head_dim,
page_size,
"NONE",
pos_encoding_mode=pos_encoding_mode,
data_type=torch.float16,
q_data_type=torch.float16,
)
Expand Down Expand Up @@ -226,7 +226,7 @@ def test_batch_decode_tensor_cores_cuda_graph(
num_kv_heads,
head_dim,
page_size,
"NONE",
pos_encoding_mode=pos_encoding_mode,
data_type=torch.float16,
q_data_type=torch.float16,
)
Expand Down