Add fp8 support to reshape_and_cache_flash (vllm-project#6667)
Signed-off-by: Alvant <alvasian@yandex.ru>
Yard1 authored and Alvant committed Oct 26, 2024
1 parent ac235a6 commit 3b8c08f
Showing 8 changed files with 98 additions and 43 deletions.
3 changes: 2 additions & 1 deletion csrc/cache.h
@@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
                              torch::Tensor& value_cache,
                              torch::Tensor& slot_mapping,
-                             const std::string& kv_cache_dtype);
+                             const std::string& kv_cache_dtype,
+                             const double k_scale, const double v_scale);
 
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
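The two new parameters are per-tensor scaling factors used when the cache dtype is an fp8 format; the updated test below simply passes 1.0 for both. As an illustration only (this derivation is not part of the commit), such a scale is often taken from the tensor's absolute maximum and the fp8 e4m3 range:

import torch

def compute_kv_scale(x: torch.Tensor, fp8_max: float = 448.0) -> float:
    # Per-tensor scale so that x / scale fits the e4m3 range (max 448).
    # Illustrative only; the commit itself just threads the scale through.
    return max(x.abs().max().item() / fp8_max, 1e-12)

k_scale = compute_kv_scale(torch.randn(16, 8, 128, dtype=torch.float16))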
75 changes: 45 additions & 30 deletions csrc/cache_kernels.cu
@@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
                                          // head_size]
-    scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
                                          // head_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride, const int key_stride, const int value_stride,
-    const int num_heads, const int head_size, const int block_size) {
+    const int num_heads, const int head_size, const int block_size,
+    const float k_scale, const float v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_value_idx = block_idx * block_stride +
-                                  block_offset * num_heads * head_size +
-                                  head_idx * head_size + head_offset;
-    k_cache[tgt_value_idx] = key[src_key_idx];
-    v_cache[tgt_value_idx] = value[src_value_idx];
+    const int64_t tgt_key_value_idx = block_idx * block_stride +
+                                      block_offset * num_heads * head_size +
+                                      head_idx * head_size + head_offset;
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      key_cache[tgt_key_value_idx] = tgt_key;
+      value_cache[tgt_key_value_idx] = tgt_value;
+    } else {
+      key_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+      value_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+    }
   }
 }
 } // namespace vllm
@@ -278,40 +288,45 @@ void reshape_and_cache(
       CALL_RESHAPE_AND_CACHE)
 }
 
+// KV_T is the stored data type of kv-cache.
+// CACHE_T is the data type of key and value tensors.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)         \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>       \
+      <<<grid, block, 0, stream>>>(                                   \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
+          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
+          value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+
 void reshape_and_cache_flash(
-    torch::Tensor& key,      // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,    // [num_tokens, num_heads, head_size]
-    torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
+    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor&
+        value_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype) {
-  // FIXME: only support auto datatype, does not support fp8
-  if (kv_cache_dtype != "auto") {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
+    const std::string& kv_cache_dtype, const double k_scale,
+    const double v_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
-  int block_size = k_cache.size(1);
+  int block_size = key_cache.size(1);
 
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
-  int block_stride = k_cache.stride(0);
-  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+  int block_stride = key_cache.stride(0);
+  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_flash", [&] {
-        vllm::reshape_and_cache_flash_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
-                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
-                value_stride, num_heads, head_size, block_size);
-      });
+
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
+                             CALL_RESHAPE_AND_CACHE_FLASH);
 }
 
 namespace vllm {
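For orientation, the kernel's addressing can be restated as a small PyTorch reference (a sketch of the non-fp8 "auto" path only; names and shapes are illustrative). For fp8 caches the kernel additionally runs each element through fp8::scaled_convert with k_scale / v_scale before the store.

import torch

def reshape_and_cache_flash_ref(key, value, key_cache, value_cache,
                                slot_mapping):
    """Reference scatter of per-token K/V into a
    [num_blocks, block_size, num_heads, head_size] cache (flash layout)."""
    block_size = key_cache.shape[1]
    for i, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:  # slot_idx can be -1 for padded tokens; skip them
            continue
        block_idx, block_offset = divmod(slot, block_size)
        key_cache[block_idx, block_offset] = key[i]
        value_cache[block_idx, block_offset] = value[i]

# Example: 8 tokens scattered into a 2-block cache with block_size 16.
key = torch.randn(8, 4, 64, dtype=torch.float16)
value = torch.randn_like(key)
key_cache = torch.zeros(2, 16, 4, 64, dtype=torch.float16)
value_cache = torch.zeros_like(key_cache)
reshape_and_cache_flash_ref(key, value, key_cache, value_cache,
                            torch.arange(8))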
3 changes: 2 additions & 1 deletion csrc/torch_bindings.cpp
@@ -248,7 +248,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       " Tensor! key_cache,"
       " Tensor! value_cache,"
       " Tensor slot_mapping,"
-      " str kv_cache_dtype) -> ()");
+      " str kv_cache_dtype,"
+      " float k_scale, float v_scale) -> ()");
   cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                  &reshape_and_cache_flash);
 
42 changes: 34 additions & 8 deletions tests/kernels/test_cache.py
@@ -215,8 +215,6 @@ def test_reshape_and_cache_flash(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
-    if kv_cache_dtype == "fp8":
-        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -248,15 +246,33 @@ def test_reshape_and_cache_flash(
         dtype,
         device=device,
     )
-    key_cache, value_cache = key_caches[0], value_caches[0]
+    key_cache, value_cache = key_caches[0].contiguous(
+    ), value_caches[0].contiguous()
+    del key_caches
+    del value_caches
 
     # Clone the KV caches.
-    cloned_key_cache = key_cache.clone()
-    cloned_value_cache = value_cache.clone()
+    if kv_cache_dtype == "fp8":
+        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache)
+        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache)
+    else:
+        cloned_key_cache = key_cache.clone()
+        cloned_value_cache = value_cache.clone()
+
+    # Using default kv_scale
+    k_scale = v_scale = 1.0
 
     # Call the reshape_and_cache kernel.
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype)
+                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
+
+    if kv_cache_dtype == "fp8":
+        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(result_key_cache, key_cache)
+        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(result_value_cache, value_cache)
 
     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
@@ -269,8 +285,18 @@ def test_reshape_and_cache_flash(
         cloned_key_cache[block_idx, block_offset, :, :] = key[i]
         cloned_value_cache[block_idx, block_offset, :, :] = value[i]
 
-    assert torch.allclose(key_cache, cloned_key_cache)
-    assert torch.allclose(value_cache, cloned_value_cache)
+    if kv_cache_dtype == "fp8":
+        assert torch.allclose(result_key_cache,
+                              cloned_key_cache,
+                              atol=0.001,
+                              rtol=0.1)
+        assert torch.allclose(result_value_cache,
+                              cloned_value_cache,
+                              atol=0.001,
+                              rtol=0.1)
+    else:
+        assert torch.allclose(key_cache, cloned_key_cache)
+        assert torch.allclose(value_cache, cloned_value_cache)
 
 
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
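The fp8 branch of the test compares in float16 space by round-tripping both the expected and the kernel-written caches through ops.convert_fp8, then using loose tolerances (atol=0.001, rtol=0.1). In isolation, the dequantize-for-comparison step looks like this (a sketch reusing the same ops calls; the import path matches what the test module uses and is otherwise an assumption):

import torch
from vllm import _custom_ops as ops

def cache_as_fp16(cache_fp8: torch.Tensor) -> torch.Tensor:
    # Dequantize an fp8 cache into a regular fp16 tensor so it can be
    # compared with torch.allclose, as the updated test does.
    out = torch.empty_like(cache_fp8, dtype=torch.float16)
    ops.convert_fp8(out, cache_fp8)
    return out

# e.g. assert torch.allclose(cache_as_fp16(key_cache), cloned_key_cache,
#                            atol=0.001, rtol=0.1)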
5 changes: 4 additions & 1 deletion vllm/_custom_ops.py
@@ -426,10 +426,13 @@ def reshape_and_cache_flash(
     value_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
     kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
 ) -> None:
     torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                    value_cache, slot_mapping,
-                                                   kv_cache_dtype)
+                                                   kv_cache_dtype, k_scale,
+                                                   v_scale)
 
 
 def copy_blocks(key_caches: List[torch.Tensor],
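The wrapper's new k_scale / v_scale parameters have no defaults, so every caller must pass them explicitly (as the flash_attn and flashinfer backends below now do). A usage sketch with hypothetical shapes, the non-quantized "auto" path, and a CUDA build of the extension assumed:

import torch
from vllm import _custom_ops as ops

# Layouts follow the comments in cache_kernels.cu:
# key/value: [num_tokens, num_heads, head_size]
# caches:    [num_blocks, block_size, num_heads, head_size]
key = torch.randn(8, 4, 64, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(2, 16, 4, 64, dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)
slot_mapping = torch.arange(8, dtype=torch.long, device="cuda")

ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                            kv_cache_dtype="auto", k_scale=1.0, v_scale=1.0)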
2 changes: 2 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -478,6 +478,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
2 changes: 2 additions & 0 deletions vllm/attention/backends/flashinfer.py
@@ -489,6 +489,8 @@ def forward(
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )
 
         query = query.contiguous(
9 changes: 7 additions & 2 deletions vllm/utils.py
@@ -491,7 +491,6 @@ def create_kv_caches_with_random_flash(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    assert cache_dtype != "fp8"
    torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
@@ -507,7 +506,13 @@ def create_kv_caches_with_random_flash(
         key_value_cache = torch.empty(size=key_value_cache_shape,
                                       dtype=torch_dtype,
                                       device=device)
-        key_value_cache.uniform_(-scale, scale)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            key_value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8':
+            _generate_random_fp8(key_value_cache, -scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_value_cache[:, 0])
         value_caches.append(key_value_cache[:, 1])
     return key_caches, value_caches
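With the assert removed, this helper can now also build fp8-filled flash-layout caches via _generate_random_fp8. A usage sketch; the leading positional parameters are not visible in this hunk and are assumed here to mirror create_kv_caches_with_random:

from vllm.utils import create_kv_caches_with_random_flash

# Assumed parameter order: (num_blocks, block_size, num_layers,
#                           num_heads, head_size, cache_dtype, model_dtype)
key_caches, value_caches = create_kv_caches_with_random_flash(
    1024,                 # num_blocks
    16,                   # block_size
    1,                    # num_layers
    8,                    # num_heads
    128,                  # head_size
    cache_dtype="fp8",    # now allowed: filled via _generate_random_fp8
    model_dtype="half",
    seed=0,
    device="cuda",
)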
