Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export fp8 te nemo to trt-llm #10096

Merged
merged 27 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f966d05
initial commit
Aug 14, 2024
5087268
PR draft
Aug 14, 2024
61d0f47
fixed scaling weights
Aug 14, 2024
542d843
Apply isort and black reformatting
Aug 14, 2024
042d325
Apply isort and black reformatting
Aug 14, 2024
76535b4
fixed zarr loading, added flags, refactor
Aug 16, 2024
63e8faa
Merge branch 'main' into export_fp8_te_nemo_to_trtllm
Laplasjan107 Aug 16, 2024
7a1d042
Apply isort and black reformatting
Laplasjan107 Aug 16, 2024
7d087dd
fix expert key mapping
Aug 16, 2024
f782f6b
Merge branch 'main' into export_fp8_te_nemo_to_trtllm
Laplasjan107 Aug 19, 2024
f5ff40e
refactor
Aug 21, 2024
a11bc2f
Apply isort and black reformatting
Laplasjan107 Aug 21, 2024
7d150d7
Merge branch 'main' into export_fp8_te_nemo_to_trtllm
Laplasjan107 Aug 21, 2024
ec14cb4
fix: failed test was finishing with exit code 0
Aug 21, 2024
078c88b
Merge branch 'export_fp8_te_nemo_to_trtllm' of https://github.com/Lap…
Aug 21, 2024
157f444
Merge branch 'main' into export_fp8_te_nemo_to_trtllm
Laplasjan107 Aug 21, 2024
73d9261
test commit -- rerun github checks
Aug 21, 2024
84a5e5e
bugfix: naming
Aug 21, 2024
250525e
bugfix v2: naming
Aug 21, 2024
69b4f69
apply code review changes
Aug 23, 2024
487edd0
Apply isort and black reformatting
Laplasjan107 Aug 23, 2024
e2a3139
fix TensorRTLLM build (fp8 still not supported)
Aug 27, 2024
19c8662
Apply isort and black reformatting
Laplasjan107 Aug 27, 2024
b01fdba
undo refactor
Aug 27, 2024
a3449d2
Merge branch 'export_fp8_te_nemo_to_trtllm' of https://github.com/Lap…
Aug 28, 2024
0c922b7
bugfix: arguments to dist_convert
Aug 28, 2024
bcf85e4
Apply isort and black reformatting
Laplasjan107 Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion nemo/export/tarutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import zarr.storage


class TarPath:
class TarPath(os.PathLike):
"""
A class that represents a path inside a TAR archive and behaves like pathlib.Path.

Expand Down Expand Up @@ -58,6 +58,9 @@ def __truediv__(self, key) -> 'TarPath':
def __str__(self) -> str:
return os.path.join(self._tar.name, self._relpath)

def __fspath__(self):
return str(self)

@property
def tarobject(self):
return self._tar
Expand Down
4 changes: 4 additions & 0 deletions nemo/export/tensorrt_llm.py
terrykong marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def export(
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
fp8_quantized: bool = False,
Laplasjan107 marked this conversation as resolved.
Show resolved Hide resolved
fp8_kvcache: bool = False,
):
"""
Exports nemo checkpoints to TensorRT-LLM.
Expand Down Expand Up @@ -324,6 +326,8 @@ def export(
gpus_per_node=gpus_per_node,
use_parallel_embedding=use_parallel_embedding,
use_embedding_sharing=use_embedding_sharing,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)

for weight_dict, model_config in zip(weights_dicts, model_configs):
Expand Down
75 changes: 49 additions & 26 deletions nemo/export/trt_llm/converter/model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np
import tensorrt_llm
import torch
from tensorrt_llm._utils import pad_vocab_size
from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.layers import MoeConfig
Expand All @@ -38,10 +39,11 @@
def get_config(decoder_type, config):
if decoder_type == "llama":
return LLaMAConfig(**config)
elif decoder_type == "gpt" or decoder_type == "gptnext":

if decoder_type in ["gpt", "gptnext"]:
Laplasjan107 marked this conversation as resolved.
Show resolved Hide resolved
return GPTConfig(**config)
else:
return PretrainedConfig(**config)

return PretrainedConfig(**config)


def prompt_convert(prompt_config, prompt_weights):
Expand Down Expand Up @@ -78,6 +80,21 @@ def prompt_convert(prompt_config, prompt_weights):
return vtokens_embeddings


def create_common_export_config(nemo_model_config, decoder_type, fp8_quantized=False, fp8_kvcache=False):
is_mcore = nemo_model_config.get("mcore_gpt", False)
Laplasjan107 marked this conversation as resolved.
Show resolved Hide resolved
return {
"apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p",
"split_gated_activation": nemo_model_config.get("activation", "gelu")
in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"]
and (decoder_type == "gptnext" or is_mcore),
"num_attention_heads": nemo_model_config["num_attention_heads"],
"use_attention_nemo_shape": True,
"transpose_weights": True,
"fp8_quantized": fp8_quantized,
"fp8_kvcache": fp8_kvcache,
}


def model_to_trtllm_ckpt(
model,
nemo_model_config,
Expand All @@ -91,15 +108,17 @@ def model_to_trtllm_ckpt(
use_embedding_sharing: bool = False,
use_distributed_convert: bool = False,
model_parallel_rank: int = None,
vocab_size: int = None,
vocab_size: int | None = None,
fp8_quantized: bool = False,
fp8_kvcache: bool = False,
) -> Tuple[List[Dict], List[PretrainedConfig]]:

if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing:
LOGGER.info(
"Found share_embeddings_and_output_weights is True in NeMo config, set use_embedding_sharing = True"
)
use_embedding_sharing = True

export_config = create_common_export_config(nemo_model_config, decoder_type, fp8_quantized, fp8_kvcache)
# If the model has been sharded with model parallelism, convert the model in a gpu-distributed manner
if use_distributed_convert:
weights_dict = dist_model_to_trt_llm_ckpt(
Expand All @@ -108,9 +127,12 @@ def model_to_trtllm_ckpt(
inference_tp_size=tensor_parallel_size,
inference_pp_size=pipeline_parallel_size,
tokenizer_vocab_size=vocab_size,
export_config=export_config,
)
vocab_size_padded = vocab_size
else:
vocab_embedding_key = "transformer.vocab_embedding.weight"

weights_dict = convert_model_to_trt_llm_ckpt(
model=model,
nemo_model_config=nemo_model_config,
Expand All @@ -119,19 +141,23 @@ def model_to_trtllm_ckpt(
processes=1,
storage_type=dtype,
use_parallel_embedding=use_parallel_embedding,
decoder_type=decoder_type,
export_config=export_config,
)

if vocab_size is None:
vocab_size = weights_dict[vocab_embedding_key].shape[0]

has_lm_head = "lm_head.weight" in weights_dict
vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size
padding = (0, 0, 0, vocab_size_padded - vocab_size)
if has_lm_head:
lm_head_weight = weights_dict["lm_head.weight"]
if vocab_size is None:
vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0]
vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size
lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0)

if has_lm_head and vocab_size_padded != vocab_size:
pad_width = vocab_size_padded - vocab_size
lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0)
if vocab_embedding_key in weights_dict:
weights_dict[vocab_embedding_key] = torch.nn.functional.pad(
weights_dict[vocab_embedding_key], padding, "constant", 0
)

world_size = tensor_parallel_size * pipeline_parallel_size
hidden_act = nemo_model_config.get('activation')
Expand Down Expand Up @@ -159,8 +185,8 @@ def model_to_trtllm_ckpt(
'embedding_sharding_dim': 0,
'share_embedding_table': use_embedding_sharing,
'quantization': {
'quant_algo': None,
'kv_cache_quant_algo': None,
'quant_algo': "FP8" if fp8_quantized else None,
'kv_cache_quant_algo': "FP8" if fp8_kvcache else None,
},
'bias': nemo_model_config.get('bias'),
'apply_query_key_layer_scaling': False,
Expand Down Expand Up @@ -203,7 +229,7 @@ def model_to_trtllm_ckpt(
return weights_dicts, model_configs

pp_key = {
"transformer.vocab_embedding.weight",
vocab_embedding_key,
JimmyZhang12 marked this conversation as resolved.
Show resolved Hide resolved
"transformer.position_embedding.weight",
"lm_head.weight",
"transformer.ln_f.weight",
Expand All @@ -228,10 +254,9 @@ def model_to_trtllm_ckpt(
continue
new_key = k
if new_key.endswith(".bin"): # TP split
if new_key.endswith(f"{mapping.tp_rank}.bin"):
new_key = new_key.replace(f".{mapping.tp_rank}.bin", "")
else:
if not new_key.endswith(f"{mapping.tp_rank}.bin"):
continue
new_key = new_key.replace(f".{mapping.tp_rank}.bin", "")
if "layers" in new_key: # PP
layer_num = int(new_key.split(".")[2])
if layer_num in layers_range:
Expand All @@ -242,14 +267,12 @@ def model_to_trtllm_ckpt(

if mapping.is_first_pp_rank():
embedding_weight = (
np.ascontiguousarray(
split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank)
)
np.ascontiguousarray(split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank))
Laplasjan107 marked this conversation as resolved.
Show resolved Hide resolved
if use_parallel_embedding
else weights_dict["transformer.vocab_embedding.weight"]
else weights_dict[vocab_embedding_key]
)

weights_dict_local["transformer.vocab_embedding.weight"] = embedding_weight
weights_dict_local[vocab_embedding_key] = embedding_weight

pos_embedding_weight = weights_dict.get("transformer.position_embedding.weight")
if pos_embedding_weight is not None:
Expand All @@ -261,9 +284,9 @@ def model_to_trtllm_ckpt(

if mapping.is_last_pp_rank():
if has_lm_head:
weights_dict_local["lm_head.weight"] = np.ascontiguousarray(
split(lm_head_weight, mapping.tp_size, mapping.tp_rank)
)
weights_dict_local["lm_head.weight"] = split(
lm_head_weight, mapping.tp_size, mapping.tp_rank
).contiguous()
weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"]

ln_f_bias = weights_dict.get("transformer.ln_f.bias")
Expand Down
Loading
Loading