From f966d05479c9584bf357ccb29bb41ff4c5dcb44c Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 14 Aug 2024 14:34:31 -0700 Subject: [PATCH 01/21] initial commit Signed-off-by: Piotr Kaminski --- nemo/export/tarutils.py | 5 +- .../trt_llm/converter/model_converter.py | 7 +- .../converter/model_to_trt_llm_ckpt.py | 51 ++--- nemo/export/trt_llm/converter/utils.py | 197 ++++++++++++------ .../trt_llm/nemo_ckpt_loader/nemo_file.py | 110 ++++++++-- 5 files changed, 259 insertions(+), 111 deletions(-) diff --git a/nemo/export/tarutils.py b/nemo/export/tarutils.py index b93f65274120..b9af03e5bbb6 100644 --- a/nemo/export/tarutils.py +++ b/nemo/export/tarutils.py @@ -20,7 +20,7 @@ import zarr.storage -class TarPath: +class TarPath(os.PathLike): """ A class that represents a path inside a TAR archive and behaves like pathlib.Path. @@ -58,6 +58,9 @@ def __truediv__(self, key) -> 'TarPath': def __str__(self) -> str: return os.path.join(self._tar.name, self._relpath) + def __fspath__(self): + return os.path.join(self._tar.name, self._relpath) + @property def tarobject(self): return self._tar diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 60d50316e9ed..a5b4b9af41a1 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -18,6 +18,7 @@ from typing import Dict, List, Tuple import numpy as np +import torch import tensorrt_llm from tensorrt_llm._utils import pad_vocab_size from tensorrt_llm.functional import non_gated_version @@ -159,7 +160,7 @@ def model_to_trtllm_ckpt( 'embedding_sharding_dim': 0, 'share_embedding_table': use_embedding_sharing, 'quantization': { - 'quant_algo': None, + 'quant_algo': "FP8" if nemo_model_config.get('fp8', False) else None, 'kv_cache_quant_algo': None, }, 'bias': nemo_model_config.get('bias'), @@ -261,9 +262,7 @@ def model_to_trtllm_ckpt( if mapping.is_last_pp_rank(): if has_lm_head: - weights_dict_local["lm_head.weight"] = np.ascontiguousarray( - split(lm_head_weight, mapping.tp_size, mapping.tp_rank) - ) + weights_dict_local["lm_head.weight"] = split(lm_head_weight, mapping.tp_size, mapping.tp_rank).contiguous() weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"] ln_f_bias = weights_dict.get("transformer.ln_f.bias") diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index 0345f979b8c2..e7557fc53675 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -148,6 +148,7 @@ def convert_model_to_trt_llm_ckpt( "use_attention_nemo_shape": True, "transpose_weights": True, "use_parallel_embedding": use_parallel_embedding, + "fp8": nemo_model_config.get('fp8', False), } # split_factor: in how many parts a TP training node is split @@ -158,7 +159,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): if tp_idx == 0 and pp_idx == 0: if has_position_embedding: val = model[get_layer_name("position_embedding", prefix)] - val = torch_to_numpy(val.to(storage_type).cpu()) + val = val.to(storage_type).cpu() model_level_weights["transformer.position_embedding.weight"].append(val) if pp_idx == 0: val = model.get("state_dict", model)[get_layer_name("word_embedding", prefix)] @@ -171,11 +172,11 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): pad_width = vocab_size_padded - vocab_size val = torch.nn.functional.pad(val, (0, 0, 0, 
pad_width), value=0) - val = torch_to_numpy(val.to(storage_type).cpu()) + val = val.to(storage_type).cpu() model_level_weights["transformer.vocab_embedding.weight"].append(val) if has_lm_head and pp_idx == training_pp_size - 1: val = model.get("state_dict", model)[get_layer_name("output_layer", prefix)] - val = torch_to_numpy(val.to(storage_type).cpu()) + val = val.to(storage_type).cpu() model_level_weights["lm_head.weight"].append(val) weights_dict = {} @@ -187,8 +188,24 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): starmap_args = [] for key, val in model.items(): - if "_extra_state" not in key: - if len(val.size()) == 1: + if len(val.size()) == 1: + starmap_args.append( + ( + tp_rank, + out_dir, + split_factor, + # Let's rename/map the key to the old layer name previously. You can try printing out + # the rename_key output of the old llama checkpoint and compare. + rename_key_dist_ckpt(key, 0), + # Since the state dict value has the full layers, let's select the ith layer weights/biases here. + [val], + storage_type, + None, + export_config, + ) + ) + else: + for i in range(num_layers): starmap_args.append( ( tp_rank, @@ -196,31 +213,14 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): split_factor, # Let's rename/map the key to the old layer name previously. You can try printing out # the rename_key output of the old llama checkpoint and compare. - rename_key_dist_ckpt(key, 0), + rename_key_dist_ckpt(key, i), # Since the state dict value has the full layers, let's select the ith layer weights/biases here. - [val], + [val[i]], storage_type, None, export_config, ) ) - else: - for i in range(num_layers): - starmap_args.append( - ( - tp_rank, - out_dir, - split_factor, - # Let's rename/map the key to the old layer name previously. You can try printing out - # the rename_key output of the old llama checkpoint and compare. - rename_key_dist_ckpt(key, i), - # Since the state dict value has the full layers, let's select the ith layer weights/biases here. - [val[i]], - storage_type, - None, - export_config, - ) - ) starmap_args = tqdm(starmap_args, desc="saving weights") @@ -236,7 +236,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): weights_dict.update(weights_dict_local) for key, values in model_level_weights.items(): - model_level_weights[key] = np.concatenate(values, axis=0) + model_level_weights[key] = torch.concatenate(values, axis=0) weights_dict[key] = model_level_weights[key] return weights_dict @@ -314,6 +314,7 @@ def dist_model_to_trt_llm_ckpt( "convert_on_device": True, "use_attention_nemo_shape": True, "transpose_weights": True, + "fp8": nemo_model_config.get('fp8', False), } starmap_config = { diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index eab17167cbd5..923632ef6847 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -174,6 +174,107 @@ def write_int8(vals, dir, base_key, split_dim, tp_rank, split_factor, kv_cache_o save_val(vals[save_key], dir, f"{base_key}.{save_key}") +def get_suffix(key): + return '.' 
+ key.split('.')[-1] + +def get_layer_prefix(key): + layer_num = key.split(".")[1] + return f'transformer.layers.{layer_num}' + +def get_new_keyname(key): + layer_prefix = get_layer_prefix(key) + + if ("post_attention_layernorm.weight" in key + or "post_attention_layernorm.bias" in key + or "post_self_attn_layernorm.weight" in key): + return f'{layer_prefix}.post_layernorm' + get_suffix(key) + + if "mlp.linear_fc2.bias" in key or "mlp.dense_4h_to_h.bias" in key: + return f'{layer_prefix}.mlp.proj.bias' + + if "attention.linear_proj.bias" in key or "attention.dense.bias" in key: + return f'{layer_prefix}.attention.dense.bias' + + if "final_layernorm.weight" in key or "final_layernorm.bias" in key: + return key.replace("final_layernorm", "transformer.ln_f") + + if "input_layernorm.weight" in key or "input_layernorm.bias" in key: + return f'{layer_prefix}.input_layernorm' + get_suffix(key) + + if "pre_mlp_layernorm.weight" in key or "pre_mlp_layernorm.bias" in key: + return f'{layer_prefix}.post_layernorm' + get_suffix(key) + + if "attention.linear_proj.weight" in key or "attention.dense.weight" in key: + return f'{layer_prefix}.attention.dense.weight' + + if "mlp.linear_fc2.weight" in key or "mlp.dense_4h_to_h.weight" in key: + return f'{layer_prefix}.mlp.proj.weight' + + if ( + "mlp.dense_h_to_4h.weight" in key + or "mlp.dense_h_to_4h.bias" in key + or "mlp.linear_fc1.weight" in key + or "mlp.linear_fc1.bias" in key + ): + return f'{layer_prefix}.mlp.fc' + get_suffix(key) + + if "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key: + return f'{layer_prefix}.attention.qkv.bias' + + if "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key: + return f'{layer_prefix}.attention.qkv.weight' + + if "mlp.router.weight" in key: + return f'{layer_prefix}.mlp.router.weight' + + if "experts.linear_fc1.weight" in key: + return f'{layer_prefix}.mlp.fc.weight' + + if "experts.linear_fc2.weight" in key: + return f'{layer_prefix}.mlp.proj.weight' + + return key + + +def is_scaling_factor(key): + return "scale_fwd" in key + +def get_scaling_factor_keys(key): + base_key = '.'.join(key.split('.')[:-2]) + '.weight' + base_key = '.'.join(get_new_keyname(base_key).split('.')[:-1]) + weight_scale = base_key + '.weights_scaling_factor' + activation_scale = base_key + '.activation_scaling_factor' + return weight_scale, activation_scale + +def handle_scaling_factor(key, val, dir, split_gated_activation): + weights_key, activation_key = get_scaling_factor_keys(key) + weights_factors = 1 / val[1].view(1) + activation_factors = 1 / val[0].view(1) + save_val(torch_to_numpy(weights_factors), dir, weights_key) + save_val(torch_to_numpy(activation_factors), dir, activation_key) + + if split_gated_activation and (("mlp.dense_h_to_4h" in key) or ("mlp.linear_fc1" in key)): + layer_num = key.split(".")[1] + layer_prefix = f'transformer.layers.{layer_num}' + mapped_key = f'{layer_prefix}.mlp.gate' + save_val(torch_to_numpy(weights_factors), dir, mapped_key + '.weights_scaling_factor') + save_val(torch_to_numpy(activation_factors), dir, mapped_key + '.activation_scaling_factor') + + global weights_dict + return weights_dict + + +def cast_val_datatype(vals, key, storage_type, is_fp8_model): + if is_fp8_model: + fp8_storage_type = torch.float8_e4m3fn + quantized_keys = ['attention.dense', 'attention.linear', 'attention.query_key_value', 'attention.linear_qkv', 'mlp.linear', 'mlp.dense'] + for k in quantized_keys: + if k in key: + storage_type = fp8_storage_type + return 
[val.to(storage_type) for val in vals] + + return [val.to(storage_type) for val in vals] + # Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head # are not split as there is only one head per key/value. @torch.no_grad() @@ -187,11 +288,13 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t num_kv_heads = config.get("num_kv_heads", num_attention_heads) size_per_head = config.get("kv_channels", None) convert_on_device = config.get("convert_on_device", False) + is_fp8_model = config.get("fp8", False) - save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" + if is_scaling_factor(key): + return handle_scaling_factor(key, vals[0], saved_dir, split_gated_activation) - layer_num = key.split(".")[1] - layer_prefix = f'transformer.layers.{layer_num}' + save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" + layer_prefix = get_layer_prefix(key) if not isinstance(vals, list): vals = [vals] @@ -201,13 +304,17 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if "layernorm.weight" in key and config.get("apply_layernorm_1p", False): vals = [val.float() + 1.0 for val in vals] - vals = [val.to(storage_type) for val in vals] + print("key ", key, vals) + vals = cast_val_datatype(vals, key, storage_type, is_fp8_model) + print(vals) + # vals = [val.to(storage_type) for val in vals] if convert_on_device: assert len(vals) == 1 # Should only convert a single device param per call assert torch.is_tensor(vals[0]) elif torch.is_tensor(vals[0]): vals = [torch_to_numpy(val.cpu()) for val in vals] + trt_llm_key = get_new_keyname(key) if ( "input_layernorm.weight" in key or "input_layernorm.bias" in key @@ -224,29 +331,8 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t or "final_layernorm.bias" in key ): # shared weights, only need to convert the weights of rank 0 - if "post_self_attn_layernorm" in key or "post_attention_layernorm" in key: - if key.endswith('weight'): - key = f'{layer_prefix}.post_layernorm.weight' - else: - key = f'{layer_prefix}.post_layernorm.bias' - elif "mlp.linear_fc2.bias" in key or "mlp.dense_4h_to_h.bias" in key: - key = f'{layer_prefix}.mlp.proj.bias' - elif "attention.linear_proj.bias" in key or "attention.dense.bias" in key: - key = f'{layer_prefix}.attention.dense.bias' - elif "final_layernorm" in key: - key = key.replace("final_layernorm", "transformer.ln_f") - elif "input_layernorm" in key: - if key.endswith('weight'): - key = f'{layer_prefix}.input_layernorm.weight' - else: - key = f'{layer_prefix}.input_layernorm.bias' - elif "pre_mlp_layernorm" in key: - if key.endswith('weight'): - key = f'{layer_prefix}.post_layernorm.weight' - else: - key = f'{layer_prefix}.post_layernorm.bias' if tp_rank == 0 or convert_on_device: - save_val(vals[0], saved_dir, key) + save_val(vals[0], saved_dir, trt_llm_key) elif ( "attention.dense.weight" in key @@ -254,21 +340,16 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t or "attention.linear_proj.weight" in key or "mlp.linear_fc2.weight" in key ): - if "attention.linear_proj.weight" in key or "attention.dense.weight" in key: - key = f'{layer_prefix}.attention.dense.weight' - elif "mlp.linear_fc2.weight" in key or "mlp.dense_4h_to_h.weight" in key: - key = f'{layer_prefix}.mlp.proj.weight' - if convert_on_device: - save_val(vals[0], saved_dir, key) + save_val(vals[0], saved_dir, trt_llm_key) else: cat_dim = 0 val = np.concatenate(vals, axis=cat_dim) 
split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") + base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) @@ -278,11 +359,6 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t or "mlp.linear_fc1.weight" in key or "mlp.linear_fc1.bias" in key ): - if key.endswith("weight"): - key = f'{layer_prefix}.mlp.fc.weight' - else: - key = f'{layer_prefix}.mlp.fc.bias' - if split_gated_activation: if convert_on_device: vals, gates = [[n] for n in torch.chunk(vals[0], 2, axis=-1)] @@ -291,48 +367,43 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t vals, gates = list(zip(*splits)) if convert_on_device: - save_val(vals[0], saved_dir, key) + save_val(vals[0], saved_dir, trt_llm_key) else: cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") + base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) if split_gated_activation: assert not save_int8 - if key.endswith("weight"): - key = f'{layer_prefix}.mlp.gate.weight' - else: - key = f'{layer_prefix}.mlp.gate.bias' - + gate_key = f'{layer_prefix}.mlp.gate' + get_suffix(trt_llm_key) if convert_on_device: - save_val(gates[0], saved_dir, key) + save_val(gates[0], saved_dir, gate_key) else: gate = np.concatenate(gates, axis=cat_dim) split_vals = np.split(gate, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, gate_key, tp_rank, split_factor) elif "mlp.dense_h_to_4h_2.weight" in key or "mlp.dense_h_to_4h_2.bias" in key: if convert_on_device: - save_val(vals[0], saved_dir, key) + save_val(vals[0], saved_dir, trt_llm_key) else: cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") + base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) elif "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key: - key = f'{layer_prefix}.attention.qkv.bias' qkv_hidden_dim = vals[0].shape[0] size_per_head = qkv_hidden_dim // (num_attention_heads + 2 * num_kv_heads) q_num = num_attention_heads // num_kv_heads @@ -349,7 +420,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if convert_on_device: qkv = torch.split(val, [q_num, 1, 1], dim=1) split_vals = torch.concatenate([qkv[0].reshape(-1), qkv[1].reshape(-1), qkv[2].reshape(-1)], dim=1) - save_val(split_vals, saved_dir, key) + 
save_val(split_vals, saved_dir, trt_llm_key) else: qkv = np.split(val, [q_num, q_num + 1], axis=1) q_split = np.split(qkv[0], split_factor, axis=0) @@ -361,10 +432,9 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t np.concatenate([q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], axis=0) for i in range(split_factor) ] - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) elif "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key: - key = f'{layer_prefix}.attention.qkv.weight' assert use_attention_nemo_shape, "Only support NEMO shape for QKV weights" hidden_dim = vals[0].shape[0] if size_per_head is None: @@ -380,7 +450,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_vals = torch.concatenate( [qkv[0].reshape(hidden_dim, -1), qkv[1].reshape(hidden_dim, -1), qkv[2].reshape(hidden_dim, -1)], dim=1 ) - save_val(split_vals, saved_dir, key) + save_val(split_vals, saved_dir, trt_llm_key) else: len_vals = len(vals) val = np.concatenate(vals, axis=1) @@ -414,10 +484,10 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t ) for i in range(split_factor) ] - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) if save_int8: - base_key = key.replace(".weight", "") + base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, is_qkv=True, multi_query_mode=multi_query_mode) write_int8( vals_i8, @@ -437,8 +507,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t pass elif "mlp.router.weight" in key: val = np.concatenate(vals, axis=1) - key = f'{layer_prefix}.mlp.router.weight' - save_val(val, saved_dir, key) + save_val(val, saved_dir, trt_llm_key) elif "experts.linear_fc1.weight" in key: cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) @@ -449,15 +518,13 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_w3s = np.split(w3, split_factor, axis=1) split_vals = [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)] - key = f'{layer_prefix}.mlp.fc.weight' - save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_expert_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) elif "experts.linear_fc2.weight" in key: cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - key = f'{layer_prefix}.mlp.proj.weight' - save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_expert_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) else: print(f"[WARNING] {key} not handled by converter") diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 1d473f497f51..cb26fc3c52b7 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -25,7 +25,10 @@ import torch import yaml import zarr +from tensorrt_llm._utils import np_bfloat16 from torch.distributed.checkpoint import FileSystemReader +from torch.distributed.checkpoint.metadata import TensorStorageMetadata, BytesStorageMetadata +from torch.distributed.checkpoint.state_dict_loader import load_state_dict from transformers import AutoTokenizer, PreTrainedTokenizer from nemo.export.sentencepiece_tokenizer 
import SentencePieceTokenizer @@ -56,9 +59,55 @@ class TarFileSystemReader(FileSystemReader): """ def __init__(self, path: Union[Path, TarPath]) -> None: - """No call to super().__init__ because it expects pure Path.""" - self.path = path - self.storage_data = dict() + """Makes sure that super().__init__ gets a pure path as expected.""" + super_path = str(path) if isinstance(path, TarPath) else path + super().__init__(super_path) + if isinstance(path, TarPath): + self.path = path # overwrites path set in super().__init__ call + + +def _get_extra_state_key(state_dict): + for key in state_dict.keys(): + if '_extra_state/' in key: + return key + return False + +def _unpack_extra_state_key(key): + basename = key.split('/')[0] + size = int(key.split('/')[1].split('_')[-1]) + return basename, size + +def _clear_key_basename_from_state_dict(state_dict, basename): + # '/' is important, as scaling factors are saved to basename.scaling_fwd + to_remove = [k for k in state_dict.keys() if basename + '/' in k] + for key in to_remove: + state_dict.pop(key) + return state_dict + +def _load_scaling_factors(state_dict, basename, size): + scales = [] + for layer in range(size): + keyname = f'{basename}/shard_{layer}_{size}' + extra_state = state_dict[keyname][0] + extra_state.seek(0) + extra_state = torch.load(extra_state) + + if 'scale_fwd' not in extra_state.keys(): + return [] + scales.append(extra_state['scale_fwd'].cpu()) + + all_scales = torch.stack(scales) + return all_scales + +def standarize_distributed_scaling_factors(state_dict): + while key := _get_extra_state_key(state_dict): + basename, size = _unpack_extra_state_key(key) + scaling_factors = _load_scaling_factors(state_dict, basename, size) + if scaling_factors != []: + state_dict[basename + '.scale_fwd'] = scaling_factors + state_dict = _clear_key_basename_from_state_dict(state_dict, basename) + + return state_dict def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): @@ -66,15 +115,17 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch metadata = fs_reader.read_metadata() state_dict = { - k: torch.empty(tp.size, dtype=tp.properties.dtype) + k: torch.empty(tp.size, dtype=tp.properties.dtype) if isinstance(tp, TensorStorageMetadata) else {} for k, tp in metadata.state_dict_metadata.items() - if isinstance(tp, TensorStorageMetadata) + if isinstance(tp, TensorStorageMetadata) or isinstance(tp, BytesStorageMetadata) } + load_state_dict( state_dict, storage_reader=fs_reader, no_dist=True, ) + state_dict = standarize_distributed_scaling_factors(state_dict) if not torch_tensor: for k, v in state_dict.items(): @@ -85,24 +136,51 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch return state_dict +def load_sharded_pickle_extra_state_scale(dir): + scales = [] + + i = 0 + while pt_file_list := list(dir.glob(f'shard_{i}_*.pt')): + pt_file = pt_file_list[0] + checkpoint = torch.load(pt_file) + checkpoint.seek(0) + state_dict = torch.load(checkpoint) + if not 'scale_fwd' in state_dict: + return [] + scale = state_dict['scale_fwd'].cpu() + scales.append(scale) + i += 1 + + all_scales = torch.stack(scales) + return all_scales + def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): sharded_state_dict = {} for subdir in checkpoint_dir.iterdir(): - if not subdir.is_dir() or not (subdir / '.zarray').exists(): + if not subdir.is_dir(): continue - key = subdir.name - - zstore = ZarrPathStore(subdir) - arr = zarr.open(zstore, 
'r') - if torch_tensor: - # sharded_state_dict[key] = torch.from_numpy(arr[:].astype("float32")).to(dtype=torch.bfloat16) - if arr.dtype.name == "bfloat16": - sharded_state_dict[key] = torch.from_numpy(arr[:].view(np.int16)).view(torch.bfloat16) + key = subdir.name + if list(subdir.glob('shard_0_*.pt')): + scales = load_sharded_pickle_extra_state_scale(subdir) + if scales != []: + key = key + '.scale_fwd' + sharded_state_dict[key] = scales + elif (subdir / '.zarray').exists(): + zstore = ZarrPathStore(subdir) + arr = zarr.open(zstore, 'r') + + if torch_tensor: + # sharded_state_dict[key] = torch.from_numpy(arr[:].astype("float32")).to(dtype=torch.bfloat16) + if arr.dtype.name == "bfloat16": + sharded_state_dict[key] = torch.from_numpy(arr[:].view(np.int16)).view(torch.bfloat16) + else: + from tensorrt_llm._utils import str_dtype_to_torch + sharded_state_dict[key] = torch.from_numpy(arr[:]).view(str_dtype_to_torch(arr.dtype.name)) else: - sharded_state_dict[key] = torch.from_numpy(arr[:]) + sharded_state_dict[key] = arr[:] else: - sharded_state_dict[key] = arr[:] + continue return sharded_state_dict From 50872684ff89959272bfe1996e3859cf8d6d3aaa Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 14 Aug 2024 14:34:55 -0700 Subject: [PATCH 02/21] PR draft Signed-off-by: Piotr Kaminski --- nemo/export/trt_llm/converter/utils.py | 26 ++++++++++--------- .../trt_llm/nemo_ckpt_loader/nemo_file.py | 16 ++++++------ 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 923632ef6847..5f58c8c6a1d3 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -246,19 +246,24 @@ def get_scaling_factor_keys(key): activation_scale = base_key + '.activation_scaling_factor' return weight_scale, activation_scale + def handle_scaling_factor(key, val, dir, split_gated_activation): weights_key, activation_key = get_scaling_factor_keys(key) - weights_factors = 1 / val[1].view(1) - activation_factors = 1 / val[0].view(1) - save_val(torch_to_numpy(weights_factors), dir, weights_key) - save_val(torch_to_numpy(activation_factors), dir, activation_key) + + activation_factor = 1 / val[0].view(1) + weights_factor = 1 / val[1].view(1) + # weights_factor_2 = 1 / val[2].view(1) + + save_val(torch_to_numpy(activation_factor), dir, activation_key) + save_val(torch_to_numpy(weights_factor), dir, weights_key) + #save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2') if split_gated_activation and (("mlp.dense_h_to_4h" in key) or ("mlp.linear_fc1" in key)): - layer_num = key.split(".")[1] - layer_prefix = f'transformer.layers.{layer_num}' + layer_prefix = get_layer_prefix(key) mapped_key = f'{layer_prefix}.mlp.gate' - save_val(torch_to_numpy(weights_factors), dir, mapped_key + '.weights_scaling_factor') - save_val(torch_to_numpy(activation_factors), dir, mapped_key + '.activation_scaling_factor') + save_val(torch_to_numpy(activation_factor), dir, mapped_key + '.activation_scaling_factor') + save_val(torch_to_numpy(weights_factor), dir, mapped_key + '.weights_scaling_factor') + #save_val(torch_to_numpy(weights_factor_2), dir, mapped_key + '.weights_scaling_factor_2') global weights_dict return weights_dict @@ -271,7 +276,7 @@ def cast_val_datatype(vals, key, storage_type, is_fp8_model): for k in quantized_keys: if k in key: storage_type = fp8_storage_type - return [val.to(storage_type) for val in vals] + break return [val.to(storage_type) for val in vals] @@ -304,10 +309,7 @@ 
def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if "layernorm.weight" in key and config.get("apply_layernorm_1p", False): vals = [val.float() + 1.0 for val in vals] - print("key ", key, vals) vals = cast_val_datatype(vals, key, storage_type, is_fp8_model) - print(vals) - # vals = [val.to(storage_type) for val in vals] if convert_on_device: assert len(vals) == 1 # Should only convert a single device param per call assert torch.is_tensor(vals[0]) diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index cb26fc3c52b7..e6df3b72994f 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -66,25 +66,25 @@ def __init__(self, path: Union[Path, TarPath]) -> None: self.path = path # overwrites path set in super().__init__ call -def _get_extra_state_key(state_dict): +def get_extra_state_key(state_dict): for key in state_dict.keys(): if '_extra_state/' in key: return key return False -def _unpack_extra_state_key(key): +def unpack_extra_state_key(key): basename = key.split('/')[0] size = int(key.split('/')[1].split('_')[-1]) return basename, size -def _clear_key_basename_from_state_dict(state_dict, basename): +def clear_key_basename_from_state_dict(state_dict, basename): # '/' is important, as scaling factors are saved to basename.scaling_fwd to_remove = [k for k in state_dict.keys() if basename + '/' in k] for key in to_remove: state_dict.pop(key) return state_dict -def _load_scaling_factors(state_dict, basename, size): +def load_scaling_factors(state_dict, basename, size): scales = [] for layer in range(size): keyname = f'{basename}/shard_{layer}_{size}' @@ -100,12 +100,12 @@ def _load_scaling_factors(state_dict, basename, size): return all_scales def standarize_distributed_scaling_factors(state_dict): - while key := _get_extra_state_key(state_dict): - basename, size = _unpack_extra_state_key(key) - scaling_factors = _load_scaling_factors(state_dict, basename, size) + while key := get_extra_state_key(state_dict): + basename, size = unpack_extra_state_key(key) + scaling_factors = load_scaling_factors(state_dict, basename, size) if scaling_factors != []: state_dict[basename + '.scale_fwd'] = scaling_factors - state_dict = _clear_key_basename_from_state_dict(state_dict, basename) + state_dict = clear_key_basename_from_state_dict(state_dict, basename) return state_dict From 61d0f4743ba138109f43c518a3e736e356a74b86 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 14 Aug 2024 14:35:10 -0700 Subject: [PATCH 03/21] fixed scaling weights Signed-off-by: Piotr Kaminski --- .../trt_llm/converter/model_converter.py | 9 ++- .../converter/model_to_trt_llm_ckpt.py | 73 +++++++++++-------- nemo/export/trt_llm/converter/utils.py | 62 +++++++++------- .../trt_llm/nemo_ckpt_loader/nemo_file.py | 34 +++++---- 4 files changed, 104 insertions(+), 74 deletions(-) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index a5b4b9af41a1..7600224ff373 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -131,8 +131,11 @@ def model_to_trtllm_ckpt( vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size if has_lm_head and vocab_size_padded != vocab_size: - pad_width = vocab_size_padded - vocab_size - lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0) 
+ padding = (0, 0, 0, vocab_size_padded - vocab_size) + embedding_key = "transformer.vocab_embedding.weight" + lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) + weights_dict[embedding_key] = torch.nn.functional.pad(weights_dict[embedding_key], padding, "constant", 0) + world_size = tensor_parallel_size * pipeline_parallel_size hidden_act = nemo_model_config.get('activation') @@ -161,7 +164,7 @@ def model_to_trtllm_ckpt( 'share_embedding_table': use_embedding_sharing, 'quantization': { 'quant_algo': "FP8" if nemo_model_config.get('fp8', False) else None, - 'kv_cache_quant_algo': None, + 'kv_cache_quant_algo': None, # TODO maybe "FP8", }, 'bias': nemo_model_config.get('bias'), 'apply_query_key_layer_scaling': False, diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index e7557fc53675..462f741a8bb9 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -94,6 +94,33 @@ def rename_key_dist_ckpt(old_key: str, layer: int): return rename_key(new_key) +def load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config): + starmap_args = [] + for key, val in model.items(): + if 'extra_state' not in key: + continue + + for i in range(num_layers): + starmap_args.append( + ( + tp_rank, + out_dir, + split_factor, + rename_key_dist_ckpt(key, i), + [val[i]], + storage_type, + None, + export_config, + {}, + ) + ) + + for starmap_arg in starmap_args: + scaling_factors = split_and_save_weight(*starmap_arg) + + return scaling_factors + + @torch.no_grad() def convert_model_to_trt_llm_ckpt( nemo_model_config, @@ -186,41 +213,24 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): handle_model_level_weights(model, 0, 0) model = extract_layers_with_prefix(model, transformer_layer_prefix) + scaling_factors = load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config) + starmap_args = [] for key, val in model.items(): + if 'extra_state' in key: + continue + + # Let's rename/map the key to the old layer name previously. + # Since the state dict value has the full layers, let's select the ith layer weights/biases here. if len(val.size()) == 1: + key_vals = [(rename_key_dist_ckpt(key, 0), val)] + else: + key_vals = [(rename_key_dist_ckpt(key, i), val[i]) for i in range(num_layers)] + + for (k, v) in key_vals: starmap_args.append( - ( - tp_rank, - out_dir, - split_factor, - # Let's rename/map the key to the old layer name previously. You can try printing out - # the rename_key output of the old llama checkpoint and compare. - rename_key_dist_ckpt(key, 0), - # Since the state dict value has the full layers, let's select the ith layer weights/biases here. - [val], - storage_type, - None, - export_config, - ) + (tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors) ) - else: - for i in range(num_layers): - starmap_args.append( - ( - tp_rank, - out_dir, - split_factor, - # Let's rename/map the key to the old layer name previously. You can try printing out - # the rename_key output of the old llama checkpoint and compare. - rename_key_dist_ckpt(key, i), - # Since the state dict value has the full layers, let's select the ith layer weights/biases here. 
- [val[i]], - storage_type, - None, - export_config, - ) - ) starmap_args = tqdm(starmap_args, desc="saving weights") @@ -239,6 +249,9 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): model_level_weights[key] = torch.concatenate(values, axis=0) weights_dict[key] = model_level_weights[key] + for key, value in scaling_factors.items(): + weights_dict[key] = value + return weights_dict diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 5f58c8c6a1d3..9ce44559a30c 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -177,12 +177,12 @@ def write_int8(vals, dir, base_key, split_dim, tp_rank, split_factor, kv_cache_o def get_suffix(key): return '.' + key.split('.')[-1] -def get_layer_prefix(key): +def get_trt_llm_prefix(key): layer_num = key.split(".")[1] return f'transformer.layers.{layer_num}' def get_new_keyname(key): - layer_prefix = get_layer_prefix(key) + layer_prefix = get_trt_llm_prefix(key) if ("post_attention_layernorm.weight" in key or "post_attention_layernorm.bias" in key @@ -239,51 +239,62 @@ def get_new_keyname(key): def is_scaling_factor(key): return "scale_fwd" in key + def get_scaling_factor_keys(key): - base_key = '.'.join(key.split('.')[:-2]) + '.weight' - base_key = '.'.join(get_new_keyname(base_key).split('.')[:-1]) + weight_key = '.'.join(key.split('.')[:-2]) + '.weight' + base_key = '.'.join(get_new_keyname(weight_key).split('.')[:-1]) weight_scale = base_key + '.weights_scaling_factor' activation_scale = base_key + '.activation_scaling_factor' return weight_scale, activation_scale +first = True def handle_scaling_factor(key, val, dir, split_gated_activation): weights_key, activation_key = get_scaling_factor_keys(key) activation_factor = 1 / val[0].view(1) weights_factor = 1 / val[1].view(1) - # weights_factor_2 = 1 / val[2].view(1) + weights_factor_2 = 1 / val[2].view(1) save_val(torch_to_numpy(activation_factor), dir, activation_key) save_val(torch_to_numpy(weights_factor), dir, weights_key) - #save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2') + # save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2') + + # global first + # if first: + # first = False + # for i in range(32): + # save_val(torch_to_numpy(weights_factor_2), dir, f'transformer.layers.{i}.attention.kv_cache_scaling_factor') if split_gated_activation and (("mlp.dense_h_to_4h" in key) or ("mlp.linear_fc1" in key)): - layer_prefix = get_layer_prefix(key) + layer_prefix = get_trt_llm_prefix(key) mapped_key = f'{layer_prefix}.mlp.gate' save_val(torch_to_numpy(activation_factor), dir, mapped_key + '.activation_scaling_factor') save_val(torch_to_numpy(weights_factor), dir, mapped_key + '.weights_scaling_factor') - #save_val(torch_to_numpy(weights_factor_2), dir, mapped_key + '.weights_scaling_factor_2') + # save_val(torch_to_numpy(weights_factor_2), dir, mapped_key + '.weights_scaling_factor_2') global weights_dict return weights_dict -def cast_val_datatype(vals, key, storage_type, is_fp8_model): +def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_factors): if is_fp8_model: fp8_storage_type = torch.float8_e4m3fn - quantized_keys = ['attention.dense', 'attention.linear', 'attention.query_key_value', 'attention.linear_qkv', 'mlp.linear', 'mlp.dense'] + quantized_keys = [ k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k] for k in quantized_keys: - if k in key: + if k in trt_llm_key: storage_type = 
fp8_storage_type + s = scaling_factors[k + '.weights_scaling_factor'] + vals = [val.to(torch.float32) / s for val in vals] break return [val.to(storage_type) for val in vals] + # Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head # are not split as there is only one head per key/value. @torch.no_grad() -def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config): +def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, sf): use_attention_nemo_shape = config.get("use_attention_nemo_shape", False) split_gated_activation = config.get("split_gated_activation", False) num_attention_heads = config.get("num_attention_heads", 0) @@ -299,7 +310,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t return handle_scaling_factor(key, vals[0], saved_dir, split_gated_activation) save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" - layer_prefix = get_layer_prefix(key) + layer_prefix = get_trt_llm_prefix(key) if not isinstance(vals, list): vals = [vals] @@ -309,14 +320,13 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if "layernorm.weight" in key and config.get("apply_layernorm_1p", False): vals = [val.float() + 1.0 for val in vals] - vals = cast_val_datatype(vals, key, storage_type, is_fp8_model) + trt_llm_key = get_new_keyname(key) + vals = cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, sf) if convert_on_device: assert len(vals) == 1 # Should only convert a single device param per call assert torch.is_tensor(vals[0]) elif torch.is_tensor(vals[0]): vals = [torch_to_numpy(val.cpu()) for val in vals] - - trt_llm_key = get_new_keyname(key) if ( "input_layernorm.weight" in key or "input_layernorm.bias" in key @@ -353,7 +363,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if act_range is not None and int8_outputs == "all": base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) - write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) + write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) # is cat dim always defined? elif ( "mlp.dense_h_to_4h.weight" in key @@ -462,13 +472,12 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t qkv = np.split(val, [q_num, q_num + 1], axis=2) query_groups_shape = qkv[0].shape - if len(query_groups_shape) > 1: - if (query_groups_shape[1] % split_factor) != 0: - raise Exception( - "Number of query groups of the models is {0}. Please select tensor parallelism size " - "that can split the number of query groups to equal number of query matrices in the " - "each GPU.".format(query_groups_shape[1]) - ) + if len(query_groups_shape) > 1 and ((query_groups_shape[1] % split_factor) != 0): + raise Exception( + "Number of query groups of the models is {0}. 
Please select tensor parallelism size " + "that can split the number of query groups to equal number of query matrices in the " + "each GPU.".format(query_groups_shape[1]) + ) q_split = np.split(qkv[0], split_factor, axis=1) k_split = np.split(qkv[1], split_factor, axis=1) @@ -538,10 +547,11 @@ def split(v, tp_size, idx, dim=0): """Splits the np tensor v on dim and return the idx's slice.""" if tp_size == 1: return v + if len(v.shape) == 1: return np.ascontiguousarray(np.split(v, tp_size)[idx]) - else: - return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) + + return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) def init_model_parallel_from_nemo(reshard_model): diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index e6df3b72994f..da148f44b989 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -77,8 +77,7 @@ def unpack_extra_state_key(key): size = int(key.split('/')[1].split('_')[-1]) return basename, size -def clear_key_basename_from_state_dict(state_dict, basename): - # '/' is important, as scaling factors are saved to basename.scaling_fwd +def clear_loaded_extra_states(state_dict, basename): to_remove = [k for k in state_dict.keys() if basename + '/' in k] for key in to_remove: state_dict.pop(key) @@ -105,7 +104,7 @@ def standarize_distributed_scaling_factors(state_dict): scaling_factors = load_scaling_factors(state_dict, basename, size) if scaling_factors != []: state_dict[basename + '.scale_fwd'] = scaling_factors - state_dict = clear_key_basename_from_state_dict(state_dict, basename) + state_dict = clear_loaded_extra_states(state_dict, basename) return state_dict @@ -138,35 +137,42 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch def load_sharded_pickle_extra_state_scale(dir): scales = [] + layer_number = 0 - i = 0 - while pt_file_list := list(dir.glob(f'shard_{i}_*.pt')): + while pt_file_list := list(dir.glob(f'shard_{layer_number}_*.pt')): pt_file = pt_file_list[0] checkpoint = torch.load(pt_file) checkpoint.seek(0) state_dict = torch.load(checkpoint) - if not 'scale_fwd' in state_dict: + if 'scale_fwd' not in state_dict: return [] scale = state_dict['scale_fwd'].cpu() scales.append(scale) - i += 1 + layer_number += 1 all_scales = torch.stack(scales) return all_scales +def contains_extra_states(subdir): + return list(subdir.glob('shard_0_*.pt')) != [] + +def load_extra_state_from_pickle(sharded_state_dict, subdir): + if scales := load_sharded_pickle_extra_state_scale(subdir): + key = subdir.name + '.scale_fwd' + sharded_state_dict[key] = scales + + return sharded_state_dict + def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): sharded_state_dict = {} for subdir in checkpoint_dir.iterdir(): if not subdir.is_dir(): continue - key = subdir.name - if list(subdir.glob('shard_0_*.pt')): - scales = load_sharded_pickle_extra_state_scale(subdir) - if scales != []: - key = key + '.scale_fwd' - sharded_state_dict[key] = scales + if contains_extra_states(subdir): + sharded_state_dict = load_extra_state_from_pickle(sharded_state_dict, subdir) elif (subdir / '.zarray').exists(): + key = subdir.name zstore = ZarrPathStore(subdir) arr = zarr.open(zstore, 'r') @@ -179,8 +185,6 @@ def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tenso sharded_state_dict[key] = torch.from_numpy(arr[:]).view(str_dtype_to_torch(arr.dtype.name)) else: 
sharded_state_dict[key] = arr[:] - else: - continue return sharded_state_dict From 542d843a6dc730699972fbb72e6e8ebc9b783fd6 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 14 Aug 2024 14:46:40 -0700 Subject: [PATCH 04/21] Apply isort and black reformatting Signed-off-by: Piotr Kaminski --- .../trt_llm/converter/model_converter.py | 47 ++-- .../converter/model_to_trt_llm_ckpt.py | 34 +-- nemo/export/trt_llm/converter/utils.py | 217 ++++++++---------- .../trt_llm/nemo_ckpt_loader/nemo_file.py | 12 +- 4 files changed, 148 insertions(+), 162 deletions(-) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 7600224ff373..bb9b19173509 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -18,8 +18,8 @@ from typing import Dict, List, Tuple import numpy as np -import torch import tensorrt_llm +import torch from tensorrt_llm._utils import pad_vocab_size from tensorrt_llm.functional import non_gated_version from tensorrt_llm.layers import MoeConfig @@ -39,10 +39,11 @@ def get_config(decoder_type, config): if decoder_type == "llama": return LLaMAConfig(**config) - elif decoder_type == "gpt" or decoder_type == "gptnext": + + if decoder_type in ["gpt", "gptnext"]: return GPTConfig(**config) - else: - return PretrainedConfig(**config) + + return PretrainedConfig(**config) def prompt_convert(prompt_config, prompt_weights): @@ -93,8 +94,9 @@ def model_to_trtllm_ckpt( use_distributed_convert: bool = False, model_parallel_rank: int = None, vocab_size: int = None, + quantize_kv_cache: bool = False ) -> Tuple[List[Dict], List[PretrainedConfig]]: - + nemo_model_config['kv_cache'] = quantize_kv_cache if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing: LOGGER.info( "Found share_embeddings_and_output_weights is True in NeMo config, set use_embedding_sharing = True" @@ -112,6 +114,8 @@ def model_to_trtllm_ckpt( ) vocab_size_padded = vocab_size else: + vocab_embedding_key = "transformer.vocab_embedding.weight" + weights_dict = convert_model_to_trt_llm_ckpt( model=model, nemo_model_config=nemo_model_config, @@ -127,15 +131,15 @@ def model_to_trtllm_ckpt( if has_lm_head: lm_head_weight = weights_dict["lm_head.weight"] if vocab_size is None: - vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0] - vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size + vocab_size = weights_dict[vocab_embedding_key].shape[0] - if has_lm_head and vocab_size_padded != vocab_size: + vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size + if vocab_size_padded != vocab_size: padding = (0, 0, 0, vocab_size_padded - vocab_size) - embedding_key = "transformer.vocab_embedding.weight" - lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) - weights_dict[embedding_key] = torch.nn.functional.pad(weights_dict[embedding_key], padding, "constant", 0) - + if has_lm_head: + lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) + if vocab_embedding_key in weights_dict: + weights_dict[vocab_embedding_key] = torch.nn.functional.pad(weights_dict[vocab_embedding_key], padding, "constant", 0) world_size = tensor_parallel_size * pipeline_parallel_size hidden_act = nemo_model_config.get('activation') @@ -164,7 +168,7 @@ def model_to_trtllm_ckpt( 'share_embedding_table': use_embedding_sharing, 'quantization': 
{ 'quant_algo': "FP8" if nemo_model_config.get('fp8', False) else None, - 'kv_cache_quant_algo': None, # TODO maybe "FP8", + 'kv_cache_quant_algo': "FP8" if quantize_kv_cache else None, }, 'bias': nemo_model_config.get('bias'), 'apply_query_key_layer_scaling': False, @@ -207,7 +211,7 @@ def model_to_trtllm_ckpt( return weights_dicts, model_configs pp_key = { - "transformer.vocab_embedding.weight", + vocab_embedding_key, "transformer.position_embedding.weight", "lm_head.weight", "transformer.ln_f.weight", @@ -232,10 +236,9 @@ def model_to_trtllm_ckpt( continue new_key = k if new_key.endswith(".bin"): # TP split - if new_key.endswith(f"{mapping.tp_rank}.bin"): - new_key = new_key.replace(f".{mapping.tp_rank}.bin", "") - else: + if not new_key.endswith(f"{mapping.tp_rank}.bin"): continue + new_key = new_key.replace(f".{mapping.tp_rank}.bin", "") if "layers" in new_key: # PP layer_num = int(new_key.split(".")[2]) if layer_num in layers_range: @@ -247,13 +250,13 @@ def model_to_trtllm_ckpt( if mapping.is_first_pp_rank(): embedding_weight = ( np.ascontiguousarray( - split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank) + split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank) ) if use_parallel_embedding - else weights_dict["transformer.vocab_embedding.weight"] + else weights_dict[vocab_embedding_key] ) - weights_dict_local["transformer.vocab_embedding.weight"] = embedding_weight + weights_dict_local[vocab_embedding_key] = embedding_weight pos_embedding_weight = weights_dict.get("transformer.position_embedding.weight") if pos_embedding_weight is not None: @@ -265,7 +268,9 @@ def model_to_trtllm_ckpt( if mapping.is_last_pp_rank(): if has_lm_head: - weights_dict_local["lm_head.weight"] = split(lm_head_weight, mapping.tp_size, mapping.tp_rank).contiguous() + weights_dict_local["lm_head.weight"] = split( + lm_head_weight, mapping.tp_size, mapping.tp_rank + ).contiguous() weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"] ln_f_bias = weights_dict.get("transformer.ln_f.bias") diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index 462f741a8bb9..6cb373159e17 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -25,7 +25,7 @@ from tqdm import tqdm from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision -from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, weights_dict +from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, load_scaling_factor, weights_dict LOGGER = logging.getLogger("NeMo") @@ -94,29 +94,18 @@ def rename_key_dist_ckpt(old_key: str, layer: int): return rename_key(new_key) -def load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config): +def load_scaling_factors(model, num_layers, out_dir, export_config): starmap_args = [] for key, val in model.items(): if 'extra_state' not in key: continue - for i in range(num_layers): - starmap_args.append( - ( - tp_rank, - out_dir, - split_factor, - rename_key_dist_ckpt(key, i), - [val[i]], - storage_type, - None, - export_config, - {}, - ) - ) + for layer in range(num_layers): + args = (rename_key_dist_ckpt(key, layer), val[layer], out_dir, export_config) + starmap_args.append(args) for starmap_arg in starmap_args: - scaling_factors = split_and_save_weight(*starmap_arg) + scaling_factors = 
load_scaling_factor(*starmap_arg) return scaling_factors @@ -132,7 +121,6 @@ def convert_model_to_trt_llm_ckpt( use_parallel_embedding, processes, ): - # if checkpoints files could be found - start preparing output dir out_dir = create_export_dir(nemo_export_dir) storage_type = str_dtype_to_torch(storage_type) @@ -213,7 +201,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): handle_model_level_weights(model, 0, 0) model = extract_layers_with_prefix(model, transformer_layer_prefix) - scaling_factors = load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config) + scaling_factors = load_scaling_factors(model, num_layers, out_dir, export_config) starmap_args = [] for key, val in model.items(): @@ -222,12 +210,10 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): # Let's rename/map the key to the old layer name previously. # Since the state dict value has the full layers, let's select the ith layer weights/biases here. - if len(val.size()) == 1: - key_vals = [(rename_key_dist_ckpt(key, 0), val)] - else: - key_vals = [(rename_key_dist_ckpt(key, i), val[i]) for i in range(num_layers)] + layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)] - for (k, v) in key_vals: + for (l, v) in layer_vals: + k = rename_key_dist_ckpt(key, l) starmap_args.append( (tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors) ) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 9ce44559a30c..1a15e8874dc4 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -31,11 +31,29 @@ "falcon": 'FalconForCausalLM', } +post_layernorm_keys = ["post_attention_layernorm.weight", "post_attention_layernorm.bias", "post_self_attn_layernorm.weight"] +mlp_proj_bias_keys = ["mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias"] +attention_dense_bias_keys = ["attention.linear_proj.bias", "attention.dense.bias"] +input_layernorm_keys = ["input_layernorm.weight", "input_layernorm.bias"] +pre_layernorm_keys = ["pre_mlp_layernorm.weight", "pre_mlp_layernorm.bias"] +attention_dense_weight_keys = ["attention.linear_proj.weight", "attention.dense.weight"] +mlp_proj_weight_keys = ["mlp.linear_fc2.weight", "mlp.dense_4h_to_h.weight"] +mlp_fc_keys = ["mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.bias", "mlp.linear_fc1.weight", "mlp.linear_fc1.bias"] +attention_qkv_bias_keys = ["attention.query_key_value.bias", "attention.linear_qkv.bias"] +attention_qkv_weight_keys = ["attention.query_key_value.weight", "attention.linear_qkv.weight"] +mlp_router_keys = ["mlp.router.weight"] +mlp_fc_expert_keys = ["experts.linear_fc1.weight"] +mlp_proj_experts_keys = ["experts.linear_fc2.weight"] +final_layernorm_keys = ["final_layernorm.weight", "final_layernorm.bias"] +mlp_dense_2_keys = ["mlp.dense_h_to_4h_2.weight", "mlp.dense_h_to_4h_2.bias"] +attention_not_mapped_keys = ["attention.query.weight", "attention.query.bias", "attention.key_value.weight", "attention.key_value.bias"] + def save_val(val, dir, key, tp_num=None): - suffix = "" if tp_num is None else f".{tp_num}.bin" - global weights_dict + if tp_num: + key += f".{tp_num}.bin" + global weights_dict # Transpose linear layer weights to the correct shape. 
if torch.is_tensor(val): val = val.detach().contiguous() @@ -43,14 +61,14 @@ def save_val(val, dir, key, tp_num=None): val = val.reshape(val.shape[0], -1) val = torch.transpose(val, 0, 1) if key not in weights_dict: - weights_dict[f"{key}{suffix}"] = torch.empty( + weights_dict[key] = torch.empty( val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True ) - weights_dict[f"{key}{suffix}"].copy_(val, non_blocking=True) + weights_dict[key].copy_(val, non_blocking=True) else: if len(val.shape) >= 2: val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0])) - weights_dict[f"{key}{suffix}"] = val + weights_dict[key] = val def save_split(split_vals, dir, key, i, split_factor): @@ -59,12 +77,13 @@ def save_split(split_vals, dir, key, i, split_factor): def save_expert_split(split_vals, dir, key, i, split_factor): + if tp_num: + key += f".{tp_num}.bin" + for j, val in enumerate(split_vals): tp_num = i * split_factor + j - suffix = "" if tp_num is None else f".{tp_num}.bin" - global weights_dict - weights_dict[f"{key}{suffix}"] = val + weights_dict[key] = val def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): @@ -177,61 +196,47 @@ def write_int8(vals, dir, base_key, split_dim, tp_rank, split_factor, kv_cache_o def get_suffix(key): return '.' + key.split('.')[-1] + def get_trt_llm_prefix(key): layer_num = key.split(".")[1] return f'transformer.layers.{layer_num}' -def get_new_keyname(key): - layer_prefix = get_trt_llm_prefix(key) - if ("post_attention_layernorm.weight" in key - or "post_attention_layernorm.bias" in key - or "post_self_attn_layernorm.weight" in key): - return f'{layer_prefix}.post_layernorm' + get_suffix(key) +def any_word_in_key(key, words): + return any([word in key for word in words]) - if "mlp.linear_fc2.bias" in key or "mlp.dense_4h_to_h.bias" in key: - return f'{layer_prefix}.mlp.proj.bias' - if "attention.linear_proj.bias" in key or "attention.dense.bias" in key: - return f'{layer_prefix}.attention.dense.bias' - - if "final_layernorm.weight" in key or "final_layernorm.bias" in key: - return key.replace("final_layernorm", "transformer.ln_f") +def sequential_key_map(key, mapping): + for (keywords, mapped) in mapping: + if any_word_in_key(key, keywords): + return mapped - if "input_layernorm.weight" in key or "input_layernorm.bias" in key: - return f'{layer_prefix}.input_layernorm' + get_suffix(key) + return None - if "pre_mlp_layernorm.weight" in key or "pre_mlp_layernorm.bias" in key: - return f'{layer_prefix}.post_layernorm' + get_suffix(key) +def get_trt_llm_infix(key): + mapping = [ + (post_layernorm_keys, '.post_layernorm'), + (mlp_proj_bias_keys, '.mlp.proj'), + (attention_dense_bias_keys, '.attention.dense'), + (input_layernorm_keys, '.input_layernorm'), + (pre_layernorm_keys, '.post_layernorm'), + (attention_dense_weight_keys, '.attention.dense'), + (mlp_proj_weight_keys, '.mlp.proj'), + (mlp_fc_keys, '.mlp.fc'), + (attention_qkv_bias_keys + attention_qkv_weight_keys, '.attention.qkv'), + (mlp_router_keys, '.mlp.router'), + (mlp_fc_keys, '.mlp.fc'), + (mlp_proj_experts_keys, '.mlp.proj') + ] + return sequential_key_map(key, mapping) - if "attention.linear_proj.weight" in key or "attention.dense.weight" in key: - return f'{layer_prefix}.attention.dense.weight' - if "mlp.linear_fc2.weight" in key or "mlp.dense_4h_to_h.weight" in key: - return f'{layer_prefix}.mlp.proj.weight' - - if ( - "mlp.dense_h_to_4h.weight" in key - or "mlp.dense_h_to_4h.bias" in key - or "mlp.linear_fc1.weight" in key - or 
"mlp.linear_fc1.bias" in key - ): - return f'{layer_prefix}.mlp.fc' + get_suffix(key) - - if "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key: - return f'{layer_prefix}.attention.qkv.bias' - - if "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key: - return f'{layer_prefix}.attention.qkv.weight' - - if "mlp.router.weight" in key: - return f'{layer_prefix}.mlp.router.weight' - - if "experts.linear_fc1.weight" in key: - return f'{layer_prefix}.mlp.fc.weight' +def get_new_keyname(key): + if any_word_in_key(key, final_layernorm_keys): + return key.replace("final_layernorm", "transformer.ln_f") - if "experts.linear_fc2.weight" in key: - return f'{layer_prefix}.mlp.proj.weight' + if infix := get_trt_llm_infix(key): + return get_trt_llm_prefix(key) + infix + get_suffix(key) return key @@ -249,23 +254,28 @@ def get_scaling_factor_keys(key): first = True -def handle_scaling_factor(key, val, dir, split_gated_activation): - weights_key, activation_key = get_scaling_factor_keys(key) +def load_scaling_factor(key, val, dir, config): + global weights_dict + if not is_scaling_factor(key): + return weights_dict activation_factor = 1 / val[0].view(1) weights_factor = 1 / val[1].view(1) weights_factor_2 = 1 / val[2].view(1) + weights_key, activation_key = get_scaling_factor_keys(key) save_val(torch_to_numpy(activation_factor), dir, activation_key) save_val(torch_to_numpy(weights_factor), dir, weights_key) - # save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2') + # TODO + # save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2') # global first # if first: # first = False # for i in range(32): # save_val(torch_to_numpy(weights_factor_2), dir, f'transformer.layers.{i}.attention.kv_cache_scaling_factor') + split_gated_activation = config.get("split_gated_activation", False) if split_gated_activation and (("mlp.dense_h_to_4h" in key) or ("mlp.linear_fc1" in key)): layer_prefix = get_trt_llm_prefix(key) mapped_key = f'{layer_prefix}.mlp.gate' @@ -273,24 +283,33 @@ def handle_scaling_factor(key, val, dir, split_gated_activation): save_val(torch_to_numpy(weights_factor), dir, mapped_key + '.weights_scaling_factor') # save_val(torch_to_numpy(weights_factor_2), dir, mapped_key + '.weights_scaling_factor_2') - global weights_dict return weights_dict def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_factors): - if is_fp8_model: - fp8_storage_type = torch.float8_e4m3fn - quantized_keys = [ k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k] - for k in quantized_keys: - if k in trt_llm_key: - storage_type = fp8_storage_type - s = scaling_factors[k + '.weights_scaling_factor'] - vals = [val.to(torch.float32) / s for val in vals] - break + if not is_fp8_model: + return [val.to(storage_type) for val in vals] + + fp8_storage_type = torch.float8_e4m3fn + quantized_keys = [ k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k ] + for k in quantized_keys: + if k in trt_llm_key: + storage_type = fp8_storage_type + scale = scaling_factors[k + '.weights_scaling_factor'] + vals = [val.to(torch.float32) / scale for val in vals] + break return [val.to(storage_type) for val in vals] +def split_val_gate(vals, convert_on_device): + if convert_on_device: + return [[n] for n in torch.chunk(vals[0], 2, axis=-1)] + + splits = [np.split(val, 2, axis=-1) for val in vals] + return list(zip(*splits)) + + # Note: in 
multi_query_mode, only query heads are split between multiple GPUs, while key/value head # are not split as there is only one head per key/value. @torch.no_grad() @@ -306,9 +325,6 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t convert_on_device = config.get("convert_on_device", False) is_fp8_model = config.get("fp8", False) - if is_scaling_factor(key): - return handle_scaling_factor(key, vals[0], saved_dir, split_gated_activation) - save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" layer_prefix = get_trt_llm_prefix(key) @@ -327,31 +343,13 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t assert torch.is_tensor(vals[0]) elif torch.is_tensor(vals[0]): vals = [torch_to_numpy(val.cpu()) for val in vals] - if ( - "input_layernorm.weight" in key - or "input_layernorm.bias" in key - or "pre_mlp_layernorm.weight" in key - or "pre_mlp_layernorm.bias" in key - or "attention.dense.bias" in key - or "attention.linear_proj.bias" in key - or "post_attention_layernorm.weight" in key - or "post_attention_layernorm.bias" in key - or "post_self_attn_layernorm.weight" in key - or "mlp.dense_4h_to_h.bias" in key - or "mlp.linear_fc2.bias" in key - or "final_layernorm.weight" in key - or "final_layernorm.bias" in key - ): + + if (any_word_in_key(key, input_layernorm_keys + pre_layernorm_keys + attention_dense_bias_keys + post_layernorm_keys + mlp_proj_bias_keys + final_layernorm_keys) + and (tp_rank == 0 or convert_on_device)): # shared weights, only need to convert the weights of rank 0 - if tp_rank == 0 or convert_on_device: - save_val(vals[0], saved_dir, trt_llm_key) + save_val(vals[0], saved_dir, trt_llm_key) - elif ( - "attention.dense.weight" in key - or "mlp.dense_4h_to_h.weight" in key - or "attention.linear_proj.weight" in key - or "mlp.linear_fc2.weight" in key - ): + elif any_word_in_key(key, attention_dense_weight_keys + mlp_proj_weight_keys): if convert_on_device: save_val(vals[0], saved_dir, trt_llm_key) else: @@ -363,20 +361,11 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if act_range is not None and int8_outputs == "all": base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) - write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) # is cat dim always defined? - - elif ( - "mlp.dense_h_to_4h.weight" in key - or "mlp.dense_h_to_4h.bias" in key - or "mlp.linear_fc1.weight" in key - or "mlp.linear_fc1.bias" in key - ): + write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) # TODO is cat dim always defined? 
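# Illustrative example (key name is assumed, not taken from this patch): with the mapping tables above,
# a converted NeMo key such as 'layers.7.mlp.linear_fc2.weight' resolves as
#     get_trt_llm_prefix('layers.7.mlp.linear_fc2.weight') -> 'transformer.layers.7'
#     get_trt_llm_infix(...)                               -> '.mlp.proj'
#     get_suffix(...)                                      -> '.weight'
# so get_new_keyname(...) returns 'transformer.layers.7.mlp.proj.weight'.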
+ + elif any_word_in_key(key, mlp_fc_keys): if split_gated_activation: - if convert_on_device: - vals, gates = [[n] for n in torch.chunk(vals[0], 2, axis=-1)] - else: - splits = [np.split(val, 2, axis=-1) for val in vals] - vals, gates = list(zip(*splits)) + vals, gates = split_val_gate(vals, convert_on_device) if convert_on_device: save_val(vals[0], saved_dir, trt_llm_key) @@ -401,7 +390,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_vals = np.split(gate, split_factor, axis=cat_dim) save_split(split_vals, saved_dir, gate_key, tp_rank, split_factor) - elif "mlp.dense_h_to_4h_2.weight" in key or "mlp.dense_h_to_4h_2.bias" in key: + elif any_word_in_key(key, mlp_dense_2_keys): if convert_on_device: save_val(vals[0], saved_dir, trt_llm_key) else: @@ -415,7 +404,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) - elif "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key: + elif any_word_in_key(key, attention_qkv_bias_keys): qkv_hidden_dim = vals[0].shape[0] size_per_head = qkv_hidden_dim // (num_attention_heads + 2 * num_kv_heads) q_num = num_attention_heads // num_kv_heads @@ -446,7 +435,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t ] save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) - elif "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key: + elif any_word_in_key(key, attention_qkv_weight_keys): assert use_attention_nemo_shape, "Only support NEMO shape for QKV weights" hidden_dim = vals[0].shape[0] if size_per_head is None: @@ -509,17 +498,15 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_factor, kv_cache_only=int8_outputs == "kv_cache_only", ) - elif ( - "attention.query.weight" in key - or "attention.query.bias" in key - or "attention.key_value.weight" in key - or "attention.key_value.bias" in key - ): + + elif any_word_in_key(key, attention_not_mapped_keys): pass - elif "mlp.router.weight" in key: + + elif any_word_in_key(key, mlp_router_keys): val = np.concatenate(vals, axis=1) save_val(val, saved_dir, trt_llm_key) - elif "experts.linear_fc1.weight" in key: + + elif any_word_in_key(key, mlp_fc_expert_keys): cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) w1, w3 = np.split(val, 2, axis=1) @@ -531,7 +518,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_vals = [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)] save_expert_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) - elif "experts.linear_fc2.weight" in key: + elif any_word_in_key(key, mlp_proj_experts_keys): cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index da148f44b989..2357c8a57269 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -27,7 +27,7 @@ import zarr from tensorrt_llm._utils import np_bfloat16 from torch.distributed.checkpoint import FileSystemReader -from torch.distributed.checkpoint.metadata import TensorStorageMetadata, BytesStorageMetadata +from torch.distributed.checkpoint.metadata import BytesStorageMetadata, 
TensorStorageMetadata from torch.distributed.checkpoint.state_dict_loader import load_state_dict from transformers import AutoTokenizer, PreTrainedTokenizer @@ -72,17 +72,20 @@ def get_extra_state_key(state_dict): return key return False + def unpack_extra_state_key(key): basename = key.split('/')[0] size = int(key.split('/')[1].split('_')[-1]) return basename, size + def clear_loaded_extra_states(state_dict, basename): to_remove = [k for k in state_dict.keys() if basename + '/' in k] for key in to_remove: state_dict.pop(key) return state_dict + def load_scaling_factors(state_dict, basename, size): scales = [] for layer in range(size): @@ -98,6 +101,7 @@ def load_scaling_factors(state_dict, basename, size): all_scales = torch.stack(scales) return all_scales + def standarize_distributed_scaling_factors(state_dict): while key := get_extra_state_key(state_dict): basename, size = unpack_extra_state_key(key) @@ -105,7 +109,7 @@ def standarize_distributed_scaling_factors(state_dict): if scaling_factors != []: state_dict[basename + '.scale_fwd'] = scaling_factors state_dict = clear_loaded_extra_states(state_dict, basename) - + return state_dict @@ -153,9 +157,11 @@ def load_sharded_pickle_extra_state_scale(dir): all_scales = torch.stack(scales) return all_scales + def contains_extra_states(subdir): return list(subdir.glob('shard_0_*.pt')) != [] + def load_extra_state_from_pickle(sharded_state_dict, subdir): if scales := load_sharded_pickle_extra_state_scale(subdir): key = subdir.name + '.scale_fwd' @@ -163,6 +169,7 @@ def load_extra_state_from_pickle(sharded_state_dict, subdir): return sharded_state_dict + def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): sharded_state_dict = {} for subdir in checkpoint_dir.iterdir(): @@ -182,6 +189,7 @@ def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tenso sharded_state_dict[key] = torch.from_numpy(arr[:].view(np.int16)).view(torch.bfloat16) else: from tensorrt_llm._utils import str_dtype_to_torch + sharded_state_dict[key] = torch.from_numpy(arr[:]).view(str_dtype_to_torch(arr.dtype.name)) else: sharded_state_dict[key] = arr[:] From 042d325c62eec21d547b9f0419ca6a13335aaac3 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 14 Aug 2024 14:47:09 -0700 Subject: [PATCH 05/21] Apply isort and black reformatting Signed-off-by: Piotr Kaminski --- .../trt_llm/converter/model_converter.py | 10 +++--- .../converter/model_to_trt_llm_ckpt.py | 4 +-- nemo/export/trt_llm/converter/utils.py | 35 +++++++++++++++---- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index bb9b19173509..6a7fe25ba824 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -94,7 +94,7 @@ def model_to_trtllm_ckpt( use_distributed_convert: bool = False, model_parallel_rank: int = None, vocab_size: int = None, - quantize_kv_cache: bool = False + quantize_kv_cache: bool = False, ) -> Tuple[List[Dict], List[PretrainedConfig]]: nemo_model_config['kv_cache'] = quantize_kv_cache if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing: @@ -139,7 +139,9 @@ def model_to_trtllm_ckpt( if has_lm_head: lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) if vocab_embedding_key in weights_dict: - weights_dict[vocab_embedding_key] = torch.nn.functional.pad(weights_dict[vocab_embedding_key], 
padding, "constant", 0) + weights_dict[vocab_embedding_key] = torch.nn.functional.pad( + weights_dict[vocab_embedding_key], padding, "constant", 0 + ) world_size = tensor_parallel_size * pipeline_parallel_size hidden_act = nemo_model_config.get('activation') @@ -249,9 +251,7 @@ def model_to_trtllm_ckpt( if mapping.is_first_pp_rank(): embedding_weight = ( - np.ascontiguousarray( - split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank) - ) + np.ascontiguousarray(split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank)) if use_parallel_embedding else weights_dict[vocab_embedding_key] ) diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index 6cb373159e17..7886c2221566 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -25,7 +25,7 @@ from tqdm import tqdm from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision -from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, load_scaling_factor, weights_dict +from nemo.export.trt_llm.converter.utils import load_scaling_factor, save_val, split_and_save_weight, weights_dict LOGGER = logging.getLogger("NeMo") @@ -212,7 +212,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): # Since the state dict value has the full layers, let's select the ith layer weights/biases here. layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)] - for (l, v) in layer_vals: + for l, v in layer_vals: k = rename_key_dist_ckpt(key, l) starmap_args.append( (tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 1a15e8874dc4..cfde6a359bdf 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -31,7 +31,11 @@ "falcon": 'FalconForCausalLM', } -post_layernorm_keys = ["post_attention_layernorm.weight", "post_attention_layernorm.bias", "post_self_attn_layernorm.weight"] +post_layernorm_keys = [ + "post_attention_layernorm.weight", + "post_attention_layernorm.bias", + "post_self_attn_layernorm.weight", +] mlp_proj_bias_keys = ["mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias"] attention_dense_bias_keys = ["attention.linear_proj.bias", "attention.dense.bias"] input_layernorm_keys = ["input_layernorm.weight", "input_layernorm.bias"] @@ -46,7 +50,12 @@ mlp_proj_experts_keys = ["experts.linear_fc2.weight"] final_layernorm_keys = ["final_layernorm.weight", "final_layernorm.bias"] mlp_dense_2_keys = ["mlp.dense_h_to_4h_2.weight", "mlp.dense_h_to_4h_2.bias"] -attention_not_mapped_keys = ["attention.query.weight", "attention.query.bias", "attention.key_value.weight", "attention.key_value.bias"] +attention_not_mapped_keys = [ + "attention.query.weight", + "attention.query.bias", + "attention.key_value.weight", + "attention.key_value.bias", +] def save_val(val, dir, key, tp_num=None): @@ -207,12 +216,13 @@ def any_word_in_key(key, words): def sequential_key_map(key, mapping): - for (keywords, mapped) in mapping: + for keywords, mapped in mapping: if any_word_in_key(key, keywords): return mapped return None + def get_trt_llm_infix(key): mapping = [ (post_layernorm_keys, '.post_layernorm'), @@ -226,7 +236,7 @@ def get_trt_llm_infix(key): (attention_qkv_bias_keys + attention_qkv_weight_keys, '.attention.qkv'), (mlp_router_keys, 
'.mlp.router'), (mlp_fc_keys, '.mlp.fc'), - (mlp_proj_experts_keys, '.mlp.proj') + (mlp_proj_experts_keys, '.mlp.proj'), ] return sequential_key_map(key, mapping) @@ -254,6 +264,8 @@ def get_scaling_factor_keys(key): first = True + + def load_scaling_factor(key, val, dir, config): global weights_dict if not is_scaling_factor(key): @@ -291,7 +303,9 @@ def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_fac return [val.to(storage_type) for val in vals] fp8_storage_type = torch.float8_e4m3fn - quantized_keys = [ k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k ] + quantized_keys = [ + k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k + ] for k in quantized_keys: if k in trt_llm_key: storage_type = fp8_storage_type @@ -344,8 +358,15 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t elif torch.is_tensor(vals[0]): vals = [torch_to_numpy(val.cpu()) for val in vals] - if (any_word_in_key(key, input_layernorm_keys + pre_layernorm_keys + attention_dense_bias_keys + post_layernorm_keys + mlp_proj_bias_keys + final_layernorm_keys) - and (tp_rank == 0 or convert_on_device)): + if any_word_in_key( + key, + input_layernorm_keys + + pre_layernorm_keys + + attention_dense_bias_keys + + post_layernorm_keys + + mlp_proj_bias_keys + + final_layernorm_keys, + ) and (tp_rank == 0 or convert_on_device): # shared weights, only need to convert the weights of rank 0 save_val(vals[0], saved_dir, trt_llm_key) From 76535b4ef05c059a6ab0f94728f828cb487939bf Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Fri, 16 Aug 2024 09:14:10 -0700 Subject: [PATCH 06/21] fixed zarr loading, added flags, refactor Signed-off-by: Piotr Kaminski --- nemo/export/tarutils.py | 3 +- nemo/export/tensorrt_llm.py | 4 + .../trt_llm/converter/model_converter.py | 50 +++++--- .../converter/model_to_trt_llm_ckpt.py | 75 +++++------- nemo/export/trt_llm/converter/utils.py | 110 +++++++++--------- .../trt_llm/nemo_ckpt_loader/nemo_file.py | 41 ++++--- scripts/export/export_to_trt_llm.py | 16 +++ tests/export/nemo_export.py | 28 +++++ 8 files changed, 190 insertions(+), 137 deletions(-) diff --git a/nemo/export/tarutils.py b/nemo/export/tarutils.py index b9af03e5bbb6..55084af68110 100644 --- a/nemo/export/tarutils.py +++ b/nemo/export/tarutils.py @@ -58,8 +58,7 @@ def __truediv__(self, key) -> 'TarPath': def __str__(self) -> str: return os.path.join(self._tar.name, self._relpath) - def __fspath__(self): - return os.path.join(self._tar.name, self._relpath) + def __fspath__(self): return str(self) @property def tarobject(self): diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index e430acc9e1b8..731ad93d392e 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -173,6 +173,8 @@ def export( multiple_profiles: bool = False, gpt_attention_plugin: str = "auto", gemm_plugin: str = "auto", + fp8_quantized: bool = False, + fp8_kvcache: bool = False, ): """ Exports nemo checkpoints to TensorRT-LLM. 
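As a usage sketch (not part of this patch): with the two new flags added to export() above, an FP8 export through the Python wrapper could look roughly like the following, assuming the TensorRTLLM class defined in this module; the checkpoint path, output directory, and model type are placeholders:

    from nemo.export.tensorrt_llm import TensorRTLLM

    exporter = TensorRTLLM(model_dir="/tmp/trt_llm_engine")      # output dir for the TRT-LLM checkpoint
    exporter.export(
        nemo_checkpoint_path="/models/llama3-8b-fp8.nemo",       # placeholder path to an FP8-trained .nemo
        model_type="llama",
        fp8_quantized=True,   # export weights as an FP8-quantized TRT-LLM checkpoint
        fp8_kvcache=True,     # also emit FP8 KV-cache scaling factors
    )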
@@ -322,6 +324,8 @@ def export( gpus_per_node=gpus_per_node, use_parallel_embedding=use_parallel_embedding, use_embedding_sharing=use_embedding_sharing, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache ) for weight_dict, model_config in zip(weights_dicts, model_configs): diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 6a7fe25ba824..875f0ef404f9 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -80,6 +80,21 @@ def prompt_convert(prompt_config, prompt_weights): return vtokens_embeddings +def create_common_export_config(nemo_model_config, decoder_type, fp8_quantized=False, fp8_kvcache=False): + is_mcore = nemo_model_config.get("mcore_gpt", False) + return { + "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", + "split_gated_activation": nemo_model_config.get("activation", "gelu") + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"] + and (decoder_type == "gptnext" or is_mcore), + "num_attention_heads": nemo_model_config["num_attention_heads"], + "use_attention_nemo_shape": True, + "transpose_weights": True, + "fp8_quantized": fp8_quantized, + "fp8_kvcache": fp8_kvcache, + } + + def model_to_trtllm_ckpt( model, nemo_model_config, @@ -93,16 +108,17 @@ def model_to_trtllm_ckpt( use_embedding_sharing: bool = False, use_distributed_convert: bool = False, model_parallel_rank: int = None, - vocab_size: int = None, - quantize_kv_cache: bool = False, + vocab_size: int | None = None, + fp8_quantized: bool = False, + fp8_kvcache: bool = False ) -> Tuple[List[Dict], List[PretrainedConfig]]: - nemo_model_config['kv_cache'] = quantize_kv_cache if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing: LOGGER.info( "Found share_embeddings_and_output_weights is True in NeMo config, set use_embedding_sharing = True" ) use_embedding_sharing = True + export_config = create_common_export_config(nemo_model_config, decoder_type, fp8_quantized, fp8_kvcache) # If the model has been sharded with model parallelism, convert the model in a gpu-distributed manner if use_distributed_convert: weights_dict = dist_model_to_trt_llm_ckpt( @@ -111,6 +127,7 @@ def model_to_trtllm_ckpt( inference_tp_size=tensor_parallel_size, inference_pp_size=pipeline_parallel_size, tokenizer_vocab_size=vocab_size, + export_config=export_config, ) vocab_size_padded = vocab_size else: @@ -124,24 +141,23 @@ def model_to_trtllm_ckpt( processes=1, storage_type=dtype, use_parallel_embedding=use_parallel_embedding, - decoder_type=decoder_type, + export_config=export_config, ) - has_lm_head = "lm_head.weight" in weights_dict - if has_lm_head: - lm_head_weight = weights_dict["lm_head.weight"] if vocab_size is None: vocab_size = weights_dict[vocab_embedding_key].shape[0] + has_lm_head = "lm_head.weight" in weights_dict vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size - if vocab_size_padded != vocab_size: - padding = (0, 0, 0, vocab_size_padded - vocab_size) - if has_lm_head: - lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) - if vocab_embedding_key in weights_dict: - weights_dict[vocab_embedding_key] = torch.nn.functional.pad( - weights_dict[vocab_embedding_key], padding, "constant", 0 - ) + padding = (0, 0, 0, vocab_size_padded - vocab_size) + if has_lm_head: + lm_head_weight = weights_dict["lm_head.weight"] + lm_head_weight = 
torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) + + if vocab_embedding_key in weights_dict: + weights_dict[vocab_embedding_key] = torch.nn.functional.pad( + weights_dict[vocab_embedding_key], padding, "constant", 0 + ) world_size = tensor_parallel_size * pipeline_parallel_size hidden_act = nemo_model_config.get('activation') @@ -169,8 +185,8 @@ def model_to_trtllm_ckpt( 'embedding_sharding_dim': 0, 'share_embedding_table': use_embedding_sharing, 'quantization': { - 'quant_algo': "FP8" if nemo_model_config.get('fp8', False) else None, - 'kv_cache_quant_algo': "FP8" if quantize_kv_cache else None, + 'quant_algo': "FP8" if fp8_quantized else None, + 'kv_cache_quant_algo': "FP8" if fp8_kvcache else None, }, 'bias': nemo_model_config.get('bias'), 'apply_query_key_layer_scaling': False, diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index 7886c2221566..e77efc3d3a6e 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -25,7 +25,7 @@ from tqdm import tqdm from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision -from nemo.export.trt_llm.converter.utils import load_scaling_factor, save_val, split_and_save_weight, weights_dict +from nemo.export.trt_llm.converter.utils import save_scaling_factor, save_val, split_and_save_weight, weights_dict LOGGER = logging.getLogger("NeMo") @@ -69,7 +69,7 @@ def get_layer_prefix(layer_names, is_mcore): return model_prefix, transformer_layer_prefix -def rename_key(new_key: str): +def rename_key(new_key: str) -> str: if "self_attention" in new_key: new_key = new_key.replace("self_attention", "attention") if "attention.linear_qkv.layer_norm_weight" in new_key: @@ -84,7 +84,7 @@ def rename_key(new_key: str): return new_key -def rename_key_dist_ckpt(old_key: str, layer: int): +def rename_key_dist_ckpt(old_key: str, layer: int) -> str: new_key = old_key if "layers." 
in old_key: split_key = old_key.split(".") @@ -94,18 +94,18 @@ def rename_key_dist_ckpt(old_key: str, layer: int): return rename_key(new_key) -def load_scaling_factors(model, num_layers, out_dir, export_config): - starmap_args = [] +def load_scaling_factors(model: dict, num_layers: int, export_config: dict) -> dict: + if not export_config.get('fp8_quantized', False): + return {} + + scaling_factors = {} for key, val in model.items(): if 'extra_state' not in key: continue for layer in range(num_layers): - args = (rename_key_dist_ckpt(key, layer), val[layer], out_dir, export_config) - starmap_args.append(args) - - for starmap_arg in starmap_args: - scaling_factors = load_scaling_factor(*starmap_arg) + renamed_key = rename_key_dist_ckpt(key, layer) + scaling_factors = save_scaling_factor(scaling_factors, renamed_key, val[layer], export_config) return scaling_factors @@ -117,9 +117,9 @@ def convert_model_to_trt_llm_ckpt( nemo_export_dir, storage_type, inference_tp_size, - decoder_type, use_parallel_embedding, processes, + export_config ): # if checkpoints files could be found - start preparing output dir out_dir = create_export_dir(nemo_export_dir) @@ -133,9 +133,6 @@ def convert_model_to_trt_llm_ckpt( has_position_embedding = get_layer_name("position_embedding", prefix) in model_state_dict has_lm_head = get_layer_name("output_layer", prefix) in model_state_dict - share_embeddings_and_output = nemo_model_config.get("share_embeddings_and_output_weights", False) - embedding_scaling = nemo_model_config.get("apply_embedding_scaling", False) - hidden_size = nemo_model_config["hidden_size"] num_layers = nemo_model_config["num_layers"] training_tp_size = 1 @@ -143,7 +140,6 @@ def convert_model_to_trt_llm_ckpt( num_kv_heads = nemo_model_config.get("num_query_groups", 0) multi_query_mode = nemo_model_config.get("multi_query_mode", False) num_attention_heads = nemo_model_config["num_attention_heads"] - kv_channels = nemo_model_config.get("kv_channels", None) if num_kv_heads == 0: if multi_query_mode: @@ -151,20 +147,14 @@ def convert_model_to_trt_llm_ckpt( else: num_kv_heads = num_attention_heads - export_config = { - "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", - "tp_size": training_tp_size, - "split_gated_activation": nemo_model_config.get("activation", "gelu") - in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"] - and (decoder_type == "gptnext" or is_mcore), - "num_attention_heads": num_attention_heads, - "num_kv_heads": num_kv_heads, - "kv_channels": kv_channels, - "use_attention_nemo_shape": True, - "transpose_weights": True, - "use_parallel_embedding": use_parallel_embedding, - "fp8": nemo_model_config.get('fp8', False), - } + export_config.update( + { + "tp_size": training_tp_size, + "num_kv_heads": num_kv_heads, + "kv_channels": nemo_model_config.get("kv_channels", None), + "use_parallel_embedding": use_parallel_embedding, + } + ) # split_factor: in how many parts a TP training node is split split_factor = inference_tp_size @@ -200,8 +190,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): handle_model_level_weights(model, 0, 0) model = extract_layers_with_prefix(model, transformer_layer_prefix) - - scaling_factors = load_scaling_factors(model, num_layers, out_dir, export_config) + scaling_factors = load_scaling_factors(model, num_layers, export_config) starmap_args = [] for key, val in model.items(): @@ -235,9 +224,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): model_level_weights[key] = torch.concatenate(values, 
axis=0) weights_dict[key] = model_level_weights[key] - for key, value in scaling_factors.items(): - weights_dict[key] = value - + weights_dict.update(scaling_factors) return weights_dict @@ -268,6 +255,7 @@ def dist_model_to_trt_llm_ckpt( inference_tp_size, inference_pp_size, tokenizer_vocab_size, + export_config ): from megatron.core import parallel_state from megatron.core.tensor_parallel.utils import VocabUtility @@ -303,18 +291,13 @@ def dist_model_to_trt_llm_ckpt( prefix, transformer_layer_prefix = get_layer_prefix(sample_state_dict, is_mcore) assert is_mcore, "Only megatron-core inflight model conversion is supported" - export_config = { - "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", - "tp_size": tp_size, - "split_gated_activation": nemo_model_config.get("activation", "gelu") - in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"], - "num_attention_heads": nemo_model_config["num_attention_heads"], - "num_kv_heads": nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), - "convert_on_device": True, - "use_attention_nemo_shape": True, - "transpose_weights": True, - "fp8": nemo_model_config.get('fp8', False), - } + export_config.update( + { + "tp_size": tp_size, + "num_kv_heads": nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), + "convert_on_device": True, + } + ) starmap_config = { "tp_rank": None, diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index cfde6a359bdf..a29f7804f582 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -13,6 +13,7 @@ # limitations under the License. +from typing import List, Optional, Tuple import numpy as np import tensorrt_llm import torch @@ -57,6 +58,8 @@ "attention.key_value.bias", ] +weight_scaling_suffix = '.weights_scaling_factor' +activation_scaling_suffix = '.activation_scaling_factor' def save_val(val, dir, key, tp_num=None): if tp_num: @@ -86,11 +89,10 @@ def save_split(split_vals, dir, key, i, split_factor): def save_expert_split(split_vals, dir, key, i, split_factor): - if tp_num: - key += f".{tp_num}.bin" - for j, val in enumerate(split_vals): tp_num = i * split_factor + j + if tp_num: + key += f".{tp_num}.bin" global weights_dict weights_dict[key] = val @@ -202,20 +204,20 @@ def write_int8(vals, dir, base_key, split_dim, tp_rank, split_factor, kv_cache_o save_val(vals[save_key], dir, f"{base_key}.{save_key}") -def get_suffix(key): +def get_suffix(key: str) -> str: return '.' 
+ key.split('.')[-1] -def get_trt_llm_prefix(key): +def get_trt_llm_prefix(key: str) -> str: layer_num = key.split(".")[1] return f'transformer.layers.{layer_num}' -def any_word_in_key(key, words): +def any_word_in_key(key: str, words: List[str]) -> bool: return any([word in key for word in words]) -def sequential_key_map(key, mapping): +def sequential_key_map(key: str, mapping: List[Tuple[List[str], str]]) -> Optional[str]: for keywords, mapped in mapping: if any_word_in_key(key, keywords): return mapped @@ -223,7 +225,7 @@ def sequential_key_map(key, mapping): return None -def get_trt_llm_infix(key): +def get_trt_llm_infix(key: str) -> Optional[str]: mapping = [ (post_layernorm_keys, '.post_layernorm'), (mlp_proj_bias_keys, '.mlp.proj'), @@ -241,7 +243,7 @@ def get_trt_llm_infix(key): return sequential_key_map(key, mapping) -def get_new_keyname(key): +def get_trt_llm_keyname(key: str) -> str: if any_word_in_key(key, final_layernorm_keys): return key.replace("final_layernorm", "transformer.ln_f") @@ -251,51 +253,46 @@ def get_new_keyname(key): return key -def is_scaling_factor(key): +def is_scaling_factor(key: str) -> bool: return "scale_fwd" in key -def get_scaling_factor_keys(key): - weight_key = '.'.join(key.split('.')[:-2]) + '.weight' - base_key = '.'.join(get_new_keyname(weight_key).split('.')[:-1]) - weight_scale = base_key + '.weights_scaling_factor' - activation_scale = base_key + '.activation_scaling_factor' - return weight_scale, activation_scale +def get_scaling_factor_keys(key: str) -> Tuple[Tuple[str, str], Tuple[str, str]]: + # Reuses existing mapping of NeMo -> TRT LLM weights key via swapping suffixes + corresponding_weight_key = '.'.join(key.split('.')[:-2]) + '.weight' + corresponding_trt_llm_weight_key = get_trt_llm_keyname(corresponding_weight_key) + base_key = '.'.join(corresponding_trt_llm_weight_key.split('.')[:-1]) + weight_scale = base_key + weight_scaling_suffix + activation_scale = base_key + activation_scaling_suffix + keys = (weight_scale, activation_scale) -first = True + layer_prefix = get_trt_llm_prefix(key) + mapped_key = layer_prefix + '.mlp.gate' + gate_activation = mapped_key + activation_scaling_suffix + gate_weight = mapped_key + weight_scaling_suffix + gate_keys = (gate_activation, gate_weight) + return keys, gate_keys -def load_scaling_factor(key, val, dir, config): - global weights_dict +def save_scaling_factor(scaling_factors: dict, key: str, val: torch.Tensor, config: dict): if not is_scaling_factor(key): - return weights_dict - - activation_factor = 1 / val[0].view(1) - weights_factor = 1 / val[1].view(1) - weights_factor_2 = 1 / val[2].view(1) + return scaling_factors - weights_key, activation_key = get_scaling_factor_keys(key) - save_val(torch_to_numpy(activation_factor), dir, activation_key) - save_val(torch_to_numpy(weights_factor), dir, weights_key) + activation_factor = torch_to_numpy(1 / val[0].view(1)) + weights_factor = torch_to_numpy(1 / val[1].view(1)) - # TODO - # save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2') - # global first - # if first: - # first = False - # for i in range(32): - # save_val(torch_to_numpy(weights_factor_2), dir, f'transformer.layers.{i}.attention.kv_cache_scaling_factor') + (weights_key, activation_key), gate_keys = get_scaling_factor_keys(key) + scaling_factors[activation_key] = activation_factor + scaling_factors[weights_key] = weights_factor split_gated_activation = config.get("split_gated_activation", False) - if split_gated_activation and (("mlp.dense_h_to_4h" in key) or 
("mlp.linear_fc1" in key)): - layer_prefix = get_trt_llm_prefix(key) - mapped_key = f'{layer_prefix}.mlp.gate' - save_val(torch_to_numpy(activation_factor), dir, mapped_key + '.activation_scaling_factor') - save_val(torch_to_numpy(weights_factor), dir, mapped_key + '.weights_scaling_factor') - # save_val(torch_to_numpy(weights_factor_2), dir, mapped_key + '.weights_scaling_factor_2') + if split_gated_activation and any_word_in_key(key, ["mlp.dense_h_to_4h", "mlp.linear_fc1"]): + (gate_activation_key, gate_weight_key) = gate_keys + scaling_factors[gate_activation_key] = activation_factor + scaling_factors[gate_weight_key] = weights_factor - return weights_dict + return scaling_factors def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_factors): @@ -304,26 +301,25 @@ def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_fac fp8_storage_type = torch.float8_e4m3fn quantized_keys = [ - k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k + k.split(weight_scaling_suffix)[0] for k in scaling_factors.keys() if k.endswith(weight_scaling_suffix) ] for k in quantized_keys: if k in trt_llm_key: storage_type = fp8_storage_type - scale = scaling_factors[k + '.weights_scaling_factor'] + scale = scaling_factors[k + weight_scaling_suffix] vals = [val.to(torch.float32) / scale for val in vals] break return [val.to(storage_type) for val in vals] -def split_val_gate(vals, convert_on_device): +def split_val_gate(vals: List[np.ndarray], convert_on_device: bool): if convert_on_device: return [[n] for n in torch.chunk(vals[0], 2, axis=-1)] splits = [np.split(val, 2, axis=-1) for val in vals] return list(zip(*splits)) - # Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head # are not split as there is only one head per key/value. @torch.no_grad() @@ -337,11 +333,11 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t num_kv_heads = config.get("num_kv_heads", num_attention_heads) size_per_head = config.get("kv_channels", None) convert_on_device = config.get("convert_on_device", False) - is_fp8_model = config.get("fp8", False) - + is_fp8_model = config.get("fp8_quantized", False) + use_fp8_kv_cache = config.get("fp8_kvcache", False) save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" - layer_prefix = get_trt_llm_prefix(key) + trt_llm_key = get_trt_llm_keyname(key) if not isinstance(vals, list): vals = [vals] @@ -350,7 +346,6 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if "layernorm.weight" in key and config.get("apply_layernorm_1p", False): vals = [val.float() + 1.0 for val in vals] - trt_llm_key = get_new_keyname(key) vals = cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, sf) if convert_on_device: assert len(vals) == 1 # Should only convert a single device param per call @@ -382,7 +377,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if act_range is not None and int8_outputs == "all": base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) - write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) # TODO is cat dim always defined? 
+ write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) elif any_word_in_key(key, mlp_fc_keys): if split_gated_activation: @@ -403,7 +398,8 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if split_gated_activation: assert not save_int8 - gate_key = f'{layer_prefix}.mlp.gate' + get_suffix(trt_llm_key) + layer_prefix = get_trt_llm_prefix(key) + gate_key = layer_prefix +'.mlp.gate' + get_suffix(trt_llm_key) if convert_on_device: save_val(gates[0], saved_dir, gate_key) else: @@ -520,6 +516,11 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t kv_cache_only=int8_outputs == "kv_cache_only", ) + if use_fp8_kv_cache: + base_key = trt_llm_key.replace('.qkv.weight', '') + scaling_factor = np.array([1.], dtype=np.float32) + save_val(scaling_factor, dir, base_key + '.kv_cache_scaling_factor') + elif any_word_in_key(key, attention_not_mapped_keys): pass @@ -551,13 +552,14 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t return weights_dict -def split(v, tp_size, idx, dim=0): +def split(v: np.ndarray | torch.Tensor, tp_size: int, idx: int, dim: int = 0): """Splits the np tensor v on dim and return the idx's slice.""" if tp_size == 1: return v - if len(v.shape) == 1: - return np.ascontiguousarray(np.split(v, tp_size)[idx]) + dim = dim if len(v.shape) != 1 else 0 + if torch.is_tensor(v): + return torch.split(v, v.size(dim) // tp_size, dim=dim)[idx].contiguous() return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 2357c8a57269..c1141b962f44 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -18,7 +18,7 @@ import logging import os from pathlib import Path -from typing import Dict, Union +from typing import Dict, Optional, Tuple, Union import numpy as np import tensorstore # This is important even though not used. Otherwise zarr raises error. 
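The extra-state helpers in this file expect scaling factors to be stored under keys of the form '<basename>._extra_state/shard_<i>_<num_shards>'. A small sketch with a made-up key name:

    from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import unpack_extra_state_key

    key = 'decoder.layers.mlp.linear_fc1._extra_state/shard_3_32'   # hypothetical key name
    basename, size = unpack_extra_state_key(key)
    assert basename == 'decoder.layers.mlp.linear_fc1._extra_state'
    assert size == 32   # number of shards whose 'scale_fwd' tensors load_scaling_factors stacks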
@@ -66,27 +66,27 @@ def __init__(self, path: Union[Path, TarPath]) -> None: self.path = path # overwrites path set in super().__init__ call -def get_extra_state_key(state_dict): +def get_extra_state_key(state_dict: dict) -> Optional[str]: for key in state_dict.keys(): if '_extra_state/' in key: return key - return False + return None -def unpack_extra_state_key(key): +def unpack_extra_state_key(key: str) -> Tuple[str, int]: basename = key.split('/')[0] size = int(key.split('/')[1].split('_')[-1]) return basename, size -def clear_loaded_extra_states(state_dict, basename): +def clear_loaded_extra_states(state_dict: dict, basename: str): to_remove = [k for k in state_dict.keys() if basename + '/' in k] for key in to_remove: state_dict.pop(key) return state_dict -def load_scaling_factors(state_dict, basename, size): +def load_scaling_factors(state_dict: dict, basename: str, size: int): scales = [] for layer in range(size): keyname = f'{basename}/shard_{layer}_{size}' @@ -102,7 +102,7 @@ def load_scaling_factors(state_dict, basename, size): return all_scales -def standarize_distributed_scaling_factors(state_dict): +def standarize_distributed_scaling_factors(state_dict: dict): while key := get_extra_state_key(state_dict): basename, size = unpack_extra_state_key(key) scaling_factors = load_scaling_factors(state_dict, basename, size) @@ -139,17 +139,20 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch return state_dict -def load_sharded_pickle_extra_state_scale(dir): +def load_sharded_pickle_extra_state_scale(dir: Union[Path, TarPath]): scales = [] layer_number = 0 while pt_file_list := list(dir.glob(f'shard_{layer_number}_*.pt')): pt_file = pt_file_list[0] - checkpoint = torch.load(pt_file) - checkpoint.seek(0) - state_dict = torch.load(checkpoint) - if 'scale_fwd' not in state_dict: - return [] + with pt_file.open('rb') as checkpoint_file: + dictionary = torch.load(checkpoint_file) + + dictionary.seek(0) + state_dict = torch.load(dictionary) + if not state_dict or 'scale_fwd' not in state_dict: + return None + scale = state_dict['scale_fwd'].cpu() scales.append(scale) layer_number += 1 @@ -158,15 +161,17 @@ def load_sharded_pickle_extra_state_scale(dir): return all_scales -def contains_extra_states(subdir): +def contains_extra_states(subdir: Union[Path, TarPath]): return list(subdir.glob('shard_0_*.pt')) != [] -def load_extra_state_from_pickle(sharded_state_dict, subdir): - if scales := load_sharded_pickle_extra_state_scale(subdir): - key = subdir.name + '.scale_fwd' - sharded_state_dict[key] = scales +def load_extra_state_from_pickle(sharded_state_dict: dict, subdir: Union[Path, TarPath]): + scales = load_sharded_pickle_extra_state_scale(subdir) + if scales is None: + return sharded_state_dict + key = subdir.name + '.scale_fwd' + sharded_state_dict[key] = scales return sharded_state_dict diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py index a9b9d92c172b..9fd6067ba48d 100644 --- a/scripts/export/export_to_trt_llm.py +++ b/scripts/export/export_to_trt_llm.py @@ -50,6 +50,20 @@ def get_args(argv): type=str, help="dtype of the model on TensorRT-LLM", ) + parser.add_argument( + "-fp8", + "--export_fp8_quantized", + default=False, + type=bool, + help="Enables exporting to a FP8-quantized TRT LLM checkpoint", + ) + parser.add_argument( + "-kv_fp8", + "--use_fp8_kv_cache", + default=False, + type=bool, + help="Enables exporting with FP8-quantizatized KV-cache", + ) parser.add_argument("-mil", "--max_input_len", default=256, 
type=int, help="Max input length of the model") parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model") parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model") @@ -153,6 +167,8 @@ def nemo_export_trt_llm(argv): use_lora_plugin=args.use_lora_plugin, lora_target_modules=args.lora_target_modules, max_lora_rank=args.max_lora_rank, + fp8_quantized=args.export_fp8_quantized, + fp8_kvcache=args.use_fp8_kv_cache ) LOGGER.info("Export is successful.") diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index 6a296fdb92eb..a0333a8f8574 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -241,6 +241,8 @@ def run_inference( test_deployment=False, test_data_path=None, save_trt_engine=False, + fp8_quantized=False, + fp8_kvcache=False, ) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: if Path(checkpoint_path).exists(): if tp_size > torch.cuda.device_count(): @@ -324,6 +326,8 @@ def run_inference( lora_target_modules=lora_target_modules, max_num_tokens=int(max_input_len * max_batch_size * 0.2), use_embedding_sharing=use_embedding_sharing, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache ) if ptuning: @@ -451,6 +455,8 @@ def run_existing_checkpoints( test_data_path=None, save_trt_engine=False, in_framework=False, + fp8_quantized=False, + fp8_kvcache=False, ) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: if tp_size > torch.cuda.device_count(): print("Skipping the test due to not enough number of GPUs") @@ -528,6 +534,8 @@ def run_existing_checkpoints( test_deployment=test_deployment, test_data_path=test_data_path, save_trt_engine=save_trt_engine, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache, ) @@ -743,6 +751,20 @@ def get_args(): type=float, help="GPU memory utilization percentage for vLLM.", ) + parser.add_argument( + "-fp8", + "--export_fp8_quantized", + default="False", + type=str, + help="Enables exporting to a FP8-quantized TRT LLM checkpoint", + ) + parser.add_argument( + "-kv_fp8", + "--use_fp8_kv_cache", + default="False", + type=str, + help="Enables exporting with FP8-quantizatized KV-cache", + ) args = parser.parse_args() @@ -763,6 +785,8 @@ def str_to_bool(name: str, s: str) -> bool: args.use_vllm = str_to_bool("use_vllm", args.use_vllm) args.use_parallel_embedding = str_to_bool("use_parallel_embedding", args.use_parallel_embedding) args.in_framework = str_to_bool("in_framework", args.in_framework) + args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized) + args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache) return args @@ -816,6 +840,8 @@ def run_inference_tests(args): test_data_path=args.test_data_path, save_trt_engine=args.save_trt_engine, in_framework=args.in_framework, + fp8_quantized=args.export_fp8_quantized, + fp8_kvcache=args.use_fp8_kv_cache, ) tps = tps * 2 @@ -871,6 +897,8 @@ def run_inference_tests(args): test_cpp_runtime=args.test_cpp_runtime, test_data_path=args.test_data_path, save_trt_engine=args.save_trt_engine, + fp8_quantized=args.export_fp8_quantized, + fp8_kvcache=args.use_fp8_kv_cache, ) tps = tps * 2 From 7a1d042d54c244d3126c855a41e6a8db4ed1e557 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Fri, 16 Aug 2024 16:17:13 +0000 Subject: [PATCH 07/21] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/export/tarutils.py | 3 ++- nemo/export/tensorrt_llm.py | 2 +- 
nemo/export/trt_llm/converter/model_converter.py | 2 +- nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py | 9 ++------- nemo/export/trt_llm/converter/utils.py | 7 +++++-- scripts/export/export_to_trt_llm.py | 2 +- tests/export/nemo_export.py | 2 +- 7 files changed, 13 insertions(+), 14 deletions(-) diff --git a/nemo/export/tarutils.py b/nemo/export/tarutils.py index 55084af68110..30ec0142f5c4 100644 --- a/nemo/export/tarutils.py +++ b/nemo/export/tarutils.py @@ -58,7 +58,8 @@ def __truediv__(self, key) -> 'TarPath': def __str__(self) -> str: return os.path.join(self._tar.name, self._relpath) - def __fspath__(self): return str(self) + def __fspath__(self): + return str(self) @property def tarobject(self): diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index c4624accf7b2..8d668a948018 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -327,7 +327,7 @@ def export( use_parallel_embedding=use_parallel_embedding, use_embedding_sharing=use_embedding_sharing, fp8_quantized=fp8_quantized, - fp8_kvcache=fp8_kvcache + fp8_kvcache=fp8_kvcache, ) for weight_dict, model_config in zip(weights_dicts, model_configs): diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 875f0ef404f9..5a4d3a6a4067 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -110,7 +110,7 @@ def model_to_trtllm_ckpt( model_parallel_rank: int = None, vocab_size: int | None = None, fp8_quantized: bool = False, - fp8_kvcache: bool = False + fp8_kvcache: bool = False, ) -> Tuple[List[Dict], List[PretrainedConfig]]: if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing: LOGGER.info( diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index b0dec240bd2e..bfe47bee7905 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -119,7 +119,7 @@ def convert_model_to_trt_llm_ckpt( inference_tp_size, use_parallel_embedding, processes, - export_config + export_config, ): # if checkpoints files could be found - start preparing output dir out_dir = create_export_dir(nemo_export_dir) @@ -250,12 +250,7 @@ def get_layer_num(param_name): @torch.no_grad() def dist_model_to_trt_llm_ckpt( - model, - nemo_model_config, - inference_tp_size, - inference_pp_size, - tokenizer_vocab_size, - export_config + model, nemo_model_config, inference_tp_size, inference_pp_size, tokenizer_vocab_size, export_config ): from megatron.core import parallel_state from megatron.core.tensor_parallel.utils import VocabUtility diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index a29f7804f582..0916c997a2b1 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -61,6 +61,7 @@ weight_scaling_suffix = '.weights_scaling_factor' activation_scaling_suffix = '.activation_scaling_factor' + def save_val(val, dir, key, tp_num=None): if tp_num: key += f".{tp_num}.bin" @@ -275,6 +276,7 @@ def get_scaling_factor_keys(key: str) -> Tuple[Tuple[str, str], Tuple[str, str]] return keys, gate_keys + def save_scaling_factor(scaling_factors: dict, key: str, val: torch.Tensor, config: dict): if not is_scaling_factor(key): return scaling_factors @@ -320,6 +322,7 @@ def split_val_gate(vals: List[np.ndarray], convert_on_device: bool): splits 
= [np.split(val, 2, axis=-1) for val in vals] return list(zip(*splits)) + # Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head # are not split as there is only one head per key/value. @torch.no_grad() @@ -399,7 +402,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if split_gated_activation: assert not save_int8 layer_prefix = get_trt_llm_prefix(key) - gate_key = layer_prefix +'.mlp.gate' + get_suffix(trt_llm_key) + gate_key = layer_prefix + '.mlp.gate' + get_suffix(trt_llm_key) if convert_on_device: save_val(gates[0], saved_dir, gate_key) else: @@ -518,7 +521,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if use_fp8_kv_cache: base_key = trt_llm_key.replace('.qkv.weight', '') - scaling_factor = np.array([1.], dtype=np.float32) + scaling_factor = np.array([1.0], dtype=np.float32) save_val(scaling_factor, dir, base_key + '.kv_cache_scaling_factor') elif any_word_in_key(key, attention_not_mapped_keys): diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py index 9fd6067ba48d..7a240a6c4e6d 100644 --- a/scripts/export/export_to_trt_llm.py +++ b/scripts/export/export_to_trt_llm.py @@ -168,7 +168,7 @@ def nemo_export_trt_llm(argv): lora_target_modules=args.lora_target_modules, max_lora_rank=args.max_lora_rank, fp8_quantized=args.export_fp8_quantized, - fp8_kvcache=args.use_fp8_kv_cache + fp8_kvcache=args.use_fp8_kv_cache, ) LOGGER.info("Export is successful.") diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index a0333a8f8574..ca34bde01c68 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -327,7 +327,7 @@ def run_inference( max_num_tokens=int(max_input_len * max_batch_size * 0.2), use_embedding_sharing=use_embedding_sharing, fp8_quantized=fp8_quantized, - fp8_kvcache=fp8_kvcache + fp8_kvcache=fp8_kvcache, ) if ptuning: From 7d087dd07a9fee9e05c7857e7431e53f8552d2e3 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Fri, 16 Aug 2024 13:56:52 -0700 Subject: [PATCH 08/21] fix expert key mapping Signed-off-by: Piotr Kaminski --- nemo/export/trt_llm/converter/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 0916c997a2b1..654b368ed288 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -238,7 +238,7 @@ def get_trt_llm_infix(key: str) -> Optional[str]: (mlp_fc_keys, '.mlp.fc'), (attention_qkv_bias_keys + attention_qkv_weight_keys, '.attention.qkv'), (mlp_router_keys, '.mlp.router'), - (mlp_fc_keys, '.mlp.fc'), + (mlp_fc_expert_keys, '.mlp.fc'), (mlp_proj_experts_keys, '.mlp.proj'), ] return sequential_key_map(key, mapping) From f5ff40ec2f1da0bf0a31c4b393f9cf5c4a68eeec Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 21 Aug 2024 03:45:48 -0700 Subject: [PATCH 09/21] refactor Signed-off-by: Piotr Kaminski --- nemo/export/tensorrt_llm.py | 2 + .../trt_llm/converter/model_converter.py | 4 +- .../converter/model_to_trt_llm_ckpt.py | 34 +++---- nemo/export/trt_llm/converter/utils.py | 8 +- .../trt_llm/nemo_ckpt_loader/nemo_file.py | 98 +++++++++++-------- 5 files changed, 80 insertions(+), 66 deletions(-) diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 8d668a948018..7b7c4e07e225 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -204,6 +204,8 @@ def export( multiple_profiles: (bool): 
enables multiple profiles feature of TRT-LLM. Default = False gpt_attention_plugin (str): enable the gpt attention plugin. Default = "auto" gemm_plugin (str): enable the gpt plugin. Default = "auto" + fp8_quantized (bool): enables exporting to FP8 TRT-LLM checkpoints + fp8_kvcache (bool): enables FP8 KV-cache quantization """ if n_gpus is not None: diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 5a4d3a6a4067..9eac3acfa708 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -15,7 +15,7 @@ import csv import logging -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import tensorrt_llm @@ -108,7 +108,7 @@ def model_to_trtllm_ckpt( use_embedding_sharing: bool = False, use_distributed_convert: bool = False, model_parallel_rank: int = None, - vocab_size: int | None = None, + vocab_size: Optional[int] = None, fp8_quantized: bool = False, fp8_kvcache: bool = False, ) -> Tuple[List[Dict], List[PretrainedConfig]]: diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index bfe47bee7905..cf88ea91bcb1 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -93,6 +93,8 @@ def rename_key_dist_ckpt(old_key: str, layer: int) -> str: return rename_key(new_key) +def is_scaling_factor(key: str) -> bool: + return "extra_state" in key def load_scaling_factors(model: dict, num_layers: int, export_config: dict) -> dict: if not export_config.get('fp8_quantized', False): @@ -100,12 +102,10 @@ def load_scaling_factors(model: dict, num_layers: int, export_config: dict) -> d scaling_factors = {} for key, val in model.items(): - if 'extra_state' not in key: - continue - - for layer in range(num_layers): - renamed_key = rename_key_dist_ckpt(key, layer) - scaling_factors = save_scaling_factor(scaling_factors, renamed_key, val[layer], export_config) + if is_scaling_factor(key): + for layer in range(num_layers): + renamed_key = rename_key_dist_ckpt(key, layer) + scaling_factors = save_scaling_factor(scaling_factors, renamed_key, val[layer], export_config) return scaling_factors @@ -194,18 +194,16 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): starmap_args = [] for key, val in model.items(): - if 'extra_state' in key: - continue - - # Let's rename/map the key to the old layer name previously. - # Since the state dict value has the full layers, let's select the ith layer weights/biases here. - layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)] - - for l, v in layer_vals: - k = rename_key_dist_ckpt(key, l) - starmap_args.append( - (tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors) - ) + if not is_scaling_factor(key): + # Let's rename/map the key to the old layer name previously. + # Since the state dict value has the full layers, let's select the ith layer weights/biases here. 
+ layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)] + + for l, v in layer_vals: + k = rename_key_dist_ckpt(key, l) + starmap_args.append( + (tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors) + ) starmap_args = tqdm(starmap_args, desc="saving weights") diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 654b368ed288..9ae86bde800e 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import tensorrt_llm import torch @@ -326,7 +326,7 @@ def split_val_gate(vals: List[np.ndarray], convert_on_device: bool): # Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head # are not split as there is only one head per key/value. @torch.no_grad() -def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, sf): +def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, scaling_factors): use_attention_nemo_shape = config.get("use_attention_nemo_shape", False) split_gated_activation = config.get("split_gated_activation", False) num_attention_heads = config.get("num_attention_heads", 0) @@ -349,7 +349,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if "layernorm.weight" in key and config.get("apply_layernorm_1p", False): vals = [val.float() + 1.0 for val in vals] - vals = cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, sf) + vals = cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_factors) if convert_on_device: assert len(vals) == 1 # Should only convert a single device param per call assert torch.is_tensor(vals[0]) @@ -555,7 +555,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t return weights_dict -def split(v: np.ndarray | torch.Tensor, tp_size: int, idx: int, dim: int = 0): +def split(v: Union[np.ndarray, torch.Tensor], tp_size: int, idx: int, dim: int = 0): """Splits the np tensor v on dim and return the idx's slice.""" if tp_size == 1: return v diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index d08500ec18a6..3a2b0df65b79 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -18,14 +18,15 @@ import logging import os from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import tensorstore # This is important even though not used. Otherwise zarr raises error. 
import torch import yaml import zarr -from tensorrt_llm._utils import np_bfloat16 +from io import BytesIO +from tensorrt_llm._utils import np_bfloat16, str_dtype_to_torch from torch.distributed.checkpoint import FileSystemReader from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata from torch.distributed.checkpoint.state_dict_loader import load_state_dict @@ -79,50 +80,68 @@ def unpack_extra_state_key(key: str) -> Tuple[str, int]: return basename, size -def clear_loaded_extra_states(state_dict: dict, basename: str): +def clear_loaded_extra_states(state_dict: dict, basename: str) -> dict: + """ The scaling factors are originally saved to state_dict under the keynames 'basename/*' + The standardized representation is saved to 'basename.*'. This function clears the former from the state. + """ to_remove = [k for k in state_dict.keys() if basename + '/' in k] for key in to_remove: state_dict.pop(key) return state_dict -def load_scaling_factors(state_dict: dict, basename: str, size: int): + +def retrieve_scale(bytes: BytesIO) -> Optional[torch.Tensor]: + bytes.seek(0) + extra_state = torch.load(bytes) + if not extra_state or 'scale_fwd' not in extra_state: + return None + return extra_state['scale_fwd'].cpu() + + +def load_scales_from_bytes(bytes_list: List[BytesIO]) -> Optional[torch.Tensor]: scales = [] - for layer in range(size): - keyname = f'{basename}/shard_{layer}_{size}' - extra_state = state_dict[keyname][0] - extra_state.seek(0) - extra_state = torch.load(extra_state) + for bytes in bytes_list: + scale = retrieve_scale(bytes) + if scale is None: + return None + scales.append(scale) + return torch.stack(scales) - if 'scale_fwd' not in extra_state.keys(): - return [] - scales.append(extra_state['scale_fwd'].cpu()) - all_scales = torch.stack(scales) - return all_scales +def load_scaling_factors(state_dict: dict, basename: str, size: int) -> Optional[torch.Tensor]: + keynames = [f'{basename}/shard_{layer}_{size}' for layer in range(size)] + bytes_list = [state_dict[keyname][0] for keyname in keynames] + return load_scales_from_bytes(bytes_list) -def standarize_distributed_scaling_factors(state_dict: dict): +def standarize_distributed_scaling_factors(state_dict: dict) -> dict: while key := get_extra_state_key(state_dict): basename, size = unpack_extra_state_key(key) scaling_factors = load_scaling_factors(state_dict, basename, size) - if scaling_factors != []: + if scaling_factors is not None: state_dict[basename + '.scale_fwd'] = scaling_factors state_dict = clear_loaded_extra_states(state_dict, basename) return state_dict -def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): +def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch_tensor: bool = True): fs_reader = TarFileSystemReader(checkpoint_dir) metadata = fs_reader.read_metadata() state_dict = { - k: torch.empty(tp.size, dtype=tp.properties.dtype) if isinstance(tp, TensorStorageMetadata) else {} + k: torch.empty(tp.size, dtype=tp.properties.dtype) for k, tp in metadata.state_dict_metadata.items() - if isinstance(tp, TensorStorageMetadata) or isinstance(tp, BytesStorageMetadata) + if isinstance(tp, TensorStorageMetadata) } + state_dict.update({ + k: {} + for k, tp in metadata.state_dict_metadata.items() + if isinstance(tp, BytesStorageMetadata) + }) + load_state_dict( state_dict, storage_reader=fs_reader, @@ -139,26 +158,25 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch return state_dict -def 
load_sharded_pickle_extra_state_scale(dir: Union[Path, TarPath]): - scales = [] - layer_number = 0 +def get_sharded_file(dir: dict, layer_number: int) -> Optional[os.PathLike]: + pt_file_list = list(dir.glob(f'shard_{layer_number}_*.pt')) + if pt_file_list == []: + return None + return pt_file_list[0] - while pt_file_list := list(dir.glob(f'shard_{layer_number}_*.pt')): - pt_file = pt_file_list[0] - with pt_file.open('rb') as checkpoint_file: - dictionary = torch.load(checkpoint_file) - dictionary.seek(0) - state_dict = torch.load(dictionary) - if not state_dict or 'scale_fwd' not in state_dict: - return None +def load_sharded_pickle_extra_state_scale(dir: Union[Path, TarPath]): + def _get_layer_number(file): + basename = os.path.basename(str(file)) + return int(basename.split('_')[1]) - scale = state_dict['scale_fwd'].cpu() - scales.append(scale) - layer_number += 1 + pt_files = list(dir.glob('shard_*_*.pt')) + bytes_list = [] + for file in sorted(pt_files, key=_get_layer_number): + with file.open('rb') as opened_file: + bytes_list.append(torch.load(opened_file)) - all_scales = torch.stack(scales) - return all_scales + return load_scales_from_bytes(bytes_list) def contains_extra_states(subdir: Union[Path, TarPath]): @@ -167,14 +185,12 @@ def contains_extra_states(subdir: Union[Path, TarPath]): def load_extra_state_from_pickle(sharded_state_dict: dict, subdir: Union[Path, TarPath]): scales = load_sharded_pickle_extra_state_scale(subdir) - if scales is None: - return sharded_state_dict + if scales is not None: + key = subdir.name + '.scale_fwd' + sharded_state_dict[key] = scales - key = subdir.name + '.scale_fwd' - sharded_state_dict[key] = scales return sharded_state_dict - def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): sharded_state_dict = {} for subdir in checkpoint_dir.iterdir(): @@ -193,8 +209,6 @@ def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tenso if arr.dtype.name == "bfloat16": sharded_state_dict[key] = torch.from_numpy(arr[:].view(np.int16)).view(torch.bfloat16) else: - from tensorrt_llm._utils import str_dtype_to_torch - sharded_state_dict[key] = torch.from_numpy(arr[:]).view(str_dtype_to_torch(arr.dtype.name)) else: sharded_state_dict[key] = arr[:] From a11bc2f13f902b8fff266af087f2cc40316148b4 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Wed, 21 Aug 2024 10:49:42 +0000 Subject: [PATCH 10/21] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- .../trt_llm/converter/model_to_trt_llm_ckpt.py | 2 ++ nemo/export/trt_llm/converter/utils.py | 4 +++- nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py | 14 ++++++-------- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index cf88ea91bcb1..55fb5b3ce948 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -93,9 +93,11 @@ def rename_key_dist_ckpt(old_key: str, layer: int) -> str: return rename_key(new_key) + def is_scaling_factor(key: str) -> bool: return "extra_state" in key + def load_scaling_factors(model: dict, num_layers: int, export_config: dict) -> dict: if not export_config.get('fp8_quantized', False): return {} diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 9ae86bde800e..f1882bbea3a1 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -326,7 
+326,9 @@ def split_val_gate(vals: List[np.ndarray], convert_on_device: bool): # Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head # are not split as there is only one head per key/value. @torch.no_grad() -def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, scaling_factors): +def split_and_save_weight( + tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, scaling_factors +): use_attention_nemo_shape = config.get("use_attention_nemo_shape", False) split_gated_activation = config.get("split_gated_activation", False) num_attention_heads = config.get("num_attention_heads", 0) diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 3a2b0df65b79..14f02b06b71b 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -17,6 +17,7 @@ import json import logging import os +from io import BytesIO from pathlib import Path from typing import Dict, List, Optional, Tuple, Union @@ -25,7 +26,6 @@ import torch import yaml import zarr -from io import BytesIO from tensorrt_llm._utils import np_bfloat16, str_dtype_to_torch from torch.distributed.checkpoint import FileSystemReader from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata @@ -81,7 +81,7 @@ def unpack_extra_state_key(key: str) -> Tuple[str, int]: def clear_loaded_extra_states(state_dict: dict, basename: str) -> dict: - """ The scaling factors are originally saved to state_dict under the keynames 'basename/*' + """The scaling factors are originally saved to state_dict under the keynames 'basename/*' The standardized representation is saved to 'basename.*'. This function clears the former from the state. 
""" to_remove = [k for k in state_dict.keys() if basename + '/' in k] @@ -90,7 +90,6 @@ def clear_loaded_extra_states(state_dict: dict, basename: str) -> dict: return state_dict - def retrieve_scale(bytes: BytesIO) -> Optional[torch.Tensor]: bytes.seek(0) extra_state = torch.load(bytes) @@ -136,11 +135,9 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch if isinstance(tp, TensorStorageMetadata) } - state_dict.update({ - k: {} - for k, tp in metadata.state_dict_metadata.items() - if isinstance(tp, BytesStorageMetadata) - }) + state_dict.update( + {k: {} for k, tp in metadata.state_dict_metadata.items() if isinstance(tp, BytesStorageMetadata)} + ) load_state_dict( state_dict, @@ -191,6 +188,7 @@ def load_extra_state_from_pickle(sharded_state_dict: dict, subdir: Union[Path, T return sharded_state_dict + def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): sharded_state_dict = {} for subdir in checkpoint_dir.iterdir(): From ec14cb40280b48cb2075c5ee96092ed27f2daea6 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 21 Aug 2024 06:37:40 -0700 Subject: [PATCH 11/21] fix: failed test was finishing with exit code 0 Signed-off-by: Piotr Kaminski --- tests/export/nemo_export.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index 5ac766123f71..7fdfd73e232f 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -968,5 +968,7 @@ def optional_bool_to_pass_fail(b: Optional[bool]): run_inference_tests(args) except UsageError as e: LOGGER.error(f"{e}") + raise e except argparse.ArgumentError as e: LOGGER.error(f"{e}") + raise e From 73d926102a3ee91047fb6a9b21291668d2a35b05 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 21 Aug 2024 07:36:26 -0700 Subject: [PATCH 12/21] test commit -- rerun github checks Signed-off-by: Piotr Kaminski --- nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index 55fb5b3ce948..04d7cc21db2b 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -187,7 +187,6 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): model_level_weights["lm_head.weight"].append(val) weights_dict = {} - tp_rank = 0 handle_model_level_weights(model, 0, 0) From 84a5e5e00b4e1c0e4ca326a970b4d7f2246439a4 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 21 Aug 2024 08:24:39 -0700 Subject: [PATCH 13/21] bugfix: naming Signed-off-by: Piotr Kaminski --- nemo/export/trt_llm/converter/utils.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index f1882bbea3a1..d7abe38c936a 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -63,8 +63,8 @@ def save_val(val, dir, key, tp_num=None): - if tp_num: - key += f".{tp_num}.bin" + suffix = f".{tp_num}.bin" if tp_num else '' + tp_key = key + suffix global weights_dict # Transpose linear layer weights to the correct shape. 
@@ -74,14 +74,14 @@ def save_val(val, dir, key, tp_num=None): val = val.reshape(val.shape[0], -1) val = torch.transpose(val, 0, 1) if key not in weights_dict: - weights_dict[key] = torch.empty( + weights_dict[tp_key] = torch.empty( val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True ) - weights_dict[key].copy_(val, non_blocking=True) + weights_dict[tp_key].copy_(val, non_blocking=True) else: if len(val.shape) >= 2: val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0])) - weights_dict[key] = val + weights_dict[tp_key] = val def save_split(split_vals, dir, key, i, split_factor): @@ -91,11 +91,10 @@ def save_split(split_vals, dir, key, i, split_factor): def save_expert_split(split_vals, dir, key, i, split_factor): for j, val in enumerate(split_vals): + suffix = f".{tp_num}.bin" if tp_num else '' tp_num = i * split_factor + j - if tp_num: - key += f".{tp_num}.bin" global weights_dict - weights_dict[key] = val + weights_dict[key + suffix] = val def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): From 250525ec510874994ae046c27ced9d548a38f38b Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 21 Aug 2024 08:27:42 -0700 Subject: [PATCH 14/21] bugfix v2: naming Signed-off-by: Piotr Kaminski --- nemo/export/trt_llm/converter/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index d7abe38c936a..471086d2a333 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -91,8 +91,9 @@ def save_split(split_vals, dir, key, i, split_factor): def save_expert_split(split_vals, dir, key, i, split_factor): for j, val in enumerate(split_vals): - suffix = f".{tp_num}.bin" if tp_num else '' tp_num = i * split_factor + j + suffix = f".{tp_num}.bin" if tp_num else '' + global weights_dict weights_dict[key + suffix] = val From 69b4f69c0150735f9a2c617f0ae5f36dc203b40d Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Fri, 23 Aug 2024 02:51:05 -0700 Subject: [PATCH 15/21] apply code review changes Signed-off-by: Piotr Kaminski --- nemo/export/tensorrt_llm.py | 8 ++-- .../trt_llm/converter/model_converter.py | 22 ++++++--- scripts/export/export_to_trt_llm.py | 46 +++++++++++++------ tests/export/nemo_export.py | 17 ++++--- 4 files changed, 62 insertions(+), 31 deletions(-) diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 7b7c4e07e225..06a876c2b833 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -167,8 +167,8 @@ def export( multiple_profiles: bool = False, gpt_attention_plugin: str = "auto", gemm_plugin: str = "auto", - fp8_quantized: bool = False, - fp8_kvcache: bool = False, + fp8_quantized: Optional[bool] = None, + fp8_kvcache: Optional[bool] = None, ): """ Exports nemo checkpoints to TensorRT-LLM. @@ -204,8 +204,8 @@ def export( multiple_profiles: (bool): enables multiple profiles feature of TRT-LLM. Default = False gpt_attention_plugin (str): enable the gpt attention plugin. Default = "auto" gemm_plugin (str): enable the gpt plugin. Default = "auto" - fp8_quantized (bool): enables exporting to FP8 TRT-LLM checkpoints - fp8_kvcache (bool): enables FP8 KV-cache quantization + fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If not set, autodetects the type. + fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If not set, autodetects the type. 
""" if n_gpus is not None: diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 9eac3acfa708..b7c959d3b5b5 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -39,11 +39,10 @@ def get_config(decoder_type, config): if decoder_type == "llama": return LLaMAConfig(**config) - - if decoder_type in ["gpt", "gptnext"]: + elif decoder_type == "gpt" or decoder_type == "gptnext": return GPTConfig(**config) - - return PretrainedConfig(**config) + else: + return PretrainedConfig(**config) def prompt_convert(prompt_config, prompt_weights): @@ -95,6 +94,16 @@ def create_common_export_config(nemo_model_config, decoder_type, fp8_quantized=F } +def determine_quantization_settings(nemo_model_config, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None) -> Tuple[bool, bool]: + is_nemo_quantized = nemo_model_config.get('fp8', False) + if fp8_quantized is None: + fp8_quantized = is_nemo_quantized + if fp8_kvcache is None: + fp8_kvcache = is_nemo_quantized + + return fp8_quantized, fp8_kvcache + + def model_to_trtllm_ckpt( model, nemo_model_config, @@ -109,8 +118,8 @@ def model_to_trtllm_ckpt( use_distributed_convert: bool = False, model_parallel_rank: int = None, vocab_size: Optional[int] = None, - fp8_quantized: bool = False, - fp8_kvcache: bool = False, + fp8_quantized: Optional[bool] = None, + fp8_kvcache: Optional[bool] = None, ) -> Tuple[List[Dict], List[PretrainedConfig]]: if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing: LOGGER.info( @@ -118,6 +127,7 @@ def model_to_trtllm_ckpt( ) use_embedding_sharing = True + fp8_quantized, fp8_kvcache = determine_quantization_settings(nemo_model_config, fp8_quantized, fp8_kvcache) export_config = create_common_export_config(nemo_model_config, decoder_type, fp8_quantized, fp8_kvcache) # If the model has been sharded with model parallelism, convert the model in a gpu-distributed manner if use_distributed_convert: diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py index 7a240a6c4e6d..06193b06aee7 100644 --- a/scripts/export/export_to_trt_llm.py +++ b/scripts/export/export_to_trt_llm.py @@ -15,11 +15,14 @@ import argparse import logging import sys +from typing import Optional from nemo.export.tensorrt_llm import TensorRTLLM LOGGER = logging.getLogger("NeMo") +class UsageError(Exception): + pass def get_args(argv): parser = argparse.ArgumentParser( @@ -50,20 +53,6 @@ def get_args(argv): type=str, help="dtype of the model on TensorRT-LLM", ) - parser.add_argument( - "-fp8", - "--export_fp8_quantized", - default=False, - type=bool, - help="Enables exporting to a FP8-quantized TRT LLM checkpoint", - ) - parser.add_argument( - "-kv_fp8", - "--use_fp8_kv_cache", - default=False, - type=bool, - help="Enables exporting with FP8-quantizatized KV-cache", - ) parser.add_argument("-mil", "--max_input_len", default=256, type=int, help="Max input length of the model") parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model") parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model") @@ -121,8 +110,37 @@ def get_args(argv): 'It is used to compute the workspace size of lora plugin.', ) parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode") + parser.add_argument( + "-fp8", + "--export_fp8_quantized", + 
default="auto", + type=str, + help="Enables exporting to a FP8-quantized TRT LLM checkpoint", + ) + parser.add_argument( + "-kv_fp8", + "--use_fp8_kv_cache", + default="auto", + type=str, + help="Enables exporting with FP8-quantizatized KV-cache", + ) args = parser.parse_args(argv) + + def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]: + s = s.lower() + true_strings = ["true", "1"] + false_strings = ["false", "0"] + if s in true_strings: + return True + if s in false_strings: + return False + if optional and s == 'auto': + return None + raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'") + + args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True) + args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True) return args diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index 7fdfd73e232f..ecaf198a0c07 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -759,27 +759,30 @@ def get_args(): parser.add_argument( "-fp8", "--export_fp8_quantized", - default="False", + default="auto", type=str, help="Enables exporting to a FP8-quantized TRT LLM checkpoint", ) parser.add_argument( "-kv_fp8", "--use_fp8_kv_cache", - default="False", + default="auto", type=str, help="Enables exporting with FP8-quantizatized KV-cache", ) args = parser.parse_args() - def str_to_bool(name: str, s: str) -> bool: + def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]: + s = s.lower() true_strings = ["true", "1"] false_strings = ["false", "0"] - if s.lower() in true_strings: + if s in true_strings: return True - if s.lower() in false_strings: + if s in false_strings: return False + if optional and s == 'auto': + return None raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'") args.test_cpp_runtime = str_to_bool("test_cpp_runtime", args.test_cpp_runtime) @@ -790,8 +793,8 @@ def str_to_bool(name: str, s: str) -> bool: args.use_vllm = str_to_bool("use_vllm", args.use_vllm) args.use_parallel_embedding = str_to_bool("use_parallel_embedding", args.use_parallel_embedding) args.in_framework = str_to_bool("in_framework", args.in_framework) - args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized) - args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache) + args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True) + args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True) return args From 487edd0dabbec94a1ccb05c40f6052f0fa9ac85c Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Fri, 23 Aug 2024 09:54:11 +0000 Subject: [PATCH 16/21] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/export/trt_llm/converter/model_converter.py | 4 +++- scripts/export/export_to_trt_llm.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index b7c959d3b5b5..3078cc3ac4ae 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -94,7 +94,9 @@ def create_common_export_config(nemo_model_config, decoder_type, fp8_quantized=F } -def determine_quantization_settings(nemo_model_config, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None) -> Tuple[bool, bool]: +def 
determine_quantization_settings( + nemo_model_config, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None +) -> Tuple[bool, bool]: is_nemo_quantized = nemo_model_config.get('fp8', False) if fp8_quantized is None: fp8_quantized = is_nemo_quantized diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py index 06193b06aee7..3f5924fde80c 100644 --- a/scripts/export/export_to_trt_llm.py +++ b/scripts/export/export_to_trt_llm.py @@ -21,9 +21,11 @@ LOGGER = logging.getLogger("NeMo") + class UsageError(Exception): pass + def get_args(argv): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, From e2a3139f759995034acc9283d209093d3cbd7425 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Tue, 27 Aug 2024 07:49:40 -0700 Subject: [PATCH 17/21] fix TensorRTLLM build (fp8 still not supported) Signed-off-by: Piotr Kaminski --- nemo/export/trt_llm/converter/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 471086d2a333..d3c128aab2af 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -327,7 +327,7 @@ def split_val_gate(vals: List[np.ndarray], convert_on_device: bool): # are not split as there is only one head per key/value. @torch.no_grad() def split_and_save_weight( - tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, scaling_factors + tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, scaling_factors = {} ): use_attention_nemo_shape = config.get("use_attention_nemo_shape", False) split_gated_activation = config.get("split_gated_activation", False) From 19c866216ae496f9f35f91ca6b0042a39f0bfc32 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Tue, 27 Aug 2024 14:51:20 +0000 Subject: [PATCH 18/21] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/export/trt_llm/converter/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index d3c128aab2af..a9b808f3d3a3 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -327,7 +327,7 @@ def split_val_gate(vals: List[np.ndarray], convert_on_device: bool): # are not split as there is only one head per key/value. 
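# Illustrative sketch of the note above, not from the patch (shapes are made up):
# with a single key/value head (multi-query attention), only the query heads are
# sliced across tensor-parallel ranks; every rank keeps the whole KV head.
import numpy as np

hidden, head_dim, num_q_heads, tp_size = 512, 64, 8, 4
q = np.zeros((hidden, num_q_heads * head_dim))   # query projection, 8 heads
kv = np.zeros((hidden, 2 * head_dim))            # one K head and one V head

q_shards = np.split(q, tp_size, axis=-1)         # 2 query heads per rank
kv_shards = [kv] * tp_size                       # the single KV head is not split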
@torch.no_grad() def split_and_save_weight( - tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, scaling_factors = {} + tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, scaling_factors={} ): use_attention_nemo_shape = config.get("use_attention_nemo_shape", False) split_gated_activation = config.get("split_gated_activation", False) From b01fdbadd2fa6257c9edc40a9afb7fe07420e467 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Tue, 27 Aug 2024 08:23:53 -0700 Subject: [PATCH 19/21] undo refactor Signed-off-by: Piotr Kaminski --- nemo/export/tarutils.py | 5 +- .../trt_llm/converter/model_converter.py | 31 ++++++------- .../converter/model_to_trt_llm_ckpt.py | 46 +++++++++++++++---- nemo/export/trt_llm/converter/utils.py | 28 +++++------ 4 files changed, 64 insertions(+), 46 deletions(-) diff --git a/nemo/export/tarutils.py b/nemo/export/tarutils.py index 30ec0142f5c4..b93f65274120 100644 --- a/nemo/export/tarutils.py +++ b/nemo/export/tarutils.py @@ -20,7 +20,7 @@ import zarr.storage -class TarPath(os.PathLike): +class TarPath: """ A class that represents a path inside a TAR archive and behaves like pathlib.Path. @@ -58,9 +58,6 @@ def __truediv__(self, key) -> 'TarPath': def __str__(self) -> str: return os.path.join(self._tar.name, self._relpath) - def __fspath__(self): - return str(self) - @property def tarobject(self): return self._tar diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 3078cc3ac4ae..73a7566c275d 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -143,8 +143,6 @@ def model_to_trtllm_ckpt( ) vocab_size_padded = vocab_size else: - vocab_embedding_key = "transformer.vocab_embedding.weight" - weights_dict = convert_model_to_trt_llm_ckpt( model=model, nemo_model_config=nemo_model_config, @@ -156,20 +154,16 @@ def model_to_trtllm_ckpt( export_config=export_config, ) - if vocab_size is None: - vocab_size = weights_dict[vocab_embedding_key].shape[0] - has_lm_head = "lm_head.weight" in weights_dict - vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size - padding = (0, 0, 0, vocab_size_padded - vocab_size) if has_lm_head: lm_head_weight = weights_dict["lm_head.weight"] - lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0) + if vocab_size is None: + vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0] + vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size - if vocab_embedding_key in weights_dict: - weights_dict[vocab_embedding_key] = torch.nn.functional.pad( - weights_dict[vocab_embedding_key], padding, "constant", 0 - ) + if has_lm_head and vocab_size_padded != vocab_size: + pad_width = vocab_size_padded - vocab_size + lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0) world_size = tensor_parallel_size * pipeline_parallel_size hidden_act = nemo_model_config.get('activation') @@ -241,7 +235,7 @@ def model_to_trtllm_ckpt( return weights_dicts, model_configs pp_key = { - vocab_embedding_key, + "transformer.vocab_embedding.weight", "transformer.position_embedding.weight", "lm_head.weight", "transformer.ln_f.weight", @@ -266,9 +260,10 @@ def model_to_trtllm_ckpt( continue new_key = k if new_key.endswith(".bin"): # TP split - if not new_key.endswith(f"{mapping.tp_rank}.bin"): + if 
new_key.endswith(f"{mapping.tp_rank}.bin"): + new_key = new_key.replace(f".{mapping.tp_rank}.bin", "") + else: continue - new_key = new_key.replace(f".{mapping.tp_rank}.bin", "") if "layers" in new_key: # PP layer_num = int(new_key.split(".")[2]) if layer_num in layers_range: @@ -279,12 +274,12 @@ def model_to_trtllm_ckpt( if mapping.is_first_pp_rank(): embedding_weight = ( - np.ascontiguousarray(split(weights_dict[vocab_embedding_key], mapping.tp_size, mapping.tp_rank)) + np.ascontiguousarray(split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank)) if use_parallel_embedding - else weights_dict[vocab_embedding_key] + else weights_dict["transformer.vocab_embedding.weight"] ) - weights_dict_local[vocab_embedding_key] = embedding_weight + weights_dict_local["transformer.vocab_embedding.weight"] = embedding_weight pos_embedding_weight = weights_dict.get("transformer.position_embedding.weight") if pos_embedding_weight is not None: diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index 04d7cc21db2b..b5420582ed99 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -69,7 +69,7 @@ def get_layer_prefix(layer_names, is_mcore): return model_prefix, transformer_layer_prefix -def rename_key(new_key: str) -> str: +def rename_key(new_key: str): if "self_attention" in new_key: new_key = new_key.replace("self_attention", "attention") if "attention.linear_qkv.layer_norm_weight" in new_key: @@ -84,7 +84,7 @@ def rename_key(new_key: str) -> str: return new_key -def rename_key_dist_ckpt(old_key: str, layer: int) -> str: +def rename_key_dist_ckpt(old_key: str, layer: int): new_key = old_key if "layers." in old_key: split_key = old_key.split(".") @@ -195,16 +195,42 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): starmap_args = [] for key, val in model.items(): - if not is_scaling_factor(key): - # Let's rename/map the key to the old layer name previously. - # Since the state dict value has the full layers, let's select the ith layer weights/biases here. - layer_vals = [(l, val[l]) for l in range(num_layers)] if len(val.size()) != 1 else [(0, val)] - - for l, v in layer_vals: - k = rename_key_dist_ckpt(key, l) + if "_extra_state" not in key: + if len(val.size()) == 1: starmap_args.append( - (tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors) + ( + tp_rank, + out_dir, + split_factor, + # Let's rename/map the key to the old layer name previously. You can try printing out + # the rename_key output of the old llama checkpoint and compare. + rename_key_dist_ckpt(key, 0), + # Since the state dict value has the full layers, let's select the ith layer weights/biases here. + [val], + storage_type, + None, + export_config, + scaling_factors, + ) ) + else: + for i in range(num_layers): + starmap_args.append( + ( + tp_rank, + out_dir, + split_factor, + # Let's rename/map the key to the old layer name previously. You can try printing out + # the rename_key output of the old llama checkpoint and compare. + rename_key_dist_ckpt(key, i), + # Since the state dict value has the full layers, let's select the ith layer weights/biases here. 
+ [val[i]], + storage_type, + None, + export_config, + scaling_factors, + ) + ) starmap_args = tqdm(starmap_args, desc="saving weights") diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index d3c128aab2af..3340b76616bd 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -63,10 +63,9 @@ def save_val(val, dir, key, tp_num=None): - suffix = f".{tp_num}.bin" if tp_num else '' - tp_key = key + suffix - + suffix = "" if tp_num is None else f".{tp_num}.bin" global weights_dict + # Transpose linear layer weights to the correct shape. if torch.is_tensor(val): val = val.detach().contiguous() @@ -74,14 +73,14 @@ def save_val(val, dir, key, tp_num=None): val = val.reshape(val.shape[0], -1) val = torch.transpose(val, 0, 1) if key not in weights_dict: - weights_dict[tp_key] = torch.empty( + weights_dict[f"{key}{suffix}"] = torch.empty( val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True ) - weights_dict[tp_key].copy_(val, non_blocking=True) + weights_dict[f"{key}{suffix}"].copy_(val, non_blocking=True) else: if len(val.shape) >= 2: val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0])) - weights_dict[tp_key] = val + weights_dict[f"{key}{suffix}"] = val def save_split(split_vals, dir, key, i, split_factor): @@ -92,10 +91,10 @@ def save_split(split_vals, dir, key, i, split_factor): def save_expert_split(split_vals, dir, key, i, split_factor): for j, val in enumerate(split_vals): tp_num = i * split_factor + j - suffix = f".{tp_num}.bin" if tp_num else '' + suffix = "" if tp_num is None else f".{tp_num}.bin" global weights_dict - weights_dict[key + suffix] = val + weights_dict[f"{key}{suffix}"] = val def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): @@ -483,12 +482,13 @@ def split_and_save_weight( qkv = np.split(val, [q_num, q_num + 1], axis=2) query_groups_shape = qkv[0].shape - if len(query_groups_shape) > 1 and ((query_groups_shape[1] % split_factor) != 0): - raise Exception( - "Number of query groups of the models is {0}. Please select tensor parallelism size " - "that can split the number of query groups to equal number of query matrices in the " - "each GPU.".format(query_groups_shape[1]) - ) + if len(query_groups_shape) > 1: + if (query_groups_shape[1] % split_factor) != 0: + raise Exception( + "Number of query groups of the models is {0}. 
Please select tensor parallelism size " + "that can split the number of query groups to equal number of query matrices in the " + "each GPU.".format(query_groups_shape[1]) + ) q_split = np.split(qkv[0], split_factor, axis=1) k_split = np.split(qkv[1], split_factor, axis=1) From 0c922b7eabcaf09abb5cb6e729e898d8a0109820 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 28 Aug 2024 01:34:12 -0700 Subject: [PATCH 20/21] bugfix: arguments to dist_convert Signed-off-by: Piotr Kaminski --- .../trt_llm/converter/model_converter.py | 23 ++----- .../converter/model_to_trt_llm_ckpt.py | 60 +++++++++++++------ 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 73a7566c275d..c9a593caf5d3 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -79,21 +79,6 @@ def prompt_convert(prompt_config, prompt_weights): return vtokens_embeddings -def create_common_export_config(nemo_model_config, decoder_type, fp8_quantized=False, fp8_kvcache=False): - is_mcore = nemo_model_config.get("mcore_gpt", False) - return { - "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", - "split_gated_activation": nemo_model_config.get("activation", "gelu") - in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"] - and (decoder_type == "gptnext" or is_mcore), - "num_attention_heads": nemo_model_config["num_attention_heads"], - "use_attention_nemo_shape": True, - "transpose_weights": True, - "fp8_quantized": fp8_quantized, - "fp8_kvcache": fp8_kvcache, - } - - def determine_quantization_settings( nemo_model_config, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None ) -> Tuple[bool, bool]: @@ -130,7 +115,6 @@ def model_to_trtllm_ckpt( use_embedding_sharing = True fp8_quantized, fp8_kvcache = determine_quantization_settings(nemo_model_config, fp8_quantized, fp8_kvcache) - export_config = create_common_export_config(nemo_model_config, decoder_type, fp8_quantized, fp8_kvcache) # If the model has been sharded with model parallelism, convert the model in a gpu-distributed manner if use_distributed_convert: weights_dict = dist_model_to_trt_llm_ckpt( @@ -139,7 +123,8 @@ def model_to_trtllm_ckpt( inference_tp_size=tensor_parallel_size, inference_pp_size=pipeline_parallel_size, tokenizer_vocab_size=vocab_size, - export_config=export_config, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache ) vocab_size_padded = vocab_size else: @@ -151,7 +136,9 @@ def model_to_trtllm_ckpt( processes=1, storage_type=dtype, use_parallel_embedding=use_parallel_embedding, - export_config=export_config, + decoder_type=decoder_type, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache ) has_lm_head = "lm_head.weight" in weights_dict diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index b5420582ed99..07ac7b334f8e 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -119,10 +119,13 @@ def convert_model_to_trt_llm_ckpt( nemo_export_dir, storage_type, inference_tp_size, + decoder_type, use_parallel_embedding, processes, - export_config, + fp8_quantized=False, + fp8_kvcache=False, ): + # if checkpoints files could be found - start preparing output dir out_dir = create_export_dir(nemo_export_dir) storage_type = str_dtype_to_torch(storage_type) @@ -135,6 +138,9 @@ def 
convert_model_to_trt_llm_ckpt( has_position_embedding = get_layer_name("position_embedding", prefix) in model_state_dict has_lm_head = get_layer_name("output_layer", prefix) in model_state_dict + share_embeddings_and_output = nemo_model_config.get("share_embeddings_and_output_weights", False) + embedding_scaling = nemo_model_config.get("apply_embedding_scaling", False) + hidden_size = nemo_model_config["hidden_size"] num_layers = nemo_model_config["num_layers"] training_tp_size = 1 @@ -142,6 +148,7 @@ def convert_model_to_trt_llm_ckpt( num_kv_heads = nemo_model_config.get("num_query_groups", 0) multi_query_mode = nemo_model_config.get("multi_query_mode", False) num_attention_heads = nemo_model_config["num_attention_heads"] + kv_channels = nemo_model_config.get("kv_channels", None) if num_kv_heads == 0: if multi_query_mode: @@ -149,14 +156,21 @@ def convert_model_to_trt_llm_ckpt( else: num_kv_heads = num_attention_heads - export_config.update( - { - "tp_size": training_tp_size, - "num_kv_heads": num_kv_heads, - "kv_channels": nemo_model_config.get("kv_channels", None), - "use_parallel_embedding": use_parallel_embedding, - } - ) + export_config = { + "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", + "tp_size": training_tp_size, + "split_gated_activation": nemo_model_config.get("activation", "gelu") + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"] + and (decoder_type == "gptnext" or is_mcore), + "num_attention_heads": num_attention_heads, + "num_kv_heads": num_kv_heads, + "kv_channels": kv_channels, + "use_attention_nemo_shape": True, + "transpose_weights": True, + "use_parallel_embedding": use_parallel_embedding, + "fp8_quantized": fp8_quantized, + "fp8_kvcache": fp8_kvcache, + } # split_factor: in how many parts a TP training node is split split_factor = inference_tp_size @@ -275,7 +289,13 @@ def get_layer_num(param_name): @torch.no_grad() def dist_model_to_trt_llm_ckpt( - model, nemo_model_config, inference_tp_size, inference_pp_size, tokenizer_vocab_size, export_config + model, + nemo_model_config, + inference_tp_size, + inference_pp_size, + tokenizer_vocab_size, + fp8_quantized=False, + fp8_kvcache=False, ): from megatron.core import parallel_state from megatron.core.tensor_parallel.utils import VocabUtility @@ -311,13 +331,19 @@ def dist_model_to_trt_llm_ckpt( prefix, transformer_layer_prefix = get_layer_prefix(sample_state_dict, is_mcore) assert is_mcore, "Only megatron-core inflight model conversion is supported" - export_config.update( - { - "tp_size": tp_size, - "num_kv_heads": nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), - "convert_on_device": True, - } - ) + export_config = { + "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", + "tp_size": tp_size, + "split_gated_activation": nemo_model_config.get("activation", "gelu") + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"], + "num_attention_heads": nemo_model_config["num_attention_heads"], + "num_kv_heads": nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), + "convert_on_device": True, + "use_attention_nemo_shape": True, + "transpose_weights": True, + "fp8_quantized": fp8_quantized, + "fp8_kvcache": fp8_kvcache, + } starmap_config = { "tp_rank": None, From bcf85e40ac9dc7d4ad901dd7fca16dde5927a638 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Wed, 28 Aug 2024 08:36:18 +0000 Subject: [PATCH 21/21] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- 
nemo/export/trt_llm/converter/model_converter.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index c9a593caf5d3..6748346a10d0 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -124,7 +124,7 @@ def model_to_trtllm_ckpt( inference_pp_size=pipeline_parallel_size, tokenizer_vocab_size=vocab_size, fp8_quantized=fp8_quantized, - fp8_kvcache=fp8_kvcache + fp8_kvcache=fp8_kvcache, ) vocab_size_padded = vocab_size else: @@ -138,7 +138,7 @@ def model_to_trtllm_ckpt( use_parallel_embedding=use_parallel_embedding, decoder_type=decoder_type, fp8_quantized=fp8_quantized, - fp8_kvcache=fp8_kvcache + fp8_kvcache=fp8_kvcache, ) has_lm_head = "lm_head.weight" in weights_dict @@ -261,7 +261,9 @@ def model_to_trtllm_ckpt( if mapping.is_first_pp_rank(): embedding_weight = ( - np.ascontiguousarray(split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank)) + np.ascontiguousarray( + split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank) + ) if use_parallel_embedding else weights_dict["transformer.vocab_embedding.weight"] )
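
Taken together, these patches make the FP8 export settings autodetected rather than opt-in: when fp8_quantized / fp8_kvcache are left unset (or given as "auto" on the command line), the exporter falls back to the fp8 field of the NeMo model config via determine_quantization_settings. A rough usage sketch follows; apart from the two new flags, the argument names mirror the existing exporter API and may differ in detail, and the paths are hypothetical:

    from nemo.export.tensorrt_llm import TensorRTLLM

    exporter = TensorRTLLM(model_dir="/tmp/trt_llm_engine")    # output/engine directory (assumed name)
    exporter.export(
        nemo_checkpoint_path="/models/llama_fp8.nemo",         # hypothetical FP8-trained checkpoint
        model_type="llama",
        fp8_quantized=None,    # None -> autodetect from the checkpoint's fp8 setting
        fp8_kvcache=None,      # None -> same autodetection for the KV cache
    )

    # CLI equivalent (other required arguments, e.g. the checkpoint path, omitted):
    #   python scripts/export/export_to_trt_llm.py ... -fp8 auto -kv_fp8 auto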