automatically set max_batch_size according to the device when it is not specified #2434

Merged · 5 commits · Sep 9, 2024
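The gist of the change: `max_batch_size` now defaults to `None` in both engine configs, and each engine fills it in at build time from the device via a new `get_max_batch_size` helper in `lmdeploy/utils.py`. A minimal sketch of the resolved behavior (this mirrors what the engines do internally; the printed value depends on your GPU):

```python
from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.utils import get_max_batch_size  # helper added by this PR

config = PytorchEngineConfig()       # max_batch_size now defaults to None
if config.max_batch_size is None:    # what the engine does at build time
    config.max_batch_size = get_max_batch_size(config.device_type)

print(config.max_batch_size)  # e.g. 256 on A100, 512 on H100, 128 otherwise
```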
14 changes: 10 additions & 4 deletions lmdeploy/cli/serve.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.utils import get_max_batch_size

from .cli import CLI
from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter,
convert_args, get_chat_template, get_lora_adapters)
@@ -202,6 +204,8 @@ def gradio(args):
from lmdeploy.messages import (PytorchEngineConfig,
TurbomindEngineConfig)
from lmdeploy.serve.gradio.app import run
max_batch_size = args.max_batch_size if args.max_batch_size \
else get_max_batch_size(args.device)
backend = args.backend

if backend != 'pytorch' and ':' not in args.model_path_or_server:
@@ -210,7 +214,7 @@
if backend == 'pytorch':
backend_config = PytorchEngineConfig(
tp=args.tp,
max_batch_size=args.max_batch_size,
max_batch_size=max_batch_size,
cache_max_entry_count=args.cache_max_entry_count,
block_size=args.cache_block_seq_len,
session_len=args.session_len,
@@ -220,7 +224,7 @@
else:
backend_config = TurbomindEngineConfig(
tp=args.tp,
max_batch_size=args.max_batch_size,
max_batch_size=max_batch_size,
session_len=args.session_len,
model_format=args.model_format,
quant_policy=args.quant_policy,
@@ -243,6 +247,8 @@ def api_server(args):
"""Serve LLMs with restful api using fastapi."""
from lmdeploy.archs import autoget_backend
from lmdeploy.serve.openai.api_server import serve as run_api_server
max_batch_size = args.max_batch_size if args.max_batch_size \
else get_max_batch_size(args.device)
backend = args.backend
if backend != 'pytorch':
# set auto backend mode
@@ -253,7 +259,7 @@
adapters = get_lora_adapters(args.adapters)
backend_config = PytorchEngineConfig(
tp=args.tp,
max_batch_size=args.max_batch_size,
max_batch_size=max_batch_size,
cache_max_entry_count=args.cache_max_entry_count,
block_size=args.cache_block_seq_len,
session_len=args.session_len,
@@ -265,7 +271,7 @@
from lmdeploy.messages import TurbomindEngineConfig
backend_config = TurbomindEngineConfig(
tp=args.tp,
max_batch_size=args.max_batch_size,
max_batch_size=max_batch_size,
session_len=args.session_len,
model_format=args.model_format,
quant_policy=args.quant_policy,
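Both CLI entry points (`gradio` and `api_server`) use the same fallback: an explicit `--max-batch-size` wins, otherwise the value is derived from the device argument. A small sketch of that resolution, with a `Namespace` standing in for the parsed CLI arguments (illustrative only):

```python
from argparse import Namespace

from lmdeploy.utils import get_max_batch_size

# as if the server were launched on Ascend without --max-batch-size
args = Namespace(max_batch_size=None, device='ascend')
max_batch_size = args.max_batch_size if args.max_batch_size \
    else get_max_batch_size(args.device)
assert max_batch_size == 16  # the Ascend default defined in lmdeploy/utils.py
```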
10 changes: 6 additions & 4 deletions lmdeploy/cli/utils.py
@@ -160,10 +160,12 @@ def session_len(parser, default: int = None):
def max_batch_size(parser):
"""Add argument max_batch_size to parser."""

return parser.add_argument('--max-batch-size',
type=int,
default=128,
help='Maximum batch size')
return parser.add_argument(
'--max-batch-size',
type=int,
default=None,
help='Maximum batch size. If not specified, the engine will '
'automatically set it according to the device')

@staticmethod
def quant_policy(parser, default: int = 0):
12 changes: 6 additions & 6 deletions lmdeploy/messages.py
@@ -115,7 +115,8 @@ class TurbomindEngineConfig:
If it is not specified, i.e. None, it will be extracted from the input model
tp (int): the number of GPU cards used in tensor parallelism, default to 1
session_len (int): the max session length of a sequence, default to None
max_batch_size (int): the max batch size during inference, default to 128
max_batch_size (int): the max batch size during inference. If it is not specified,
the engine will automatically set it according to the device
cache_max_entry_count (float): the percentage of gpu memory occupied by the k/v cache.
For versions of lmdeploy between `v0.2.0` and `v0.2.1`, it defaults to 0.5, depicting the percentage of TOTAL GPU memory to be allocated to the k/v cache.
For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8, signifying the percentage of FREE GPU memory to be reserved for the k/v cache
@@ -135,7 +136,7 @@ class TurbomindEngineConfig:
model_format: Optional[str] = None
tp: int = 1
session_len: Optional[int] = None
max_batch_size: int = 128
max_batch_size: int = None
cache_max_entry_count: float = 0.8
cache_chunk_size: int = -1
cache_block_seq_len: int = 64
@@ -152,7 +153,6 @@ class TurbomindEngineConfig:
def __post_init__(self):
"""Check input validation."""
assert self.tp >= 1, 'tp must be a positive integer'
assert self.max_batch_size >= 1, 'max_batch_size must be a positive integer' # noqa
assert self.cache_max_entry_count > 0 and self.cache_max_entry_count < 1, 'invalid cache_max_entry_count' # noqa
assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor'
@@ -167,7 +167,8 @@ class PytorchEngineConfig:
Args:
tp (int): Tensor Parallelism. default 1.
session_len (int): Max session length. Default None.
max_batch_size (int): Max batch size. Default 128.
max_batch_size (int): Max batch size. If it is not specified,
the engine will automatically set it according to the device
cache_max_entry_count (float): the percentage of gpu memory occupied
by the k/v cache. For lmdeploy versions greater than `v0.2.1`,
it defaults to 0.8, signifying the percentage of FREE GPU memory
@@ -192,7 +193,7 @@
"""
tp: int = 1
session_len: int = None
max_batch_size: int = 128
max_batch_size: int = None
cache_max_entry_count: float = 0.8
prefill_interval: int = 16
block_size: int = 64
@@ -209,7 +210,6 @@ class PytorchEngineConfig:
def __post_init__(self):
"""Check input validation."""
assert self.tp >= 1, 'invalid tp'
assert self.max_batch_size >= 1, 'invalid max_batch_size'
assert self.cache_max_entry_count > 0 and self.cache_max_entry_count < 1, 'invalid cache_max_entry_count' # noqa
assert self.num_cpu_blocks >= 0, 'invalid num_cpu_blocks'
assert self.max_prefill_token_num >= 0, 'invalid max_prefill_token_num'
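With the default switched to `None` and the eager `max_batch_size >= 1` asserts dropped from `__post_init__`, both the auto mode and an explicit value are valid at construction time; the concrete number is only filled in when an engine is built. A short sketch:

```python
from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig

auto_cfg = TurbomindEngineConfig()                  # stays None until an engine resolves it
fixed_cfg = PytorchEngineConfig(max_batch_size=64)  # explicit value is used as-is

assert auto_cfg.max_batch_size is None
assert fixed_cfg.max_batch_size == 64
```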
15 changes: 10 additions & 5 deletions lmdeploy/pytorch/engine/engine.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import copy
import os
from dataclasses import dataclass
from typing import Any, Dict, List
@@ -9,7 +10,8 @@

from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
ResponseType)
from lmdeploy.utils import get_logger, get_model, logging_timer
from lmdeploy.utils import (get_logger, get_max_batch_size, get_model,
logging_timer)

from ..adapter.adapter import AdapterManager, SchedulerAdapter
from ..check_env import check_adapters, check_env, check_model
@@ -116,15 +118,18 @@ def __init__(self,
trust_remote_code: bool = True) -> None:
if engine_config is None:
engine_config = PytorchEngineConfig()
else:
engine_config = copy.deepcopy(engine_config)
check_env(engine_config.device_type)
check_model(model_path, trust_remote_code)
if engine_config.max_batch_size is None:
engine_config.max_batch_size = get_max_batch_size(
engine_config.device_type)
if engine_config.adapters is not None:
check_adapters(list(engine_config.adapters.values()))

self.engine_config = engine_config
tp = engine_config.tp

self.tp = tp
self.tp = engine_config.tp

self.device_context = DeviceContext(
device_type=engine_config.device_type)
@@ -156,7 +161,7 @@ def __init__(self,
cache_config=cache_config,
trust_remote_code=trust_remote_code,
adapters=adapters,
tp=tp)
tp=self.tp)

cache_config = self.model_agent.cache_config
self.adapter_manager = self._build_adapter_manager(adapters)
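Note the `copy.deepcopy` before the default is filled in: the engine resolves `max_batch_size` on its own copy, so the config object the caller passed in is not mutated. A minimal sketch of that pattern (the class below is a stand-in for the real `Engine`, which also loads the model; `get_max_batch_size('cuda')` queries the local GPU name):

```python
import copy

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.utils import get_max_batch_size


class EngineSketch:

    def __init__(self, engine_config=None):
        if engine_config is None:
            engine_config = PytorchEngineConfig()
        else:
            engine_config = copy.deepcopy(engine_config)
        if engine_config.max_batch_size is None:
            engine_config.max_batch_size = get_max_batch_size(
                engine_config.device_type)
        self.engine_config = engine_config


user_cfg = PytorchEngineConfig()
engine = EngineSketch(user_cfg)
assert user_cfg.max_batch_size is None           # caller's object untouched
assert engine.engine_config.max_batch_size >= 1  # resolved for the local device
```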
14 changes: 2 additions & 12 deletions lmdeploy/serve/async_engine.py
@@ -188,15 +188,10 @@ def _build_turbomind(
PytorchEngineConfig]] = None,
**kwargs):
"""Innter build method for turbomind backend."""
if backend_config is None:
backend_config = TurbomindEngineConfig()
assert isinstance(backend_config, TurbomindEngineConfig), 'Please'\
' use TurbomindEngineConfig imported from lmdeploy.messages for ' \
'turbomind backend'
from lmdeploy import turbomind as tm
self.engine = tm.TurboMind.from_pretrained(
model_path, engine_config=backend_config, **kwargs)
self.backend_config = backend_config
self.backend_config = self.engine.engine_config
self.hf_tm_cfg = self.engine.config

def _build_pytorch(
@@ -207,14 +202,9 @@ def _build_pytorch(
**kwargs):
"""Innter build method for pytorch backend."""
from lmdeploy.pytorch.engine import Engine
if backend_config is None:
backend_config = PytorchEngineConfig()
assert isinstance(backend_config, PytorchEngineConfig), 'Please '\
'use PytorchEngineConfig imported from lmdeploy.messages for ' \
'pytorch backend'
self.engine = Engine(model_path=model_path,
engine_config=backend_config)
self.backend_config = backend_config
self.backend_config = self.engine.engine_config
self.hf_tm_cfg = getattr(self.engine.model_config, 'hf_config', None)

def __call__(self,
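Since the default-filling now happens inside the engines, `AsyncEngine` reads the resolved config back via `self.engine.engine_config`, so callers see a concrete `max_batch_size` rather than `None`. A hedged usage sketch (the model id below is only an example):

```python
from lmdeploy.serve.async_engine import AsyncEngine

pipe = AsyncEngine('internlm/internlm2-chat-7b')  # turbomind backend by default
print(pipe.backend_config.max_batch_size)         # a concrete number, never None
```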
37 changes: 22 additions & 15 deletions lmdeploy/turbomind/turbomind.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import copy
import json
import os.path as osp
import sys
@@ -18,7 +19,7 @@
from lmdeploy.messages import (EngineOutput, GenerationConfig, ResponseType,
TurbomindEngineConfig)
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger, get_model
from lmdeploy.utils import get_logger, get_max_batch_size, get_model

from .deploy.config import TurbomindModelConfig
from .supported_models import is_supported
@@ -67,12 +68,14 @@ class TurboMind:

Args:
model_path (str): the path of turbomind's model
model_source (int): model source
model_format (str): needed when model_path is a hf model and not
managed by lmdeploy
group_size (int): needed when model_path is a hf model and not
managed by lmdeploy
tp (int): tensor parallel
model_name (str): the name of the served model
chat_template_name (str): the name of the chat template, which is
supposed to be a builtin chat template defined in
`lmdeploy/model.py`
engine_config (TurbomindEngineConfig): the config of the inference
engine
model_source (int): the source of the model, which is either
turbomind model, or a transformers model
"""

def __init__(self,
@@ -85,24 +88,28 @@ def __init__(self,
self.model_name = model_name
self.chat_template_name = chat_template_name

tp = 1 if engine_config is None else engine_config.tp
assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n'
self.gpu_count = tp
_engine_config = copy.deepcopy(engine_config)
if _engine_config is None:
_engine_config = TurbomindEngineConfig()
if _engine_config.max_batch_size is None:
_engine_config.max_batch_size = get_max_batch_size('cuda')

self.gpu_count = _engine_config.tp

if model_source == ModelSource.WORKSPACE:
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
self.tokenizer = Tokenizer(tokenizer_model_path)
self.model_comm = self._from_workspace(model_path=model_path,
engine_config=engine_config)
self.model_comm = self._from_workspace(
model_path=model_path, engine_config=_engine_config)
else:
if not osp.exists(model_path):
model_path = get_model(model_path, engine_config.download_dir,
engine_config.revision)
model_path = get_model(model_path, _engine_config.download_dir,
_engine_config.revision)
self.tokenizer = Tokenizer(model_path)
self.model_comm = self._from_hf(model_source=model_source,
model_path=model_path,
engine_config=engine_config)
engine_config=_engine_config)

with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
ranks = [
27 changes: 27 additions & 0 deletions lmdeploy/utils.py
@@ -319,3 +319,30 @@ def _get_and_verify_max_len(
f'({max_len_key}={derived_max_model_len} or model_max_length='
f"{model_max_length} in model's config.json).")
return int(max_model_len)


def get_max_batch_size(device_type: str):
"""Get the max inference batch size for LLM models according to the device
type.
Args:
device_type (str): the type of device
"""
assert device_type in ['cuda', 'ascend']
if device_type == 'cuda':
max_batch_size_map = {
'a100': 256,
'a800': 256,
'h100': 512,
'h800': 512
}
import torch
device_name = torch.cuda.get_device_name(0).lower()
for name, size in max_batch_size_map.items():
if name in device_name:
return size
# for devices that are not in `max_batch_size_map`, set
# the max_batch_size 128
return 128
elif device_type == 'ascend':
return 16
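A quick illustration of the new helper (the device names come straight from the table above; any CUDA device not listed, e.g. a consumer GPU, falls back to 128):

```python
from lmdeploy.utils import get_max_batch_size

print(get_max_batch_size('ascend'))  # 16
print(get_max_batch_size('cuda'))    # 256 on A100/A800, 512 on H100/H800, else 128
```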