Skip to content

Commit

Permalink
Revert #7509 (#7887)
Browse the repository at this point in the history
  • Branch information:
comaniac authored Aug 27, 2024
1 parent 64cc644 commit 9606c71
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ def _get_decode_wrapper(self):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
use_tensor_cores = num_qo_heads // num_kv_heads > 4
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
Expand Down Expand Up @@ -172,8 +171,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
use_tensor_cores = num_qo_heads // num_kv_heads > 4
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
Expand Down

0 comments on commit 9606c71

Please sign in to comment.