From 0b1c6ca64a2f99dab4e11ede24d12c1c90dff855 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Sat, 14 Sep 2024 17:42:27 +0800 Subject: [PATCH 1/9] support specifying data type --- lmdeploy/messages.py | 91 +++++++++++++------ lmdeploy/turbomind/deploy/converter.py | 30 ++++-- .../turbomind/deploy/target_model/base.py | 19 ++-- src/turbomind/python/bind.cpp | 4 +- .../triton_backend/llama/LlamaTritonModel.cc | 8 +- 5 files changed, 100 insertions(+), 52 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 021a1def6..13227824e 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -117,30 +117,58 @@ class TurbomindEngineConfig: """TurboMind Engine config. Args: - model_format (str): the layout of the deployed model. It can be one of the following values [hf, meta_llama, awq, gptq], - `hf` meaning huggingface model(.bin, .safetensors), `meta_llama` being meta llama's format(.pth), - `awq` and `gptq` meaning the quantized model by AWQ and GPTQ, respectively. - If it is not specified, i.e. None, it will be extracted from the input model - tp (int): the number of GPU cards used in tensor parallelism, default to 1 - session_len (int): the max session length of a sequence, default to None - max_batch_size (int): the max batch size during inference. If it is not specified, - the engine will automatically set it according to the device - cache_max_entry_count (float): the percentage of gpu memory occupied by the k/v cache. - For versions of lmdeploy between `v0.2.0` and `v0.2.1`, it defaults to 0.5, depicting the percentage of TOTAL GPU memory to be allocated to the k/v cache. - For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8, signifying the percentage of FREE GPU memory to be reserved for the k/v cache - cache_chunk_size (int): The policy to apply for KV block from the block manager, default to -1. - cache_block_seq_len (int): the length of the token sequence in a k/v block, default to 64 - enable_prefix_caching (bool): enable cache prompts for block reuse, default to False - quant_policy (int): default to 0. When k/v is quantized into 8 bit, set it to 4 - rope_scaling_factor (float): scaling factor used for dynamic ntk, default to 0. TurboMind follows the implementation of transformer LlamaAttention + dtype (str): data type for model weights and activations. It can be + one of the following values, ['auto', 'float16', 'bfloat16'] + The `auto` option will use FP16 precision for FP32 and FP16 + models, and BF16 precision for BF16 models. + model_format (str): the layout of the deployed model. It can be one + of the following values [hf, meta_llama, awq, gptq],`hf` meaning + huggingface model(.bin, .safetensors), `meta_llama` being + meta llama's format(.pth), `awq` and `gptq` meaning the quantized + model by AWQ and GPTQ, respectively. If it is not specified, + i.e. None, it will be extracted from the input model + tp (int): the number of GPU cards used in tensor parallelism, + default to 1 + session_len (int): the max session length of a sequence, default to + None + max_batch_size (int): the max batch size during inference. If it is + not specified, the engine will automatically set it according to + the device + cache_max_entry_count (float): the percentage of gpu memory occupied + by the k/v cache. + For versions of lmdeploy between `v0.2.0` and `v0.2.1`, it + defaults to 0.5, depicting the percentage of TOTAL GPU memory to + be allocated to the k/v cache. 
+ For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8, + signifying the percentage of FREE GPU memory to be reserved for + the k/v cache + cache_chunk_size (int): The policy to apply for KV block from + the block manager, default to -1. + cache_block_seq_len (int): the length of the token sequence in + a k/v block, default to 64 + enable_prefix_caching (bool): enable cache prompts for block reuse, + default to False + quant_policy (int): default to 0. When k/v is quantized into 4 or 8 + bit, set it to 4 or 8, respectively + rope_scaling_factor (float): scaling factor used for dynamic ntk, + default to 0. TurboMind follows the implementation of transformer + LlamaAttention use_logn_attn (bool): whether or not to use log attn: default to False - download_dir (str): Directory to download and load the weights, default to the default cache directory of huggingface. - revision (str): The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. - max_prefill_token_num(int): the number of tokens each iteration during prefill, default to 8192 - num_tokens_per_iter(int): the number of tokens processed in each forward pass. Working with `max_prefill_iters` enables "Dynamic SplitFuse"-like scheduling - max_prefill_iters(int): the max number of forward pass during prefill stage - """ # noqa: E501 + download_dir (str): Directory to download and load the weights, + default to the default cache directory of huggingface. + revision (str): The specific model version to use. It can be a branch + name, a tag name, or a commit id. If unspecified, will use the + default version. + max_prefill_token_num(int): the number of tokens each iteration during + prefill, default to 8192 + num_tokens_per_iter(int): the number of tokens processed in each + forward pass. Working with `max_prefill_iters` enables the + "Dynamic SplitFuse"-like scheduling + max_prefill_iters(int): the max number of forward pass during prefill + stage + """ + dtype: str = 'auto' model_format: Optional[str] = None tp: int = 1 session_len: Optional[int] = None @@ -160,11 +188,14 @@ class TurbomindEngineConfig: def __post_init__(self): """Check input validation.""" + assert self.dtype in ['auto', 'float16', 'bfloat16'] assert self.tp >= 1, 'tp must be a positive integer' - assert self.cache_max_entry_count > 0 and self.cache_max_entry_count < 1, 'invalid cache_max_entry_count' # noqa + assert 0 < self.cache_max_entry_count < 1, \ + 'invalid cache_max_entry_count' assert self.quant_policy in (0, 4, 8), 'invalid quant_policy' assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor' - assert self.max_prefill_token_num >= 0, 'invalid max_prefill_token_num' + assert self.max_prefill_token_num >= 0, \ + 'invalid max_prefill_token_num' assert self.num_tokens_per_iter >= 0, 'invalid num_tokens_per_iter' @@ -173,6 +204,10 @@ class PytorchEngineConfig: """PyTorch Engine Config. Args: + dtype (str): data type for model weights and activations. It can be + one of the following values, ['auto', 'float16', 'bfloat16'] + The `auto` option will use FP16 precision for FP32 and FP16 + models, and BF16 precision for BF16 models. tp (int): Tensor Parallelism. default 1. session_len (int): Max session length. Default None. max_batch_size (int): Max batch size. If it is not specified, @@ -199,6 +234,7 @@ class PytorchEngineConfig: It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. 
""" + dtype: str = 'auto' tp: int = 1 session_len: int = None max_batch_size: int = None @@ -219,10 +255,13 @@ class PytorchEngineConfig: def __post_init__(self): """Check input validation.""" + assert self.dtype in ['auto', 'float16', 'bfloat16'] assert self.tp >= 1, 'invalid tp' - assert self.cache_max_entry_count > 0 and self.cache_max_entry_count < 1, 'invalid cache_max_entry_count' # noqa + assert 0 < self.cache_max_entry_count < 1, \ + 'invalid cache_max_entry_count' assert self.num_cpu_blocks >= 0, 'invalid num_cpu_blocks' - assert self.max_prefill_token_num >= 0, 'invalid max_prefill_token_num' + assert self.max_prefill_token_num >= 0, \ + 'invalid max_prefill_token_num' assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks' assert self.device_type in [ 'cuda', 'ascend' diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index bce9bbd61..442c2e0ce 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -85,7 +85,7 @@ def copy_tokenizer(model_path: str, tokenizer_path: str, def get_output_model_registered_name_and_config(model_path: str, - model_format: str, + model_format: str, dtype: str, group_size: int): """Get the registered name of the turbomind model and its configuration according to the input model path, format and user-input config. The name @@ -95,11 +95,12 @@ def get_output_model_registered_name_and_config(model_path: str, model_path (str): the path of the input model model_format (str): the format of the model, which can be one of ['meta_llama', 'hf', 'awq', 'gptq'] + dtype (str): the data type of the model's weights and activations group_size (int): the size of group used by awq model """ register_name = 'tm' turbomind_model_arch = 'llama' - weight_type = 'fp16' + weight_type = 'float16' config = TurbomindModelConfig.from_dict() @@ -114,16 +115,26 @@ def get_output_model_registered_name_and_config(model_path: str, group_size = 128 if group_size == 0 else group_size else: torch_dtype = getattr(model_config, 'torch_dtype', 'float16') - TORCH_DTYPE_MAP = {torch.bfloat16: 'bf16', torch.float16: 'fp16'} - weight_type = TORCH_DTYPE_MAP.get(torch_dtype, 'fp16') + TORCH_DTYPE_MAP = { + torch.bfloat16: 'bfloat16', + torch.float16: 'float16' + } + weight_type = TORCH_DTYPE_MAP.get(torch_dtype, 'float16') # Qwen-1 didn't set torch_dtype. It used bf16 as default if model_arch == 'QWenLMHeadModel': - weight_type = 'bf16' - if not torch.cuda.is_bf16_supported(): - print( - 'Device does not support bfloat16. 
Set float16 forcefully') - weight_type = 'fp16' + weight_type = 'bfloat16' + + if dtype == 'auto': + weight_type = weight_type if weight_type in [ + 'float16', 'bfloat16', 'int4' + ] else 'float16' + elif dtype in ['float16', 'bfloat16']: + assert weight_type != 'int4', f'the model {model_path} is a 4bit ' \ + f'weight quantized model but user specifies dtype {dtype}' + weight_type = dtype + else: + assert 0, 'unsupported specified data type {dtype}' config.model_config.model_arch = model_arch config.model_config.weight_type = weight_type @@ -251,6 +262,7 @@ def get_tm_model(model_path, get_output_model_registered_name_and_config( model_path=model_path, model_format=engine_config.model_format, + dtype=engine_config.dtype, group_size=group_size) tm_cfg.model_config.chat_template = chat_template_name diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index b2d69cdd6..b2da4b441 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -28,13 +28,10 @@ def tprint(*args, **kwargs): def _weight_dtype_map(weight_type: str, default=None): """map literal data type to torch dtype.""" - _WEIGHT_DTYPE_MAP = dict( - int4=torch.float16, - fp16=torch.float16, - fp32=torch.float16, - bf16=torch.bfloat16 - if torch.cuda.is_bf16_supported() else torch.float16, - ) + _WEIGHT_DTYPE_MAP = dict(int4=torch.float16, + float16=torch.float16, + float32=torch.float16, + bfloat16=torch.bfloat16) return _WEIGHT_DTYPE_MAP.get(weight_type, default) @@ -128,7 +125,7 @@ def _tofile(tensor, path): elif len(self.tm_params) > 0: tm_params = self.tm_params weight_type = self.model_config.weight_type - assert weight_type in ['fp16', 'fp32', 'bf16', 'int4'] + assert weight_type in ['float16', 'bfloat16', 'int4'] # currently, the tensor type should in # [torch.float, torch.half, torch.bfloat16, torch.int32] @@ -137,12 +134,12 @@ def _tofile(tensor, path): torch.int32, torch.float, torch.half, torch.bfloat16 ] if torch_tensor.dtype != torch.int32: - if weight_type in ['fp16', 'int4']: + if weight_type in ['float16', 'int4']: torch_tensor = torch_tensor.half() - elif weight_type == 'bf16': + elif weight_type == 'bfloat16': torch_tensor = torch_tensor.bfloat16() else: - torch_tensor = torch_tensor.float() + torch_tensor = torch_tensor.half() for tm_tensor in tm_params[name]: tm_tensor.copy_from(torch_tensor) tm_params.pop(name) diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp index a716f8a49..62142a052 100644 --- a/src/turbomind/python/bind.cpp +++ b/src/turbomind/python/bind.cpp @@ -373,13 +373,13 @@ PYBIND11_MODULE(_turbomind, m) PyGILState_Release(state); } }; - if (data_type == "half" || data_type == "fp16" || data_type == "int4") { + if (data_type == "half" || data_type == "fp16" || data_type == 'float16' || data_type == "int4") { auto model = std::make_shared>( tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config); model->set_ffi_lock(gil_control); return model; } - else if (data_type == "bf16") { + else if (data_type == "bf16" || data_type == 'bfloat16') { #ifdef ENABLE_BF16 auto model = std::make_shared>( tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index e2a564aa4..c67177400 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ 
b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -49,14 +49,14 @@ std::shared_ptr AbstractTransformerModel::createLlamaM int tensor_para_size = ft_instance_hyperparameter["tensor_para_size"].as(); std::string model_dir = ft_instance_hyperparameter["model_dir"].as(); - if (data_type == "half" || data_type == "fp16") { + if (data_type == "half" || data_type == "fp16" || data_type == 'float16') { return std::make_shared>( ft_instance_hyperparameter["tensor_para_size"].as(), ft_instance_hyperparameter["pipeline_para_size"].as(), ft_instance_hyperparameter["enable_custom_all_reduce"].as(0), model_dir); } - else if (data_type == "bf16") { + else if (data_type == "bf16" || data_type == 'bfloat16') { #ifdef ENABLE_BF16 return std::make_shared>( ft_instance_hyperparameter["tensor_para_size"].as(), @@ -278,10 +278,10 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, engines_.resize(device_count); const std::string weight_type_str = model_reader["weight_type"].as(); - if (weight_type_str == "fp16") { + if (weight_type_str == "fp16" || weight_type_str == 'float16') { weight_type_ = ft::WeightType::kFP16; } - else if (weight_type_str == "bf16") { + else if (weight_type_str == "bf16" || weight_type_str == 'bfloat16') { weight_type_ = ft::WeightType::kBF16; } else if (weight_type_str == "fp32") { From 5969cf3c15e7d9f70354e2b6f4abec56983a3614 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Sat, 14 Sep 2024 19:10:58 +0800 Subject: [PATCH 2/9] update --- src/turbomind/python/bind.cpp | 4 ++-- src/turbomind/triton_backend/llama/LlamaTritonModel.cc | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp index 62142a052..4eb34249f 100644 --- a/src/turbomind/python/bind.cpp +++ b/src/turbomind/python/bind.cpp @@ -373,13 +373,13 @@ PYBIND11_MODULE(_turbomind, m) PyGILState_Release(state); } }; - if (data_type == "half" || data_type == "fp16" || data_type == 'float16' || data_type == "int4") { + if (data_type == "half" || data_type == "fp16" || data_type == "float16" || data_type == "int4") { auto model = std::make_shared>( tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config); model->set_ffi_lock(gil_control); return model; } - else if (data_type == "bf16" || data_type == 'bfloat16') { + else if (data_type == "bf16" || data_type == "bfloat16") { #ifdef ENABLE_BF16 auto model = std::make_shared>( tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index c67177400..7829a4924 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -49,14 +49,14 @@ std::shared_ptr AbstractTransformerModel::createLlamaM int tensor_para_size = ft_instance_hyperparameter["tensor_para_size"].as(); std::string model_dir = ft_instance_hyperparameter["model_dir"].as(); - if (data_type == "half" || data_type == "fp16" || data_type == 'float16') { + if (data_type == "half" || data_type == "fp16" || data_type == "float16") { return std::make_shared>( ft_instance_hyperparameter["tensor_para_size"].as(), ft_instance_hyperparameter["pipeline_para_size"].as(), ft_instance_hyperparameter["enable_custom_all_reduce"].as(0), model_dir); } - else if (data_type == "bf16" || data_type == 'bfloat16') { + else if (data_type == "bf16" || data_type == "bfloat16") { #ifdef ENABLE_BF16 return 
std::make_shared>( ft_instance_hyperparameter["tensor_para_size"].as(), @@ -278,10 +278,10 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, engines_.resize(device_count); const std::string weight_type_str = model_reader["weight_type"].as(); - if (weight_type_str == "fp16" || weight_type_str == 'float16') { + if (weight_type_str == "fp16" || weight_type_str == "float16") { weight_type_ = ft::WeightType::kFP16; } - else if (weight_type_str == "bf16" || weight_type_str == 'bfloat16') { + else if (weight_type_str == "bf16" || weight_type_str == "bfloat16") { weight_type_ = ft::WeightType::kBF16; } else if (weight_type_str == "fp32") { From 9cf4041a5f1031d74afe19cc2679e117d542e5b4 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Sat, 14 Sep 2024 19:14:51 +0800 Subject: [PATCH 3/9] update --- lmdeploy/messages.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 13227824e..fa4219e8f 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -54,10 +54,29 @@ class GenerationConfig: in the decoding. Default to be True. logprobs (int): Number of log probabilities to return per output token. response_format (Dict): Only pytorch backend support formatting - response. Examples: `{"type": "json_schema", "json_schema": {"name":"test","schema": {"properties": {"name": {"type": "string"}}, "required": ["name"], "type": "object"}}}` - or `{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}` + response. Examples: + { + "type": "json_schema", + "json_schema": { + "name": "test", + "schema": { + "properties": { + "name": { + "type": "string" + } + }, + "required": ["name"], + "type": "object" + } + } + } + or, + { + "type": "regex_schema", + "regex_schema": "call me [A-Za-z]{1,10}" + } logits_processors (List[Callable]): Custom logit processors. - """ # noqa + """ n: int = 1 max_new_tokens: int = 512 From a26dfd6bf8682b0d466c701b1654c70f1a044046 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Sat, 14 Sep 2024 19:42:44 +0800 Subject: [PATCH 4/9] add dtype in CLI --- lmdeploy/cli/cli.py | 13 +++++++++++++ lmdeploy/cli/serve.py | 8 ++++++++ lmdeploy/cli/utils.py | 13 +++++++++++++ lmdeploy/turbomind/chat.py | 27 +++++++++++++++++++------- lmdeploy/turbomind/deploy/converter.py | 17 ++++++++++++---- 5 files changed, 67 insertions(+), 11 deletions(-) diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index 59fb2f345..7eedd458f 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -66,6 +66,16 @@ def add_parser_convert(): default=None, help='the name of the built-in chat template, which can be ' 'overviewed by `lmdeploy list`') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'float16', 'bfloat16'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models. 
This option will be ignored if ' + 'the model is a quantized model') parser.set_defaults(run=CLI.convert) @staticmethod @@ -113,6 +123,7 @@ def add_parser_chat(): ArgumentHelper.adapters(pt_group) ArgumentHelper.device(pt_group) # common engine args + dtype_act = ArgumentHelper.dtype(pt_group) tp_act = ArgumentHelper.tp(pt_group) session_len_act = ArgumentHelper.session_len(pt_group) cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group) @@ -121,6 +132,7 @@ def add_parser_chat(): # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') # common engine args + tb_group._group_actions.append(dtype_act) tb_group._group_actions.append(tp_act) tb_group._group_actions.append(session_len_act) tb_group._group_actions.append(cache_max_entry_act) @@ -245,6 +257,7 @@ def chat(args): adapters = get_lora_adapters(args.adapters) engine_config = PytorchEngineConfig( + dtype=args.dtype, tp=args.tp, session_len=args.session_len, cache_max_entry_count=args.cache_max_entry_count, diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 7dca6403b..bc03bb577 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -59,6 +59,7 @@ def add_parser_gradio(): pt_group = parser.add_argument_group('PyTorch engine arguments') # common engine args + dtype_act = ArgumentHelper.dtype(pt_group) tp_act = ArgumentHelper.tp(pt_group) ArgumentHelper.device(pt_group) session_len_act = ArgumentHelper.session_len(pt_group) @@ -71,6 +72,7 @@ def add_parser_gradio(): # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') # common engine args + tb_group._group_actions.append(dtype_act) tb_group._group_actions.append(tp_act) tb_group._group_actions.append(session_len_act) tb_group._group_actions.append(max_batch_size_act) @@ -150,6 +152,7 @@ def add_parser_api_server(): ArgumentHelper.adapters(pt_group) ArgumentHelper.device(pt_group) # common engine args + dtype_act = ArgumentHelper.dtype(pt_group) tp_act = ArgumentHelper.tp(pt_group) session_len_act = ArgumentHelper.session_len(pt_group) max_batch_size_act = ArgumentHelper.max_batch_size(pt_group) @@ -161,6 +164,7 @@ def add_parser_api_server(): # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') # common engine args + tb_group._group_actions.append(dtype_act) tb_group._group_actions.append(tp_act) tb_group._group_actions.append(session_len_act) tb_group._group_actions.append(max_batch_size_act) @@ -213,6 +217,7 @@ def gradio(args): backend = autoget_backend(args.model_path_or_server) if backend == 'pytorch': backend_config = PytorchEngineConfig( + dtype=args.dtype, tp=args.tp, max_batch_size=max_batch_size, cache_max_entry_count=args.cache_max_entry_count, @@ -223,6 +228,7 @@ def gradio(args): max_prefill_token_num=args.max_prefill_token_num) else: backend_config = TurbomindEngineConfig( + dtype=args.dtype, tp=args.tp, max_batch_size=max_batch_size, session_len=args.session_len, @@ -258,6 +264,7 @@ def api_server(args): from lmdeploy.messages import PytorchEngineConfig adapters = get_lora_adapters(args.adapters) backend_config = PytorchEngineConfig( + dtype=args.dtype, tp=args.tp, max_batch_size=max_batch_size, cache_max_entry_count=args.cache_max_entry_count, @@ -270,6 +277,7 @@ def api_server(args): else: from lmdeploy.messages import TurbomindEngineConfig backend_config = TurbomindEngineConfig( + dtype=args.dtype, tp=args.tp, max_batch_size=max_batch_size, session_len=args.session_len, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 
f4223164d..3437b864a 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -100,6 +100,19 @@ def model_name(parser): 'by the RESTful API `/v1/models`. If it is not specified, ' '`model_path` will be adopted') + @staticmethod + def dtype(parser, default: str = 'auto'): + return parser.add_argument( + '--dtype', + type=str, + default=default, + choices=['auto', 'float16', 'bfloat16'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models. This option will be ignored if ' + 'the model is a quantized model') + @staticmethod def model_format(parser, default: str = None): return parser.add_argument( diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py index ade7875ce..e106beae1 100644 --- a/lmdeploy/turbomind/chat.py +++ b/lmdeploy/turbomind/chat.py @@ -35,6 +35,7 @@ def main(model_path: str, temperature: float = 0.8, repetition_penalty: float = 1.0, cap: str = 'chat', + dtype: str = 'auto', tp: int = 1, model_format: str = None, quant_policy: int = 0, @@ -57,20 +58,31 @@ def main(model_path: str, top_p (int): sampling top p. temperature (float): sampling temperature. repetition_penalty (float): parameter to penalize repetition - cap (str): the capability of a model. For example, codellama has the ability among ['completion', 'infilling', 'chat', 'python'] + cap (str): the capability of a model. For example, codellama has the + ability among ['completion', 'infilling', 'chat', 'python'] + dtype (str): data type for model weights and activations. It can be + one of the following values, ['auto', 'float16', 'bfloat16'] + The `auto` option will use FP16 precision for FP32 and FP16 + models, and BF16 precision for BF16 models. tp (int): GPU number used in tensor parallelism - model_format (str): the layout of the deployed model. It can be one of the following values [hf, llama, awq] - quant_policy (int): default to 0. When k/v is quantized into 8 bit, set it to 4 - cache_max_entry_count (float): the percentage of gpu memory occupied by the k/v cache. - cache_block_seq_len (int): the length of the token sequence in a k/v block, default to 64 - rope_scaling_factor (float): scaling factor used for dynamic ntk, default to 0. TurboMind follows the implementation of transformer LlamaAttention + model_format (str): the layout of the deployed model. It can be one + of the following values [hf, llama, awq] + quant_policy (int): default to 0. When k/v is quantized into 4 or 8 + bit, set it to 4 or 8, respectively + cache_max_entry_count (float): the percentage of gpu memory occupied + by the k/v cache. + cache_block_seq_len (int): the length of the token sequence in a k/v + block, default to 64 + rope_scaling_factor (float): scaling factor used for dynamic ntk, + default to 0. 
TurboMind follows the implementation of transformer + LlamaAttention enable_prefix_caching (bool): whether enable prefix caching session_len (int): the length input output tokens stream_output (bool): indicator for streaming output or not request_output_len (int): output token nums chat_template_config (ChatTemplateConfig): chat template config kwargs (dict): unused args - """ # noqa: E 501 + """ # chat template _, chat_template_name = get_names_from_model(model_path) @@ -96,6 +108,7 @@ def main(model_path: str, enable_prefix_caching=enable_prefix_caching, quant_policy=quant_policy, rope_scaling_factor=rope_scaling_factor, + dtype=dtype, tp=tp) print('engine_cfg:\n', engine_cfg, sep='', flush=True) diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 442c2e0ce..cad772336 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -130,9 +130,11 @@ def get_output_model_registered_name_and_config(model_path: str, 'float16', 'bfloat16', 'int4' ] else 'float16' elif dtype in ['float16', 'bfloat16']: - assert weight_type != 'int4', f'the model {model_path} is a 4bit ' \ - f'weight quantized model but user specifies dtype {dtype}' - weight_type = dtype + if weight_type == 'int4': + logger.warn(f'The model {model_path} is a quantized model, so the ' + f'specified data type {dtype} is ignored') + else: + weight_type = dtype else: assert 0, 'unsupported specified data type {dtype}' @@ -281,6 +283,7 @@ def get_tm_model(model_path, def main(model_name: str, model_path: str, model_format: str = 'hf', + dtype: str = 'auto', chat_template: str = None, tokenizer_path: str = None, dst_path: str = 'workspace', @@ -299,6 +302,10 @@ def main(model_name: str, llama format, 'hf' means huggingface model, and 'awq', `gptq` means models quantized by `autoawq` and `autogptq` respectively. The default value is hf + dtype (str): data type for model weights and activations. It can be + one of the following values, ['auto', 'float16', 'bfloat16'] + The `auto` option will use FP16 precision for FP32 and FP16 + models, and BF16 precision for BF16 models. chat_template (str): the name of the built-in chat template. tokenizer_path (str): the path of tokenizer model dst_path (str): the destination path that saves outputs @@ -345,7 +352,9 @@ def main(model_name: str, tm_weight_path, tm_tokenizer_path = create_workspace(dst_path) copy_tokenizer(model_path, tokenizer_path, tm_tokenizer_path) - engine_config = TurbomindEngineConfig(tp=tp, model_format=model_format) + engine_config = TurbomindEngineConfig(tp=tp, + model_format=model_format, + dtype=dtype) tm_model = get_tm_model(model_path, model_name, chat_template, engine_config, group_size, tm_weight_path) tm_model.export() From d358e1e1fc7a15e90ec856ef02e1251b8e09ae01 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 16 Sep 2024 18:39:59 +0800 Subject: [PATCH 5/9] add user specified dtype for pytorch engine --- lmdeploy/messages.py | 6 +++- lmdeploy/pytorch/config.py | 48 +++++++++++++++++++------- lmdeploy/pytorch/engine/engine.py | 3 ++ lmdeploy/pytorch/engine/model_agent.py | 16 +++++++-- lmdeploy/turbomind/deploy/converter.py | 2 +- 5 files changed, 59 insertions(+), 16 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index fa4219e8f..742a48008 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -247,6 +247,10 @@ class PytorchEngineConfig: thread_safe (bool): thread safe engine instance. 
enable_prefix_caching (bool): Enable token match and sharing caches. device_type (str): The inference device type, options ['cuda'] + eager_mode (bool): Enable "eager" mode or not + custom_module_map (Dict): nn module map customized by users. Once + provided, the original nn modules of the model will be + substituted by the mapping ones download_dir (str): Directory to download and load the weights, default to the default cache directory of huggingface. revision (str): The specific model version to use. @@ -268,7 +272,7 @@ class PytorchEngineConfig: enable_prefix_caching: bool = False device_type: str = 'cuda' eager_mode: bool = False - custom_module_map: str = None + custom_module_map: Dict[str, str] = None download_dir: str = None revision: str = None diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index de5aeb0b1..71a041558 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. from dataclasses import dataclass -from typing import Any, List +from typing import Any, Dict, List import torch -def _update_torch_dtype(config: 'ModelConfig', default: str = 'float16'): +def _update_torch_dtype(config: 'ModelConfig', dtype: str): """Update the torch dtype from the model config. Args: config (ModelConfig): The input model config. - default (str): default device type. + dtype (str): user specified data type. Refer to + `PyTorchEngineConfig.dtype` for detailed info """ from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') @@ -26,12 +27,20 @@ def _update_torch_dtype(config: 'ModelConfig', default: str = 'float16'): torch_dtype = getattr(config.hf_config, 'torch_dtype', None) if torch_dtype is None: + _dtype = 'float16' if dtype == 'auto' else dtype logger.warning('Model config does not have `torch_dtype`,' - f' use default: {default}') - torch_dtype = default + f' use: {_dtype}') + torch_dtype = _dtype # update hf_config as well setattr(config.hf_config, 'torch_dtype', torch_dtype) - + else: + # change to user specified data type if it is not 'auto' + if dtype == 'auto': + torch_dtype = torch_dtype if torch_dtype in [ + torch.float16, torch.bfloat16 + ] else 'float16' + else: + torch_dtype = dtype config.dtype = eval(f'torch.{torch_dtype}') return config @@ -97,7 +106,7 @@ class ModelConfig: vocab_size: int = 40000 hf_config: Any = None cogvlm_style: bool = False - custom_module_map: str = None + custom_module_map: Dict[str, setattr] = None def get_head_size(self): """get head size.""" @@ -106,8 +115,18 @@ def get_head_size(self): @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, - trust_remote_code: bool = True): - """build ModelConfig from model path or name.""" + trust_remote_code: bool = True, + dtype: str = 'auto'): + """Instantiate one of the configuration classes of the library from a + pretrained model configuration. + + Args: + pretrained_model_name_or_path (str): the pretrained model path + trust_remote_code (bool): Whether or not to allow for custom + models defined on the Hub in their own modeling files. + dtype (str): user specified data type for model weights and + activations. Refer to `PyTorchEngineConfig` for details + """ from transformers import AutoConfig hf_config = AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code) @@ -115,10 +134,15 @@ def from_pretrained(cls, # phi3 + trust_remote_code leads to error when tp. 
hf_config = AutoConfig.from_pretrained( pretrained_model_name_or_path) - return cls.from_hf_config(hf_config, pretrained_model_name_or_path) + return cls.from_hf_config(hf_config, + pretrained_model_name_or_path, + dtype=dtype) @classmethod - def from_hf_config(cls, hf_config: Any, model_path: str = None): + def from_hf_config(cls, + hf_config: Any, + model_path: str = None, + dtype: str = 'auto'): """from huggingface config.""" from lmdeploy.pytorch.configurations import AutoModelConfigBuilder @@ -132,7 +156,7 @@ def from_hf_config(cls, hf_config: Any, model_path: str = None): model_config.v_head_dim = model_config.head_dim # should after setting `hf_config` and `model_arch` attributes - model_config = _update_torch_dtype(model_config) + model_config = _update_torch_dtype(model_config, dtype) # update eos_token_id to list if isinstance(model_config.eos_token_id, int): diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 995fea6cb..b9919efb3 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -129,6 +129,8 @@ def __init__(self, engine_config.device_type) if engine_config.adapters is not None: check_adapters(list(engine_config.adapters.values())) + assert engine_config.dtype in ['auto', 'float16', 'bfloat16'], \ + f'unsupported specified data type {engine_config.dtype}' self.engine_config = engine_config self.tp = engine_config.tp @@ -171,6 +173,7 @@ def __init__(self, trust_remote_code=trust_remote_code, adapters=adapters, tp=self.tp, + dtype=engine_config.dtype, custom_module_map=engine_config.custom_module_map) cache_config = self.model_agent.cache_config diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 46d17b5c8..2a80f09b0 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -821,10 +821,22 @@ def build_model_agent(model_path: str, trust_remote_code: bool, adapters: Dict[str, str] = None, tp: int = 1, + dtype: str = 'auto', custom_module_map: str = None): - """create model agent.""" + """create model agent. 
+ + Args: + model_path (str): the path of the input model + cache_config (CacheConfig): config of kv cache + backend_config (BackendConfig): config of backend devices + trust_remote_code (bool): To use the remote modeling code or not + adapters (Dict): lora adapters + tp (int): the number of devices to be used in tensor parallelism + dtype (str): the data type of model weights and activations + custom_module_map (str): customized nn module map + """ model_config = ModelConfig.from_pretrained( - model_path, trust_remote_code=trust_remote_code) + model_path, trust_remote_code=trust_remote_code, dtype=dtype) model_config.custom_module_map = custom_module_map if tp == 1: model_agent = BaseModelAgent(model_path, diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index cad772336..a444501b3 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -136,7 +136,7 @@ def get_output_model_registered_name_and_config(model_path: str, else: weight_type = dtype else: - assert 0, 'unsupported specified data type {dtype}' + assert 0, f'unsupported specified data type {dtype}' config.model_config.model_arch = model_arch config.model_config.weight_type = weight_type From ec93ccafedbcd1d6f8ccbc13dd54638afb2b8c3b Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 16 Sep 2024 18:42:36 +0800 Subject: [PATCH 6/9] update log --- src/turbomind/kernels/attention/attention.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/turbomind/kernels/attention/attention.cu b/src/turbomind/kernels/attention/attention.cu index ab615ea23..60cc7e690 100644 --- a/src/turbomind/kernels/attention/attention.cu +++ b/src/turbomind/kernels/attention/attention.cu @@ -31,6 +31,11 @@ void dispatchAttention(const AttentionParams& params) params); } } + else { + if (params.arch < 80) { + TM_LOG_ERROR("CUDA architecture sm%d does not support data type 'bfloat16'. Please specify dtype 'float16'", params.arch); + } + } } FT_CHECK(0); } From 65c539540ae0f3be3b51f8402e1c93a7f66eeb36 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 16 Sep 2024 19:23:02 +0800 Subject: [PATCH 7/9] fix linting & ut --- src/turbomind/kernels/attention/attention.cu | 4 +- .../test_turbomind/test_converter.py | 75 ++++++++++++------- 2 files changed, 52 insertions(+), 27 deletions(-) diff --git a/src/turbomind/kernels/attention/attention.cu b/src/turbomind/kernels/attention/attention.cu index 60cc7e690..ffbad56b4 100644 --- a/src/turbomind/kernels/attention/attention.cu +++ b/src/turbomind/kernels/attention/attention.cu @@ -33,7 +33,9 @@ void dispatchAttention(const AttentionParams& params) } else { if (params.arch < 80) { - TM_LOG_ERROR("CUDA architecture sm%d does not support data type 'bfloat16'. Please specify dtype 'float16'", params.arch); + TM_LOG_ERROR( + "CUDA architecture sm%d does not support data type 'bfloat16'. 
Please specify dtype 'float16'", + params.arch); } } } diff --git a/tests/test_lmdeploy/test_turbomind/test_converter.py b/tests/test_lmdeploy/test_turbomind/test_converter.py index 0d125fe74..3548eac7d 100644 --- a/tests/test_lmdeploy/test_turbomind/test_converter.py +++ b/tests/test_lmdeploy/test_turbomind/test_converter.py @@ -7,30 +7,31 @@ def test_registered_models(): for model, model_format, group_size, weight_type, register_name in [ - ('internlm/internlm2-7b', 'hf', 0, 'bf16', 'tm'), - ('baichuan-inc/Baichuan-7B', 'hf', 0, 'fp16', 'tm'), - ('baichuan-inc/Baichuan2-7B-Chat', 'hf', 0, 'bf16', 'tm'), - ('baichuan-inc/Baichuan-13B-Chat', 'hf', 0, 'bf16', 'tm'), - ('baichuan-inc/Baichuan2-13B-Chat', 'hf', 0, 'bf16', 'tm'), - ('internlm/internlm-chat-7b', 'hf', 0, 'fp16', 'tm'), - ('internlm/internlm2-chat-7b', 'hf', 0, 'bf16', 'tm'), - ('internlm/internlm-xcomposer2-4khd-7b', 'hf', 0, 'bf16', 'tm'), - ('internlm/internlm-xcomposer2-vl-7b', 'hf', 0, 'bf16', 'tm'), - ('internlm/internlm-xcomposer2-7b', 'hf', 0, 'bf16', 'tm'), - ('lmsys/vicuna-7b-v1.5', 'hf', 0, 'fp16', 'tm'), - ('01-ai/Yi-1.5-9B', 'hf', 0, 'bf16', 'tm'), - ('deepseek-ai/deepseek-coder-6.7b-instruct', 'hf', 0, 'bf16', 'tm'), - ('deepseek-ai/deepseek-llm-7b-chat', 'hf', 0, 'bf16', 'tm'), - ('Qwen/Qwen-7B-Chat', 'hf', 0, 'bf16', 'tm'), - ('Qwen/Qwen1.5-7B-Chat', 'hf', 0, 'bf16', 'tm'), - ('Qwen/Qwen2-7B-Instruct', 'hf', 0, 'bf16', 'tm'), - ('Qwen/Qwen-VL-Chat', 'hf', 0, 'bf16', 'tm'), - ('liuhaotian/llava-v1.6-34b', 'hf', 0, 'bf16', 'tm'), - ('liuhaotian/llava-v1.6-mistral-7b', 'hf', 0, 'bf16', 'tm'), - ('liuhaotian/llava-v1.6-vicuna-13b', 'hf', 0, 'bf16', 'tm'), - ('OpenGVLab/InternVL-Chat-V1-5', 'hf', 0, 'bf16', 'tm'), - ('deepseek-ai/deepseek-vl-7b-chat', 'hf', 0, 'fp16', 'tm'), - ('YanweiLi/MGM-7B', 'hf', 0, 'bf16', 'tm'), + ('internlm/internlm2-7b', 'hf', 0, 'bfloat16', 'tm'), + ('baichuan-inc/Baichuan-7B', 'hf', 0, 'float16', 'tm'), + ('baichuan-inc/Baichuan2-7B-Chat', 'hf', 0, 'bfloat16', 'tm'), + ('baichuan-inc/Baichuan-13B-Chat', 'hf', 0, 'bfloat16', 'tm'), + ('baichuan-inc/Baichuan2-13B-Chat', 'hf', 0, 'bfloat16', 'tm'), + ('internlm/internlm-chat-7b', 'hf', 0, 'float16', 'tm'), + ('internlm/internlm2-chat-7b', 'hf', 0, 'bfloat16', 'tm'), + ('internlm/internlm-xcomposer2-4khd-7b', 'hf', 0, 'bfloat16', 'tm'), + ('internlm/internlm-xcomposer2-vl-7b', 'hf', 0, 'bfloat16', 'tm'), + ('internlm/internlm-xcomposer2-7b', 'hf', 0, 'bfloat16', 'tm'), + ('lmsys/vicuna-7b-v1.5', 'hf', 0, 'float16', 'tm'), + ('01-ai/Yi-1.5-9B', 'hf', 0, 'bfloat16', 'tm'), + ('deepseek-ai/deepseek-coder-6.7b-instruct', 'hf', 0, + 'bfloat16', 'tm'), + ('deepseek-ai/deepseek-llm-7b-chat', 'hf', 0, 'bfloat16', 'tm'), + ('Qwen/Qwen-7B-Chat', 'hf', 0, 'bfloat16', 'tm'), + ('Qwen/Qwen1.5-7B-Chat', 'hf', 0, 'bfloat16', 'tm'), + ('Qwen/Qwen2-7B-Instruct', 'hf', 0, 'bfloat16', 'tm'), + ('Qwen/Qwen-VL-Chat', 'hf', 0, 'bfloat16', 'tm'), + ('liuhaotian/llava-v1.6-34b', 'hf', 0, 'bfloat16', 'tm'), + ('liuhaotian/llava-v1.6-mistral-7b', 'hf', 0, 'bfloat16', 'tm'), + ('liuhaotian/llava-v1.6-vicuna-13b', 'hf', 0, 'bfloat16', 'tm'), + ('OpenGVLab/InternVL-Chat-V1-5', 'hf', 0, 'bfloat16', 'tm'), + ('deepseek-ai/deepseek-vl-7b-chat', 'hf', 0, 'float16', 'tm'), + ('YanweiLi/MGM-7B', 'hf', 0, 'bfloat16', 'tm'), ('Qwen/Qwen1.5-4B-Chat-AWQ', 'awq', 128, 'int4', 'tm'), ('solidrust/Meta-Llama-3-8B-Instruct-hf-AWQ', 'awq', 128, 'int4', 'tm'), @@ -42,7 +43,7 @@ def test_registered_models(): assert input_name in list(INPUT_MODELS.module_dict.keys()) output_name, config, _ 
= get_output_model_registered_name_and_config( - model, model_format=model_format, group_size=0) + model, model_format=model_format, dtype='auto', group_size=0) assert output_name == register_name assert config.model_config.group_size == group_size assert config.weight_type == weight_type @@ -53,7 +54,10 @@ def test_registered_models(): def test_update_from_engine_config(): import copy _, _config, _ = get_output_model_registered_name_and_config( - 'internlm/internlm2-chat-7b', model_format='hf', group_size=0) + 'internlm/internlm2-chat-7b', + model_format='hf', + dtype='auto', + group_size=0) config = copy.deepcopy(_config) config.update_from_engine_config(None) assert (config == _config) @@ -85,3 +89,22 @@ def test_update_from_engine_config(): engine_config.rope_scaling_factor) assert ( config.attention_config.use_logn_attn == engine_config.use_logn_attn) + + +def test_dtype(): + testsets = [('auto', 'bfloat16'), ('float16', 'float16'), + ('bfloat16', 'bfloat16')] + for specified_dtype, expected_dtype in testsets: + _, _config, _ = get_output_model_registered_name_and_config( + 'internlm/internlm2-chat-7b', + model_format='hf', + dtype=specified_dtype, + group_size=0) + assert _config.weight_type == expected_dtype + for specified_dtype in ['auto', 'float16', 'bfloat16']: + _, _config, _ = get_output_model_registered_name_and_config( + 'internlm/internlm2_5-20b-chat-4bit-awq', + model_format='awq', + dtype=specified_dtype, + group_size=128) + assert _config.weight_type == 'int4' From 15f8641d3d2db900bf73c97cc4f80e6a2317e535 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 16 Sep 2024 21:17:13 +0800 Subject: [PATCH 8/9] fix vl --- lmdeploy/turbomind/turbomind.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 4407d0fea..5e79fb856 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -471,7 +471,9 @@ def prepare_embeddings(self, if item and isinstance(item[0], np.ndarray): item = [torch.from_numpy(x).squeeze() for x in item] # convert to lookup table type - _MAP = dict(fp32=torch.float, bf16=torch.bfloat16) + _MAP = dict(float=torch.float, + bfloat16=torch.bfloat16, + float16=torch.float16) dtype = _MAP.get(self.tm_model.config.weight_type, torch.float16) item = [x.to(dtype=dtype) for x in item] item = item or [torch.zeros(0, hidden_dim, dtype=dtype)] From 1aa9bc36dd56a8fae2f9c02beeaf7f89f7fc9961 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Wed, 18 Sep 2024 12:27:43 +0800 Subject: [PATCH 9/9] minor --- lmdeploy/pytorch/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 71a041558..2625de5dc 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -38,7 +38,7 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): if dtype == 'auto': torch_dtype = torch_dtype if torch_dtype in [ torch.float16, torch.bfloat16 - ] else 'float16' + ] else torch.float16 else: torch_dtype = dtype config.dtype = eval(f'torch.{torch_dtype}')
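
---
Usage sketch (editorial note, not part of the patch series): the snippet below illustrates how the `dtype` option introduced by this series could be exercised. It assumes the existing `lmdeploy.pipeline` entry point, which is not touched by these patches; the model path is only a placeholder, and the `dtype='auto'` behaviour follows the docstrings added above (FP16 for FP32/FP16 checkpoints, BF16 for BF16 checkpoints).

    # Illustrative only; `pipeline` and the model id are assumptions outside this diff.
    from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig, pipeline

    # TurboMind backend: override the checkpoint precision and run in bfloat16.
    # For AWQ/GPTQ (int4) models the option is ignored, per the converter change.
    tm_cfg = TurbomindEngineConfig(dtype='bfloat16', tp=1)
    pipe = pipeline('internlm/internlm2-chat-7b', backend_config=tm_cfg)
    print(pipe(['Hi, please introduce yourself']))

    # PyTorch backend: 'auto' (the default) keeps FP32/FP16 models in float16
    # and BF16 models in bfloat16.
    pt_cfg = PytorchEngineConfig(dtype='auto')
    pipe = pipeline('internlm/internlm2-chat-7b', backend_config=pt_cfg)

The same option is exposed on the command line by patch 4, e.g. `lmdeploy chat <model_path> --dtype float16` or `lmdeploy serve api_server <model_path> --dtype bfloat16`.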