From 69b4f69c0150735f9a2c617f0ae5f36dc203b40d Mon Sep 17 00:00:00 2001
From: Piotr Kaminski
Date: Fri, 23 Aug 2024 02:51:05 -0700
Subject: [PATCH] apply code review changes

Signed-off-by: Piotr Kaminski
---
 nemo/export/tensorrt_llm.py                  |  8 ++--
 .../trt_llm/converter/model_converter.py     | 22 ++++++---
 scripts/export/export_to_trt_llm.py          | 46 +++++++++++++------
 tests/export/nemo_export.py                  | 17 ++++---
 4 files changed, 62 insertions(+), 31 deletions(-)

diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 7b7c4e07e225..06a876c2b833 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -167,8 +167,8 @@ def export(
         multiple_profiles: bool = False,
         gpt_attention_plugin: str = "auto",
         gemm_plugin: str = "auto",
-        fp8_quantized: bool = False,
-        fp8_kvcache: bool = False,
+        fp8_quantized: Optional[bool] = None,
+        fp8_kvcache: Optional[bool] = None,
     ):
         """
         Exports nemo checkpoints to TensorRT-LLM.
@@ -204,8 +204,8 @@
             multiple_profiles: (bool): enables multiple profiles feature of TRT-LLM. Default = False
             gpt_attention_plugin (str): enable the gpt attention plugin. Default = "auto"
             gemm_plugin (str): enable the gpt plugin. Default = "auto"
-            fp8_quantized (bool): enables exporting to FP8 TRT-LLM checkpoints
-            fp8_kvcache (bool): enables FP8 KV-cache quantization
+            fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If left unset, the value is auto-detected from the checkpoint.
+            fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If left unset, the value is auto-detected from the checkpoint.
         """
 
         if n_gpus is not None:
diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py
index 9eac3acfa708..b7c959d3b5b5 100755
--- a/nemo/export/trt_llm/converter/model_converter.py
+++ b/nemo/export/trt_llm/converter/model_converter.py
@@ -39,11 +39,10 @@
 def get_config(decoder_type, config):
     if decoder_type == "llama":
         return LLaMAConfig(**config)
-
-    if decoder_type in ["gpt", "gptnext"]:
+    elif decoder_type == "gpt" or decoder_type == "gptnext":
         return GPTConfig(**config)
-
-    return PretrainedConfig(**config)
+    else:
+        return PretrainedConfig(**config)
 
 
 def prompt_convert(prompt_config, prompt_weights):
@@ -95,6 +94,16 @@ def create_common_export_config(nemo_model_config, decoder_type, fp8_quantized=F
     }
 
 
+def determine_quantization_settings(nemo_model_config, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None) -> Tuple[bool, bool]:
+    is_nemo_quantized = nemo_model_config.get('fp8', False)
+    if fp8_quantized is None:
+        fp8_quantized = is_nemo_quantized
+    if fp8_kvcache is None:
+        fp8_kvcache = is_nemo_quantized
+
+    return fp8_quantized, fp8_kvcache
+
+
 def model_to_trtllm_ckpt(
     model,
     nemo_model_config,
@@ -109,8 +118,8 @@
     use_distributed_convert: bool = False,
     model_parallel_rank: int = None,
     vocab_size: Optional[int] = None,
-    fp8_quantized: bool = False,
-    fp8_kvcache: bool = False,
+    fp8_quantized: Optional[bool] = None,
+    fp8_kvcache: Optional[bool] = None,
 ) -> Tuple[List[Dict], List[PretrainedConfig]]:
     if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing:
         LOGGER.info(
@@ -118,6 +127,7 @@
         )
         use_embedding_sharing = True
 
+    fp8_quantized, fp8_kvcache = determine_quantization_settings(nemo_model_config, fp8_quantized, fp8_kvcache)
     export_config = create_common_export_config(nemo_model_config, decoder_type, fp8_quantized, fp8_kvcache)
     # If the model has been sharded with model parallelism, convert the model in a gpu-distributed manner
     if use_distributed_convert:
diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py
index 7a240a6c4e6d..06193b06aee7 100644
--- a/scripts/export/export_to_trt_llm.py
+++ b/scripts/export/export_to_trt_llm.py
@@ -15,11 +15,14 @@
 import argparse
 import logging
 import sys
+from typing import Optional
 
 from nemo.export.tensorrt_llm import TensorRTLLM
 
 LOGGER = logging.getLogger("NeMo")
 
+class UsageError(Exception):
+    pass
 
 def get_args(argv):
     parser = argparse.ArgumentParser(
@@ -50,20 +53,6 @@ def get_args(argv):
         type=str,
         help="dtype of the model on TensorRT-LLM",
     )
-    parser.add_argument(
-        "-fp8",
-        "--export_fp8_quantized",
-        default=False,
-        type=bool,
-        help="Enables exporting to a FP8-quantized TRT LLM checkpoint",
-    )
-    parser.add_argument(
-        "-kv_fp8",
-        "--use_fp8_kv_cache",
-        default=False,
-        type=bool,
-        help="Enables exporting with FP8-quantizatized KV-cache",
-    )
     parser.add_argument("-mil", "--max_input_len", default=256, type=int, help="Max input length of the model")
     parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
     parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model")
@@ -121,8 +110,37 @@ def get_args(argv):
         'It is used to compute the workspace size of lora plugin.',
     )
     parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")
+    parser.add_argument(
+        "-fp8",
+        "--export_fp8_quantized",
+        default="auto",
+        type=str,
+        help="Enables exporting to an FP8-quantized TRT LLM checkpoint ('auto' infers it from the source checkpoint)",
+    )
+    parser.add_argument(
+        "-kv_fp8",
+        "--use_fp8_kv_cache",
+        default="auto",
+        type=str,
+        help="Enables exporting with an FP8-quantized KV-cache ('auto' infers it from the source checkpoint)",
+    )
 
     args = parser.parse_args(argv)
+
+    def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]:
+        s = s.lower()
+        true_strings = ["true", "1"]
+        false_strings = ["false", "0"]
+        if s in true_strings:
+            return True
+        if s in false_strings:
+            return False
+        if optional and s == 'auto':
+            return None
+        raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'")
+
+    args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True)
+    args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True)
 
     return args
 
diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py
index 7fdfd73e232f..ecaf198a0c07 100644
--- a/tests/export/nemo_export.py
+++ b/tests/export/nemo_export.py
@@ -759,27 +759,30 @@ def get_args():
     parser.add_argument(
         "-fp8",
         "--export_fp8_quantized",
-        default="False",
+        default="auto",
         type=str,
         help="Enables exporting to a FP8-quantized TRT LLM checkpoint",
     )
     parser.add_argument(
         "-kv_fp8",
         "--use_fp8_kv_cache",
-        default="False",
+        default="auto",
         type=str,
         help="Enables exporting with FP8-quantizatized KV-cache",
     )
 
     args = parser.parse_args()
 
-    def str_to_bool(name: str, s: str) -> bool:
+    def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]:
+        s = s.lower()
         true_strings = ["true", "1"]
         false_strings = ["false", "0"]
-        if s.lower() in true_strings:
+        if s in true_strings:
             return True
-        if s.lower() in false_strings:
+        if s in false_strings:
             return False
+        if optional and s == 'auto':
+            return None
         raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'")
 
     args.test_cpp_runtime = str_to_bool("test_cpp_runtime", args.test_cpp_runtime)
@@ -790,8 +793,8 @@ def str_to_bool(name: str, s: str) -> bool:
     args.use_vllm = str_to_bool("use_vllm", args.use_vllm)
     args.use_parallel_embedding = str_to_bool("use_parallel_embedding", args.use_parallel_embedding)
     args.in_framework = str_to_bool("in_framework", args.in_framework)
-    args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized)
-    args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache)
+    args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True)
+    args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True)
 
     return args
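
Note on the new behavior (editor's sketch, not part of the patch): determine_quantization_settings() keeps an explicit True/False and only falls back to the checkpoint's own 'fp8' config entry when a flag is left at None, which is what the CLI value "auto" maps to. A minimal illustration, assuming only the names introduced in the diff above; the dict literals are hypothetical stand-ins for a real nemo_model_config:

    from nemo.export.trt_llm.converter.model_converter import determine_quantization_settings

    fp8_ckpt = {'fp8': True}   # checkpoint trained with FP8
    bf16_ckpt = {}             # no 'fp8' entry -> treated as unquantized

    # Flags left at None ("auto" on the CLI) follow the checkpoint:
    assert determine_quantization_settings(fp8_ckpt) == (True, True)
    assert determine_quantization_settings(bf16_ckpt) == (False, False)

    # An explicit value always overrides detection:
    assert determine_quantization_settings(fp8_ckpt, fp8_quantized=False) == (False, True)

On the command line, str_to_bool() accepts "true"/"1", "false"/"0" and, for these two flags only, "auto" (case-insensitive), e.g. "-fp8 auto -kv_fp8 false"; any other value raises UsageError.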