Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pytorch engine kv int4/int8 quantization #2438

Merged
merged 18 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion benchmark/profile_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def parse_args():
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
quant_policy_act = ArgumentHelper.quant_policy(pt_group, default=0)

# turbomind engine args
tb_group = parser.add_argument_group('TurboMind engine argument')
Expand All @@ -294,8 +295,8 @@ def parse_args():
tb_group._group_actions.append(cache_count_act)
tb_group._group_actions.append(cache_block_seq_len_act)
tb_group._group_actions.append(prefix_caching_act)
tb_group._group_actions.append(quant_policy_act)
ArgumentHelper.model_format(tb_group, default='hf')
ArgumentHelper.quant_policy(tb_group, default=0)
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)

Expand Down Expand Up @@ -328,6 +329,7 @@ def main():
tp=args.tp,
thread_safe=True,
enable_prefix_caching=args.enable_prefix_caching,
quant_policy=args.quant_policy,
)

engine = Engine(args.model_path, engine_config, csv=args.csv)
Expand Down
6 changes: 4 additions & 2 deletions lmdeploy/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def add_parser_chat():
session_len_act = ArgumentHelper.session_len(pt_group)
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
quant_policy = ArgumentHelper.quant_policy(pt_group)

# turbomind args
tb_group = parser.add_argument_group('TurboMind engine arguments')
Expand All @@ -137,8 +138,8 @@ def add_parser_chat():
tb_group._group_actions.append(session_len_act)
tb_group._group_actions.append(cache_max_entry_act)
tb_group._group_actions.append(prefix_caching_act)
tb_group._group_actions.append(quant_policy)
ArgumentHelper.model_format(tb_group)
ArgumentHelper.quant_policy(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)

@staticmethod
Expand Down Expand Up @@ -263,7 +264,8 @@ def chat(args):
cache_max_entry_count=args.cache_max_entry_count,
adapters=adapters,
enable_prefix_caching=args.enable_prefix_caching,
device_type=args.device)
device_type=args.device,
quant_policy=args.quant_policy)
run_chat(args.model_path,
engine_config,
chat_template_config=chat_template_config)
Expand Down
5 changes: 4 additions & 1 deletion lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def add_parser_api_server():
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(
pt_group)
quant_policy = ArgumentHelper.quant_policy(pt_group)
# turbomind args
tb_group = parser.add_argument_group('TurboMind engine arguments')
# common engine args
Expand All @@ -179,8 +180,8 @@ def add_parser_api_server():
tb_group._group_actions.append(cache_block_seq_len_act)
tb_group._group_actions.append(prefix_caching_act)
tb_group._group_actions.append(max_prefill_token_num_act)
tb_group._group_actions.append(quant_policy)
ArgumentHelper.model_format(tb_group)
ArgumentHelper.quant_policy(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
Expand Down Expand Up @@ -258,6 +259,7 @@ def gradio(args):
session_len=args.session_len,
enable_prefix_caching=args.enable_prefix_caching,
device_type=args.device,
quant_policy=args.quant_policy,
max_prefill_token_num=args.max_prefill_token_num)
else:
backend_config = TurbomindEngineConfig(
Expand Down Expand Up @@ -307,6 +309,7 @@ def api_server(args):
adapters=adapters,
enable_prefix_caching=args.enable_prefix_caching,
device_type=args.device,
quant_policy=args.quant_policy,
max_prefill_token_num=args.max_prefill_token_num)
else:
from lmdeploy.messages import TurbomindEngineConfig
Expand Down
4 changes: 4 additions & 0 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ class PytorchEngineConfig:
revision (str): The specific model version to use.
It can be a branch name, a tag name, or a commit id.
If unspecified, will use the default version.
quant_policy (int): default to 0. When k/v is quantized into 4 or 8
bit, set it to 4 or 8, respectively
"""
dtype: str = 'auto'
tp: int = 1
Expand All @@ -275,6 +277,7 @@ class PytorchEngineConfig:
custom_module_map: Dict[str, str] = None
download_dir: str = None
revision: str = None
quant_policy: Literal[0, 4, 8] = 0
grimoire marked this conversation as resolved.
Show resolved Hide resolved

def __post_init__(self):
"""Check input validation."""
Expand All @@ -286,6 +289,7 @@ def __post_init__(self):
assert self.max_prefill_token_num >= 0, \
'invalid max_prefill_token_num'
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
assert self.device_type in [
'cuda', 'ascend'
], (f'invalid device_type: {self.device_type}')
Expand Down
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/backends/attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar
from typing import Generic, Literal, TypeVar

import torch

Expand All @@ -14,6 +14,7 @@ class AttentionMetadata:
q_start_loc: torch.Tensor = None
q_seqlens: torch.Tensor = None
kv_seqlens: torch.Tensor = None
quant_policy: Literal[0, 4, 8] = 0


T = TypeVar('T', bound=AttentionMetadata)
Expand Down
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/backends/cuda/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def forward(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_metadata: TritonAttentionMetadata,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
inplace: bool = True,
) -> torch.Tensor:
"""forward."""
Expand All @@ -77,6 +79,7 @@ def forward(
q_start_loc = attn_metadata.q_start_loc
q_seqlens = attn_metadata.q_seqlens
kv_seqlens = attn_metadata.kv_seqlens
quant_policy = attn_metadata.quant_policy
max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))

# fill kv cache
Expand All @@ -90,6 +93,9 @@ def forward(
kv_seq_length=kv_seqlens,
max_q_seq_length=max_q_seqlen,
block_offsets=block_offsets,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
)

if inplace:
Expand All @@ -110,6 +116,9 @@ def forward(
q_seqlens=q_seqlens,
kv_seqlens=kv_seqlens,
max_seqlen=max_q_seqlen,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logit_softcapping,
Expand All @@ -127,6 +136,9 @@ def forward(
max_input_len=max_q_seqlen,
head_offset=self.alibi_head_offset,
num_heads=self.alibi_num_heads,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
)

return attn_output
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def update_step_context(cls, step_context):
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_seqlens=step_context.kv_seqlens,
quant_policy=step_context.kv_quant_policy,
)

step_context.attn_metadata = attn_metadata
Expand Down
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Any, Dict, List, Literal

import torch

Expand Down Expand Up @@ -76,6 +76,7 @@ class CacheConfig:
cache_max_entry_count: float = 0.8
max_prefill_token_num: int = 4096
enable_prefix_caching: bool = False
quant_policy: Literal[0, 4, 8] = 0

def __post_init__(self):
"""post init."""
Expand Down
93 changes: 82 additions & 11 deletions lmdeploy/pytorch/engine/cache_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
from typing import Dict, List, Tuple
from typing import Dict, List, Literal, Tuple

import torch

Expand Down Expand Up @@ -43,6 +43,8 @@ def __init__(
self.block_size = cache_config.block_size
self.num_layers = model_config.num_layers
self.kv_cache_dtype = model_config.dtype
if cache_config.quant_policy > 0:
self.kv_cache_dtype = torch.uint8

# Initialize the cache.
self.local_gpu_cache = self.allocate_gpu_cache()
Expand Down Expand Up @@ -84,6 +86,7 @@ def _get_key_block_shape_impl(cls,
block_size: int,
head_size: int,
world_size: int = 1,
quant_policy: Literal[0, 4, 8] = 0,
local: bool = True):
"""get single block shape."""
attn_backend = get_backend()
Expand All @@ -93,6 +96,10 @@ def _get_key_block_shape_impl(cls,
assert num_heads % world_size == 0, \
f'num_heads: {num_heads}, world_size: {world_size}'
num_heads = num_heads // world_size
if quant_policy == 4: # pack head_dim to uint8
assert head_size % 2 == 0, \
f'head_size: {head_size}, quant_policy: {quant_policy}'
head_size = head_size // 2
return attn_backend.get_k_block_shape(block_size, num_heads, head_size,
dtype)

Expand All @@ -102,6 +109,7 @@ def _get_value_block_shape_impl(cls,
block_size: int,
head_size: int,
world_size: int = 1,
quant_policy: Literal[0, 4, 8] = 0,
local: bool = True):
"""get single block shape."""
attn_backend = get_backend()
Expand All @@ -111,6 +119,11 @@ def _get_value_block_shape_impl(cls,
assert num_heads % world_size == 0, \
f'num_heads: {num_heads}, world_size: {world_size}'
num_heads = num_heads // world_size
if quant_policy == 4: # pack head_dim to uint8
assert head_size % 2 == 0, \
f'head_size: {head_size}, quant_policy: {quant_policy}'
head_size = head_size // 2

return attn_backend.get_v_block_shape(block_size, num_heads, head_size,
dtype)

Expand All @@ -124,6 +137,7 @@ def get_key_block_shape(self, local: bool = False) -> Tuple[int, int, int]:
block_size=self.block_size,
head_size=head_size,
world_size=self.world_size,
quant_policy=self.cache_config.quant_policy,
local=local,
)

Expand All @@ -138,6 +152,7 @@ def get_value_block_shape(self,
block_size=self.block_size,
head_size=head_size,
world_size=self.world_size,
quant_policy=self.cache_config.quant_policy,
local=local,
)

Expand All @@ -158,7 +173,21 @@ def allocate_gpu_cache(self):
dtype=self.kv_cache_dtype,
device='cuda',
)
gpu_cache.append((key_blocks, value_blocks))
if self.cache_config.quant_policy in (4, 8):
key_scales_zeros = torch.empty(
size=(self.num_gpu_blocks, *key_block_shape[:-1], 2),
dtype=self.model_config.dtype,
device='cuda',
)
value_scales_zeros = torch.empty(
size=(self.num_gpu_blocks, *value_block_shape[:-1], 2),
dtype=self.model_config.dtype,
device='cuda',
)
gpu_cache.append((key_blocks, value_blocks, key_scales_zeros,
value_scales_zeros))
else:
gpu_cache.append((key_blocks, value_blocks))

return gpu_cache

Expand All @@ -182,7 +211,21 @@ def allocate_cpu_cache(self):
dtype=self.kv_cache_dtype,
pin_memory=pin_memory,
)
cpu_cache.append((key_blocks, value_blocks))
if self.cache_config.quant_policy in (4, 8):
key_scales_zeros = torch.empty(
size=(self.num_cpu_blocks, *key_block_shape[:-1], 2),
dtype=self.model_config.dtype,
pin_memory=pin_memory,
)
value_scales_zeros = torch.empty(
size=(self.num_cpu_blocks, *value_block_shape[:-1], 2),
dtype=self.model_config.dtype,
pin_memory=pin_memory,
)
cpu_cache.append((key_blocks, value_blocks, key_scales_zeros,
value_scales_zeros))
else:
cpu_cache.append((key_blocks, value_blocks))
return cpu_cache

@torch.inference_mode()
Expand All @@ -201,8 +244,9 @@ def _swap(self, src: List[KVCache], dst: List[KVCache],
dst_key_cache, dst_value_cache = dst[i]

for src_id, dst_id in src_to_dst.items():
dst_key_cache[dst_id].copy_(src_key_cache[src_id])
dst_value_cache[dst_id].copy_(src_value_cache[src_id])
if isinstance(dst_key_cache[dst_id], torch.Tensor):
dst_key_cache[dst_id].copy_(src_key_cache[src_id])
dst_value_cache[dst_id].copy_(src_value_cache[src_id])

event = self.events[i]
event.record(stream=self.cache_stream)
Expand All @@ -227,7 +271,8 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None:
def get_cache_block_size(cls,
block_size: int,
model_config: ModelConfig,
world_size: int = 1) -> int:
world_size: int = 1,
quant_policy: int = 0) -> int:
"""Get the required cache size of the model.

Args:
Expand All @@ -250,18 +295,44 @@ def get_cache_block_size(cls,
head_size=key_head_size,
world_size=world_size,
local=True,
quant_policy=quant_policy,
)
value_shape = cls._get_value_block_shape_impl(
model_config,
block_size=block_size,
head_size=value_head_size,
world_size=world_size,
quant_policy=quant_policy,
local=True,
)
dtype = model_config.dtype
key_block = torch.empty(key_shape, dtype=dtype, device='meta')
value_block = torch.empty(value_shape, dtype=dtype, device='meta')
mem_key_block = key_block.numel() * key_block.element_size()
mem_value_block = value_block.numel() * value_block.element_size()
if quant_policy == 0:
dtype = model_config.dtype
key_block = torch.empty(key_shape, dtype=dtype, device='meta')
value_block = torch.empty(value_shape, dtype=dtype, device='meta')
mem_key_block = key_block.numel() * key_block.element_size()
mem_value_block = value_block.numel() * value_block.element_size()
elif quant_policy in (4, 8):
key_block = torch.empty(key_shape,
dtype=torch.uint8,
device='meta')
value_block = torch.empty(value_shape,
dtype=torch.uint8,
device='meta')
key_scale_zero_block = torch.empty((*key_shape[:-1], 2),
dtype=model_config.dtype,
device='meta')
value_scale_zero_block = torch.empty((*value_shape[:-1], 2),
dtype=model_config.dtype,
device='meta')
mem_key_block = key_block.numel() * key_block.element_size(
) + key_scale_zero_block.numel(
) * key_scale_zero_block.element_size()
mem_value_block = value_block.numel() * value_block.element_size(
) + value_scale_zero_block.numel(
) * value_scale_zero_block.element_size()
else:
raise ValueError(f'unsupported quant_policy {quant_policy}')

# TODO quant 4
AllentDan marked this conversation as resolved.
Show resolved Hide resolved
total = num_layers * (mem_key_block + mem_value_block)
return total
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(self,
cache_max_entry_count=engine_config.cache_max_entry_count,
max_prefill_token_num=engine_config.max_prefill_token_num,
enable_prefix_caching=engine_config.enable_prefix_caching,
quant_policy=engine_config.quant_policy,
)

if not os.path.exists(model_path):
Expand Down
Loading
Loading