Apply isort and black reformatting
Signed-off-by: Piotr Kaminski <pikaminski@nvidia.com>
Piotr Kaminski committed Aug 14, 2024
1 parent 61d0f47 commit 542d843
Showing 4 changed files with 148 additions and 162 deletions.
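
For readers who want to reproduce the isort + black pass named in the commit message, here is a minimal sketch of doing so from Python; the black profile for isort and the 119-character line length are assumptions, and NeMo's pyproject.toml remains the authoritative source for the project's formatter settings.

```python
# Hypothetical helper, not part of this commit: re-run isort and black on the
# two files shown here. The profile and line length are assumptions; check
# NeMo's pyproject.toml for the project's real formatter settings.
from pathlib import Path

import black
import isort

FILES = [
    "nemo/export/trt_llm/converter/model_converter.py",
    "nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py",
]

for name in FILES:
    path = Path(name)
    source = path.read_text()
    source = isort.code(source, profile="black")  # sort and group imports
    source = black.format_str(source, mode=black.Mode(line_length=119))
    path.write_text(source)
```
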
47 changes: 26 additions & 21 deletions nemo/export/trt_llm/converter/model_converter.py
@@ -18,8 +18,8 @@
from typing import Dict, List, Tuple

import numpy as np
import torch
import tensorrt_llm
import torch
from tensorrt_llm._utils import pad_vocab_size
from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.layers import MoeConfig
@@ -39,10 +39,11 @@
def get_config(decoder_type, config):
if decoder_type == "llama":
return LLaMAConfig(**config)
elif decoder_type == "gpt" or decoder_type == "gptnext":

if decoder_type in ["gpt", "gptnext"]:
return GPTConfig(**config)
else:
return PretrainedConfig(**config)

return PretrainedConfig(**config)


def prompt_convert(prompt_config, prompt_weights):
@@ -93,8 +94,9 @@ def model_to_trtllm_ckpt(
use_distributed_convert: bool = False,
model_parallel_rank: int = None,
vocab_size: int = None,
quantize_kv_cache: bool = False
) -> Tuple[List[Dict], List[PretrainedConfig]]:

nemo_model_config['kv_cache'] = quantize_kv_cache
if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing:
LOGGER.info(
"Found share_embeddings_and_output_weights is True in NeMo config, set use_embedding_sharing = True"
@@ -112,6 +114,8 @@
)
vocab_size_padded = vocab_size
else:
vocab_embedding_key = "transformer.vocab_embedding.weight"

weights_dict = convert_model_to_trt_llm_ckpt(
model=model,
nemo_model_config=nemo_model_config,
@@ -127,15 +131,15 @@
if has_lm_head:
lm_head_weight = weights_dict["lm_head.weight"]
if vocab_size is None:
vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0]
vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size
vocab_size = weights_dict[vocab_embedding_key].shape[0]

if has_lm_head and vocab_size_padded != vocab_size:
vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size
if vocab_size_padded != vocab_size:
padding = (0, 0, 0, vocab_size_padded - vocab_size)
embedding_key = "transformer.vocab_embedding.weight"
lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0)
weights_dict[embedding_key] = torch.nn.functional.pad(weights_dict[embedding_key], padding, "constant", 0)

if has_lm_head:
lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0)
if vocab_embedding_key in weights_dict:
weights_dict[vocab_embedding_key] = torch.nn.functional.pad(weights_dict[vocab_embedding_key], padding, "constant", 0)

world_size = tensor_parallel_size * pipeline_parallel_size
hidden_act = nemo_model_config.get('activation')
@@ -164,7 +168,7 @@
'share_embedding_table': use_embedding_sharing,
'quantization': {
'quant_algo': "FP8" if nemo_model_config.get('fp8', False) else None,
'kv_cache_quant_algo': None, # TODO maybe "FP8",
'kv_cache_quant_algo': "FP8" if quantize_kv_cache else None,
},
'bias': nemo_model_config.get('bias'),
'apply_query_key_layer_scaling': False,
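
A side note on the hunk above: the new quantize_kv_cache keyword flows directly into the quantization block of the generated TRT-LLM config. A tiny self-contained illustration of that mapping (the dict literal below is illustrative, not the full config):

```python
# Illustrative only: how the existing FP8 weight flag and the new
# quantize_kv_cache argument populate the quantization section.
nemo_model_config = {"fp8": True}  # stand-in for the NeMo model config
quantize_kv_cache = True           # new keyword argument in this diff

quantization = {
    "quant_algo": "FP8" if nemo_model_config.get("fp8", False) else None,
    "kv_cache_quant_algo": "FP8" if quantize_kv_cache else None,
}
assert quantization == {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8"}
```
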
@@ -207,7 +211,7 @@
return weights_dicts, model_configs

pp_key = {
"transformer.vocab_embedding.weight",
vocab_embedding_key,
"transformer.position_embedding.weight",
"lm_head.weight",
"transformer.ln_f.weight",
@@ -232,10 +236,9 @@
continue
new_key = k
if new_key.endswith(".bin"): # TP split
if new_key.endswith(f"{mapping.tp_rank}.bin"):
new_key = new_key.replace(f".{mapping.tp_rank}.bin", "")
else:
if not new_key.endswith(f"{mapping.tp_rank}.bin"):
continue
new_key = new_key.replace(f".{mapping.tp_rank}.bin", "")
if "layers" in new_key: # PP
layer_num = int(new_key.split(".")[2])
if layer_num in layers_range:
@@ -247,13 +250,13 @@
if mapping.is_first_pp_rank():
embedding_weight = (
np.ascontiguousarray(
split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank)
split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank)
)
if use_parallel_embedding
else weights_dict["transformer.vocab_embedding.weight"]
else weights_dict[vocab_embedding_key]
)

weights_dict_local["transformer.vocab_embedding.weight"] = embedding_weight
weights_dict_local[vocab_embedding_key] = embedding_weight

pos_embedding_weight = weights_dict.get("transformer.position_embedding.weight")
if pos_embedding_weight is not None:
@@ -265,7 +268,9 @@

if mapping.is_last_pp_rank():
if has_lm_head:
weights_dict_local["lm_head.weight"] = split(lm_head_weight, mapping.tp_size, mapping.tp_rank).contiguous()
weights_dict_local["lm_head.weight"] = split(
lm_head_weight, mapping.tp_size, mapping.tp_rank
).contiguous()
weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"]

ln_f_bias = weights_dict.get("transformer.ln_f.bias")
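
The vocabulary-padding change in model_converter.py is easier to follow in isolation. Below is a self-contained sketch of the same pattern; the local pad_vocab_size is assumed to match the tensorrt_llm._utils helper (round the vocabulary up to a multiple of the tensor-parallel size), and the shapes are made up.

```python
# Standalone sketch of the lm_head / embedding padding pattern: grow the
# vocabulary dimension so it splits evenly across tensor-parallel ranks.
import torch
import torch.nn.functional as F


def pad_vocab_size(vocab_size: int, tp_size: int) -> int:
    # Assumed behaviour of tensorrt_llm._utils.pad_vocab_size:
    # round up to the next multiple of tp_size.
    return -(-vocab_size // tp_size) * tp_size


vocab_size, hidden_size, tp_size = 32003, 64, 8
lm_head_weight = torch.randn(vocab_size, hidden_size)

vocab_size_padded = pad_vocab_size(vocab_size, tp_size)  # 32008
if vocab_size_padded != vocab_size:
    padding = (0, 0, 0, vocab_size_padded - vocab_size)  # pad rows, not columns
    lm_head_weight = F.pad(lm_head_weight, padding, "constant", 0)

assert lm_head_weight.shape == (vocab_size_padded, hidden_size)
```
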
34 changes: 10 additions & 24 deletions nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
@@ -25,7 +25,7 @@
from tqdm import tqdm

from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, weights_dict
from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, load_scaling_factor, weights_dict

LOGGER = logging.getLogger("NeMo")

@@ -94,29 +94,18 @@ def rename_key_dist_ckpt(old_key: str, layer: int):
return rename_key(new_key)


def load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config):
def load_scaling_factors(model, num_layers, out_dir, export_config):
starmap_args = []
for key, val in model.items():
if 'extra_state' not in key:
continue

for i in range(num_layers):
starmap_args.append(
(
tp_rank,
out_dir,
split_factor,
rename_key_dist_ckpt(key, i),
[val[i]],
storage_type,
None,
export_config,
{},
)
)
for layer in range(num_layers):
args = (rename_key_dist_ckpt(key, layer), val[layer], out_dir, export_config)
starmap_args.append(args)

for starmap_arg in starmap_args:
scaling_factors = split_and_save_weight(*starmap_arg)
scaling_factors = load_scaling_factor(*starmap_arg)

return scaling_factors

@@ -132,7 +121,6 @@ def convert_model_to_trt_llm_ckpt(
use_parallel_embedding,
processes,
):

# if checkpoints files could be found - start preparing output dir
out_dir = create_export_dir(nemo_export_dir)
storage_type = str_dtype_to_torch(storage_type)
@@ -213,7 +201,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
handle_model_level_weights(model, 0, 0)
model = extract_layers_with_prefix(model, transformer_layer_prefix)

scaling_factors = load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config)
scaling_factors = load_scaling_factors(model, num_layers, out_dir, export_config)

starmap_args = []
for key, val in model.items():
@@ -222,12 +210,10 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):

# Let's rename/map the key to the old layer name previously.
# Since the state dict value has the full layers, let's select the ith layer weights/biases here.
if len(val.size()) == 1:
key_vals = [(rename_key_dist_ckpt(key, 0), val)]
else:
key_vals = [(rename_key_dist_ckpt(key, i), val[i]) for i in range(num_layers)]
layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)]

for (k, v) in key_vals:
for (l, v) in layer_vals:
k = rename_key_dist_ckpt(key, l)
starmap_args.append(
(tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors)
)
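
The rewritten loop at the end of model_to_trt_llm_ckpt.py treats rank-1 tensors as belonging to a single pseudo-layer (layer 0) and higher-rank tensors as one slice per transformer layer before renaming and saving. A toy sketch of that fan-out, with invented key names and shapes:

```python
# Toy illustration of the layer_vals expansion: key names and shapes are
# invented; only the 1-D versus N-D branching mirrors the converter code.
import torch

num_layers = 2
model = {
    "decoder.final_layernorm.weight": torch.randn(16),                        # 1-D: layer 0 only
    "decoder.layers.mlp.linear_fc1.weight": torch.randn(num_layers, 64, 16),  # one slice per layer
}

for key, val in model.items():
    layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)]
    for l, v in layer_vals:
        # In the real converter, rename_key_dist_ckpt(key, l) would produce the
        # legacy per-layer key before the tensor is split and saved.
        print(f"{key} -> layer {l}, shape {tuple(v.shape)}")
```
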
