From 542d843a6dc730699972fbb72e6e8ebc9b783fd6 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 14 Aug 2024 14:46:40 -0700 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: Piotr Kaminski --- .../trt_llm/converter/model_converter.py | 47 ++-- .../converter/model_to_trt_llm_ckpt.py | 34 +-- nemo/export/trt_llm/converter/utils.py | 217 ++++++++---------- .../trt_llm/nemo_ckpt_loader/nemo_file.py | 12 +- 4 files changed, 148 insertions(+), 162 deletions(-) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 7600224ff373..bb9b19173509 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -18,8 +18,8 @@ from typing import Dict, List, Tuple import numpy as np -import torch import tensorrt_llm +import torch from tensorrt_llm._utils import pad_vocab_size from tensorrt_llm.functional import non_gated_version from tensorrt_llm.layers import MoeConfig @@ -39,10 +39,11 @@ def get_config(decoder_type, config): if decoder_type == "llama": return LLaMAConfig(**config) - elif decoder_type == "gpt" or decoder_type == "gptnext": + + if decoder_type in ["gpt", "gptnext"]: return GPTConfig(**config) - else: - return PretrainedConfig(**config) + + return PretrainedConfig(**config) def prompt_convert(prompt_config, prompt_weights): @@ -93,8 +94,9 @@ def model_to_trtllm_ckpt( use_distributed_convert: bool = False, model_parallel_rank: int = None, vocab_size: int = None, + quantize_kv_cache: bool = False ) -> Tuple[List[Dict], List[PretrainedConfig]]: - + nemo_model_config['kv_cache'] = quantize_kv_cache if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing: LOGGER.info( "Found share_embeddings_and_output_weights is True in NeMo config, set use_embedding_sharing = True" @@ -112,6 +114,8 @@ def model_to_trtllm_ckpt( ) vocab_size_padded = vocab_size else: + vocab_embedding_key = "transformer.vocab_embedding.weight" + weights_dict = convert_model_to_trt_llm_ckpt( model=model, nemo_model_config=nemo_model_config, @@ -127,15 +131,15 @@ def model_to_trtllm_ckpt( if has_lm_head: lm_head_weight = weights_dict["lm_head.weight"] if vocab_size is None: - vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0] - vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size + vocab_size = weights_dict[vocab_embedding_key].shape[0] - if has_lm_head and vocab_size_padded != vocab_size: + vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size + if vocab_size_padded != vocab_size: padding = (0, 0, 0, vocab_size_padded - vocab_size) - embedding_key = "transformer.vocab_embedding.weight" - lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) - weights_dict[embedding_key] = torch.nn.functional.pad(weights_dict[embedding_key], padding, "constant", 0) - + if has_lm_head: + lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) + if vocab_embedding_key in weights_dict: + weights_dict[vocab_embedding_key] = torch.nn.functional.pad(weights_dict[vocab_embedding_key], padding, "constant", 0) world_size = tensor_parallel_size * pipeline_parallel_size hidden_act = nemo_model_config.get('activation') @@ -164,7 +168,7 @@ def model_to_trtllm_ckpt( 'share_embedding_table': use_embedding_sharing, 'quantization': { 'quant_algo': "FP8" if nemo_model_config.get('fp8', False) else None, - 
'kv_cache_quant_algo': None, # TODO maybe "FP8", + 'kv_cache_quant_algo': "FP8" if quantize_kv_cache else None, }, 'bias': nemo_model_config.get('bias'), 'apply_query_key_layer_scaling': False, @@ -207,7 +211,7 @@ def model_to_trtllm_ckpt( return weights_dicts, model_configs pp_key = { - "transformer.vocab_embedding.weight", + vocab_embedding_key, "transformer.position_embedding.weight", "lm_head.weight", "transformer.ln_f.weight", @@ -232,10 +236,9 @@ def model_to_trtllm_ckpt( continue new_key = k if new_key.endswith(".bin"): # TP split - if new_key.endswith(f"{mapping.tp_rank}.bin"): - new_key = new_key.replace(f".{mapping.tp_rank}.bin", "") - else: + if not new_key.endswith(f"{mapping.tp_rank}.bin"): continue + new_key = new_key.replace(f".{mapping.tp_rank}.bin", "") if "layers" in new_key: # PP layer_num = int(new_key.split(".")[2]) if layer_num in layers_range: @@ -247,13 +250,13 @@ def model_to_trtllm_ckpt( if mapping.is_first_pp_rank(): embedding_weight = ( np.ascontiguousarray( - split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank) + split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank) ) if use_parallel_embedding - else weights_dict["transformer.vocab_embedding.weight"] + else weights_dict[vocab_embedding_key] ) - weights_dict_local["transformer.vocab_embedding.weight"] = embedding_weight + weights_dict_local[vocab_embedding_key] = embedding_weight pos_embedding_weight = weights_dict.get("transformer.position_embedding.weight") if pos_embedding_weight is not None: @@ -265,7 +268,9 @@ def model_to_trtllm_ckpt( if mapping.is_last_pp_rank(): if has_lm_head: - weights_dict_local["lm_head.weight"] = split(lm_head_weight, mapping.tp_size, mapping.tp_rank).contiguous() + weights_dict_local["lm_head.weight"] = split( + lm_head_weight, mapping.tp_size, mapping.tp_rank + ).contiguous() weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"] ln_f_bias = weights_dict.get("transformer.ln_f.bias") diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index 462f741a8bb9..6cb373159e17 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -25,7 +25,7 @@ from tqdm import tqdm from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision -from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, weights_dict +from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, load_scaling_factor, weights_dict LOGGER = logging.getLogger("NeMo") @@ -94,29 +94,18 @@ def rename_key_dist_ckpt(old_key: str, layer: int): return rename_key(new_key) -def load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config): +def load_scaling_factors(model, num_layers, out_dir, export_config): starmap_args = [] for key, val in model.items(): if 'extra_state' not in key: continue - for i in range(num_layers): - starmap_args.append( - ( - tp_rank, - out_dir, - split_factor, - rename_key_dist_ckpt(key, i), - [val[i]], - storage_type, - None, - export_config, - {}, - ) - ) + for layer in range(num_layers): + args = (rename_key_dist_ckpt(key, layer), val[layer], out_dir, export_config) + starmap_args.append(args) for starmap_arg in starmap_args: - scaling_factors = split_and_save_weight(*starmap_arg) + scaling_factors = load_scaling_factor(*starmap_arg) return scaling_factors @@ -132,7 +121,6 @@ 
def convert_model_to_trt_llm_ckpt( use_parallel_embedding, processes, ): - # if checkpoints files could be found - start preparing output dir out_dir = create_export_dir(nemo_export_dir) storage_type = str_dtype_to_torch(storage_type) @@ -213,7 +201,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): handle_model_level_weights(model, 0, 0) model = extract_layers_with_prefix(model, transformer_layer_prefix) - scaling_factors = load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config) + scaling_factors = load_scaling_factors(model, num_layers, out_dir, export_config) starmap_args = [] for key, val in model.items(): @@ -222,12 +210,10 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): # Let's rename/map the key to the old layer name previously. # Since the state dict value has the full layers, let's select the ith layer weights/biases here. - if len(val.size()) == 1: - key_vals = [(rename_key_dist_ckpt(key, 0), val)] - else: - key_vals = [(rename_key_dist_ckpt(key, i), val[i]) for i in range(num_layers)] + layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)] - for (k, v) in key_vals: + for (l, v) in layer_vals: + k = rename_key_dist_ckpt(key, l) starmap_args.append( (tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors) ) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 9ce44559a30c..1a15e8874dc4 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -31,11 +31,29 @@ "falcon": 'FalconForCausalLM', } +post_layernorm_keys = ["post_attention_layernorm.weight", "post_attention_layernorm.bias", "post_self_attn_layernorm.weight"] +mlp_proj_bias_keys = ["mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias"] +attention_dense_bias_keys = ["attention.linear_proj.bias", "attention.dense.bias"] +input_layernorm_keys = ["input_layernorm.weight", "input_layernorm.bias"] +pre_layernorm_keys = ["pre_mlp_layernorm.weight", "pre_mlp_layernorm.bias"] +attention_dense_weight_keys = ["attention.linear_proj.weight", "attention.dense.weight"] +mlp_proj_weight_keys = ["mlp.linear_fc2.weight", "mlp.dense_4h_to_h.weight"] +mlp_fc_keys = ["mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.bias", "mlp.linear_fc1.weight", "mlp.linear_fc1.bias"] +attention_qkv_bias_keys = ["attention.query_key_value.bias", "attention.linear_qkv.bias"] +attention_qkv_weight_keys = ["attention.query_key_value.weight", "attention.linear_qkv.weight"] +mlp_router_keys = ["mlp.router.weight"] +mlp_fc_expert_keys = ["experts.linear_fc1.weight"] +mlp_proj_experts_keys = ["experts.linear_fc2.weight"] +final_layernorm_keys = ["final_layernorm.weight", "final_layernorm.bias"] +mlp_dense_2_keys = ["mlp.dense_h_to_4h_2.weight", "mlp.dense_h_to_4h_2.bias"] +attention_not_mapped_keys = ["attention.query.weight", "attention.query.bias", "attention.key_value.weight", "attention.key_value.bias"] + def save_val(val, dir, key, tp_num=None): - suffix = "" if tp_num is None else f".{tp_num}.bin" - global weights_dict + if tp_num: + key += f".{tp_num}.bin" + global weights_dict # Transpose linear layer weights to the correct shape. 
     if torch.is_tensor(val):
         val = val.detach().contiguous()
@@ -43,14 +61,14 @@ def save_val(val, dir, key, tp_num=None):
             val = val.reshape(val.shape[0], -1)
             val = torch.transpose(val, 0, 1)
         if key not in weights_dict:
-            weights_dict[f"{key}{suffix}"] = torch.empty(
+            weights_dict[key] = torch.empty(
                 val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True
             )
-        weights_dict[f"{key}{suffix}"].copy_(val, non_blocking=True)
+        weights_dict[key].copy_(val, non_blocking=True)
     else:
         if len(val.shape) >= 2:
             val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0]))
-        weights_dict[f"{key}{suffix}"] = val
+        weights_dict[key] = val
 
 
 def save_split(split_vals, dir, key, i, split_factor):
@@ -59,12 +77,13 @@ def save_split(split_vals, dir, key, i, split_factor):
 
 
 def save_expert_split(split_vals, dir, key, i, split_factor):
     for j, val in enumerate(split_vals):
         tp_num = i * split_factor + j
-        suffix = "" if tp_num is None else f".{tp_num}.bin"
-
-        global weights_dict
-        weights_dict[f"{key}{suffix}"] = val
+        weights_dict[f"{key}.{tp_num}.bin"] = val
 
 
 def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
@@ -177,61 +196,47 @@ def write_int8(vals, dir, base_key, split_dim, tp_rank, split_factor, kv_cache_o
 def get_suffix(key):
     return '.' + key.split('.')[-1]
 
+
 def get_trt_llm_prefix(key):
     layer_num = key.split(".")[1]
     return f'transformer.layers.{layer_num}'
 
 
-def get_new_keyname(key):
-    layer_prefix = get_trt_llm_prefix(key)
-    if ("post_attention_layernorm.weight" in key
-        or "post_attention_layernorm.bias" in key
-        or "post_self_attn_layernorm.weight" in key):
-        return f'{layer_prefix}.post_layernorm' + get_suffix(key)
+def any_word_in_key(key, words):
+    return any([word in key for word in words])
 
-    if "mlp.linear_fc2.bias" in key or "mlp.dense_4h_to_h.bias" in key:
-        return f'{layer_prefix}.mlp.proj.bias'
-    if "attention.linear_proj.bias" in key or "attention.dense.bias" in key:
-        return f'{layer_prefix}.attention.dense.bias'
-
-    if "final_layernorm.weight" in key or "final_layernorm.bias" in key:
-        return key.replace("final_layernorm", "transformer.ln_f")
+def sequential_key_map(key, mapping):
+    for (keywords, mapped) in mapping:
+        if any_word_in_key(key, keywords):
+            return mapped
 
-    if "input_layernorm.weight" in key or "input_layernorm.bias" in key:
-        return f'{layer_prefix}.input_layernorm' + get_suffix(key)
+    return None
 
-    if "pre_mlp_layernorm.weight" in key or "pre_mlp_layernorm.bias" in key:
-        return f'{layer_prefix}.post_layernorm' + get_suffix(key)
 
+def get_trt_llm_infix(key):
+    mapping = [
+        (post_layernorm_keys, '.post_layernorm'),
+        (mlp_proj_bias_keys, '.mlp.proj'),
+        (attention_dense_bias_keys, '.attention.dense'),
+        (input_layernorm_keys, '.input_layernorm'),
+        (pre_layernorm_keys, '.post_layernorm'),
+        (attention_dense_weight_keys, '.attention.dense'),
+        (mlp_proj_weight_keys, '.mlp.proj'),
+        (mlp_fc_keys, '.mlp.fc'),
+        (attention_qkv_bias_keys + attention_qkv_weight_keys, '.attention.qkv'),
+        (mlp_router_keys, '.mlp.router'),
+        (mlp_fc_expert_keys, '.mlp.fc'),
+        (mlp_proj_experts_keys, '.mlp.proj')
+    ]
+    return sequential_key_map(key, mapping)
 
-    if "attention.linear_proj.weight" in key or "attention.dense.weight" in key:
-        return f'{layer_prefix}.attention.dense.weight'
-    if "mlp.linear_fc2.weight" in key or "mlp.dense_4h_to_h.weight" in key:
-        return f'{layer_prefix}.mlp.proj.weight'
-
-    if (
-        "mlp.dense_h_to_4h.weight" in key
-        or "mlp.dense_h_to_4h.bias" in key
-        or "mlp.linear_fc1.weight" in key
-        or 
"mlp.linear_fc1.bias" in key - ): - return f'{layer_prefix}.mlp.fc' + get_suffix(key) - - if "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key: - return f'{layer_prefix}.attention.qkv.bias' - - if "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key: - return f'{layer_prefix}.attention.qkv.weight' - - if "mlp.router.weight" in key: - return f'{layer_prefix}.mlp.router.weight' - - if "experts.linear_fc1.weight" in key: - return f'{layer_prefix}.mlp.fc.weight' +def get_new_keyname(key): + if any_word_in_key(key, final_layernorm_keys): + return key.replace("final_layernorm", "transformer.ln_f") - if "experts.linear_fc2.weight" in key: - return f'{layer_prefix}.mlp.proj.weight' + if infix := get_trt_llm_infix(key): + return get_trt_llm_prefix(key) + infix + get_suffix(key) return key @@ -249,23 +254,28 @@ def get_scaling_factor_keys(key): first = True -def handle_scaling_factor(key, val, dir, split_gated_activation): - weights_key, activation_key = get_scaling_factor_keys(key) +def load_scaling_factor(key, val, dir, config): + global weights_dict + if not is_scaling_factor(key): + return weights_dict activation_factor = 1 / val[0].view(1) weights_factor = 1 / val[1].view(1) weights_factor_2 = 1 / val[2].view(1) + weights_key, activation_key = get_scaling_factor_keys(key) save_val(torch_to_numpy(activation_factor), dir, activation_key) save_val(torch_to_numpy(weights_factor), dir, weights_key) - # save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2') + # TODO + # save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2') # global first # if first: # first = False # for i in range(32): # save_val(torch_to_numpy(weights_factor_2), dir, f'transformer.layers.{i}.attention.kv_cache_scaling_factor') + split_gated_activation = config.get("split_gated_activation", False) if split_gated_activation and (("mlp.dense_h_to_4h" in key) or ("mlp.linear_fc1" in key)): layer_prefix = get_trt_llm_prefix(key) mapped_key = f'{layer_prefix}.mlp.gate' @@ -273,24 +283,33 @@ def handle_scaling_factor(key, val, dir, split_gated_activation): save_val(torch_to_numpy(weights_factor), dir, mapped_key + '.weights_scaling_factor') # save_val(torch_to_numpy(weights_factor_2), dir, mapped_key + '.weights_scaling_factor_2') - global weights_dict return weights_dict def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_factors): - if is_fp8_model: - fp8_storage_type = torch.float8_e4m3fn - quantized_keys = [ k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k] - for k in quantized_keys: - if k in trt_llm_key: - storage_type = fp8_storage_type - s = scaling_factors[k + '.weights_scaling_factor'] - vals = [val.to(torch.float32) / s for val in vals] - break + if not is_fp8_model: + return [val.to(storage_type) for val in vals] + + fp8_storage_type = torch.float8_e4m3fn + quantized_keys = [ k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k ] + for k in quantized_keys: + if k in trt_llm_key: + storage_type = fp8_storage_type + scale = scaling_factors[k + '.weights_scaling_factor'] + vals = [val.to(torch.float32) / scale for val in vals] + break return [val.to(storage_type) for val in vals] +def split_val_gate(vals, convert_on_device): + if convert_on_device: + return [[n] for n in torch.chunk(vals[0], 2, axis=-1)] + + splits = [np.split(val, 2, axis=-1) for val in vals] + return list(zip(*splits)) + + # Note: in 
multi_query_mode, only query heads are split between multiple GPUs, while key/value head # are not split as there is only one head per key/value. @torch.no_grad() @@ -306,9 +325,6 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t convert_on_device = config.get("convert_on_device", False) is_fp8_model = config.get("fp8", False) - if is_scaling_factor(key): - return handle_scaling_factor(key, vals[0], saved_dir, split_gated_activation) - save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" layer_prefix = get_trt_llm_prefix(key) @@ -327,31 +343,13 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t assert torch.is_tensor(vals[0]) elif torch.is_tensor(vals[0]): vals = [torch_to_numpy(val.cpu()) for val in vals] - if ( - "input_layernorm.weight" in key - or "input_layernorm.bias" in key - or "pre_mlp_layernorm.weight" in key - or "pre_mlp_layernorm.bias" in key - or "attention.dense.bias" in key - or "attention.linear_proj.bias" in key - or "post_attention_layernorm.weight" in key - or "post_attention_layernorm.bias" in key - or "post_self_attn_layernorm.weight" in key - or "mlp.dense_4h_to_h.bias" in key - or "mlp.linear_fc2.bias" in key - or "final_layernorm.weight" in key - or "final_layernorm.bias" in key - ): + + if (any_word_in_key(key, input_layernorm_keys + pre_layernorm_keys + attention_dense_bias_keys + post_layernorm_keys + mlp_proj_bias_keys + final_layernorm_keys) + and (tp_rank == 0 or convert_on_device)): # shared weights, only need to convert the weights of rank 0 - if tp_rank == 0 or convert_on_device: - save_val(vals[0], saved_dir, trt_llm_key) + save_val(vals[0], saved_dir, trt_llm_key) - elif ( - "attention.dense.weight" in key - or "mlp.dense_4h_to_h.weight" in key - or "attention.linear_proj.weight" in key - or "mlp.linear_fc2.weight" in key - ): + elif any_word_in_key(key, attention_dense_weight_keys + mlp_proj_weight_keys): if convert_on_device: save_val(vals[0], saved_dir, trt_llm_key) else: @@ -363,20 +361,11 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if act_range is not None and int8_outputs == "all": base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) - write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) # is cat dim always defined? - - elif ( - "mlp.dense_h_to_4h.weight" in key - or "mlp.dense_h_to_4h.bias" in key - or "mlp.linear_fc1.weight" in key - or "mlp.linear_fc1.bias" in key - ): + write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) # TODO is cat dim always defined? 
+ + elif any_word_in_key(key, mlp_fc_keys): if split_gated_activation: - if convert_on_device: - vals, gates = [[n] for n in torch.chunk(vals[0], 2, axis=-1)] - else: - splits = [np.split(val, 2, axis=-1) for val in vals] - vals, gates = list(zip(*splits)) + vals, gates = split_val_gate(vals, convert_on_device) if convert_on_device: save_val(vals[0], saved_dir, trt_llm_key) @@ -401,7 +390,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_vals = np.split(gate, split_factor, axis=cat_dim) save_split(split_vals, saved_dir, gate_key, tp_rank, split_factor) - elif "mlp.dense_h_to_4h_2.weight" in key or "mlp.dense_h_to_4h_2.bias" in key: + elif any_word_in_key(key, mlp_dense_2_keys): if convert_on_device: save_val(vals[0], saved_dir, trt_llm_key) else: @@ -415,7 +404,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) - elif "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key: + elif any_word_in_key(key, attention_qkv_bias_keys): qkv_hidden_dim = vals[0].shape[0] size_per_head = qkv_hidden_dim // (num_attention_heads + 2 * num_kv_heads) q_num = num_attention_heads // num_kv_heads @@ -446,7 +435,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t ] save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) - elif "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key: + elif any_word_in_key(key, attention_qkv_weight_keys): assert use_attention_nemo_shape, "Only support NEMO shape for QKV weights" hidden_dim = vals[0].shape[0] if size_per_head is None: @@ -509,17 +498,15 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_factor, kv_cache_only=int8_outputs == "kv_cache_only", ) - elif ( - "attention.query.weight" in key - or "attention.query.bias" in key - or "attention.key_value.weight" in key - or "attention.key_value.bias" in key - ): + + elif any_word_in_key(key, attention_not_mapped_keys): pass - elif "mlp.router.weight" in key: + + elif any_word_in_key(key, mlp_router_keys): val = np.concatenate(vals, axis=1) save_val(val, saved_dir, trt_llm_key) - elif "experts.linear_fc1.weight" in key: + + elif any_word_in_key(key, mlp_fc_expert_keys): cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) w1, w3 = np.split(val, 2, axis=1) @@ -531,7 +518,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_vals = [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)] save_expert_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) - elif "experts.linear_fc2.weight" in key: + elif any_word_in_key(key, mlp_proj_experts_keys): cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index da148f44b989..2357c8a57269 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -27,7 +27,7 @@ import zarr from tensorrt_llm._utils import np_bfloat16 from torch.distributed.checkpoint import FileSystemReader -from torch.distributed.checkpoint.metadata import TensorStorageMetadata, BytesStorageMetadata +from torch.distributed.checkpoint.metadata import BytesStorageMetadata, 
TensorStorageMetadata from torch.distributed.checkpoint.state_dict_loader import load_state_dict from transformers import AutoTokenizer, PreTrainedTokenizer @@ -72,17 +72,20 @@ def get_extra_state_key(state_dict): return key return False + def unpack_extra_state_key(key): basename = key.split('/')[0] size = int(key.split('/')[1].split('_')[-1]) return basename, size + def clear_loaded_extra_states(state_dict, basename): to_remove = [k for k in state_dict.keys() if basename + '/' in k] for key in to_remove: state_dict.pop(key) return state_dict + def load_scaling_factors(state_dict, basename, size): scales = [] for layer in range(size): @@ -98,6 +101,7 @@ def load_scaling_factors(state_dict, basename, size): all_scales = torch.stack(scales) return all_scales + def standarize_distributed_scaling_factors(state_dict): while key := get_extra_state_key(state_dict): basename, size = unpack_extra_state_key(key) @@ -105,7 +109,7 @@ def standarize_distributed_scaling_factors(state_dict): if scaling_factors != []: state_dict[basename + '.scale_fwd'] = scaling_factors state_dict = clear_loaded_extra_states(state_dict, basename) - + return state_dict @@ -153,9 +157,11 @@ def load_sharded_pickle_extra_state_scale(dir): all_scales = torch.stack(scales) return all_scales + def contains_extra_states(subdir): return list(subdir.glob('shard_0_*.pt')) != [] + def load_extra_state_from_pickle(sharded_state_dict, subdir): if scales := load_sharded_pickle_extra_state_scale(subdir): key = subdir.name + '.scale_fwd' @@ -163,6 +169,7 @@ def load_extra_state_from_pickle(sharded_state_dict, subdir): return sharded_state_dict + def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): sharded_state_dict = {} for subdir in checkpoint_dir.iterdir(): @@ -182,6 +189,7 @@ def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tenso sharded_state_dict[key] = torch.from_numpy(arr[:].view(np.int16)).view(torch.bfloat16) else: from tensorrt_llm._utils import str_dtype_to_torch + sharded_state_dict[key] = torch.from_numpy(arr[:]).view(str_dtype_to_torch(arr.dtype.name)) else: sharded_state_dict[key] = arr[:]
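
Note on the renaming refactor in utils.py above: get_new_keyname() now derives the TensorRT-LLM weight name from an ordered keyword-to-infix table (sequential_key_map) instead of a long if/elif chain. The following standalone Python sketch reproduces that flow for just two key groups so it can be run and inspected in isolation; the constant lists and helper names mirror the ones added in utils.py, while the trimmed-down table and the example key are illustrative only, not part of the patch.

# Standalone sketch (not part of the patch): a runnable, trimmed-down version of the
# key-renaming helpers introduced in utils.py, covering only two of the key groups.
attention_qkv_weight_keys = ["attention.query_key_value.weight", "attention.linear_qkv.weight"]
mlp_fc_keys = ["mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.bias", "mlp.linear_fc1.weight", "mlp.linear_fc1.bias"]


def any_word_in_key(key, words):
    # True if any of the given keyword fragments occurs in the key.
    return any(word in key for word in words)


def sequential_key_map(key, mapping):
    # Return the infix of the first keyword group that matches, else None.
    for keywords, mapped in mapping:
        if any_word_in_key(key, keywords):
            return mapped
    return None


def get_trt_llm_prefix(key):
    # Keys look like "layers.<N>.<rest>", so element 1 is the layer number.
    layer_num = key.split(".")[1]
    return f"transformer.layers.{layer_num}"


def get_new_keyname(key):
    mapping = [
        (attention_qkv_weight_keys, ".attention.qkv"),
        (mlp_fc_keys, ".mlp.fc"),
    ]
    infix = sequential_key_map(key, mapping)
    if infix is None:
        return key
    suffix = "." + key.split(".")[-1]
    return get_trt_llm_prefix(key) + infix + suffix


print(get_new_keyname("layers.5.attention.linear_qkv.weight"))
# -> transformer.layers.5.attention.qkv.weight

Because sequential_key_map() returns the first match, the order of the table entries matters: the expert-MLP patterns ("experts.linear_fc1.weight", "experts.linear_fc2.weight") do not collide with the dense-MLP ones, but any new key group should be added with that first-match behaviour in mind.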