-
Notifications
You must be signed in to change notification settings - Fork 163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Separate Q and KV dtypes for decode #286
Conversation
@yzh119 Please let me know if this is on the right track! I couldn't see anything directly related to the dtype of the query in the kernels, so my assumption is this should "just work", but I don't know if this will not affect eg. |
Yes I do think you are on the right track, thank you!
I don't think so. |
@yzh119 The modified unit test passes for me, can you review and validate? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @Yard1 , thanks so much for doing this and it look good to me in general.
I beg some other changes, mainly around BeginForward
functions because it seems you assume we are using the same data type for q and kv and it might affect some resource estimation.
I left some suggested changes, besides them, you also need to separate qtype and kvtype in this function (pass the qtype also as an empty tensor):
flashinfer/python/flashinfer/decode.py
Lines 532 to 620 in 1250b68
def begin_forward( | |
self, | |
indptr: torch.Tensor, | |
indices: torch.Tensor, | |
last_page_len: torch.Tensor, | |
num_qo_heads: int, | |
num_kv_heads: int, | |
head_dim: int, | |
page_size: int, | |
pos_encoding_mode: str = "NONE", | |
data_type: Union[str, torch.dtype] = "float16", | |
): | |
r"""Create auxiliary data structures for batch decode for multiple forward calls | |
within the same decode step. | |
Parameters | |
---------- | |
indptr : torch.Tensor | |
The indptr of the paged kv cache, shape: ``[batch_size + 1]`` | |
indices : torch.Tensor | |
The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]`` | |
last_page_len : torch.Tensor | |
The number of entries in the last page of each request in the paged kv | |
cache, shape: ``[batch_size]`` | |
num_qo_heads : int | |
The number of query/output heads | |
num_kv_heads : int | |
The number of key/value heads | |
head_dim : int | |
The dimension of the heads | |
page_size : int | |
The page size of the paged kv cache | |
pos_encoding_mode : str | |
Whether to apply RoPE on-the-fly inside attention kernels, could be | |
``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. | |
data_type : Union[str, torch.dtype] | |
The data type of the paged kv cache | |
Note | |
---- | |
The :meth:`begin_forward` method should be called before any :meth:`forward` or | |
:meth:`forward_return_lse` calls, auxiliary data structures will be created | |
during this call and cached for multiple forward calls. | |
The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` | |
is not equal to ``num_kv_heads``, the function will use | |
`grouped query attention <https://arxiv.org/abs/2305.13245>`_. | |
""" | |
batch_size = len(last_page_len) | |
if self.is_cuda_graph_enabled: | |
if batch_size != self._fixed_batch_size: | |
raise ValueError( | |
"The batch size should be fixed in cudagraph mode, the runtime batch size {} " | |
" mismatches the batch size set during initialization {}".format( | |
batch_size, self._fixed_batch_size | |
) | |
) | |
if len(indices) > len(self._paged_kv_indices_buf): | |
raise ValueError( | |
"The size of indices should be less than or equal to the allocated buffer" | |
) | |
self._paged_kv_indptr_buf.copy_(indptr) | |
self._paged_kv_indices_buf[: len(indices)] = indices | |
self._paged_kv_last_page_len_buf.copy_(last_page_len) | |
else: | |
self._paged_kv_indptr_buf = indptr | |
self._paged_kv_indices_buf = indices | |
self._paged_kv_last_page_len_buf = last_page_len | |
# NOTE(Zihao): the following tensor acts as placeholder to pass dtype info | |
empty_data = torch.empty( | |
0, | |
dtype=( | |
getattr(torch, data_type) if isinstance(data_type, str) else data_type | |
), | |
) | |
self._wrapper.begin_forward( | |
self._workspace_buffer, | |
indptr, | |
last_page_len, | |
batch_size, | |
num_qo_heads, | |
num_kv_heads, | |
head_dim, | |
page_size, | |
PosEncodingMode[pos_encoding_mode].value, | |
empty_data, | |
) |
and update
flashinfer/python/csrc/flashinfer_ops.h
Lines 77 to 80 in 1250b68
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, | |
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, | |
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, | |
unsigned int pos_encoding_mode, torch::Tensor empty_data); |
flashinfer/python/csrc/batch_decode.cu
Lines 120 to 188 in 1250b68
void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( | |
torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, | |
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, | |
unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode, | |
torch::Tensor empty_data) { | |
// NOTE(zihao): not necessary to be CUDA tensor | |
CHECK_CONTIGUOUS(indptr); | |
CHECK_CONTIGUOUS(last_page_len); | |
CHECK_CONTIGUOUS(workspace_buffer); | |
CHECK_DIM(1, indptr); | |
CHECK_DIM(1, last_page_len); | |
CHECK_DIM(1, workspace_buffer); | |
CHECK_EQ(indptr.scalar_type(), torch::kInt32); | |
CHECK_EQ(indptr.scalar_type(), torch::kInt32); | |
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); | |
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); | |
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); | |
handler_->SetCUDAStream(torch_current_stream); | |
if (is_float8_tensor(empty_data)) { | |
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] { | |
return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { | |
return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { | |
return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { | |
return DISPATCH_pos_encoding_mode( | |
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { | |
cudaError_t status = | |
handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, | |
KV_LAYOUT, POS_ENCODING_MODE, c_type, | |
nv_half, int32_t>( | |
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes, | |
static_cast<int32_t*>(indptr.data_ptr()), | |
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads, | |
page_size); | |
TORCH_CHECK(status == cudaSuccess, | |
"BatchDecodeWithPagedKVCache failed with error ", | |
cudaGetErrorString(status)); | |
return true; | |
}); | |
}); | |
}); | |
}); | |
}); | |
} else { | |
DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { | |
return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { | |
return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { | |
return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { | |
return DISPATCH_pos_encoding_mode( | |
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { | |
cudaError_t status = | |
handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, | |
KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type, | |
int32_t>( | |
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes, | |
static_cast<int32_t*>(indptr.data_ptr()), | |
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads, | |
page_size); | |
TORCH_CHECK(status == cudaSuccess, | |
"BatchDecodeWithPagedKVCache failed with error ", | |
cudaGetErrorString(status)); | |
return true; | |
}); | |
}); | |
}); | |
}); | |
}); | |
} | |
} |
accordingly.
@yzh119 correct, I wanted to avoid having to modify the public API. I don't think the information about the query dtype will be used in resource estimation, but please correct me if that's not the case - happy to do the change then |
Hi @Yard1 , I'm a little bit conservative here because this section of code flashinfer/include/flashinfer/attention/handler.cuh Lines 121 to 130 in 1250b68
might produce different |
Ok sounds good! Let me make the change. |
@yzh119 Updated, ptal! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, Thank you @Yard1 !
🤖 I have created a release *beep* *boop* --- ## [0.1.0](v0.0.4...v0.1.0) (2024-06-20) ### Highlights * Support any GQA group size support for tensor-cores kernels. * Support any page size support for tensor-cores kernels. * Support CUDA-Graph for prefill/decode APIs. * Add an option to accelerate decode kernels with Tensor Cores. * Support custom attention mask. (https://docs.flashinfer.ai/tutorials/kv_layout.html#mask-layout-2d-ragged-tensor) * Support logits cap in Grok-1 models. * Fused GPU-sampling kernels: top-p, top-k, speculative verification. (https://docs.flashinfer.ai/api/python/sampling.html) * PyTorch wrapper of group-gemm cutlass kernels. (https://docs.flashinfer.ai/api/python/sampling.html) ### Acknowledgement We thank [@ibsidorenko](https://github.com/ibsidorenko), [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU), [@Yard1](https://github.com/Yard1) [@AgrawalAmey](https://github.com/AgrawalAmey), [@xuzhenqi](https://github.com/xuzhenqi), [@mgerstgrasser](https://github.com/mgerstgrasser), [@esmeetu](https://github.com/esmeetu), [@yz-tang](https://github.com/yz-tang), [@HSQ79815](https://github.com/HSQ79815), [@Qubitium](https://github.com/Qubitium), [@shreygupta2809](https://github.com/shreygupta2809), [@sighingnow](https://github.com/sighingnow), [@vinx13](https://github.com/vinx13), [@tqchen](https://github.com/tqchen), [@merrymercy](https://github.com/merrymercy), [@comaniac](https://github.com/comaniac) and many others for their contributions and helpful discussions for 0.0.5 release. ### Refactor * support any GQA group size for tensor-cores kernels ([#301](#301)) ([c111ca](c111ca6)) * support any page size for tensor-cores kernels ([#306](#306)) ([82fd8c](82fd8c7)) ### Features * add `use_tensor_cores` option to decode kernels to accelerate GQA ([#317](#317)) ([3b50dd5](3b50dd5)) * add group gemm operators ([#282](#282)) ([e08ba42](e08ba42)) * initial support of distributed operators ([#289](#289)) ([03553da](03553da)) * initial support of logits hook ([#298](#298)) ([ab1e2ad](ab1e2ad)) * Separate Q and KV dtypes for decode ([#286](#286)) ([5602659](5602659)) * support cuda graph for batched multi-query(prefill/append) attention ([#275](#275)) ([83ceb67](83ceb67)) * support cuda graph for batched multi-query(prefill/append) attention ([#277](#277)) ([24cc583](24cc583)) * support custom attention mask in prefill/append attention kernels ([#266](#266)) ([7304282](7304282)) * fused speculative sampilng kernels ([#259](#259)) ([cea2bb](cea2bb9)) * expose sampling APIs in pytorch ([#238](#238)) ([092902](0929023)) ### Performance Improvements * initial cuda graph support ([#256](#256)) ([7e9cc7f](7e9cc7f)) * split kv-cache for prefill/append kernels ([#310](#310)) ([f0bb0a3](f0bb0a3)) * use packed bit array for attention mask ([#308](#308)) ([3d43dc9](3d43dc9)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Zihao Ye <expye@outlook.com>
Closes #285
Modified unit tests pass. May need some extra validation.