Export fp8 te nemo to trt-llm #10096

Merged
27 commits merged on Aug 29, 2024

Changes from all commits (27 commits)
f966d05
initial commit
Aug 14, 2024
5087268
PR draft
Aug 14, 2024
61d0f47
fixed scaling weights
Aug 14, 2024
542d843
Apply isort and black reformatting
Aug 14, 2024
042d325
Apply isort and black reformatting
Aug 14, 2024
76535b4
fixed zarr loading, added flags, refactor
Aug 16, 2024
63e8faa
Merge branch 'main' into export_fp8_te_nemo_to_trtllm
Laplasjan107 Aug 16, 2024
7a1d042
Apply isort and black reformatting
Laplasjan107 Aug 16, 2024
7d087dd
fix expert key mapping
Aug 16, 2024
f782f6b
Merge branch 'main' into export_fp8_te_nemo_to_trtllm
Laplasjan107 Aug 19, 2024
f5ff40e
refactor
Aug 21, 2024
a11bc2f
Apply isort and black reformatting
Laplasjan107 Aug 21, 2024
7d150d7
Merge branch 'main' into export_fp8_te_nemo_to_trtllm
Laplasjan107 Aug 21, 2024
ec14cb4
fix: failed test was finishing with exit code 0
Aug 21, 2024
078c88b
Merge branch 'export_fp8_te_nemo_to_trtllm' of https://github.com/Lap…
Aug 21, 2024
157f444
Merge branch 'main' into export_fp8_te_nemo_to_trtllm
Laplasjan107 Aug 21, 2024
73d9261
test commit -- rerun github checks
Aug 21, 2024
84a5e5e
bugfix: naming
Aug 21, 2024
250525e
bugfix v2: naming
Aug 21, 2024
69b4f69
apply code review changes
Aug 23, 2024
487edd0
Apply isort and black reformatting
Laplasjan107 Aug 23, 2024
e2a3139
fix TensorRTLLM build (fp8 still not supported)
Aug 27, 2024
19c8662
Apply isort and black reformatting
Laplasjan107 Aug 27, 2024
b01fdba
undo refactor
Aug 27, 2024
a3449d2
Merge branch 'export_fp8_te_nemo_to_trtllm' of https://github.com/Lap…
Aug 28, 2024
0c922b7
bugfix: arguments to dist_convert
Aug 28, 2024
bcf85e4
Apply isort and black reformatting
Laplasjan107 Aug 28, 2024
6 changes: 6 additions & 0 deletions nemo/export/tensorrt_llm.py
@@ -167,6 +167,8 @@ def export(
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
fp8_quantized: Optional[bool] = None,
fp8_kvcache: Optional[bool] = None,
):
"""
Exports nemo checkpoints to TensorRT-LLM.
@@ -202,6 +204,8 @@ def export(
multiple_profiles: (bool): enables multiple profiles feature of TRT-LLM. Default = False
gpt_attention_plugin (str): enable the gpt attention plugin. Default = "auto"
gemm_plugin (str): enable the gemm plugin. Default = "auto"
fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If not set, autodetects the type.
fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If not set, autodetects the type.
"""

if n_gpus is not None:
@@ -324,6 +328,8 @@ def export(
gpus_per_node=gpus_per_node,
use_parallel_embedding=use_parallel_embedding,
use_embedding_sharing=use_embedding_sharing,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)

for weight_dict, model_config in zip(weights_dicts, model_configs):
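The two new flags are plumbed from `TensorRTLLM.export` down to the converter. Below is a hedged usage sketch: only `fp8_quantized` and `fp8_kvcache` come from this diff, while the constructor and checkpoint arguments (`model_dir`, `nemo_checkpoint_path`, `model_type`) are assumptions about the existing exporter API.

```python
# Usage sketch, not part of this PR: argument names other than the two fp8 flags are assumed.
from nemo.export.tensorrt_llm import TensorRTLLM

exporter = TensorRTLLM(model_dir="/tmp/trtllm_engine")  # assumed constructor argument

# Leave both flags as None to autodetect FP8 from the NeMo checkpoint config.
exporter.export(
    nemo_checkpoint_path="/models/model_fp8.nemo",  # assumed parameter name
    model_type="llama",                             # assumed parameter name
    fp8_quantized=None,  # autodetect: export FP8 weights if the checkpoint was trained with fp8
    fp8_kvcache=None,    # autodetect: quantize the KV cache to FP8 under the same condition
)

# Or force FP8 weights while keeping the KV cache unquantized:
exporter.export(
    nemo_checkpoint_path="/models/model_fp8.nemo",
    model_type="llama",
    fp8_quantized=True,
    fp8_kvcache=False,
)
```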
35 changes: 27 additions & 8 deletions nemo/export/trt_llm/converter/model_converter.py
@@ -15,10 +15,11 @@

import csv
import logging
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import numpy as np
import tensorrt_llm
import torch
from tensorrt_llm._utils import pad_vocab_size
from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.layers import MoeConfig
@@ -78,6 +79,18 @@ def prompt_convert(prompt_config, prompt_weights):
return vtokens_embeddings


def determine_quantization_settings(
nemo_model_config, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None
) -> Tuple[bool, bool]:
is_nemo_quantized = nemo_model_config.get('fp8', False)
if fp8_quantized is None:
fp8_quantized = is_nemo_quantized
if fp8_kvcache is None:
fp8_kvcache = is_nemo_quantized

return fp8_quantized, fp8_kvcache


def model_to_trtllm_ckpt(
model,
nemo_model_config,
@@ -91,15 +104,17 @@ def model_to_trtllm_ckpt(
use_embedding_sharing: bool = False,
use_distributed_convert: bool = False,
model_parallel_rank: int = None,
vocab_size: int = None,
vocab_size: Optional[int] = None,
fp8_quantized: Optional[bool] = None,
fp8_kvcache: Optional[bool] = None,
) -> Tuple[List[Dict], List[PretrainedConfig]]:

if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing:
LOGGER.info(
"Found share_embeddings_and_output_weights is True in NeMo config, set use_embedding_sharing = True"
)
use_embedding_sharing = True

fp8_quantized, fp8_kvcache = determine_quantization_settings(nemo_model_config, fp8_quantized, fp8_kvcache)
# If the model has been sharded with model parallelism, convert the model in a gpu-distributed manner
if use_distributed_convert:
weights_dict = dist_model_to_trt_llm_ckpt(
@@ -108,6 +123,8 @@ def model_to_trtllm_ckpt(
inference_tp_size=tensor_parallel_size,
inference_pp_size=pipeline_parallel_size,
tokenizer_vocab_size=vocab_size,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)
vocab_size_padded = vocab_size
else:
@@ -120,6 +137,8 @@ def model_to_trtllm_ckpt(
storage_type=dtype,
use_parallel_embedding=use_parallel_embedding,
decoder_type=decoder_type,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)

has_lm_head = "lm_head.weight" in weights_dict
@@ -159,8 +178,8 @@ def model_to_trtllm_ckpt(
'embedding_sharding_dim': 0,
'share_embedding_table': use_embedding_sharing,
'quantization': {
'quant_algo': None,
'kv_cache_quant_algo': None,
'quant_algo': "FP8" if fp8_quantized else None,
'kv_cache_quant_algo': "FP8" if fp8_kvcache else None,
},
'bias': nemo_model_config.get('bias'),
'apply_query_key_layer_scaling': False,
@@ -261,9 +280,9 @@ def model_to_trtllm_ckpt(

if mapping.is_last_pp_rank():
if has_lm_head:
weights_dict_local["lm_head.weight"] = np.ascontiguousarray(
split(lm_head_weight, mapping.tp_size, mapping.tp_rank)
)
weights_dict_local["lm_head.weight"] = split(
lm_head_weight, mapping.tp_size, mapping.tp_rank
).contiguous()
weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"]

ln_f_bias = weights_dict.get("transformer.ln_f.bias")
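The resolution logic for the two flags lives in the new `determine_quantization_settings` helper: an explicit `True`/`False` from the caller wins, and `None` falls back to the checkpoint's own `fp8` field, which in turn selects the `quant_algo` / `kv_cache_quant_algo` values written into the TRT-LLM `PretrainedConfig`. A small self-contained sketch of that behavior (the helper body is copied from the diff; the asserts are illustrative):

```python
from typing import Optional, Tuple

def determine_quantization_settings(
    nemo_model_config: dict,
    fp8_quantized: Optional[bool] = None,
    fp8_kvcache: Optional[bool] = None,
) -> Tuple[bool, bool]:
    # None means "not decided by the caller": fall back to the checkpoint's fp8 flag.
    is_nemo_quantized = nemo_model_config.get('fp8', False)
    if fp8_quantized is None:
        fp8_quantized = is_nemo_quantized
    if fp8_kvcache is None:
        fp8_kvcache = is_nemo_quantized
    return fp8_quantized, fp8_kvcache

# FP8-trained checkpoint, no overrides: both resolve to True.
assert determine_quantization_settings({'fp8': True}) == (True, True)
# Keep FP8 weights but opt out of the FP8 KV cache.
assert determine_quantization_settings({'fp8': True}, fp8_kvcache=False) == (True, False)
# Non-FP8 checkpoint: everything stays disabled.
assert determine_quantization_settings({'fp8': False}) == (False, False)

# The resolved flags drive the quantization block of the TRT-LLM config:
fp8_quantized, fp8_kvcache = determine_quantization_settings({'fp8': True})
quantization = {
    'quant_algo': "FP8" if fp8_quantized else None,
    'kv_cache_quant_algo': "FP8" if fp8_kvcache else None,
}
assert quantization == {'quant_algo': "FP8", 'kv_cache_quant_algo': "FP8"}
```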
41 changes: 35 additions & 6 deletions nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
@@ -25,7 +25,7 @@
from tqdm import tqdm

from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, weights_dict
from nemo.export.trt_llm.converter.utils import save_scaling_factor, save_val, split_and_save_weight, weights_dict

LOGGER = logging.getLogger("NeMo")

@@ -94,6 +94,24 @@ def rename_key_dist_ckpt(old_key: str, layer: int):
return rename_key(new_key)


def is_scaling_factor(key: str) -> bool:
return "extra_state" in key


def load_scaling_factors(model: dict, num_layers: int, export_config: dict) -> dict:
if not export_config.get('fp8_quantized', False):
return {}

scaling_factors = {}
for key, val in model.items():
if is_scaling_factor(key):
for layer in range(num_layers):
renamed_key = rename_key_dist_ckpt(key, layer)
scaling_factors = save_scaling_factor(scaling_factors, renamed_key, val[layer], export_config)

return scaling_factors


@torch.no_grad()
def convert_model_to_trt_llm_ckpt(
nemo_model_config,
Expand All @@ -104,6 +122,8 @@ def convert_model_to_trt_llm_ckpt(
decoder_type,
use_parallel_embedding,
processes,
fp8_quantized=False,
fp8_kvcache=False,
):

# if checkpoint files could be found - start preparing output dir
@@ -148,6 +168,8 @@ def convert_model_to_trt_llm_ckpt(
"use_attention_nemo_shape": True,
"transpose_weights": True,
"use_parallel_embedding": use_parallel_embedding,
"fp8_quantized": fp8_quantized,
"fp8_kvcache": fp8_kvcache,
}

# split_factor: in how many parts a TP training node is split
@@ -158,7 +180,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
if tp_idx == 0 and pp_idx == 0:
if has_position_embedding:
val = model[get_layer_name("position_embedding", prefix)]
val = torch_to_numpy(val.to(storage_type).cpu())
val = val.to(storage_type).cpu()
model_level_weights["transformer.position_embedding.weight"].append(val)
if pp_idx == 0:
val = model.get("state_dict", model)[get_layer_name("word_embedding", prefix)]
@@ -171,19 +193,19 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
pad_width = vocab_size_padded - vocab_size
val = torch.nn.functional.pad(val, (0, 0, 0, pad_width), value=0)

val = torch_to_numpy(val.to(storage_type).cpu())
val = val.to(storage_type).cpu()
model_level_weights["transformer.vocab_embedding.weight"].append(val)
if has_lm_head and pp_idx == training_pp_size - 1:
val = model.get("state_dict", model)[get_layer_name("output_layer", prefix)]
val = torch_to_numpy(val.to(storage_type).cpu())
val = val.to(storage_type).cpu()
model_level_weights["lm_head.weight"].append(val)

weights_dict = {}

tp_rank = 0

handle_model_level_weights(model, 0, 0)
model = extract_layers_with_prefix(model, transformer_layer_prefix)
scaling_factors = load_scaling_factors(model, num_layers, export_config)

starmap_args = []
for key, val in model.items():
@@ -202,6 +224,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
storage_type,
None,
export_config,
scaling_factors,
)
)
else:
@@ -219,6 +242,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
storage_type,
None,
export_config,
scaling_factors,
)
)

@@ -236,9 +260,10 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
weights_dict.update(weights_dict_local)

for key, values in model_level_weights.items():
model_level_weights[key] = np.concatenate(values, axis=0)
model_level_weights[key] = torch.concatenate(values, axis=0)
weights_dict[key] = model_level_weights[key]

weights_dict.update(scaling_factors)
return weights_dict


@@ -269,6 +294,8 @@ def dist_model_to_trt_llm_ckpt(
inference_tp_size,
inference_pp_size,
tokenizer_vocab_size,
fp8_quantized=False,
fp8_kvcache=False,
):
from megatron.core import parallel_state
from megatron.core.tensor_parallel.utils import VocabUtility
@@ -314,6 +341,8 @@ def dist_model_to_trt_llm_ckpt(
"convert_on_device": True,
"use_attention_nemo_shape": True,
"transpose_weights": True,
"fp8_quantized": fp8_quantized,
"fp8_kvcache": fp8_kvcache,
}

starmap_config = {
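FP8 scaling factors are picked up from the Transformer Engine `extra_state` entries of the NeMo state dict: `load_scaling_factors` runs only when `fp8_quantized` is enabled, renames each per-layer entry, and the resulting dict is both passed to `split_and_save_weight` and merged into the final `weights_dict`. A minimal illustration of the key filtering with a dummy state dict (the real conversion goes through `rename_key_dist_ckpt` and `save_scaling_factor`, which are not reproduced here; the dummy key names and values are made up):

```python
import torch

def is_scaling_factor(key: str) -> bool:
    # Transformer Engine serializes FP8 amax/scale metadata under '*extra_state' keys.
    return "extra_state" in key

# Dummy checkpoint slice: one regular weight plus one TE extra_state blob
# holding per-layer FP8 metadata (shapes and contents are illustrative only).
dummy_model = {
    "layers.self_attention.linear_qkv.weight": torch.zeros(8, 8),
    "layers.self_attention.linear_qkv._extra_state": [b"fp8-meta-layer-0", b"fp8-meta-layer-1"],
}

scaling_keys = [k for k in dummy_model if is_scaling_factor(k)]
print(scaling_keys)  # ['layers.self_attention.linear_qkv._extra_state']

# With export_config = {'fp8_quantized': False}, load_scaling_factors returns {},
# so the non-FP8 conversion path is left unchanged by this PR.
```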