Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: Laplasjan107 <Laplasjan107@users.noreply.github.com>
  • Loading branch information
Piotr Kaminski committed Aug 14, 2024
1 parent b316c9f commit 4816a6e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
10 changes: 5 additions & 5 deletions nemo/export/trt_llm/converter/model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def model_to_trtllm_ckpt(
use_distributed_convert: bool = False,
model_parallel_rank: int = None,
vocab_size: int = None,
quantize_kv_cache: bool = False
quantize_kv_cache: bool = False,
) -> Tuple[List[Dict], List[PretrainedConfig]]:
nemo_model_config['kv_cache'] = quantize_kv_cache
if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing:
Expand Down Expand Up @@ -139,7 +139,9 @@ def model_to_trtllm_ckpt(
if has_lm_head:
lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0)
if vocab_embedding_key in weights_dict:
weights_dict[vocab_embedding_key] = torch.nn.functional.pad(weights_dict[vocab_embedding_key], padding, "constant", 0)
weights_dict[vocab_embedding_key] = torch.nn.functional.pad(
weights_dict[vocab_embedding_key], padding, "constant", 0
)

world_size = tensor_parallel_size * pipeline_parallel_size
hidden_act = nemo_model_config.get('activation')
Expand Down Expand Up @@ -249,9 +251,7 @@ def model_to_trtllm_ckpt(

if mapping.is_first_pp_rank():
embedding_weight = (
np.ascontiguousarray(
split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank)
)
np.ascontiguousarray(split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank))
if use_parallel_embedding
else weights_dict[vocab_embedding_key]
)
Expand Down
4 changes: 2 additions & 2 deletions nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tqdm import tqdm

from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, load_scaling_factor, weights_dict
from nemo.export.trt_llm.converter.utils import load_scaling_factor, save_val, split_and_save_weight, weights_dict

LOGGER = logging.getLogger("NeMo")

Expand Down Expand Up @@ -212,7 +212,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
# Since the state dict value has the full layers, let's select the ith layer weights/biases here.
layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)]

for (l, v) in layer_vals:
for l, v in layer_vals:
k = rename_key_dist_ckpt(key, l)
starmap_args.append(
(tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors)
Expand Down
35 changes: 28 additions & 7 deletions nemo/export/trt_llm/converter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
"falcon": 'FalconForCausalLM',
}

post_layernorm_keys = ["post_attention_layernorm.weight", "post_attention_layernorm.bias", "post_self_attn_layernorm.weight"]
post_layernorm_keys = [
"post_attention_layernorm.weight",
"post_attention_layernorm.bias",
"post_self_attn_layernorm.weight",
]
mlp_proj_bias_keys = ["mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias"]
attention_dense_bias_keys = ["attention.linear_proj.bias", "attention.dense.bias"]
input_layernorm_keys = ["input_layernorm.weight", "input_layernorm.bias"]
Expand All @@ -46,7 +50,12 @@
mlp_proj_experts_keys = ["experts.linear_fc2.weight"]
final_layernorm_keys = ["final_layernorm.weight", "final_layernorm.bias"]
mlp_dense_2_keys = ["mlp.dense_h_to_4h_2.weight", "mlp.dense_h_to_4h_2.bias"]
attention_not_mapped_keys = ["attention.query.weight", "attention.query.bias", "attention.key_value.weight", "attention.key_value.bias"]
attention_not_mapped_keys = [
"attention.query.weight",
"attention.query.bias",
"attention.key_value.weight",
"attention.key_value.bias",
]


def save_val(val, dir, key, tp_num=None):
Expand Down Expand Up @@ -207,12 +216,13 @@ def any_word_in_key(key, words):


def sequential_key_map(key, mapping):
for (keywords, mapped) in mapping:
for keywords, mapped in mapping:
if any_word_in_key(key, keywords):
return mapped

return None


def get_trt_llm_infix(key):
mapping = [
(post_layernorm_keys, '.post_layernorm'),
Expand All @@ -226,7 +236,7 @@ def get_trt_llm_infix(key):
(attention_qkv_bias_keys + attention_qkv_weight_keys, '.attention.qkv'),
(mlp_router_keys, '.mlp.router'),
(mlp_fc_keys, '.mlp.fc'),
(mlp_proj_experts_keys, '.mlp.proj')
(mlp_proj_experts_keys, '.mlp.proj'),
]
return sequential_key_map(key, mapping)

Expand Down Expand Up @@ -254,6 +264,8 @@ def get_scaling_factor_keys(key):


first = True


def load_scaling_factor(key, val, dir, config):
global weights_dict
if not is_scaling_factor(key):
Expand Down Expand Up @@ -291,7 +303,9 @@ def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_fac
return [val.to(storage_type) for val in vals]

fp8_storage_type = torch.float8_e4m3fn
quantized_keys = [ k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k ]
quantized_keys = [
k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k
]
for k in quantized_keys:
if k in trt_llm_key:
storage_type = fp8_storage_type
Expand Down Expand Up @@ -344,8 +358,15 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t
elif torch.is_tensor(vals[0]):
vals = [torch_to_numpy(val.cpu()) for val in vals]

if (any_word_in_key(key, input_layernorm_keys + pre_layernorm_keys + attention_dense_bias_keys + post_layernorm_keys + mlp_proj_bias_keys + final_layernorm_keys)
and (tp_rank == 0 or convert_on_device)):
if any_word_in_key(
key,
input_layernorm_keys
+ pre_layernorm_keys
+ attention_dense_bias_keys
+ post_layernorm_keys
+ mlp_proj_bias_keys
+ final_layernorm_keys,
) and (tp_rank == 0 or convert_on_device):
# shared weights, only need to convert the weights of rank 0
save_val(vals[0], saved_dir, trt_llm_key)

Expand Down

0 comments on commit 4816a6e

Please sign in to comment.