diff --git a/lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py b/lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py
index 1e54b5c13..66a844203 100644
--- a/lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py
+++ b/lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py
@@ -12,7 +12,7 @@
 
 assert triton.__version__ >= '2.1.0'
 
-LOG2 = math.log(2)
+LOG2: tl.constexpr = math.log(2)
 
 
 @triton.jit
diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
index d8e6ec501..7790a44b1 100644
--- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
@@ -621,6 +621,7 @@ def convert_pv(p, v):
 
 
 _convert_pv = None
+_nv_cap = None
 
 
 # TODO: how to support inplace autotune?
@@ -1099,9 +1100,10 @@ def paged_attention_fwd(
         max_seqlen (int): The max input length.
         BLOCK (int): The kernel block size.
     """
-    global _convert_pv
+    global _convert_pv, _nv_cap
    if _convert_pv is None:
         nv_cap = torch.cuda.get_device_capability()
+        _nv_cap = nv_cap
         _convert_pv = _get_convert_pv(nv_cap)
 
     if kv_layout == 'bshd':
@@ -1150,7 +1152,10 @@ def _get_block_d(Lk):
     is_decoding = q.shape[-3] == q_seqlens.size(0)
     if not is_decoding:
         BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)
-        BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
+        if _nv_cap[0] < 8:
+            BLOCK_M = max(16, min(BLOCK, 8192 // BLOCK_DMODEL))
+        else:
+            BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
         num_warps = 4
         num_stages = 2
         kv_head = k.shape[h_dim]