From e36084d377c07c907c1751f974fcf4782e794faf Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 10 Jul 2024 06:45:31 +0000 Subject: [PATCH 1/2] upd --- python/tests/test_decode_prefill_lse.py | 76 ++++++++++++++++++++++++ python/tests/test_tensor_cores_decode.py | 8 +-- 2 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 python/tests/test_decode_prefill_lse.py diff --git a/python/tests/test_decode_prefill_lse.py b/python/tests/test_decode_prefill_lse.py new file mode 100644 index 00000000..99783fe7 --- /dev/null +++ b/python/tests/test_decode_prefill_lse.py @@ -0,0 +1,76 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import flashinfer +import numpy as np +import torch +import pytest + + +def test_mlc_failed_case(): + kv_layout = "HND" + num_pages = 12 + kv_indptr_1 = torch.tensor([0, 0, 9]).int().to(0) + kv_indices_1 = torch.tensor([3, 4, 5, 6, 7, 8, 9, 10, 11]).int().to(0) + kv_last_page_len_1 = torch.tensor([0, 1]).int().to(0) + num_qo_heads = 32 + num_kv_heads = 32 + page_size = 16 + head_dim = 128 + q = torch.randn(2, num_qo_heads, head_dim).to(0).half() + kv_data = torch.randn(12, 2, num_kv_heads, page_size, head_dim).to(0).half() + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) + wrapper.begin_forward( + kv_indptr_1, + kv_indices_1, + kv_last_page_len_1, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode="NONE", + data_type=torch.float16, + q_data_type=torch.float16, + ) + o_1, lse_1 = wrapper.forward_return_lse(q, kv_data) + + wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, use_tensor_cores=True + ) + wrapper_tensor_cores.begin_forward( + kv_indptr_1, + kv_indices_1, + kv_last_page_len_1, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode="NONE", + data_type=torch.float16, + q_data_type=torch.float16, + ) + o_1_tc, lse_1_tc = wrapper_tensor_cores.forward_return_lse(q, kv_data) + + np.testing.assert_allclose( + lse_1.cpu().numpy(), lse_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + print(lse_1[1], lse_1_tc[1]) + + +if __name__ == "__main__": + test_mlc_failed_case() diff --git a/python/tests/test_tensor_cores_decode.py b/python/tests/test_tensor_cores_decode.py index 4cac6be7..b49c522d 100644 --- a/python/tests/test_tensor_cores_decode.py +++ b/python/tests/test_tensor_cores_decode.py @@ -104,7 +104,7 @@ def test_batch_decode_tensor_cores( num_kv_heads, head_dim, page_size, - "NONE", + pos_encoding_mode=pos_encoding_mode, data_type=torch.float16, q_data_type=torch.float16, ) @@ -121,7 +121,7 @@ def test_batch_decode_tensor_cores( num_kv_heads, head_dim, page_size, - "NONE", + pos_encoding_mode=pos_encoding_mode, data_type=torch.float16, q_data_type=torch.float16, ) @@ -187,7 +187,7 @@ def test_batch_decode_tensor_cores_cuda_graph( num_kv_heads, head_dim, page_size, - "NONE", + pos_encoding_mode=pos_encoding_mode, data_type=torch.float16, q_data_type=torch.float16, ) @@ -226,7 +226,7 @@ def test_batch_decode_tensor_cores_cuda_graph( num_kv_heads, head_dim, page_size, - "NONE", + pos_encoding_mode=pos_encoding_mode, data_type=torch.float16, q_data_type=torch.float16, ) From 585c136f973631248710f99a9f21e0ac248cd3fb Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 10 Jul 2024 07:01:17 +0000 Subject: [PATCH 2/2] upd --- include/flashinfer/attention/handler.cuh | 2 +- python/tests/test_decode_prefill_lse.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 2a3d4495..632ccb85 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -238,7 +238,7 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo( for (uint32_t batch_idx = 0; batch_idx < old_batch_size; batch_idx++) { uint32_t num_chunks = ceil_div(old_indptr_h[batch_idx + 1] - old_indptr_h[batch_idx], max_num_pages_per_batch); - chunk_indptr_vec.push_back(chunk_indptr_vec.back() + num_chunks); + chunk_indptr_vec.push_back(chunk_indptr_vec.back() + std::max(num_chunks, 1U)); if (num_chunks == 0) { new_page_indptr_vec.push_back(old_indptr_h[batch_idx]); new_last_page_len_vec.push_back(0); diff --git a/python/tests/test_decode_prefill_lse.py b/python/tests/test_decode_prefill_lse.py index 99783fe7..f79bdac8 100644 --- a/python/tests/test_decode_prefill_lse.py +++ b/python/tests/test_decode_prefill_lse.py @@ -64,13 +64,16 @@ def test_mlc_failed_case(): data_type=torch.float16, q_data_type=torch.float16, ) - o_1_tc, lse_1_tc = wrapper_tensor_cores.forward_return_lse(q, kv_data) + o_1_tc, lse_1_tc = wrapper_tensor_cores.forward_return_lse( + q, kv_data + ) np.testing.assert_allclose( lse_1.cpu().numpy(), lse_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3 ) - print(lse_1[1], lse_1_tc[1]) - + np.testing.assert_allclose( + o_1.cpu().numpy(), o_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) if __name__ == "__main__": test_mlc_failed_case()