
feat: Separate Q and KV dtypes for decode #286

Merged: 10 commits merged into flashinfer-ai:main from the separate_q_kv_dtype_decode branch on Jun 13, 2024

Conversation

@Yard1 (Contributor) commented on Jun 5, 2024

Closes #285

Modified unit tests pass. May need some extra validation.
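
For context, a rough sketch of the use case this enables: fp16 queries decoding against an fp8 paged KV cache. The wrapper follows the flashinfer Python API quoted later in this thread, but the `q_data_type` keyword, the exact `forward` signature, and the fp8 cache construction are illustrative assumptions rather than the confirmed API of this PR:

```python
import torch
import flashinfer

# Toy sizes, for illustration only.
batch_size, page_size, num_kv_heads, num_qo_heads, head_dim = 4, 16, 8, 32, 128
total_num_pages = 8  # two pages per request

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Paged KV cache stored in fp8, queries kept in fp16.
kv_cache = torch.randn(
    total_num_pages, 2, page_size, num_kv_heads, head_dim,
    device="cuda", dtype=torch.float16,
).to(torch.float8_e4m3fn)

kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * 2
kv_indices = torch.arange(total_num_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.begin_forward(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    data_type=torch.float8_e4m3fn,  # KV-cache dtype
    q_data_type=torch.float16,      # query dtype (parameter name assumed)
)

q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
out = wrapper.forward(q, kv_cache)  # q dtype differs from KV dtype
wrapper.end_forward()
```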

@Yard1 (Contributor, Author) commented on Jun 5, 2024

@yzh119 Please let me know if this is on the right track! I couldn't see anything directly related to the dtype of the query in the kernels, so my assumption is this should "just work", but I'm not sure whether this will affect, e.g., q_vec loading. I am compiling it to test right now.

@yzh119 (Collaborator) commented on Jun 5, 2024

Yes, I do think you are on the right track, thank you!

> but I'm not sure whether this will affect, e.g., q_vec loading.

I don't think so.

@Yard1 marked this pull request as ready for review on June 11, 2024 21:31
@Yard1 (Contributor, Author) commented on Jun 11, 2024

@yzh119 The modified unit test passes for me; can you review and validate?

@Yard1 changed the title from "[WIP] Separate Q and KV dtypes for decode" to "Separate Q and KV dtypes for decode" on Jun 11, 2024
@Yard1 changed the title from "Separate Q and KV dtypes for decode" to "feat: Separate Q and KV dtypes for decode" on Jun 11, 2024
@yzh119 (Collaborator) left a comment

Hi @Yard1, thanks so much for doing this; it looks good to me in general.

I'd like to request a few other changes, mainly around the BeginForward functions: it seems you assume the same data type for q and kv there, which might affect some resource estimation.

I left some suggested changes. Besides those, you also need to separate the q dtype and kv dtype in this function (pass the q dtype as an additional empty tensor; a rough sketch follows after the snippets below):

def begin_forward(
    self,
    indptr: torch.Tensor,
    indices: torch.Tensor,
    last_page_len: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    data_type: Union[str, torch.dtype] = "float16",
):
    r"""Create auxiliary data structures for batch decode for multiple forward calls
    within the same decode step.

    Parameters
    ----------
    indptr : torch.Tensor
        The indptr of the paged kv cache, shape: ``[batch_size + 1]``
    indices : torch.Tensor
        The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]``
    last_page_len : torch.Tensor
        The number of entries in the last page of each request in the paged kv
        cache, shape: ``[batch_size]``
    num_qo_heads : int
        The number of query/output heads
    num_kv_heads : int
        The number of key/value heads
    head_dim : int
        The dimension of the heads
    page_size : int
        The page size of the paged kv cache
    pos_encoding_mode : str
        Whether to apply RoPE on-the-fly inside attention kernels, could be
        ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
    data_type : Union[str, torch.dtype]
        The data type of the paged kv cache

    Note
    ----
    The :meth:`begin_forward` method should be called before any :meth:`forward` or
    :meth:`forward_return_lse` calls, auxiliary data structures will be created
    during this call and cached for multiple forward calls.

    The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
    is not equal to ``num_kv_heads``, the function will use
    `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
    """
    batch_size = len(last_page_len)
    if self.is_cuda_graph_enabled:
        if batch_size != self._fixed_batch_size:
            raise ValueError(
                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
                " mismatches the batch size set during initialization {}".format(
                    batch_size, self._fixed_batch_size
                )
            )
        if len(indices) > len(self._paged_kv_indices_buf):
            raise ValueError(
                "The size of indices should be less than or equal to the allocated buffer"
            )
        self._paged_kv_indptr_buf.copy_(indptr)
        self._paged_kv_indices_buf[: len(indices)] = indices
        self._paged_kv_last_page_len_buf.copy_(last_page_len)
    else:
        self._paged_kv_indptr_buf = indptr
        self._paged_kv_indices_buf = indices
        self._paged_kv_last_page_len_buf = last_page_len
    # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
    empty_data = torch.empty(
        0,
        dtype=(
            getattr(torch, data_type) if isinstance(data_type, str) else data_type
        ),
    )
    self._wrapper.begin_forward(
        self._workspace_buffer,
        indptr,
        last_page_len,
        batch_size,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        PosEncodingMode[pos_encoding_mode].value,
        empty_data,
    )

and update

void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
                  torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
                  unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
                  unsigned int pos_encoding_mode, torch::Tensor empty_data);

void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
    torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len,
    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
    unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode,
    torch::Tensor empty_data) {
  // NOTE(zihao): not necessary to be CUDA tensor
  CHECK_CONTIGUOUS(indptr);
  CHECK_CONTIGUOUS(last_page_len);
  CHECK_CONTIGUOUS(workspace_buffer);
  CHECK_DIM(1, indptr);
  CHECK_DIM(1, last_page_len);
  CHECK_DIM(1, workspace_buffer);
  CHECK_EQ(indptr.scalar_type(), torch::kInt32);
  CHECK_EQ(indptr.scalar_type(), torch::kInt32);
  CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
  handler_->SetCUDAStream(torch_current_stream);
  if (is_float8_tensor(empty_data)) {
    DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] {
      return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] {
        return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
          return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
            return DISPATCH_pos_encoding_mode(
                PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
                  cudaError_t status =
                      handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
                                                       KV_LAYOUT, POS_ENCODING_MODE, c_type,
                                                       nv_half, int32_t>(
                          static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
                          static_cast<int32_t*>(indptr.data_ptr()),
                          static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
                          page_size);
                  TORCH_CHECK(status == cudaSuccess,
                              "BatchDecodeWithPagedKVCache failed with error ",
                              cudaGetErrorString(status));
                  return true;
                });
          });
        });
      });
    });
  } else {
    DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
      return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] {
        return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
          return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
            return DISPATCH_pos_encoding_mode(
                PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
                  cudaError_t status =
                      handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
                                                       KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type,
                                                       int32_t>(
                          static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
                          static_cast<int32_t*>(indptr.data_ptr()),
                          static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
                          page_size);
                  TORCH_CHECK(status == cudaSuccess,
                              "BatchDecodeWithPagedKVCache failed with error ",
                              cudaGetErrorString(status));
                  return true;
                });
          });
        });
      });
    });
  }
}

accordingly.
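
To make this concrete, here is a minimal, self-contained sketch of the "empty tensor as dtype carrier" idea extended to two dtypes. The helper and names (`_dtype_placeholder`, `q_data_type`, `kv_data_type`, `empty_q_data`, `empty_kv_data`) are illustrative only, not the final API of this PR:

```python
import torch
from typing import Optional, Union

def _dtype_placeholder(data_type: Union[str, torch.dtype]) -> torch.Tensor:
    """Zero-element tensor that only carries dtype information across the
    Python/C++ boundary (the same trick as the existing ``empty_data``)."""
    dtype = getattr(torch, data_type) if isinstance(data_type, str) else data_type
    return torch.empty(0, dtype=dtype)

# KV-cache dtype and (optionally different) query dtype, e.g. fp8 KV + fp16 Q.
kv_data_type: Union[str, torch.dtype] = torch.float8_e4m3fn
q_data_type: Optional[Union[str, torch.dtype]] = torch.float16

empty_kv_data = _dtype_placeholder(kv_data_type)
empty_q_data = _dtype_placeholder(q_data_type if q_data_type is not None else kv_data_type)

# Both placeholders would then be forwarded to the C++ BeginForward so it can
# dispatch on the (q dtype, kv dtype) pair instead of a single dtype.
print(empty_q_data.dtype, empty_kv_data.dtype)  # torch.float16 torch.float8_e4m3fn
```

The zero-element tensors cost nothing to allocate; they exist only so the C++ side can dispatch on both dtypes.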

Inline review comments (now outdated and resolved) were left on include/flashinfer/attention/handler.cuh (4 threads) and src/flashinfer_ops.cuh (2 threads).
@Yard1 (Contributor, Author) commented on Jun 12, 2024

@yzh119 Correct, I wanted to avoid having to modify the public API. I don't think the query dtype is used in resource estimation, but please correct me if that's not the case; I'm happy to make the change then.

@yzh119 (Collaborator) commented on Jun 12, 2024

Hi @Yard1, I'm a little conservative here because this section of code

auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel<
    /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx,
    bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));

might produce a different num_blocks_per_sm depending on the q dtype used in the kernel.
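
For illustration only (this is a standalone toy example, not FlashInfer code): cudaOccupancyMaxActiveBlocksPerMultiprocessor is evaluated per kernel instantiation, so templating the same kernel on a different query dtype can change the reported blocks per SM, and therefore the resource estimation.

```cuda
// Toy illustration: occupancy is queried per kernel instantiation, so changing
// only the query dtype can change the reported blocks-per-SM.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

template <typename DTypeQ, typename DTypeKV>
__global__ void toy_decode_kernel(const DTypeQ* q, const DTypeKV* kv, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Arbitrary work; only the instantiation's register/shared-memory usage matters here.
    out[i] = static_cast<float>(q[i]) * static_cast<float>(kv[i]);
  }
}

template <typename DTypeQ, typename DTypeKV>
int BlocksPerSM(int num_threads, size_t smem_size) {
  int num_blocks_per_sm = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &num_blocks_per_sm, toy_decode_kernel<DTypeQ, DTypeKV>, num_threads, smem_size);
  return num_blocks_per_sm;
}

int main() {
  const int num_threads = 128;
  const size_t smem_size = 16 * 1024;
  // Same KV dtype, different Q dtype: the two instantiations may report
  // different occupancy, which feeds into the handler's resource estimation.
  printf("q=half : %d blocks/SM\n", BlocksPerSM<half, half>(num_threads, smem_size));
  printf("q=float: %d blocks/SM\n", BlocksPerSM<float, half>(num_threads, smem_size));
  return 0;
}
```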

@Yard1 (Contributor, Author) commented on Jun 12, 2024

Ok sounds good! Let me make the change.

@Yard1 requested a review from yzh119 on June 13, 2024 23:10
@Yard1 (Contributor, Author) commented on Jun 13, 2024

@yzh119 Updated, ptal!

@yzh119 (Collaborator) left a comment

LGTM, thank you @Yard1!

@yzh119 merged commit 5602659 into flashinfer-ai:main on Jun 13, 2024
@Yard1 deleted the separate_q_kv_dtype_decode branch on June 13, 2024 23:51
yzh119 added a commit that referenced this pull request on Jun 20, 2024
🤖 I have created a release *beep* *boop*
---

## [0.1.0](v0.0.4...v0.1.0) (2024-06-20)

### Highlights

* Support any GQA group size for tensor-cores kernels.
* Support any page size for tensor-cores kernels.
* Support CUDA-Graph for prefill/decode APIs.
* Add an option to accelerate decode kernels with Tensor Cores.
* Support custom attention mask.
(https://docs.flashinfer.ai/tutorials/kv_layout.html#mask-layout-2d-ragged-tensor)
* Support logits cap in Grok-1 models.
* Fused GPU-sampling kernels: top-p, top-k, speculative verification.
(https://docs.flashinfer.ai/api/python/sampling.html)
* PyTorch wrapper of group-gemm cutlass kernels.
(https://docs.flashinfer.ai/api/python/sampling.html)

### Acknowledgement

We thank [@ibsidorenko](https://github.com/ibsidorenko),
[@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU),
[@Yard1](https://github.com/Yard1),
[@AgrawalAmey](https://github.com/AgrawalAmey),
[@xuzhenqi](https://github.com/xuzhenqi),
[@mgerstgrasser](https://github.com/mgerstgrasser),
[@esmeetu](https://github.com/esmeetu),
[@yz-tang](https://github.com/yz-tang),
[@HSQ79815](https://github.com/HSQ79815),
[@Qubitium](https://github.com/Qubitium),
[@shreygupta2809](https://github.com/shreygupta2809),
[@sighingnow](https://github.com/sighingnow),
[@vinx13](https://github.com/vinx13),
[@tqchen](https://github.com/tqchen),
[@merrymercy](https://github.com/merrymercy),
[@comaniac](https://github.com/comaniac) and many others for their
contributions and helpful discussions for the 0.1.0 release.

### Refactor

* support any GQA group size for tensor-cores kernels
([#301](#301))
([c111ca](c111ca6))
* support any page size for tensor-cores kernels
([#306](#306))
([82fd8c](82fd8c7))


### Features

* add `use_tensor_cores` option to decode kernels to accelerate GQA
([#317](#317))
([3b50dd5](3b50dd5))
* add group gemm operators
([#282](#282))
([e08ba42](e08ba42))
* initial support of distributed operators
([#289](#289))
([03553da](03553da))
* initial support of logits hook
([#298](#298))
([ab1e2ad](ab1e2ad))
* Separate Q and KV dtypes for decode
([#286](#286))
([5602659](5602659))
* support cuda graph for batched multi-query(prefill/append) attention
([#275](#275))
([83ceb67](83ceb67))
* support cuda graph for batched multi-query(prefill/append) attention
([#277](#277))
([24cc583](24cc583))
* support custom attention mask in prefill/append attention kernels
([#266](#266))
([7304282](7304282))
* fused speculative sampling kernels
([#259](#259))
([cea2bb](cea2bb9))
* expose sampling APIs in pytorch
([#238](#238))
([092902](0929023))


### Performance Improvements

* initial cuda graph support
([#256](#256))
([7e9cc7f](7e9cc7f))
* split kv-cache for prefill/append kernels
([#310](#310))
([f0bb0a3](f0bb0a3))
* use packed bit array for attention mask
([#308](#308))
([3d43dc9](3d43dc9))

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>

Successfully merging this pull request may close these issues.

[Q&A] Any plans for different dtypes for Q (query) and KV (kv-cache)?