
Commit: apply code review changes
Signed-off-by: Piotr Kaminski <pikaminski@nvidia.com>
Piotr Kaminski committed Aug 23, 2024
1 parent 250525e commit 69b4f69
Showing 4 changed files with 62 additions and 31 deletions.
8 changes: 4 additions & 4 deletions nemo/export/tensorrt_llm.py
@@ -167,8 +167,8 @@ def export(
         multiple_profiles: bool = False,
         gpt_attention_plugin: str = "auto",
         gemm_plugin: str = "auto",
-        fp8_quantized: bool = False,
-        fp8_kvcache: bool = False,
+        fp8_quantized: Optional[bool] = None,
+        fp8_kvcache: Optional[bool] = None,
     ):
         """
         Exports nemo checkpoints to TensorRT-LLM.
@@ -204,8 +204,8 @@ def export(
             multiple_profiles (bool): enables multiple profiles feature of TRT-LLM. Default = False
             gpt_attention_plugin (str): enable the gpt attention plugin. Default = "auto"
             gemm_plugin (str): enable the gemm plugin. Default = "auto"
-            fp8_quantized (bool): enables exporting to FP8 TRT-LLM checkpoints
-            fp8_kvcache (bool): enables FP8 KV-cache quantization
+            fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If not set, autodetects the type.
+            fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If not set, autodetects the type.
         """

         if n_gpus is not None:
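In practice the new defaults mean FP8 export can simply be left to autodetection. A minimal usage sketch (the engine directory, checkpoint path, and the nemo_checkpoint_path/model_type parameter names are assumptions for illustration; only the fp8_* parameters are taken from this diff):

    from nemo.export.tensorrt_llm import TensorRTLLM

    # Hypothetical engine directory and checkpoint path.
    exporter = TensorRTLLM(model_dir="/tmp/trt_llm_engine")
    exporter.export(
        nemo_checkpoint_path="/models/model.nemo",  # assumed parameter name
        model_type="llama",                         # assumed parameter name
        # fp8_quantized / fp8_kvcache left at their new None defaults:
        # the exporter autodetects FP8 from the checkpoint itself.
    )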
22 changes: 16 additions & 6 deletions nemo/export/trt_llm/converter/model_converter.py
@@ -39,11 +39,10 @@
 def get_config(decoder_type, config):
     if decoder_type == "llama":
         return LLaMAConfig(**config)
-
-    if decoder_type in ["gpt", "gptnext"]:
+    elif decoder_type == "gpt" or decoder_type == "gptnext":
         return GPTConfig(**config)
-
-    return PretrainedConfig(**config)
+    else:
+        return PretrainedConfig(**config)


def prompt_convert(prompt_config, prompt_weights):
@@ -95,6 +94,16 @@ def create_common_export_config(nemo_model_config, decoder_type, fp8_quantized=F
}


+def determine_quantization_settings(nemo_model_config, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None) -> Tuple[bool, bool]:
+    is_nemo_quantized = nemo_model_config.get('fp8', False)
+    if fp8_quantized is None:
+        fp8_quantized = is_nemo_quantized
+    if fp8_kvcache is None:
+        fp8_kvcache = is_nemo_quantized
+
+    return fp8_quantized, fp8_kvcache


 def model_to_trtllm_ckpt(
     model,
     nemo_model_config,
@@ -109,15 +118,16 @@ def model_to_trtllm_ckpt(
     use_distributed_convert: bool = False,
     model_parallel_rank: int = None,
     vocab_size: Optional[int] = None,
-    fp8_quantized: bool = False,
-    fp8_kvcache: bool = False,
+    fp8_quantized: Optional[bool] = None,
+    fp8_kvcache: Optional[bool] = None,
 ) -> Tuple[List[Dict], List[PretrainedConfig]]:
     if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing:
         LOGGER.info(
             "Found share_embeddings_and_output_weights is True in NeMo config, set use_embedding_sharing = True"
         )
         use_embedding_sharing = True

+    fp8_quantized, fp8_kvcache = determine_quantization_settings(nemo_model_config, fp8_quantized, fp8_kvcache)
     export_config = create_common_export_config(nemo_model_config, decoder_type, fp8_quantized, fp8_kvcache)
     # If the model has been sharded with model parallelism, convert the model in a gpu-distributed manner
     if use_distributed_convert:
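The new helper gives an explicit flag precedence and falls back to the checkpoint's own 'fp8' field when the flag is left as None. A minimal sketch of the resulting behavior (the dicts stand in for a loaded NeMo model config):

    fp8_ckpt = {'fp8': True}   # checkpoint quantized in FP8
    bf16_ckpt = {}             # no 'fp8' key: treated as not quantized

    determine_quantization_settings(fp8_ckpt)                       # -> (True, True)
    determine_quantization_settings(bf16_ckpt)                      # -> (False, False)
    determine_quantization_settings(fp8_ckpt, fp8_quantized=False)  # -> (False, True): explicit flag wins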
46 changes: 32 additions & 14 deletions scripts/export/export_to_trt_llm.py
@@ -15,11 +15,14 @@
import argparse
import logging
import sys
+from typing import Optional

from nemo.export.tensorrt_llm import TensorRTLLM

LOGGER = logging.getLogger("NeMo")

+class UsageError(Exception):
+    pass

def get_args(argv):
parser = argparse.ArgumentParser(
@@ -50,20 +53,6 @@ def get_args(argv):
         type=str,
         help="dtype of the model on TensorRT-LLM",
     )
-    parser.add_argument(
-        "-fp8",
-        "--export_fp8_quantized",
-        default=False,
-        type=bool,
-        help="Enables exporting to a FP8-quantized TRT LLM checkpoint",
-    )
-    parser.add_argument(
-        "-kv_fp8",
-        "--use_fp8_kv_cache",
-        default=False,
-        type=bool,
-        help="Enables exporting with FP8-quantizatized KV-cache",
-    )
parser.add_argument("-mil", "--max_input_len", default=256, type=int, help="Max input length of the model")
parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model")
@@ -121,8 +110,37 @@ def get_args(argv):
         'It is used to compute the workspace size of lora plugin.',
     )
     parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")
+    parser.add_argument(
+        "-fp8",
+        "--export_fp8_quantized",
+        default="auto",
+        type=str,
+        help="Enables exporting to a FP8-quantized TRT LLM checkpoint",
+    )
+    parser.add_argument(
+        "-kv_fp8",
+        "--use_fp8_kv_cache",
+        default="auto",
+        type=str,
+        help="Enables exporting with FP8-quantized KV-cache",
+    )

     args = parser.parse_args(argv)

+    def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]:
+        s = s.lower()
+        true_strings = ["true", "1"]
+        false_strings = ["false", "0"]
+        if s in true_strings:
+            return True
+        if s in false_strings:
+            return False
+        if optional and s == 'auto':
+            return None
+        raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'")
+
+    args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True)
+    args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True)
     return args


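The str_to_bool helper converts the string-typed flags into the tri-state Optional[bool] values the export path now expects; a quick sketch of the mapping:

    # Parsing is case-insensitive; 'auto' is accepted only when optional=True.
    str_to_bool("export_fp8_quantized", "True", optional=True)   # -> True
    str_to_bool("export_fp8_quantized", "0", optional=True)      # -> False
    str_to_bool("export_fp8_quantized", "auto", optional=True)   # -> None (autodetect)
    str_to_bool("export_fp8_quantized", "maybe", optional=True)  # raises UsageError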
17 changes: 10 additions & 7 deletions tests/export/nemo_export.py
@@ -759,27 +759,30 @@ def get_args():
     parser.add_argument(
         "-fp8",
         "--export_fp8_quantized",
-        default="False",
+        default="auto",
         type=str,
         help="Enables exporting to a FP8-quantized TRT LLM checkpoint",
     )
     parser.add_argument(
         "-kv_fp8",
         "--use_fp8_kv_cache",
-        default="False",
+        default="auto",
         type=str,
         help="Enables exporting with FP8-quantized KV-cache",
     )

args = parser.parse_args()

-    def str_to_bool(name: str, s: str) -> bool:
+    def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]:
         s = s.lower()
         true_strings = ["true", "1"]
         false_strings = ["false", "0"]
-        if s.lower() in true_strings:
+        if s in true_strings:
             return True
-        if s.lower() in false_strings:
+        if s in false_strings:
             return False
+        if optional and s == 'auto':
+            return None
         raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'")

args.test_cpp_runtime = str_to_bool("test_cpp_runtime", args.test_cpp_runtime)
@@ -790,8 +793,8 @@ def str_to_bool(name: str, s: str) -> bool:
     args.use_vllm = str_to_bool("use_vllm", args.use_vllm)
     args.use_parallel_embedding = str_to_bool("use_parallel_embedding", args.use_parallel_embedding)
     args.in_framework = str_to_bool("in_framework", args.in_framework)
-    args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized)
-    args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache)
+    args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True)
+    args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True)

     return args

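After parsing, both FP8 flags land on the args namespace as True, False, or None, ready to pass straight to the export API. A sketch, assuming the test script is invoked with -fp8 auto -kv_fp8 false:

    args = get_args()
    assert args.export_fp8_quantized is None  # 'auto': defer to the checkpoint
    assert args.use_fp8_kv_cache is False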
