From ef99a78760896316dd05f96683b8d8176bfacd7a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 28 Aug 2024 21:27:06 -0700 Subject: [PATCH] Revert "[Core][Kernels] Use FlashInfer backend for FP8 KV Cache when available." (#7982) --- tests/kernels/test_flashinfer.py | 228 +------------------------- vllm/attention/backends/flashinfer.py | 29 +--- vllm/attention/selector.py | 4 - 3 files changed, 12 insertions(+), 249 deletions(-) diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index 67f12cf1ee08e..f109792ad251b 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -73,14 +73,11 @@ def ref_paged_attn( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode -def test_flashinfer_decode_with_paged_kv( - kv_lens: List[int], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, - soft_cap: Optional[float], -) -> None: +def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], + num_heads: Tuple[int, + int], head_size: int, + dtype: torch.dtype, block_size: int, + soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) num_seqs = len(kv_lens) @@ -91,7 +88,6 @@ def test_flashinfer_decode_with_paged_kv( scale = head_size**-0.5 query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_value_cache = torch.randn(NUM_BLOCKS, 2, block_size, @@ -129,7 +125,7 @@ def test_flashinfer_decode_with_paged_kv( wrapper = flashinfer.\ BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", use_tensor_cores=( - (num_query_heads//num_kv_heads) > 4) + (num_query_heads//num_kv_heads) not in (1, 2, 4, 8)) ) wrapper.begin_forward(kv_indptr, kv_indices, @@ -253,215 +249,3 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], soft_cap=soft_cap) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" - - -@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]]) -@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)]) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) -def test_flashinfer_prefill_with_paged_fp8_kv( - seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], - head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float]) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_seqs = len(seq_lens) - query_lens = [x[0] for x in seq_lens] - kv_lens = [x[1] for x in seq_lens] - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - kv_cache_dtype = torch.float8_e4m3fn - - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - NUM_BLOCKS_FP8 = 2048 - key_value_cache = torch.randn(NUM_BLOCKS_FP8, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) - key_cache /= head_size**0.5 - value_cache /= head_size**0.5 - - k_scale = key_cache.amax().item() / 448.0 - v_scale = value_cache.amax().item() / 448.0 - - kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], - dim=1).to(kv_cache_dtype) - - assert (kv_cache_fp8.shape == key_value_cache.shape) - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS_FP8, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - qo_indptr = [0] - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + block_size - 1) // block_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % block_size - if kv_last_page_len == 0: - kv_last_page_len = block_size - kv_last_page_lens.append(kv_last_page_len) - qo_indptr.append(qo_indptr[-1] + query_lens[i]) - - qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32) - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD") - wrapper.begin_forward( - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - ) - - output = wrapper.forward(query, - kv_cache_fp8, - logits_soft_cap=soft_cap, - k_scale=k_scale, - v_scale=v_scale) - - ref_output = ref_paged_attn(query=query, - key_cache=key_cache.squeeze(1), - value_cache=value_cache.squeeze(1), - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap) - del query - del block_tables - # verify prefill fp8 - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" - - -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)]) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) -@torch.inference_mode -def test_flashinfer_decode_with_paged_fp8_kv( - kv_lens: List[int], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, - soft_cap: Optional[float], -) -> None: - # test doesn't work for num_heads = (16,16) - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - use_tensor_cores = (num_query_heads // num_kv_heads) > 4 - kv_cache_dtype = torch.float8_e4m3fn - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - NUM_BLOCKS_FP8 = 2048 - key_value_cache = torch.randn(NUM_BLOCKS_FP8, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) - key_cache /= head_size**0.5 - value_cache /= head_size**0.5 - - k_scale = key_cache.amax().item() / 448.0 - v_scale = value_cache.amax().item() / 448.0 - - key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype) - value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype) - assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1) - kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS_FP8, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + block_size - 1) // block_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % block_size - if kv_last_page_len == 0: - kv_last_page_len = block_size - kv_last_page_lens.append(kv_last_page_len) - - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - use_tensor_cores=use_tensor_cores) - wrapper.begin_forward(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - data_type=dtype) - output = wrapper.forward(query, - kv_cache_fp8, - logits_soft_cap=soft_cap, - k_scale=k_scale, - v_scale=v_scale) - key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) - value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) - - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap) - # Temporary fix: Increasing the tolerance. Seems like a flashinfer issue - torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index ca42f77f51cd4..a8d76b79ff204 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -83,15 +83,6 @@ def copy_blocks( def get_supported_head_sizes() -> List[int]: return [64, 128, 256] - @staticmethod - def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - return torch.float8_e4m3fn - elif kv_cache_dtype == "fp8_e5m2": - return torch.float8_e5m2 - else: - return ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") - class FlashInferState(AttentionState): @@ -186,9 +177,9 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_indices_buffer, _last_page_len_buffer, "NHD", use_tensor_cores) + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) - kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.runner.kv_cache_dtype) paged_kv_indptr_tensor_host = torch.arange(0, batch_size + 1, dtype=torch.int32) @@ -349,7 +340,7 @@ def begin_forward(self): self.page_size, # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", - ) + data_type=self.data_type) def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -375,8 +366,7 @@ def prefill_metadata(self) -> Optional["FlashInferMetadata"]: def decode_metadata(self) -> Optional["FlashInferMetadata"]: # Currently chunked prefill is not supported if self.num_prefills > 0: - assert self.num_decode_tokens == 0, ( - "Chunked prefill is not supported with flashinfer yet.") + assert self.num_decode_tokens == 0 return None return self @@ -588,7 +578,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], kv_cache_dtype = get_kv_cache_torch_dtype( self.runner.kv_cache_dtype, self.runner.model_config.dtype) - return FlashInferMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -672,6 +661,7 @@ def forward( if attn_metadata.num_decode_tokens > 0: assert attn_metadata.num_prefill_tokens == 0, ( "Chunked prefill is not supported with flashinfer yet.") + if kv_cache is not None: # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( @@ -684,11 +674,6 @@ def forward( k_scale, v_scale, ) - # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 - # to process the cache in fp8 - torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.kv_cache_dtype) - kv_cache = kv_cache.view(torch_dtype) query = query.contiguous( ) # Flashinfer requires query to be contiguous @@ -726,7 +711,5 @@ def forward( query, kv_cache, sm_scale=self.scale, - logits_soft_cap=self.logits_soft_cap, - k_scale=k_scale, - v_scale=v_scale) + logits_soft_cap=self.logits_soft_cap) return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index c0e592c8b12a0..54558fc2d7e53 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -226,10 +226,6 @@ def which_attn_to_use( elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): logger.info( "Cannot use FlashAttention-2 backend for FP8 KV cache.") - logger.warning( - "Please use FlashInfer backend with FP8 KV Cache for " - "better performance by set environment " - "VLLM_ATTENTION_BACKEND=FLASHINFER") selected_backend = _Backend.XFORMERS elif block_size % 16 != 0: logger.info(