From 042d325c62eec21d547b9f0419ca6a13335aaac3 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 14 Aug 2024 14:47:09 -0700 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: Piotr Kaminski --- .../trt_llm/converter/model_converter.py | 10 +++--- .../converter/model_to_trt_llm_ckpt.py | 4 +-- nemo/export/trt_llm/converter/utils.py | 35 +++++++++++++++---- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index bb9b19173509..6a7fe25ba824 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -94,7 +94,7 @@ def model_to_trtllm_ckpt( use_distributed_convert: bool = False, model_parallel_rank: int = None, vocab_size: int = None, - quantize_kv_cache: bool = False + quantize_kv_cache: bool = False, ) -> Tuple[List[Dict], List[PretrainedConfig]]: nemo_model_config['kv_cache'] = quantize_kv_cache if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing: @@ -139,7 +139,9 @@ def model_to_trtllm_ckpt( if has_lm_head: lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) if vocab_embedding_key in weights_dict: - weights_dict[vocab_embedding_key] = torch.nn.functional.pad(weights_dict[vocab_embedding_key], padding, "constant", 0) + weights_dict[vocab_embedding_key] = torch.nn.functional.pad( + weights_dict[vocab_embedding_key], padding, "constant", 0 + ) world_size = tensor_parallel_size * pipeline_parallel_size hidden_act = nemo_model_config.get('activation') @@ -249,9 +251,7 @@ def model_to_trtllm_ckpt( if mapping.is_first_pp_rank(): embedding_weight = ( - np.ascontiguousarray( - split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank) - ) + np.ascontiguousarray(split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank)) if use_parallel_embedding else weights_dict[vocab_embedding_key] ) diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index 6cb373159e17..7886c2221566 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -25,7 +25,7 @@ from tqdm import tqdm from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision -from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, load_scaling_factor, weights_dict +from nemo.export.trt_llm.converter.utils import load_scaling_factor, save_val, split_and_save_weight, weights_dict LOGGER = logging.getLogger("NeMo") @@ -212,7 +212,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): # Since the state dict value has the full layers, let's select the ith layer weights/biases here. layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)] - for (l, v) in layer_vals: + for l, v in layer_vals: k = rename_key_dist_ckpt(key, l) starmap_args.append( (tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 1a15e8874dc4..cfde6a359bdf 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -31,7 +31,11 @@ "falcon": 'FalconForCausalLM', } -post_layernorm_keys = ["post_attention_layernorm.weight", "post_attention_layernorm.bias", "post_self_attn_layernorm.weight"] +post_layernorm_keys = [ + "post_attention_layernorm.weight", + "post_attention_layernorm.bias", + "post_self_attn_layernorm.weight", +] mlp_proj_bias_keys = ["mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias"] attention_dense_bias_keys = ["attention.linear_proj.bias", "attention.dense.bias"] input_layernorm_keys = ["input_layernorm.weight", "input_layernorm.bias"] @@ -46,7 +50,12 @@ mlp_proj_experts_keys = ["experts.linear_fc2.weight"] final_layernorm_keys = ["final_layernorm.weight", "final_layernorm.bias"] mlp_dense_2_keys = ["mlp.dense_h_to_4h_2.weight", "mlp.dense_h_to_4h_2.bias"] -attention_not_mapped_keys = ["attention.query.weight", "attention.query.bias", "attention.key_value.weight", "attention.key_value.bias"] +attention_not_mapped_keys = [ + "attention.query.weight", + "attention.query.bias", + "attention.key_value.weight", + "attention.key_value.bias", +] def save_val(val, dir, key, tp_num=None): @@ -207,12 +216,13 @@ def any_word_in_key(key, words): def sequential_key_map(key, mapping): - for (keywords, mapped) in mapping: + for keywords, mapped in mapping: if any_word_in_key(key, keywords): return mapped return None + def get_trt_llm_infix(key): mapping = [ (post_layernorm_keys, '.post_layernorm'), @@ -226,7 +236,7 @@ def get_trt_llm_infix(key): (attention_qkv_bias_keys + attention_qkv_weight_keys, '.attention.qkv'), (mlp_router_keys, '.mlp.router'), (mlp_fc_keys, '.mlp.fc'), - (mlp_proj_experts_keys, '.mlp.proj') + (mlp_proj_experts_keys, '.mlp.proj'), ] return sequential_key_map(key, mapping) @@ -254,6 +264,8 @@ def get_scaling_factor_keys(key): first = True + + def load_scaling_factor(key, val, dir, config): global weights_dict if not is_scaling_factor(key): @@ -291,7 +303,9 @@ def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_fac return [val.to(storage_type) for val in vals] fp8_storage_type = torch.float8_e4m3fn - quantized_keys = [ k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k ] + quantized_keys = [ + k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k + ] for k in quantized_keys: if k in trt_llm_key: storage_type = fp8_storage_type @@ -344,8 +358,15 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t elif torch.is_tensor(vals[0]): vals = [torch_to_numpy(val.cpu()) for val in vals] - if (any_word_in_key(key, input_layernorm_keys + pre_layernorm_keys + attention_dense_bias_keys + post_layernorm_keys + mlp_proj_bias_keys + final_layernorm_keys) - and (tp_rank == 0 or convert_on_device)): + if any_word_in_key( + key, + input_layernorm_keys + + pre_layernorm_keys + + attention_dense_bias_keys + + post_layernorm_keys + + mlp_proj_bias_keys + + final_layernorm_keys, + ) and (tp_rank == 0 or convert_on_device): # shared weights, only need to convert the weights of rank 0 save_val(vals[0], saved_dir, trt_llm_key)