small block_m for sm7.x #2626

Merged · 2 commits · Oct 23, 2024
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py
@@ -12,7 +12,7 @@
 
 assert triton.__version__ >= '2.1.0'
 
-LOG2 = math.log(2)
+LOG2: tl.constexpr = math.log(2)
 
 
 @triton.jit
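Note on the alibi_pagedattention.py change: a module-level constant that is referenced inside an @triton.jit kernel needs the tl.constexpr annotation so Triton treats it as a compile-time constant rather than a runtime Python global. A minimal standalone sketch of that pattern (the kernel below is illustrative, not the repository kernel, and it needs a CUDA device to run):

import math

import torch
import triton
import triton.language as tl

LOG2: tl.constexpr = math.log(2)


@triton.jit
def _scale_by_log2_kernel(x_ptr, out_ptr, N: tl.constexpr):
    # LOG2 is baked in at compile time because of the tl.constexpr
    # annotation on the module-level global.
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    tl.store(out_ptr + offs, x * LOG2)


x = torch.randn(1024, device='cuda')
out = torch.empty_like(x)
_scale_by_log2_kernel[(1, )](x, out, N=1024)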
9 changes: 7 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/pagedattention.py
@@ -621,6 +621,7 @@ def convert_pv(p, v):
 
 
 _convert_pv = None
+_nv_cap = None
 
 
 # TODO: how to support inplace autotune?
@@ -1099,9 +1100,10 @@ def paged_attention_fwd(
         max_seqlen (int): The max input length.
         BLOCK (int): The kernel block size.
     """
-    global _convert_pv
+    global _convert_pv, _nv_cap
     if _convert_pv is None:
         nv_cap = torch.cuda.get_device_capability()
         _nv_cap = nv_cap
         _convert_pv = _get_convert_pv(nv_cap)
 
     if kv_layout == 'bshd':
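For context, _convert_pv and _nv_cap are lazy module-level caches: the device capability is queried once per process and reused on later calls. A hedged sketch of that caching pattern, with an illustrative helper name that is not part of the patch:

import torch

_nv_cap = None


def _get_nv_cap():
    """Query the CUDA compute capability once and cache it at module level."""
    global _nv_cap
    if _nv_cap is None:
        # e.g. (7, 5) on Turing, (8, 0) on A100
        _nv_cap = torch.cuda.get_device_capability()
    return _nv_cap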
@@ -1150,7 +1152,10 @@ def _get_block_d(Lk):
     is_decoding = q.shape[-3] == q_seqlens.size(0)
     if not is_decoding:
         BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)
-        BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
+        if _nv_cap[0] < 8:
+            BLOCK_M = max(16, min(BLOCK, 8192 // BLOCK_DMODEL))
+        else:
+            BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
         num_warps = 4
         num_stages = 2
         kv_head = k.shape[h_dim]
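The core of the PR is the capability-dependent BLOCK_M: on sm7.x devices the BLOCK_M * BLOCK_DMODEL budget for the prefill path is halved (8192 instead of 16384), which presumably keeps the tile within the tighter per-SM resources of Volta/Turing parts. A standalone sketch of the selection logic, with a hypothetical helper name but arithmetic mirroring the diff:

import torch


def _select_block_m(block, block_dmodel, nv_cap=None):
    """Pick BLOCK_M for the prefill path, mirroring the diff above.

    The BLOCK_M * BLOCK_DMODEL budget is halved on sm7.x, and BLOCK_M is
    kept between 16 and the requested BLOCK.
    """
    if nv_cap is None:
        nv_cap = torch.cuda.get_device_capability()
    budget = 8192 if nv_cap[0] < 8 else 16384
    return max(16, min(block, budget // block_dmodel))


# e.g. BLOCK=128, BLOCK_DMODEL=128 -> 64 on sm7.x, 128 on sm8.x and newer
print(_select_block_m(128, 128, nv_cap=(7, 5)))  # 64
print(_select_block_m(128, 128, nv_cap=(8, 0)))  # 128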