From 70f0b13d5386d8f6120cd7b4d9a4ad6b50d64fff Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 3 Jul 2024 00:10:58 +0000 Subject: [PATCH 01/11] Separate kv_scale into key_scale and value_scale --- csrc/attention/attention_kernels.cu | 55 +++++++++++---------- csrc/cache.h | 2 +- csrc/cache_kernels.cu | 11 +++-- csrc/cpu/attention.cpp | 12 ++--- csrc/cpu/cache.cpp | 5 +- csrc/cpu/torch_bindings.cpp | 10 ++-- csrc/ops.h | 8 +-- csrc/torch_bindings.cpp | 10 ++-- tests/kernels/test_attention.py | 2 + tests/kernels/test_blocksparse_attention.py | 2 + tests/kernels/test_cache.py | 2 +- vllm/_custom_ops.py | 19 ++++--- vllm/_ipex_ops.py | 9 ++-- vllm/attention/backends/abstract.py | 3 +- vllm/attention/backends/blocksparse_attn.py | 9 ++-- vllm/attention/backends/flash_attn.py | 6 ++- vllm/attention/backends/flashinfer.py | 6 ++- vllm/attention/backends/ipex_attn.py | 14 ++++-- vllm/attention/backends/pallas.py | 5 +- vllm/attention/backends/rocm_flash_attn.py | 9 ++-- vllm/attention/backends/torch_sdpa.py | 11 +++-- vllm/attention/backends/xformers.py | 9 ++-- vllm/attention/layer.py | 3 +- vllm/attention/ops/ipex_attn.py | 6 ++- vllm/attention/ops/paged_attn.py | 15 ++++-- 25 files changed, 144 insertions(+), 99 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 91083481705cb..677bac38ab0ef 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -105,9 +105,9 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float key_scale, const float value_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -285,7 +285,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert( - k_vec_quant, kv_scale); + k_vec_quant, key_scale); } } @@ -414,8 +414,8 @@ __device__ void paged_attention_kernel( V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. 
- v_vec = fp8::scaled_convert(v_quant_vec, - kv_scale); + v_vec = fp8::scaled_convert( + v_quant_vec, value_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float key_scale, const float value_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, + kv_head_stride, key_scale, value_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float key_scale, const float value_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, kv_scale, tp_rank, + kv_block_stride, kv_head_stride, key_scale, value_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel( out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - kv_scale, tp_rank, blocksparse_local_blocks, \ + key_scale, value_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -694,8 +694,8 @@ void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float kv_scale, - const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, float key_scale, + float value_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -770,7 +770,7 @@ void paged_attention_v1_launcher( paged_attention_v1_launcher( \ out, query, key_cache, 
value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ + seq_lens, max_seq_len, alibi_slopes, key_scale, value_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); @@ -815,8 +815,8 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double key_scale, double value_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); @@ -833,7 +833,7 @@ void paged_attention_v1( exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ + kv_block_stride, kv_head_stride, key_scale, value_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel& alibi_slopes, float kv_scale, - const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, float key_scale, + float value_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -932,8 +932,9 @@ void paged_attention_v2_launcher( IS_BLOCK_SPARSE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); + key_scale, value_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ switch (is_block_sparse) { \ @@ -980,8 +981,8 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double key_scale, double value_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/cache.h b/csrc/cache.h index 86caa9345361d..183b12e3d2543 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -19,7 +19,7 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, - const double kv_scale); + const double key_scale, const double value_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 
72041076ae009..861df888db351 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -160,7 +160,7 @@ __global__ void reshape_and_cache_kernel( const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, const int x, - const float kv_scale) { + const float key_scale, const float value_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( value_cache[tgt_value_idx] = tgt_value; } else { key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, kv_scale); + fp8::scaled_convert(tgt_key, key_scale); value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, kv_scale); + fp8::scaled_convert(tgt_value, value_scale); } } } @@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, kv_scale); + num_heads, head_size, block_size, x, key_scale, value_scale); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -258,7 +258,8 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double kv_scale) { + const std::string& kv_cache_dtype, const double key_scale, + const double value_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 8367093325314..8cc159bda597a 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -423,11 +423,11 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double key_scale, double value_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(key_scale == 1.0f && value_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", @@ -742,11 +742,11 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double key_scale, double value_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(key_scale == 1.0f && value_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); 
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 2b5c3bd6ee70b..058d480cb10ad 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -107,8 +107,9 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double kv_scale) { - TORCH_CHECK(kv_scale == 1.0f); + const std::string& kv_cache_dtype, double key_scale, + double value_scale) { + TORCH_CHECK(key_scale == 1.0f && value_scale == 1.0f); int num_tokens = key.size(0); int num_heads = key.size(1); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 39e8cf3ed3c10..06d9e8bd1106a 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float key_scale, float value_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); @@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float key_scale, float value_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); @@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! 
value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float kv_scale) -> ()"); + " float key_scale, float value_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/ops.h b/csrc/ops.h index 8a92afdc81a9b..2932521dc81a2 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -8,8 +8,8 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double key_scale, double value_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -19,8 +19,8 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double key_scale, double value_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index faf29e1f1e01e..e27defb42b430 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float key_scale, float value_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); @@ -41,8 +41,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float key_scale, float value_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); @@ -219,7 +219,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float kv_scale) -> ()"); + " float key_scale, float value_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. 
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index f848ad51c7014..1f71f7f700614 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -194,6 +194,7 @@ def test_paged_attention( alibi_slopes, kv_cache_dtype, kv_scale, + kv_scale, ) elif version == "v2": num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) @@ -225,6 +226,7 @@ def test_paged_attention( alibi_slopes, kv_cache_dtype, kv_scale, + kv_scale, ) else: raise AssertionError(f"Unknown version: {version}") diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 402545d1980d6..1730f6edb6d97 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -232,6 +232,7 @@ def test_paged_attention( alibi_slopes, kv_cache_dtype, kv_scale, + kv_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, @@ -268,6 +269,7 @@ def test_paged_attention( alibi_slopes, kv_cache_dtype, kv_scale, + kv_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 23b6baa60c05b..ad6164d9e388c 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -159,7 +159,7 @@ def test_reshape_and_cache( # Call the reshape_and_cache kernel. ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, kv_scale) + kv_cache_dtype, kv_scale, kv_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 479ea08e49072..4ea497512b47c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -84,7 +84,8 @@ def paged_attention_v1( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + key_scale: float, + value_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -94,8 +95,9 @@ def paged_attention_v1( torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, - blocksparse_block_size, blocksparse_head_sliding_step) + key_scale, value_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step) def paged_attention_v2( @@ -114,7 +116,8 @@ def paged_attention_v2( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + key_scale: float, + value_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -124,7 +127,7 @@ def paged_attention_v2( torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, + alibi_slopes, kv_cache_dtype, key_scale, value_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -365,11 +368,13 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + key_scale: float, + value_scale: float, ) -> None: 
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, kv_scale) + kv_cache_dtype, key_scale, + value_scale) def reshape_and_cache_flash( diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 99a875c9b3fb7..0105d040b3613 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -59,7 +59,8 @@ def paged_attention_v1( max_context_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + key_scale: float, + value_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -99,7 +100,8 @@ def paged_attention_v2( max_context_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + key_scale: float, + value_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -227,7 +229,8 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + key_scale: float, + value_scale: float, ) -> None: assert kv_cache_dtype == "auto" ipex.llm.modules.PagedAttention.reshape_and_cache( diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 40768532f59c2..d2eabc2434f1e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -127,6 +127,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, - kv_scale: float = 1.0, + key_scale: float = 1.0, + value_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 7b4578fcd8b9d..fdfd1edfc4479 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -327,7 +327,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: BlocksparseFlashAttentionMetadata, - kv_scale: float = 1.0, + key_scale: float = 1.0, + value_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -361,7 +362,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - kv_scale, + key_scale, + value_scale, ) if prefill_meta := attn_metadata.prefill_metadata: @@ -398,7 +400,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + key_scale, + value_scale, tp_rank=self.tp_rank, blocksparse_local_blocks=self.local_blocks, blocksparse_vert_stride=self.vert_stride, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 8cb5c3101a804..979fb337a0834 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -256,7 +256,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - kv_scale: float = 1.0, + key_scale: float = 1.0, + value_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -270,7 +271,8 @@ def forward( shape = [num_tokens, num_heads * head_size] """ # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." + assert key_scale == 1.0 and value_scale == 1.0, ( + "key/value_scale is not supported in FlashAttention.") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. 
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 4d023282fad49..a83e74ff96bae 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -221,9 +221,11 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: FlashInferMetadata, - kv_scale: float = 1.0, + key_scale: float = 1.0, + value_scale: float = 1.0, ) -> torch.Tensor: - assert kv_scale == 1.0 + assert key_scale == 1.0 and value_scale == 1.0, ( + "key/value_scale is not supported in FlashInfer.") num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 5114bfa6e1589..2a2786383584e 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -156,7 +156,8 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: IpexAttnMetadata, # type: ignore - kv_scale: float = 1.0, + key_scale: float = 1.0, + value_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. @@ -169,7 +170,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert kv_scale == 1.0 + assert key_scale == 1.0 and value_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -186,7 +187,8 @@ def forward( value_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, - kv_scale, + key_scale, + value_scale, ) if attn_metadata.is_prompt: @@ -267,7 +269,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - kv_scale, + key_scale, + value_scale, ) else: # Run PagedAttention V2. @@ -299,7 +302,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - kv_scale, + key_scale, + value_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 22cb1a1bd1fd3..6df160c42d992 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -131,7 +131,8 @@ def forward( value: torch.Tensor, kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], attn_metadata: PallasMetadata, - kv_scale: float = 1.0, + key_scale: float = 1.0, + value_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -145,7 +146,7 @@ def forward( Returns: shape = [batch_size, seq_len, num_heads * head_size] """ - assert kv_scale == 1.0 + assert key_scale == 1.0 and value_scale == 1.0 batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 81fabdbdfc83c..62c0683428190 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -265,7 +265,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, - kv_scale: float = 1.0, + key_scale: float = 1.0, + value_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. 
@@ -298,7 +299,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - kv_scale, + key_scale, + value_scale, ) num_prefill_tokens = attn_metadata.num_prefill_tokens @@ -402,7 +404,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + key_scale, + value_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 63f8466da9316..9707c4f1f5180 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -144,7 +144,8 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float = 1.0, + key_scale: float = 1.0, + value_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -157,7 +158,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert kv_scale == 1.0 + assert key_scale == 1.0 and value_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -170,7 +171,8 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + self.kv_cache_dtype, key_scale, + value_scale) if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None @@ -233,7 +235,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + key_scale, + value_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index ff449c3ff74f8..7e4b29fe831bc 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -242,7 +242,8 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: "XFormersMetadata", - kv_scale: float = 1.0, + key_scale: float = 1.0, + value_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -269,7 +270,8 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + self.kv_cache_dtype, key_scale, + value_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -331,7 +333,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + key_scale, + value_scale, ) # Reshape the output tensor. 
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index dfe93be462184..580c36c289e54 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -91,8 +91,9 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: + # TODO(mgoin): Add capacity for loading separate key and value scales return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._kv_scale) + self._kv_scale, self._kv_scale) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 5a5317b65004e..17962d0fd1f5c 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -45,7 +45,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + key_scale: float, + value_scale: float, *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( @@ -64,7 +65,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - kv_scale: float, + key_scale: float, + value_scale: float, *args, ) -> torch.Tensor: output = torch.empty_like(query) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index a214f40d16514..576428a926b68 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -66,7 +66,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + key_scale: float, + value_scale: float, ) -> None: ops.reshape_and_cache( key, @@ -75,7 +76,8 @@ def write_to_paged_cache( value_cache, slot_mapping.flatten(), kv_cache_dtype, - kv_scale, + key_scale, + value_scale, ) @staticmethod @@ -90,7 +92,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - kv_scale: float, + key_scale: float, + value_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -135,7 +138,8 @@ def forward_decode( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + key_scale, + value_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, @@ -172,7 +176,8 @@ def forward_decode( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + key_scale, + value_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, From 95fbd2c6ac311ac9d563777452e218ccbd714896 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 3 Jul 2024 00:17:46 +0000 Subject: [PATCH 02/11] Fix tests --- benchmarks/kernels/benchmark_paged_attention.py | 8 +++++--- tests/kernels/test_attention.py | 10 +++++----- tests/kernels/test_blocksparse_attention.py | 10 +++++----- tests/kernels/test_cache.py | 4 ++-- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 16de60477c305..a093b4e13d271 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -100,7 +100,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - kv_scale = 1.0 + key_scale = value_scale = 1.0 for _ in range(num_iters): if version == "v1": @@ -117,7 +117,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + key_scale, + value_scale, ) elif version == "v2": 
ops.paged_attention_v2( @@ -136,7 +137,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + key_scale, + value_scale, ) else: raise ValueError(f"Invalid version: {version}") diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 1f71f7f700614..ef2d469e0dc59 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -175,7 +175,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - kv_scale = 1.0 + key_scale = value_scale = 1.0 # Call the paged attention kernel. output = torch.empty_like(query) @@ -193,8 +193,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, - kv_scale, + key_scale, + value_scale, ) elif version == "v2": num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) @@ -225,8 +225,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, - kv_scale, + key_scale, + value_scale, ) else: raise AssertionError(f"Unknown version: {version}") diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 1730f6edb6d97..57b349ea1fb26 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -212,7 +212,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - kv_scale = 1.0 + key_scale = value_scale = 1.0 tp_rank = 0 # Call the paged attention kernel. @@ -231,8 +231,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, - kv_scale, + key_scale, + value_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, @@ -268,8 +268,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, - kv_scale, + key_scale, + value_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index ad6164d9e388c..09b918f1105fc 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -155,11 +155,11 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Using default kv_scale - kv_scale = 1.0 + key_scale = value_scale = 1.0 # Call the reshape_and_cache kernel. 
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, kv_scale, kv_scale) + kv_cache_dtype, key_scale, value_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) From 1050a1814ee2c5c63063aa95e42d4ce904f52eb1 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 3 Jul 2024 01:47:38 +0000 Subject: [PATCH 03/11] Try adding scales to weight loader --- csrc/cache_kernels.cu | 8 +-- vllm/attention/layer.py | 13 +++-- .../model_executor/layers/quantization/fp8.py | 42 ++++++++++---- .../model_loader/weight_utils.py | 55 +++++++++++++++++++ vllm/model_executor/models/llama.py | 19 ++----- vllm/model_executor/models/mixtral.py | 19 ++----- vllm/model_executor/models/qwen2.py | 19 ++----- 7 files changed, 113 insertions(+), 62 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 861df888db351..39a2f971fa90f 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -319,13 +319,13 @@ namespace vllm { template __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache, - const float kv_scale, + const float scale, const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; dst_cache[idx] = - fp8::scaled_convert(src_cache[idx], kv_scale); + fp8::scaled_convert(src_cache[idx], scale); } } @@ -334,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ vllm::convert_fp8_kernel<<>>( \ reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), kv_scale, block_stride); + reinterpret_cast(dst_cache.data_ptr()), scale, block_stride); // Only for testing. void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const double kv_scale, const std::string& kv_cache_dtype) { + const double scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 580c36c289e54..5e72ec633280e 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -47,13 +47,14 @@ def __init__( if num_kv_heads is None: num_kv_heads = num_heads - # The default kv_scale is set to 1.0. This is ignored + # The default key/value_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we - # expect the pre-quantized kv_scale to be loaded along + # expect the pre-quantized key/value_scale to be loaded along # with the model weights. self.kv_cache_dtype = kv_cache_dtype - self._kv_scale = 1.0 + self._key_scale = 1.0 + self._value_scale = 1.0 quant_method = quant_config.get_quant_method( self) if quant_config else None if quant_method is not None: @@ -66,8 +67,8 @@ def __init__( "fp8 checkpoints.") # When FP8 quantization is enabled, we make a parameter # "kv_scale" so that it can be loaded from FP8 checkpoint. - # The kv_scale will then be converted back to self._kv_scale - # in a native float32 value after weight loading. 
+ # The key/value_scale will then be converted back to + # self._kv_scale in a native float32 value after weight loading self.quant_method = quant_method self.quant_method.create_weights(self) @@ -93,7 +94,7 @@ def forward( ) -> torch.Tensor: # TODO(mgoin): Add capacity for loading separate key and value scales return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._kv_scale, self._kv_scale) + self._key_scale, self._value_scale) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5d503a2215c96..ac714fad00800 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -287,6 +287,8 @@ def create_weights(self, layer: torch.nn.Module): # If the kv_scale appears in the checkpoint, it will be # overwritten when loading weights. layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) + layer.key_scale = Parameter(torch.tensor(-1.0), requires_grad=False) + layer.value_scale = Parameter(torch.tensor(-1.0), requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") @@ -295,17 +297,37 @@ def process_weights_after_loading(self, layer: Module) -> None: # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. if layer.kv_cache_dtype != "auto": - kv_scale = layer.kv_scale.to("cpu").tolist() - if not isinstance(kv_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") - layer._kv_scale = kv_scale - if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: - print_warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " - "cause accuracy issues. Please make sure kv-cache scaling " - "factor is available in the fp8 checkpoint.") + # We prefer to use separate key_scale and value_scale if present + if layer.key_scale > 0.0 and layer.value_scale > 0.0: + key_scale = layer.key_scale.to("cpu").tolist() + value_scale = layer.value_scale.to("cpu").tolist() + if not isinstance(key_scale, float) or not isinstance( + value_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + layer._key_scale = key_scale + layer._value_scale = value_scale + if (layer._key_scale == 1.0 and layer._value_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure kv-cache " + "scaling factor is available in the fp8 checkpoint.") + else: + kv_scale = layer.kv_scale.to("cpu").tolist() + if not isinstance(kv_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + layer._kv_scale = kv_scale + if (layer._kv_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. 
Please make sure kv-cache " + "scaling factor is available in the fp8 checkpoint.") del layer.kv_scale + del layer.key_scale + del layer.value_scale def per_tensor_quantize(tensor: torch.Tensor, diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 943022a3f03c7..21b0a25fcf280 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) from vllm.model_executor.layers.quantization.schema import QuantParamSchema +from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -457,3 +458,57 @@ def initialize_dummy_weights( param.data.copy_(tmp_param) else: param.uniform_(low, high) + + +def remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: + """Remap the name of FP8 kv-scale parameters. + + This function handles the remapping of FP8 kv-scale parameter names. + It checks if the given name ends with "kv_scale" and attempts to remap + it to the expected name format in the model. If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. + + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. + """ + if name.endswith("kv_scale"): + print_warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}). This format " + "is deprecated in favor of separate key_scale and value_scale " + "tensors and will be removed in a future release.") + remapped_scale_name = name.replace(".kv_scale", ".attn.kv_scale") + if remapped_scale_name not in params_dict: + print_warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_scale_name}). kv_scale is " + "not loaded.") + return None + return remapped_scale_name + elif name.endswith("key_scale"): + remapped_scale_name = name.replace(".key_scale", ".attn.key_scale") + if remapped_scale_name not in params_dict: + print_warning_once( + f"Found key_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_scale_name}). key_scale is " + "not loaded.") + return None + return remapped_scale_name + elif name.endswith("value_scale"): + remapped_scale_name = name.replace(".value_scale", ".attn.value_scale") + if remapped_scale_name not in params_dict: + print_warning_once( + f"Found value_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_scale_name}). 
value_scale is " + "not loaded.") + return None + return remapped_scale_name + return name diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 54d01701f04fb..f78a6926a4f8b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -44,10 +44,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader) + default_weight_loader, kv_cache_scales_loader, remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from vllm.utils import is_hip, print_warning_once +from vllm.utils import is_hip from .interfaces import SupportsLoRA @@ -425,18 +425,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - f"Found kv scale in the checkpoint (e.g. {name}), " - "but not found the expected name in the model " - f"(e.g. {remapped_kv_scale_name}). kv-scale is " - "not loaded.") - continue - else: - name = remapped_kv_scale_name + remapped_name = remap_kv_scale_name(name, params_dict) + if remapped_name is None: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a662db6d28d00..acbc06c3d6eaa 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -48,7 +48,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput @@ -619,19 +620,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). 
" - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name + remapped_name = remap_kv_scale_name(name, params_dict) + if remapped_name is None: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e2d725af63593..4ee2c79b329f3 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -43,10 +43,10 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -381,18 +381,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - f"Found kv scale in the checkpoint (e.g. {name}), " - "but not found the expected name in the model " - f"(e.g. {remapped_kv_scale_name}). kv-scale is " - "not loaded.") - continue - else: - name = remapped_kv_scale_name + remapped_name = remap_kv_scale_name(name, params_dict) + if remapped_name is None: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From 33699508a95d4f98f807f8e3d25a8ea1890730cd Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 15 Jul 2024 19:25:22 +0000 Subject: [PATCH 04/11] Format --- vllm/attention/backends/xformers.py | 3 +-- vllm/model_executor/model_loader/weight_utils.py | 2 +- vllm/model_executor/models/llama.py | 10 ++++------ vllm/model_executor/models/mixtral.py | 8 ++++---- vllm/model_executor/models/qwen2.py | 10 ++++------ 5 files changed, 14 insertions(+), 19 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3995ceb2af837..0048ef45eb049 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -532,8 +532,7 @@ def forward( value_cache, updated_slot_mapping, self.kv_cache_dtype, - key_scale, - value_scale) + key_scale, value_scale) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 141a0468d9b91..0b26aceed9e42 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -465,7 +465,7 @@ def initialize_dummy_weights( param.uniform_(low, high) -def remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: """Remap the name of FP8 kv-scale parameters. This function handles the remapping of FP8 kv-scale parameter names. 
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index fa873880c4276..f03e34b9e7c92 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -44,10 +44,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, remap_kv_scale_name) + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import is_hip, print_warning_once +from vllm.utils import is_hip from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers @@ -460,11 +460,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - remapped_name = remap_kv_scale_name(name, params_dict) - if remapped_name is None: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: continue - else: - name = remapped_name if is_pp_missing_parameter(name, self): continue diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 4b4627af8b17e..e739df87cf96a 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -43,10 +43,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -416,9 +415,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - remapped_name = remap_kv_scale_name(name, params_dict) - if remapped_name is None: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index df1834353ecaf..e9aa4416eded4 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -44,10 +44,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -383,11 +382,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. 
- remapped_name = remap_kv_scale_name(name, params_dict) - if remapped_name is None: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: continue - else: - name = remapped_name + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From b1616ebd9504e4cd0d89e92c3bc741a19598b3b0 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 15 Jul 2024 21:09:26 +0000 Subject: [PATCH 05/11] Update to use kv_scale if found --- .../model_executor/layers/quantization/fp8.py | 75 ++++++++++--------- .../model_loader/weight_utils.py | 26 ++++--- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 6296cad73b512..3f2485dbe2c1a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -407,53 +407,54 @@ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module): - """Create "weight" (aka kv_scale) for an attention layer. + """Create "weight" (aka k_scale and v_scale) for an attention layer. Args: layer: The layer that is using the QuantizeMethodBase factory. """ - # Initialize the KV cache scale to 1.0 as the default value. - # If the kv_scale appears in the checkpoint, it will be + # Initialize the KV cache scales to -1.0, which is an invalid value. + # If the k/v_scale appears in the checkpoint, it will be # overwritten when loading weights. - layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) - layer.key_scale = Parameter(torch.tensor(-1.0), requires_grad=False) - layer.value_scale = Parameter(torch.tensor(-1.0), requires_grad=False) + layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False) + layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") def process_weights_after_loading(self, layer: Module) -> None: - # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. if layer.kv_cache_dtype != "auto": - # We prefer to use separate key_scale and value_scale if present - if layer.key_scale > 0.0 and layer.value_scale > 0.0: - key_scale = layer.key_scale.to("cpu").tolist() - value_scale = layer.value_scale.to("cpu").tolist() - if not isinstance(key_scale, float) or not isinstance( - value_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") - layer._key_scale = key_scale - layer._value_scale = value_scale - if (layer._key_scale == 1.0 and layer._value_scale == 1.0 - and "e5m2" not in layer.kv_cache_dtype): - print_warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. 
Please make sure kv-cache " - "scaling factor is available in the fp8 checkpoint.") + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = Parameter(torch.tensor(1.0), requires_grad=False) + v_scale = Parameter(torch.tensor(1.0), requires_grad=False) else: - kv_scale = layer.kv_scale.to("cpu").tolist() - if not isinstance(kv_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") - layer._kv_scale = kv_scale - if (layer._kv_scale == 1.0 - and "e5m2" not in layer.kv_cache_dtype): - print_warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. Please make sure kv-cache " - "scaling factor is available in the fp8 checkpoint.") - del layer.kv_scale - del layer.key_scale - del layer.value_scale + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + + if not isinstance(k_scale, float) or not isinstance( + v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + layer._k_scale = k_scale + layer._v_scale = v_scale + if (layer._k_scale == 1.0 and layer._v_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + del layer.k_scale + del layer.v_scale diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 0b26aceed9e42..bc09f56edb067 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -482,11 +482,13 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: if no remapping is needed. None: If the remapped name is not found in params_dict. """ - if name.endswith("kv_scale"): + if name.endswith(".kv_scale"): print_warning_once( - f"Found kv_scale in the checkpoint (e.g. {name}). This format " - "is deprecated in favor of separate key_scale and value_scale " - "tensors and will be removed in a future release.") + f"DEPRECATED. Found kv_scale in the checkpoint (e.g. {name}). " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. 
" + "Functionally for now, we will remap kv_scale to k_scale and " + "duplicate k_scale to v_scale") remapped_scale_name = name.replace(".kv_scale", ".attn.kv_scale") if remapped_scale_name not in params_dict: print_warning_once( @@ -496,23 +498,23 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: "not loaded.") return None return remapped_scale_name - elif name.endswith("key_scale"): - remapped_scale_name = name.replace(".key_scale", ".attn.key_scale") + elif name.endswith(".k_scale"): + remapped_scale_name = name.replace(".k_scale", ".attn.k_scale") if remapped_scale_name not in params_dict: print_warning_once( - f"Found key_scale in the checkpoint (e.g. {name}), " + f"Found k_scale in the checkpoint (e.g. {name}), " "but not found the expected name in the model " - f"(e.g. {remapped_scale_name}). key_scale is " + f"(e.g. {remapped_scale_name}). k_scale is " "not loaded.") return None return remapped_scale_name - elif name.endswith("value_scale"): - remapped_scale_name = name.replace(".value_scale", ".attn.value_scale") + elif name.endswith(".v_scale"): + remapped_scale_name = name.replace(".v_scale", ".attn.v_scale") if remapped_scale_name not in params_dict: print_warning_once( - f"Found value_scale in the checkpoint (e.g. {name}), " + f"Found v_scale in the checkpoint (e.g. {name}), " "but not found the expected name in the model " - f"(e.g. {remapped_scale_name}). value_scale is " + f"(e.g. {remapped_scale_name}). v_scale is " "not loaded.") return None return remapped_scale_name From 7f639067aec85f6c21780b8b7abbf10acbc75927 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 15 Jul 2024 21:32:05 +0000 Subject: [PATCH 06/11] Format --- vllm/model_executor/layers/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3f2485dbe2c1a..1354e5a1c44aa 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -446,7 +446,7 @@ def process_weights_after_loading(self, layer: Module) -> None: if not isinstance(k_scale, float) or not isinstance( v_scale, float): raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") + "for fp8 KV cache") layer._k_scale = k_scale layer._v_scale = v_scale if (layer._k_scale == 1.0 and layer._v_scale == 1.0 From ccc2a805a1a97f8cbe36c768260323e7cb8071e8 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 15 Jul 2024 21:35:31 +0000 Subject: [PATCH 07/11] Rename --- vllm/attention/layer.py | 14 +++++++------- vllm/model_executor/layers/quantization/fp8.py | 2 ++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 05fe6a9e8ebc1..0619bda90a2a7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -47,14 +47,14 @@ def __init__( if num_kv_heads is None: num_kv_heads = num_heads - # The default key/value_scale is set to 1.0. This is ignored + # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we - # expect the pre-quantized key/value_scale to be loaded along + # expect the pre-quantized k/v_scale to be loaded along # with the model weights. 
self.kv_cache_dtype = kv_cache_dtype - self._key_scale = 1.0 - self._value_scale = 1.0 + self._k_scale = 1.0 + self._v_scale = 1.0 quant_method = quant_config.get_quant_method( self) if quant_config else None if quant_method is not None: @@ -67,7 +67,7 @@ def __init__( "fp8 checkpoints.") # When FP8 quantization is enabled, we make a parameter # "kv_scale" so that it can be loaded from FP8 checkpoint. - # The key/value_scale will then be converted back to + # The k/v_scale will then be converted back to # self._kv_scale in a native float32 value after weight loading self.quant_method = quant_method self.quant_method.create_weights(self) @@ -99,8 +99,8 @@ def forward( value, kv_cache, attn_metadata, - self._key_scale, - self._value_scale, + self._k_scale, + self._v_scale, attn_type=attn_type) def extra_repr(self) -> str: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1354e5a1c44aa..cfef914ed6cf7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -447,6 +447,8 @@ def process_weights_after_loading(self, layer: Module) -> None: v_scale, float): raise ValueError("Only support per-tensor scaling factor " "for fp8 KV cache") + + # These are used in the final Attention.forward() layer._k_scale = k_scale layer._v_scale = v_scale if (layer._k_scale == 1.0 and layer._v_scale == 1.0 From 35f02a6b0e4db6c9cc5504b448dce0f0cc59a444 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 15 Jul 2024 23:23:24 +0000 Subject: [PATCH 08/11] Poke From d26453e72798a639983130caaa271936df057cb9 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 16 Jul 2024 16:47:08 +0000 Subject: [PATCH 09/11] Review comments to cleanup maybe_remap_kv_scale_name --- .../model_loader/weight_utils.py | 50 +++++++++---------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index bc09f56edb067..3b76c2cdd0ca6 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -466,10 +466,10 @@ def initialize_dummy_weights( def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: - """Remap the name of FP8 kv-scale parameters. + """Remap the name of FP8 k/v_scale parameters. - This function handles the remapping of FP8 kv-scale parameter names. - It checks if the given name ends with "kv_scale" and attempts to remap + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap it to the expected name format in the model. If the remapped name is not found in the params_dict, a warning is printed and None is returned. @@ -489,33 +489,29 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: "v_scale tensors and will be removed in a future release. " "Functionally for now, we will remap kv_scale to k_scale and " "duplicate k_scale to v_scale") - remapped_scale_name = name.replace(".kv_scale", ".attn.kv_scale") - if remapped_scale_name not in params_dict: + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: print_warning_once( f"Found kv_scale in the checkpoint (e.g. {name}), " "but not found the expected name in the model " - f"(e.g. {remapped_scale_name}). kv_scale is " + f"(e.g. {remapped_name}). 
kv_scale is " "not loaded.") return None - return remapped_scale_name - elif name.endswith(".k_scale"): - remapped_scale_name = name.replace(".k_scale", ".attn.k_scale") - if remapped_scale_name not in params_dict: - print_warning_once( - f"Found k_scale in the checkpoint (e.g. {name}), " - "but not found the expected name in the model " - f"(e.g. {remapped_scale_name}). k_scale is " - "not loaded.") - return None - return remapped_scale_name - elif name.endswith(".v_scale"): - remapped_scale_name = name.replace(".v_scale", ".attn.v_scale") - if remapped_scale_name not in params_dict: - print_warning_once( - f"Found v_scale in the checkpoint (e.g. {name}), " - "but not found the expected name in the model " - f"(e.g. {remapped_scale_name}). v_scale is " - "not loaded.") - return None - return remapped_scale_name + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + print_warning_once( + f"Found {scale_name} in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). {scale_name} is " + "not loaded.") + return None + return remapped_name + + # If there were no matches, return the untouched param name return name From 728e524ed24e7f553422b8eef1c921d862ff95e1 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 16 Jul 2024 20:19:25 +0000 Subject: [PATCH 10/11] Add test and fix for loading --- tests/quantization/test_fp8.py | 40 ++++++++++++++++--- vllm/model_executor/layers/linear.py | 9 +++++ .../model_loader/weight_utils.py | 11 ++--- 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 0ed91cbb447fd..82dc775f8d812 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -7,19 +7,49 @@ from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod, + Fp8LinearMethod) MODELS = [ - "neuralmagic/Meta-Llama-3-8B-Instruct-FP8", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", "nm-testing/Phi-3-mini-128k-instruct-FP8", ] @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") -@pytest.mark.parametrize("model", MODELS) -def test_model_load_and_run(vllm_runner, model: str): - with vllm_runner(model) as llm: +@pytest.mark.parametrize("model_id", MODELS) +def test_model_load_and_run(vllm_runner, model_id: str): + with vllm_runner(model_id) as llm: + # note: this does not test accuracy, just that we can run through + # see lm-eval tests for accuracy + outputs = llm.generate_greedy(prompts=["Hello my name is"], + max_tokens=10) + print(outputs[0][1]) + + +KV_CACHE_MODELS = [ + # Deprecated AutoFP8 format using .kv_scale + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", + # AutoFP8 format using separate .k_scale and .v_scale + "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", +] + + +@pytest.mark.skipif(not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_id", KV_CACHE_MODELS) +def test_kv_cache_model_load_and_run(vllm_runner, model_id: str): + with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: + + model = 
llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + attn = model.model.layers[0].self_attn.attn + assert isinstance(attn.quant_method, Fp8KVCacheMethod) + # NOTE: it is valid for scales to be 1.0 (default value), but we know + # these checkpoints have scales < 1.0 + assert 0.0 < attn._k_scale < 1.0 + assert 0.0 < attn._v_scale < 1.0 + # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy outputs = llm.generate_greedy(prompts=["Hello my name is"], diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index bc07d2b831862..684e1abf7bcf7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -196,6 +196,15 @@ def __init__(self, else: self.register_parameter("bias", None) + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # If the weight on disk does not have a shape, give it one + # (such scales for AutoFp8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 3b76c2cdd0ca6..cb83f43a2a4e2 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -432,11 +432,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" - # If the weight on disk does not have a shape, give it one - # (such scales for AutoFp8). - if len(loaded_weight.shape) == 0: - loaded_weight = loaded_weight.reshape(1) - assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) @@ -484,11 +479,11 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: """ if name.endswith(".kv_scale"): print_warning_once( - f"DEPRECATED. Found kv_scale in the checkpoint (e.g. {name}). " + "DEPRECATED. Found kv_scale in the checkpoint. " "This format is deprecated in favor of separate k_scale and " "v_scale tensors and will be removed in a future release. 
" - "Functionally for now, we will remap kv_scale to k_scale and " - "duplicate k_scale to v_scale") + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale") # NOTE: we remap the deprecated kv_scale to k_scale remapped_name = name.replace(".kv_scale", ".attn.k_scale") if remapped_name not in params_dict: From 31e15b833dcbdc56546e6d1fd82129c685f01caa Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 16 Jul 2024 20:23:53 +0000 Subject: [PATCH 11/11] Replace key_scale and value_scale everywhere --- .../kernels/benchmark_paged_attention.py | 10 +++--- csrc/attention/attention_kernels.cu | 36 +++++++++---------- csrc/cache.h | 4 +-- csrc/cache_kernels.cu | 14 ++++---- csrc/cpu/attention.cpp | 8 ++--- csrc/cpu/cache.cpp | 6 ++-- csrc/cpu/torch_bindings.cpp | 6 ++-- csrc/ops.h | 4 +-- csrc/torch_bindings.cpp | 6 ++-- tests/kernels/test_attention.py | 10 +++--- tests/kernels/test_blocksparse_attention.py | 10 +++--- tests/kernels/test_cache.py | 4 +-- vllm/_custom_ops.py | 19 +++++----- vllm/_ipex_ops.py | 12 +++---- vllm/attention/backends/abstract.py | 4 +-- vllm/attention/backends/blocksparse_attn.py | 12 +++---- vllm/attention/backends/flash_attn.py | 8 ++--- vllm/attention/backends/flashinfer.py | 8 ++--- vllm/attention/backends/ipex_attn.py | 18 +++++----- vllm/attention/backends/pallas.py | 6 ++-- vllm/attention/backends/rocm_flash_attn.py | 12 +++---- vllm/attention/backends/torch_sdpa.py | 14 ++++---- vllm/attention/backends/xformers.py | 10 +++--- vllm/attention/ops/ipex_attn.py | 8 ++--- vllm/attention/ops/paged_attn.py | 20 +++++------ 25 files changed, 134 insertions(+), 135 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index a093b4e13d271..78cac8a555d1b 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -100,7 +100,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - key_scale = value_scale = 1.0 + k_scale = v_scale = 1.0 for _ in range(num_iters): if version == "v1": @@ -117,8 +117,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: max_seq_len, alibi_slopes, kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) elif version == "v2": ops.paged_attention_v2( @@ -137,8 +137,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: max_seq_len, alibi_slopes, kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) else: raise ValueError(f"Invalid version: {version}") diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 677bac38ab0ef..350dbce1d7ba9 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -105,7 +105,7 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float key_scale, const float value_scale, const int tp_rank, + const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; @@ -285,7 +285,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = 
fp8::scaled_convert( - k_vec_quant, key_scale); + k_vec_quant, k_scale); } } @@ -414,8 +414,8 @@ __device__ void paged_attention_kernel( V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8::scaled_convert( - v_quant_vec, value_scale); + v_vec = fp8::scaled_convert(v_quant_vec, + v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float key_scale, const float value_scale, const int tp_rank, + const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, key_scale, value_scale, tp_rank, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel( out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - key_scale, value_scale, tp_rank, blocksparse_local_blocks, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -694,8 +694,8 @@ void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float key_scale, - float value_scale, const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -770,7 +770,7 @@ void paged_attention_v1_launcher( paged_attention_v1_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, key_scale, value_scale, tp_rank, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); @@ -815,7 +815,7 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double key_scale, double value_scale, + const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { @@ -833,7 +833,7 @@ void paged_attention_v1( exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, 
block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, key_scale, value_scale, tp_rank, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel& alibi_slopes, float key_scale, - float value_scale, const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -932,7 +932,7 @@ void paged_attention_v2_launcher( IS_BLOCK_SPARSE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - key_scale, value_scale, tp_rank, blocksparse_local_blocks, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -981,7 +981,7 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double key_scale, double value_scale, + const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { diff --git a/csrc/cache.h b/csrc/cache.h index 183b12e3d2543..52177e8901a89 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -18,8 +18,8 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - const double key_scale, const double value_scale); + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 39a2f971fa90f..caef7f5e18630 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel( // block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x, - const float key_scale, const float value_scale) { + const int head_size, const int block_size, const int x, const float k_scale, + const float v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( value_cache[tgt_value_idx] = tgt_value; } else { key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, key_scale); + fp8::scaled_convert(tgt_key, k_scale); value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, value_scale); + fp8::scaled_convert(tgt_value, v_scale); } } } @@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, key_scale, 
value_scale); + num_heads, head_size, block_size, x, k_scale, v_scale); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -258,8 +258,8 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double key_scale, - const double value_scale) { + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 8cc159bda597a..abb4e3bea14bb 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -423,11 +423,11 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double key_scale, double value_scale, + const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(key_scale == 1.0f && value_scale == 1.0f); + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", @@ -742,11 +742,11 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double key_scale, double value_scale, + const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(key_scale == 1.0f && value_scale == 1.0f); + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 058d480cb10ad..31d454328b2c1 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -107,9 +107,9 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double key_scale, - double value_scale) { - TORCH_CHECK(key_scale == 1.0f && value_scale == 1.0f); + const std::string& kv_cache_dtype, double k_scale, + double v_scale) { + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); int num_tokens = key.size(0); int num_heads = key.size(1); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 06d9e8bd1106a..5be0e9810b5b9 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -16,7 +16,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float key_scale, float value_scale," + " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float key_scale, float value_scale," + " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float key_scale, float value_scale) -> ()"); + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/ops.h b/csrc/ops.h index 44b92e6226877..f9feb3deff5e4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -8,7 +8,7 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double key_scale, double value_scale, + const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -19,7 +19,7 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double key_scale, double value_scale, + const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 868e76bb38d38..9dc7cefc404ca 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -27,7 +27,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float key_scale, float value_scale," + " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -41,7 +41,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float key_scale, float value_scale," + " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -223,7 +223,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float key_scale, float value_scale) -> ()"); + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index ef2d469e0dc59..2e6412c28958e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -175,7 +175,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - key_scale = value_scale = 1.0 + k_scale = v_scale = 1.0 # Call the paged attention kernel. output = torch.empty_like(query) @@ -193,8 +193,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) elif version == "v2": num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) @@ -225,8 +225,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) else: raise AssertionError(f"Unknown version: {version}") diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 57b349ea1fb26..b3adb152949a2 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -212,7 +212,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - key_scale = value_scale = 1.0 + k_scale = v_scale = 1.0 tp_rank = 0 # Call the paged attention kernel. @@ -231,8 +231,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, @@ -268,8 +268,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 09b918f1105fc..70ae3d0c6e0c3 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -155,11 +155,11 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Using default kv_scale - key_scale = value_scale = 1.0 + k_scale = v_scale = 1.0 # Call the reshape_and_cache kernel. 
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, key_scale, value_scale) + kv_cache_dtype, k_scale, v_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d120e8d1f96a0..4ca67224a91b8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -84,8 +84,8 @@ def paged_attention_v1( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -95,7 +95,7 @@ def paged_attention_v1( torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - key_scale, value_scale, tp_rank, blocksparse_local_blocks, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -116,8 +116,8 @@ def paged_attention_v2( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -127,7 +127,7 @@ def paged_attention_v2( torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, key_scale, value_scale, tp_rank, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -377,13 +377,12 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, ) -> None: torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, key_scale, - value_scale) + kv_cache_dtype, k_scale, v_scale) def reshape_and_cache_flash( diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 0105d040b3613..b4721b4e1aedd 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -59,8 +59,8 @@ def paged_attention_v1( max_context_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -100,8 +100,8 @@ def paged_attention_v2( max_context_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -229,8 +229,8 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, ) -> None: assert kv_cache_dtype == "auto" ipex.llm.modules.PagedAttention.reshape_and_cache( diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 5626761e59049..1310bb1679e15 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -134,8 +134,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, - 
key_scale: float = 1.0, - value_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index e5825a98f6914..6308cf07ce41e 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -327,8 +327,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: BlocksparseFlashAttentionMetadata, - key_scale: float = 1.0, - value_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -369,8 +369,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) if prefill_meta := attn_metadata.prefill_metadata: @@ -407,8 +407,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - key_scale, - value_scale, + k_scale, + v_scale, tp_rank=self.tp_rank, blocksparse_local_blocks=self.local_blocks, blocksparse_vert_stride=self.vert_stride, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 6bf8a5b220152..0b6bd21279393 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -256,8 +256,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - key_scale: float = 1.0, - value_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -278,8 +278,8 @@ def forward( "FlashAttentionImpl") # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert key_scale == 1.0 and value_scale == 1.0, ( - "key/value_scale is not supported in FlashAttention.") + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. 
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 1d2a3ed2ae7a2..a4b01c6d3b508 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -223,12 +223,12 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: FlashInferMetadata, - key_scale: float = 1.0, - value_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: - assert key_scale == 1.0 and value_scale == 1.0, ( - "key/value_scale is not supported in FlashInfer.") + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashInfer.") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 3e00b597c45cf..4559dd15f600c 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -156,8 +156,8 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: IpexAttnMetadata, # type: ignore - key_scale: float = 1.0, - value_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. @@ -171,7 +171,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert key_scale == 1.0 and value_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0 if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -193,8 +193,8 @@ def forward( value_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) if attn_metadata.is_prompt: @@ -275,8 +275,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) else: # Run PagedAttention V2. @@ -308,8 +308,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 34ffcffcaeeee..b83a83bb177d4 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -131,8 +131,8 @@ def forward( value: torch.Tensor, kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], attn_metadata: PallasMetadata, - key_scale: float = 1.0, - value_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with Pallas attention. 
@@ -147,7 +147,7 @@ def forward( Returns: shape = [batch_size, seq_len, num_heads * head_size] """ - assert key_scale == 1.0 and value_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0 if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 839b3fcc20d05..f6ecea30da492 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -296,8 +296,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, - key_scale: float = 1.0, - value_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -337,8 +337,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) num_prefill_tokens = attn_metadata.num_prefill_tokens @@ -458,8 +458,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - key_scale, - value_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 8a8266034fa19..fe6a56123ce72 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -144,8 +144,8 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore - key_scale: float = 1.0, - value_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -159,7 +159,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert key_scale == 1.0 and value_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0 if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -177,8 +177,8 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, key_scale, - value_scale) + self.kv_cache_dtype, k_scale, + v_scale) if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None @@ -241,8 +241,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - key_scale, - value_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0048ef45eb049..3dd60ed5be528 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -427,8 +427,8 @@ def forward( value: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor], attn_metadata: "XFormersMetadata", - key_scale: float = 1.0, - value_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -532,7 +532,7 @@ def forward( value_cache, updated_slot_mapping, self.kv_cache_dtype, - key_scale, value_scale) + k_scale, v_scale) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. 
@@ -621,8 +621,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - key_scale, - value_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 17962d0fd1f5c..81d308c4d4e22 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -45,8 +45,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( @@ -65,8 +65,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, *args, ) -> torch.Tensor: output = torch.empty_like(query) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 576428a926b68..ce7b4d129779c 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -66,8 +66,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, ) -> None: ops.reshape_and_cache( key, @@ -76,8 +76,8 @@ def write_to_paged_cache( value_cache, slot_mapping.flatten(), kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, ) @staticmethod @@ -92,8 +92,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - key_scale: float, - value_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -138,8 +138,8 @@ def forward_decode( max_seq_len, alibi_slopes, kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, @@ -176,8 +176,8 @@ def forward_decode( max_seq_len, alibi_slopes, kv_cache_dtype, - key_scale, - value_scale, + k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
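
For reference, the scale-handling behavior introduced across these patches can be summarized with a small standalone Python sketch (the names below are illustrative only, not vLLM APIs): checkpoint parameters ending in the deprecated ".kv_scale" are remapped to ".attn.k_scale", ".k_scale"/".v_scale" get the ".attn" prefix inserted, and at load time the -1.0 sentinels fall back to 1.0 while a lone remapped kv_scale is duplicated into both k_scale and v_scale.

    from typing import Dict, Optional, Tuple

    def remap_scale_name(name: str, params: Dict[str, float]) -> Optional[str]:
        # Deprecated ".kv_scale" is remapped to ".attn.k_scale";
        # ".k_scale"/".v_scale" get ".attn" inserted before the suffix.
        if name.endswith(".kv_scale"):
            remapped = name.replace(".kv_scale", ".attn.k_scale")
        elif name.endswith(".k_scale"):
            remapped = name.replace(".k_scale", ".attn.k_scale")
        elif name.endswith(".v_scale"):
            remapped = name.replace(".v_scale", ".attn.v_scale")
        else:
            return name  # not a scale parameter, leave untouched
        # If the model does not expose the expected name, skip loading it.
        return remapped if remapped in params else None

    def resolve_kv_scales(k_scale: float, v_scale: float) -> Tuple[float, float]:
        # Scales start at the sentinel -1.0 and are overwritten only if the
        # checkpoint provided them.
        if k_scale > 0.0 and v_scale > 0.0:
            return k_scale, v_scale          # separate k/v scales found
        if k_scale < 0.0 and v_scale < 0.0:
            return 1.0, 1.0                  # nothing loaded: default to 1.0
        # A single legacy kv_scale was remapped into k_scale; duplicate it.
        scale = max(k_scale, v_scale)
        return scale, scale

    if __name__ == "__main__":
        params = {"model.layers.0.self_attn.attn.k_scale": 0.0,
                  "model.layers.0.self_attn.attn.v_scale": 0.0}
        print(remap_scale_name("model.layers.0.self_attn.kv_scale", params))
        print(resolve_kv_scales(-1.0, -1.0))   # -> (1.0, 1.0)
        print(resolve_kv_scales(0.023, -1.0))  # -> (0.023, 0.023)

The actual patches additionally verify that the resolved scales are per-tensor floats and warn when a 1.0 scale is used together with an fp8_e4m3 KV cache.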