diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 3c73da1c0731..06a876c2b833 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -167,6 +167,8 @@ def export( multiple_profiles: bool = False, gpt_attention_plugin: str = "auto", gemm_plugin: str = "auto", + fp8_quantized: Optional[bool] = None, + fp8_kvcache: Optional[bool] = None, ): """ Exports nemo checkpoints to TensorRT-LLM. @@ -202,6 +204,8 @@ def export( multiple_profiles: (bool): enables multiple profiles feature of TRT-LLM. Default = False gpt_attention_plugin (str): enable the gpt attention plugin. Default = "auto" gemm_plugin (str): enable the gpt plugin. Default = "auto" + fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If not set, autodetects the type. + fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If not set, autodetects the type. """ if n_gpus is not None: @@ -324,6 +328,8 @@ def export( gpus_per_node=gpus_per_node, use_parallel_embedding=use_parallel_embedding, use_embedding_sharing=use_embedding_sharing, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache, ) for weight_dict, model_config in zip(weights_dicts, model_configs): diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 60d50316e9ed..6748346a10d0 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -15,10 +15,11 @@ import csv import logging -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import tensorrt_llm +import torch from tensorrt_llm._utils import pad_vocab_size from tensorrt_llm.functional import non_gated_version from tensorrt_llm.layers import MoeConfig @@ -78,6 +79,18 @@ def prompt_convert(prompt_config, prompt_weights): return vtokens_embeddings +def determine_quantization_settings( + nemo_model_config, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None +) -> Tuple[bool, bool]: + is_nemo_quantized = nemo_model_config.get('fp8', False) + if fp8_quantized is None: + fp8_quantized = is_nemo_quantized + if fp8_kvcache is None: + fp8_kvcache = is_nemo_quantized + + return fp8_quantized, fp8_kvcache + + def model_to_trtllm_ckpt( model, nemo_model_config, @@ -91,15 +104,17 @@ def model_to_trtllm_ckpt( use_embedding_sharing: bool = False, use_distributed_convert: bool = False, model_parallel_rank: int = None, - vocab_size: int = None, + vocab_size: Optional[int] = None, + fp8_quantized: Optional[bool] = None, + fp8_kvcache: Optional[bool] = None, ) -> Tuple[List[Dict], List[PretrainedConfig]]: - if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing: LOGGER.info( "Found share_embeddings_and_output_weights is True in NeMo config, set use_embedding_sharing = True" ) use_embedding_sharing = True + fp8_quantized, fp8_kvcache = determine_quantization_settings(nemo_model_config, fp8_quantized, fp8_kvcache) # If the model has been sharded with model parallelism, convert the model in a gpu-distributed manner if use_distributed_convert: weights_dict = dist_model_to_trt_llm_ckpt( @@ -108,6 +123,8 @@ def model_to_trtllm_ckpt( inference_tp_size=tensor_parallel_size, inference_pp_size=pipeline_parallel_size, tokenizer_vocab_size=vocab_size, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache, ) vocab_size_padded = vocab_size else: @@ -120,6 +137,8 @@ def model_to_trtllm_ckpt( storage_type=dtype, 
use_parallel_embedding=use_parallel_embedding, decoder_type=decoder_type, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache, ) has_lm_head = "lm_head.weight" in weights_dict @@ -159,8 +178,8 @@ def model_to_trtllm_ckpt( 'embedding_sharding_dim': 0, 'share_embedding_table': use_embedding_sharing, 'quantization': { - 'quant_algo': None, - 'kv_cache_quant_algo': None, + 'quant_algo': "FP8" if fp8_quantized else None, + 'kv_cache_quant_algo': "FP8" if fp8_kvcache else None, }, 'bias': nemo_model_config.get('bias'), 'apply_query_key_layer_scaling': False, @@ -261,9 +280,9 @@ def model_to_trtllm_ckpt( if mapping.is_last_pp_rank(): if has_lm_head: - weights_dict_local["lm_head.weight"] = np.ascontiguousarray( - split(lm_head_weight, mapping.tp_size, mapping.tp_rank) - ) + weights_dict_local["lm_head.weight"] = split( + lm_head_weight, mapping.tp_size, mapping.tp_rank + ).contiguous() weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"] ln_f_bias = weights_dict.get("transformer.ln_f.bias") diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index db8a66308047..07ac7b334f8e 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -25,7 +25,7 @@ from tqdm import tqdm from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision -from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, weights_dict +from nemo.export.trt_llm.converter.utils import save_scaling_factor, save_val, split_and_save_weight, weights_dict LOGGER = logging.getLogger("NeMo") @@ -94,6 +94,24 @@ def rename_key_dist_ckpt(old_key: str, layer: int): return rename_key(new_key) +def is_scaling_factor(key: str) -> bool: + return "extra_state" in key + + +def load_scaling_factors(model: dict, num_layers: int, export_config: dict) -> dict: + if not export_config.get('fp8_quantized', False): + return {} + + scaling_factors = {} + for key, val in model.items(): + if is_scaling_factor(key): + for layer in range(num_layers): + renamed_key = rename_key_dist_ckpt(key, layer) + scaling_factors = save_scaling_factor(scaling_factors, renamed_key, val[layer], export_config) + + return scaling_factors + + @torch.no_grad() def convert_model_to_trt_llm_ckpt( nemo_model_config, @@ -104,6 +122,8 @@ def convert_model_to_trt_llm_ckpt( decoder_type, use_parallel_embedding, processes, + fp8_quantized=False, + fp8_kvcache=False, ): # if checkpoints files could be found - start preparing output dir @@ -148,6 +168,8 @@ def convert_model_to_trt_llm_ckpt( "use_attention_nemo_shape": True, "transpose_weights": True, "use_parallel_embedding": use_parallel_embedding, + "fp8_quantized": fp8_quantized, + "fp8_kvcache": fp8_kvcache, } # split_factor: in how many parts a TP training node is split @@ -158,7 +180,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): if tp_idx == 0 and pp_idx == 0: if has_position_embedding: val = model[get_layer_name("position_embedding", prefix)] - val = torch_to_numpy(val.to(storage_type).cpu()) + val = val.to(storage_type).cpu() model_level_weights["transformer.position_embedding.weight"].append(val) if pp_idx == 0: val = model.get("state_dict", model)[get_layer_name("word_embedding", prefix)] @@ -171,19 +193,19 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): pad_width = vocab_size_padded - vocab_size val = torch.nn.functional.pad(val, (0, 0, 0, pad_width), value=0) - val = 
torch_to_numpy(val.to(storage_type).cpu()) + val = val.to(storage_type).cpu() model_level_weights["transformer.vocab_embedding.weight"].append(val) if has_lm_head and pp_idx == training_pp_size - 1: val = model.get("state_dict", model)[get_layer_name("output_layer", prefix)] - val = torch_to_numpy(val.to(storage_type).cpu()) + val = val.to(storage_type).cpu() model_level_weights["lm_head.weight"].append(val) weights_dict = {} - tp_rank = 0 handle_model_level_weights(model, 0, 0) model = extract_layers_with_prefix(model, transformer_layer_prefix) + scaling_factors = load_scaling_factors(model, num_layers, export_config) starmap_args = [] for key, val in model.items(): @@ -202,6 +224,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): storage_type, None, export_config, + scaling_factors, ) ) else: @@ -219,6 +242,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): storage_type, None, export_config, + scaling_factors, ) ) @@ -236,9 +260,10 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): weights_dict.update(weights_dict_local) for key, values in model_level_weights.items(): - model_level_weights[key] = np.concatenate(values, axis=0) + model_level_weights[key] = torch.concatenate(values, axis=0) weights_dict[key] = model_level_weights[key] + weights_dict.update(scaling_factors) return weights_dict @@ -269,6 +294,8 @@ def dist_model_to_trt_llm_ckpt( inference_tp_size, inference_pp_size, tokenizer_vocab_size, + fp8_quantized=False, + fp8_kvcache=False, ): from megatron.core import parallel_state from megatron.core.tensor_parallel.utils import VocabUtility @@ -314,6 +341,8 @@ def dist_model_to_trt_llm_ckpt( "convert_on_device": True, "use_attention_nemo_shape": True, "transpose_weights": True, + "fp8_quantized": fp8_quantized, + "fp8_kvcache": fp8_kvcache, } starmap_config = { diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index eab17167cbd5..3f9f2a31a307 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -13,6 +13,7 @@ # limitations under the License. 
+from typing import List, Optional, Tuple, Union import numpy as np import tensorrt_llm import torch @@ -31,6 +32,35 @@ "falcon": 'FalconForCausalLM', } +post_layernorm_keys = [ + "post_attention_layernorm.weight", + "post_attention_layernorm.bias", + "post_self_attn_layernorm.weight", +] +mlp_proj_bias_keys = ["mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias"] +attention_dense_bias_keys = ["attention.linear_proj.bias", "attention.dense.bias"] +input_layernorm_keys = ["input_layernorm.weight", "input_layernorm.bias"] +pre_layernorm_keys = ["pre_mlp_layernorm.weight", "pre_mlp_layernorm.bias"] +attention_dense_weight_keys = ["attention.linear_proj.weight", "attention.dense.weight"] +mlp_proj_weight_keys = ["mlp.linear_fc2.weight", "mlp.dense_4h_to_h.weight"] +mlp_fc_keys = ["mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.bias", "mlp.linear_fc1.weight", "mlp.linear_fc1.bias"] +attention_qkv_bias_keys = ["attention.query_key_value.bias", "attention.linear_qkv.bias"] +attention_qkv_weight_keys = ["attention.query_key_value.weight", "attention.linear_qkv.weight"] +mlp_router_keys = ["mlp.router.weight"] +mlp_fc_expert_keys = ["experts.linear_fc1.weight"] +mlp_proj_experts_keys = ["experts.linear_fc2.weight"] +final_layernorm_keys = ["final_layernorm.weight", "final_layernorm.bias"] +mlp_dense_2_keys = ["mlp.dense_h_to_4h_2.weight", "mlp.dense_h_to_4h_2.bias"] +attention_not_mapped_keys = [ + "attention.query.weight", + "attention.query.bias", + "attention.key_value.weight", + "attention.key_value.bias", +] + +weight_scaling_suffix = '.weights_scaling_factor' +activation_scaling_suffix = '.activation_scaling_factor' + def save_val(val, dir, key, tp_num=None): suffix = "" if tp_num is None else f".{tp_num}.bin" @@ -174,10 +204,130 @@ def write_int8(vals, dir, base_key, split_dim, tp_rank, split_factor, kv_cache_o save_val(vals[save_key], dir, f"{base_key}.{save_key}") +def get_suffix(key: str) -> str: + return '.' 
+ key.split('.')[-1] + + +def get_trt_llm_prefix(key: str) -> str: + layer_num = key.split(".")[1] + return f'transformer.layers.{layer_num}' + + +def any_word_in_key(key: str, words: List[str]) -> bool: + return any([word in key for word in words]) + + +def sequential_key_map(key: str, mapping: List[Tuple[List[str], str]]) -> Optional[str]: + for keywords, mapped in mapping: + if any_word_in_key(key, keywords): + return mapped + + return None + + +def get_trt_llm_infix(key: str) -> Optional[str]: + mapping = [ + (post_layernorm_keys, '.post_layernorm'), + (mlp_proj_bias_keys, '.mlp.proj'), + (attention_dense_bias_keys, '.attention.dense'), + (input_layernorm_keys, '.input_layernorm'), + (pre_layernorm_keys, '.post_layernorm'), + (attention_dense_weight_keys, '.attention.dense'), + (mlp_proj_weight_keys, '.mlp.proj'), + (mlp_fc_keys, '.mlp.fc'), + (attention_qkv_bias_keys + attention_qkv_weight_keys, '.attention.qkv'), + (mlp_router_keys, '.mlp.router'), + (mlp_fc_expert_keys, '.mlp.fc'), + (mlp_proj_experts_keys, '.mlp.proj'), + ] + return sequential_key_map(key, mapping) + + +def get_trt_llm_keyname(key: str) -> str: + if any_word_in_key(key, final_layernorm_keys): + return key.replace("final_layernorm", "transformer.ln_f") + + if infix := get_trt_llm_infix(key): + return get_trt_llm_prefix(key) + infix + get_suffix(key) + + return key + + +def is_scaling_factor(key: str) -> bool: + return "scale_fwd" in key + + +def get_scaling_factor_keys(key: str) -> Tuple[Tuple[str, str], Tuple[str, str]]: + # Reuses existing mapping of NeMo -> TRT LLM weights key via swapping suffixes + corresponding_weight_key = '.'.join(key.split('.')[:-2]) + '.weight' + corresponding_trt_llm_weight_key = get_trt_llm_keyname(corresponding_weight_key) + base_key = '.'.join(corresponding_trt_llm_weight_key.split('.')[:-1]) + + weight_scale = base_key + weight_scaling_suffix + activation_scale = base_key + activation_scaling_suffix + keys = (weight_scale, activation_scale) + + layer_prefix = get_trt_llm_prefix(key) + mapped_key = layer_prefix + '.mlp.gate' + gate_activation = mapped_key + activation_scaling_suffix + gate_weight = mapped_key + weight_scaling_suffix + gate_keys = (gate_activation, gate_weight) + + return keys, gate_keys + + +def save_scaling_factor(scaling_factors: dict, key: str, val: torch.Tensor, config: dict): + if not is_scaling_factor(key): + return scaling_factors + + activation_factor = torch_to_numpy(1 / val[0].view(1)) + weights_factor = torch_to_numpy(1 / val[1].view(1)) + + (weights_key, activation_key), gate_keys = get_scaling_factor_keys(key) + scaling_factors[activation_key] = activation_factor + scaling_factors[weights_key] = weights_factor + + split_gated_activation = config.get("split_gated_activation", False) + if split_gated_activation and any_word_in_key(key, ["mlp.dense_h_to_4h", "mlp.linear_fc1"]): + (gate_activation_key, gate_weight_key) = gate_keys + scaling_factors[gate_activation_key] = activation_factor + scaling_factors[gate_weight_key] = weights_factor + + return scaling_factors + + +def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_factors): + if not is_fp8_model: + return [val.to(storage_type) for val in vals] + + fp8_storage_type = torch.float8_e4m3fn + quantized_keys = [ + k.split(weight_scaling_suffix)[0] for k in scaling_factors.keys() if k.endswith(weight_scaling_suffix) + ] + for k in quantized_keys: + if k in trt_llm_key: + storage_type = fp8_storage_type + scale = scaling_factors[k + weight_scaling_suffix] + vals = 
[val.to(torch.float32) / scale for val in vals] + break + + return [val.to(storage_type) for val in vals] + + +def split_val_gate(vals: List[np.ndarray], convert_on_device: bool): + if convert_on_device: + return [[n] for n in torch.chunk(vals[0], 2, axis=-1)] + + splits = [np.split(val, 2, axis=-1) for val in vals] + return list(zip(*splits)) + + # Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head # are not split as there is only one head per key/value. @torch.no_grad() -def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config): +def split_and_save_weight( + tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, scaling_factors={} +): use_attention_nemo_shape = config.get("use_attention_nemo_shape", False) split_gated_activation = config.get("split_gated_activation", False) num_attention_heads = config.get("num_attention_heads", 0) @@ -187,12 +337,11 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t num_kv_heads = config.get("num_kv_heads", num_attention_heads) size_per_head = config.get("kv_channels", None) convert_on_device = config.get("convert_on_device", False) - + is_fp8_model = config.get("fp8_quantized", False) + use_fp8_kv_cache = config.get("fp8_kvcache", False) save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" - layer_num = key.split(".")[1] - layer_prefix = f'transformer.layers.{layer_num}' - + trt_llm_key = get_trt_llm_keyname(key) if not isinstance(vals, list): vals = [vals] @@ -201,138 +350,82 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if "layernorm.weight" in key and config.get("apply_layernorm_1p", False): vals = [val.float() + 1.0 for val in vals] - vals = [val.to(storage_type) for val in vals] + vals = cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_factors) if convert_on_device: assert len(vals) == 1 # Should only convert a single device param per call assert torch.is_tensor(vals[0]) elif torch.is_tensor(vals[0]): vals = [torch_to_numpy(val.cpu()) for val in vals] - if ( - "input_layernorm.weight" in key - or "input_layernorm.bias" in key - or "pre_mlp_layernorm.weight" in key - or "pre_mlp_layernorm.bias" in key - or "attention.dense.bias" in key - or "attention.linear_proj.bias" in key - or "post_attention_layernorm.weight" in key - or "post_attention_layernorm.bias" in key - or "post_self_attn_layernorm.weight" in key - or "mlp.dense_4h_to_h.bias" in key - or "mlp.linear_fc2.bias" in key - or "final_layernorm.weight" in key - or "final_layernorm.bias" in key - ): + if any_word_in_key( + key, + input_layernorm_keys + + pre_layernorm_keys + + attention_dense_bias_keys + + post_layernorm_keys + + mlp_proj_bias_keys + + final_layernorm_keys, + ) and (tp_rank == 0 or convert_on_device): # shared weights, only need to convert the weights of rank 0 - if "post_self_attn_layernorm" in key or "post_attention_layernorm" in key: - if key.endswith('weight'): - key = f'{layer_prefix}.post_layernorm.weight' - else: - key = f'{layer_prefix}.post_layernorm.bias' - elif "mlp.linear_fc2.bias" in key or "mlp.dense_4h_to_h.bias" in key: - key = f'{layer_prefix}.mlp.proj.bias' - elif "attention.linear_proj.bias" in key or "attention.dense.bias" in key: - key = f'{layer_prefix}.attention.dense.bias' - elif "final_layernorm" in key: - key = key.replace("final_layernorm", "transformer.ln_f") - elif "input_layernorm" in key: - if key.endswith('weight'): - key 
= f'{layer_prefix}.input_layernorm.weight' - else: - key = f'{layer_prefix}.input_layernorm.bias' - elif "pre_mlp_layernorm" in key: - if key.endswith('weight'): - key = f'{layer_prefix}.post_layernorm.weight' - else: - key = f'{layer_prefix}.post_layernorm.bias' - if tp_rank == 0 or convert_on_device: - save_val(vals[0], saved_dir, key) - - elif ( - "attention.dense.weight" in key - or "mlp.dense_4h_to_h.weight" in key - or "attention.linear_proj.weight" in key - or "mlp.linear_fc2.weight" in key - ): - if "attention.linear_proj.weight" in key or "attention.dense.weight" in key: - key = f'{layer_prefix}.attention.dense.weight' - elif "mlp.linear_fc2.weight" in key or "mlp.dense_4h_to_h.weight" in key: - key = f'{layer_prefix}.mlp.proj.weight' + save_val(vals[0], saved_dir, trt_llm_key) + elif any_word_in_key(key, attention_dense_weight_keys + mlp_proj_weight_keys): if convert_on_device: - save_val(vals[0], saved_dir, key) + save_val(vals[0], saved_dir, trt_llm_key) else: cat_dim = 0 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") + base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) - elif ( - "mlp.dense_h_to_4h.weight" in key - or "mlp.dense_h_to_4h.bias" in key - or "mlp.linear_fc1.weight" in key - or "mlp.linear_fc1.bias" in key - ): - if key.endswith("weight"): - key = f'{layer_prefix}.mlp.fc.weight' - else: - key = f'{layer_prefix}.mlp.fc.bias' - + elif any_word_in_key(key, mlp_fc_keys): if split_gated_activation: - if convert_on_device: - vals, gates = [[n] for n in torch.chunk(vals[0], 2, axis=-1)] - else: - splits = [np.split(val, 2, axis=-1) for val in vals] - vals, gates = list(zip(*splits)) + vals, gates = split_val_gate(vals, convert_on_device) if convert_on_device: - save_val(vals[0], saved_dir, key) + save_val(vals[0], saved_dir, trt_llm_key) else: cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") + base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) if split_gated_activation: assert not save_int8 - if key.endswith("weight"): - key = f'{layer_prefix}.mlp.gate.weight' - else: - key = f'{layer_prefix}.mlp.gate.bias' - + layer_prefix = get_trt_llm_prefix(key) + gate_key = layer_prefix + '.mlp.gate' + get_suffix(trt_llm_key) if convert_on_device: - save_val(gates[0], saved_dir, key) + save_val(gates[0], saved_dir, gate_key) else: gate = np.concatenate(gates, axis=cat_dim) split_vals = np.split(gate, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, gate_key, tp_rank, split_factor) - elif "mlp.dense_h_to_4h_2.weight" in key or "mlp.dense_h_to_4h_2.bias" in key: + elif any_word_in_key(key, mlp_dense_2_keys): if convert_on_device: - save_val(vals[0], saved_dir, key) + 
save_val(vals[0], saved_dir, trt_llm_key) else: cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") + base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) - elif "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key: - key = f'{layer_prefix}.attention.qkv.bias' + elif any_word_in_key(key, attention_qkv_bias_keys): qkv_hidden_dim = vals[0].shape[0] size_per_head = qkv_hidden_dim // (num_attention_heads + 2 * num_kv_heads) q_num = num_attention_heads // num_kv_heads @@ -349,7 +442,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if convert_on_device: qkv = torch.split(val, [q_num, 1, 1], dim=1) split_vals = torch.concatenate([qkv[0].reshape(-1), qkv[1].reshape(-1), qkv[2].reshape(-1)], dim=1) - save_val(split_vals, saved_dir, key) + save_val(split_vals, saved_dir, trt_llm_key) else: qkv = np.split(val, [q_num, q_num + 1], axis=1) q_split = np.split(qkv[0], split_factor, axis=0) @@ -361,10 +454,9 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t np.concatenate([q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], axis=0) for i in range(split_factor) ] - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) - elif "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key: - key = f'{layer_prefix}.attention.qkv.weight' + elif any_word_in_key(key, attention_qkv_weight_keys): assert use_attention_nemo_shape, "Only support NEMO shape for QKV weights" hidden_dim = vals[0].shape[0] if size_per_head is None: @@ -380,7 +472,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_vals = torch.concatenate( [qkv[0].reshape(hidden_dim, -1), qkv[1].reshape(hidden_dim, -1), qkv[2].reshape(hidden_dim, -1)], dim=1 ) - save_val(split_vals, saved_dir, key) + save_val(split_vals, saved_dir, trt_llm_key) else: len_vals = len(vals) val = np.concatenate(vals, axis=1) @@ -414,10 +506,10 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t ) for i in range(split_factor) ] - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) if save_int8: - base_key = key.replace(".weight", "") + base_key = trt_llm_key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, is_qkv=True, multi_query_mode=multi_query_mode) write_int8( vals_i8, @@ -428,18 +520,20 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_factor, kv_cache_only=int8_outputs == "kv_cache_only", ) - elif ( - "attention.query.weight" in key - or "attention.query.bias" in key - or "attention.key_value.weight" in key - or "attention.key_value.bias" in key - ): + + if use_fp8_kv_cache: + base_key = trt_llm_key.replace('.qkv.weight', '') + scaling_factor = np.array([1.0], dtype=np.float32) + save_val(scaling_factor, dir, base_key + '.kv_cache_scaling_factor') + + elif any_word_in_key(key, attention_not_mapped_keys): pass - elif 
"mlp.router.weight" in key: + + elif any_word_in_key(key, mlp_router_keys): val = np.concatenate(vals, axis=1) - key = f'{layer_prefix}.mlp.router.weight' - save_val(val, saved_dir, key) - elif "experts.linear_fc1.weight" in key: + save_val(val, saved_dir, trt_llm_key) + + elif any_word_in_key(key, mlp_fc_expert_keys): cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) w1, w3 = np.split(val, 2, axis=1) @@ -449,15 +543,13 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_w3s = np.split(w3, split_factor, axis=1) split_vals = [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)] - key = f'{layer_prefix}.mlp.fc.weight' - save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_expert_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) - elif "experts.linear_fc2.weight" in key: + elif any_word_in_key(key, mlp_proj_experts_keys): cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - key = f'{layer_prefix}.mlp.proj.weight' - save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_expert_split(split_vals, saved_dir, trt_llm_key, tp_rank, split_factor) else: print(f"[WARNING] {key} not handled by converter") @@ -465,14 +557,16 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t return weights_dict -def split(v, tp_size, idx, dim=0): +def split(v: Union[np.ndarray, torch.Tensor], tp_size: int, idx: int, dim: int = 0): """Splits the np tensor v on dim and return the idx's slice.""" if tp_size == 1: return v - if len(v.shape) == 1: - return np.ascontiguousarray(np.split(v, tp_size)[idx]) - else: - return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) + + dim = dim if len(v.shape) != 1 else 0 + if torch.is_tensor(v): + return torch.split(v, v.size(dim) // tp_size, dim=dim)[idx].contiguous() + + return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) def init_model_parallel_from_nemo(reshard_model): diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 1b711b5edbf3..14f02b06b71b 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -17,16 +17,18 @@ import json import logging import os +from io import BytesIO from pathlib import Path -from typing import Dict, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import tensorstore # This is important even though not used. Otherwise zarr raises error. 
import torch import yaml import zarr -from tensorrt_llm._utils import np_bfloat16 -from torch.distributed.checkpoint import FileSystemReader, TensorStorageMetadata +from tensorrt_llm._utils import np_bfloat16, str_dtype_to_torch +from torch.distributed.checkpoint import FileSystemReader +from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata from torch.distributed.checkpoint.state_dict_loader import load_state_dict from transformers import AutoTokenizer, PreTrainedTokenizer @@ -65,7 +67,65 @@ def __init__(self, path: Union[Path, TarPath]) -> None: self.path = path # overwrites path set in super().__init__ call -def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): +def get_extra_state_key(state_dict: dict) -> Optional[str]: + for key in state_dict.keys(): + if '_extra_state/' in key: + return key + return None + + +def unpack_extra_state_key(key: str) -> Tuple[str, int]: + basename = key.split('/')[0] + size = int(key.split('/')[1].split('_')[-1]) + return basename, size + + +def clear_loaded_extra_states(state_dict: dict, basename: str) -> dict: + """The scaling factors are originally saved to state_dict under the keynames 'basename/*' + The standardized representation is saved to 'basename.*'. This function clears the former from the state. + """ + to_remove = [k for k in state_dict.keys() if basename + '/' in k] + for key in to_remove: + state_dict.pop(key) + return state_dict + + +def retrieve_scale(bytes: BytesIO) -> Optional[torch.Tensor]: + bytes.seek(0) + extra_state = torch.load(bytes) + if not extra_state or 'scale_fwd' not in extra_state: + return None + return extra_state['scale_fwd'].cpu() + + +def load_scales_from_bytes(bytes_list: List[BytesIO]) -> Optional[torch.Tensor]: + scales = [] + for bytes in bytes_list: + scale = retrieve_scale(bytes) + if scale is None: + return None + scales.append(scale) + return torch.stack(scales) + + +def load_scaling_factors(state_dict: dict, basename: str, size: int) -> Optional[torch.Tensor]: + keynames = [f'{basename}/shard_{layer}_{size}' for layer in range(size)] + bytes_list = [state_dict[keyname][0] for keyname in keynames] + return load_scales_from_bytes(bytes_list) + + +def standarize_distributed_scaling_factors(state_dict: dict) -> dict: + while key := get_extra_state_key(state_dict): + basename, size = unpack_extra_state_key(key) + scaling_factors = load_scaling_factors(state_dict, basename, size) + if scaling_factors is not None: + state_dict[basename + '.scale_fwd'] = scaling_factors + state_dict = clear_loaded_extra_states(state_dict, basename) + + return state_dict + + +def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch_tensor: bool = True): fs_reader = TarFileSystemReader(checkpoint_dir) metadata = fs_reader.read_metadata() @@ -74,11 +134,17 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch for k, tp in metadata.state_dict_metadata.items() if isinstance(tp, TensorStorageMetadata) } + + state_dict.update( + {k: {} for k, tp in metadata.state_dict_metadata.items() if isinstance(tp, BytesStorageMetadata)} + ) + load_state_dict( state_dict, storage_reader=fs_reader, no_dist=True, ) + state_dict = standarize_distributed_scaling_factors(state_dict) if not torch_tensor: for k, v in state_dict.items(): @@ -89,24 +155,61 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch return state_dict +def get_sharded_file(dir: dict, layer_number: int) -> 
Optional[os.PathLike]: + pt_file_list = list(dir.glob(f'shard_{layer_number}_*.pt')) + if pt_file_list == []: + return None + return pt_file_list[0] + + +def load_sharded_pickle_extra_state_scale(dir: Union[Path, TarPath]): + def _get_layer_number(file): + basename = os.path.basename(str(file)) + return int(basename.split('_')[1]) + + pt_files = list(dir.glob('shard_*_*.pt')) + bytes_list = [] + for file in sorted(pt_files, key=_get_layer_number): + with file.open('rb') as opened_file: + bytes_list.append(torch.load(opened_file)) + + return load_scales_from_bytes(bytes_list) + + +def contains_extra_states(subdir: Union[Path, TarPath]): + return list(subdir.glob('shard_0_*.pt')) != [] + + +def load_extra_state_from_pickle(sharded_state_dict: dict, subdir: Union[Path, TarPath]): + scales = load_sharded_pickle_extra_state_scale(subdir) + if scales is not None: + key = subdir.name + '.scale_fwd' + sharded_state_dict[key] = scales + + return sharded_state_dict + + def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tensor=True): sharded_state_dict = {} for subdir in checkpoint_dir.iterdir(): - if not subdir.is_dir() or not (subdir / '.zarray').exists(): + if not subdir.is_dir(): continue - key = subdir.name - - zstore = ZarrPathStore(subdir) - arr = zarr.open(zstore, 'r') - if torch_tensor: - # sharded_state_dict[key] = torch.from_numpy(arr[:].astype("float32")).to(dtype=torch.bfloat16) - if arr.dtype.name == "bfloat16": - sharded_state_dict[key] = torch.from_numpy(arr[:].view(np.int16)).view(torch.bfloat16) + if contains_extra_states(subdir): + sharded_state_dict = load_extra_state_from_pickle(sharded_state_dict, subdir) + elif (subdir / '.zarray').exists(): + key = subdir.name + zstore = ZarrPathStore(subdir) + arr = zarr.open(zstore, 'r') + + if torch_tensor: + # sharded_state_dict[key] = torch.from_numpy(arr[:].astype("float32")).to(dtype=torch.bfloat16) + if arr.dtype.name == "bfloat16": + sharded_state_dict[key] = torch.from_numpy(arr[:].view(np.int16)).view(torch.bfloat16) + else: + sharded_state_dict[key] = torch.from_numpy(arr[:]).view(str_dtype_to_torch(arr.dtype.name)) else: - sharded_state_dict[key] = torch.from_numpy(arr[:]) - else: - sharded_state_dict[key] = arr[:] + sharded_state_dict[key] = arr[:] return sharded_state_dict diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py index a9b9d92c172b..3f5924fde80c 100644 --- a/scripts/export/export_to_trt_llm.py +++ b/scripts/export/export_to_trt_llm.py @@ -15,12 +15,17 @@ import argparse import logging import sys +from typing import Optional from nemo.export.tensorrt_llm import TensorRTLLM LOGGER = logging.getLogger("NeMo") +class UsageError(Exception): + pass + + def get_args(argv): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -107,8 +112,37 @@ def get_args(argv): 'It is used to compute the workspace size of lora plugin.', ) parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode") + parser.add_argument( + "-fp8", + "--export_fp8_quantized", + default="auto", + type=str, + help="Enables exporting to a FP8-quantized TRT LLM checkpoint", + ) + parser.add_argument( + "-kv_fp8", + "--use_fp8_kv_cache", + default="auto", + type=str, + help="Enables exporting with FP8-quantizatized KV-cache", + ) args = parser.parse_args(argv) + + def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]: + s = s.lower() + true_strings = ["true", "1"] + false_strings = ["false", "0"] + 
if s in true_strings: + return True + if s in false_strings: + return False + if optional and s == 'auto': + return None + raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'") + + args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True) + args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True) return args @@ -153,6 +187,8 @@ def nemo_export_trt_llm(argv): use_lora_plugin=args.use_lora_plugin, lora_target_modules=args.lora_target_modules, max_lora_rank=args.max_lora_rank, + fp8_quantized=args.export_fp8_quantized, + fp8_kvcache=args.use_fp8_kv_cache, ) LOGGER.info("Export is successful.") diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index 557d6c07613d..ecaf198a0c07 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -242,6 +242,8 @@ def run_inference( test_deployment=False, test_data_path=None, save_trt_engine=False, + fp8_quantized=False, + fp8_kvcache=False, ) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: if Path(checkpoint_path).exists(): if tp_size > torch.cuda.device_count(): @@ -325,6 +327,8 @@ def run_inference( lora_target_modules=lora_target_modules, max_num_tokens=max_num_tokens, use_embedding_sharing=use_embedding_sharing, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache, ) if ptuning: @@ -452,6 +456,8 @@ def run_existing_checkpoints( test_data_path=None, save_trt_engine=False, in_framework=False, + fp8_quantized=False, + fp8_kvcache=False, ) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: if tp_size > torch.cuda.device_count(): print("Skipping the test due to not enough number of GPUs") @@ -530,6 +536,8 @@ def run_existing_checkpoints( test_deployment=test_deployment, test_data_path=test_data_path, save_trt_engine=save_trt_engine, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache, ) @@ -748,16 +756,33 @@ def get_args(): type=float, help="GPU memory utilization percentage for vLLM.", ) + parser.add_argument( + "-fp8", + "--export_fp8_quantized", + default="auto", + type=str, + help="Enables exporting to a FP8-quantized TRT LLM checkpoint", + ) + parser.add_argument( + "-kv_fp8", + "--use_fp8_kv_cache", + default="auto", + type=str, + help="Enables exporting with FP8-quantizatized KV-cache", + ) args = parser.parse_args() - def str_to_bool(name: str, s: str) -> bool: + def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]: + s = s.lower() true_strings = ["true", "1"] false_strings = ["false", "0"] - if s.lower() in true_strings: + if s in true_strings: return True - if s.lower() in false_strings: + if s in false_strings: return False + if optional and s == 'auto': + return None raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'") args.test_cpp_runtime = str_to_bool("test_cpp_runtime", args.test_cpp_runtime) @@ -768,6 +793,8 @@ def str_to_bool(name: str, s: str) -> bool: args.use_vllm = str_to_bool("use_vllm", args.use_vllm) args.use_parallel_embedding = str_to_bool("use_parallel_embedding", args.use_parallel_embedding) args.in_framework = str_to_bool("in_framework", args.in_framework) + args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True) + args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True) return args @@ -821,6 +848,8 @@ def run_inference_tests(args): test_data_path=args.test_data_path, save_trt_engine=args.save_trt_engine, 
in_framework=args.in_framework, + fp8_quantized=args.export_fp8_quantized, + fp8_kvcache=args.use_fp8_kv_cache, ) tps = tps * 2 @@ -877,6 +906,8 @@ def run_inference_tests(args): test_cpp_runtime=args.test_cpp_runtime, test_data_path=args.test_data_path, save_trt_engine=args.save_trt_engine, + fp8_quantized=args.export_fp8_quantized, + fp8_kvcache=args.use_fp8_kv_cache, ) tps = tps * 2 @@ -940,5 +971,7 @@ def optional_bool_to_pass_fail(b: Optional[bool]): run_inference_tests(args) except UsageError as e: LOGGER.error(f"{e}") + raise e except argparse.ArgumentError as e: LOGGER.error(f"{e}") + raise e
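
How the new flags behave end to end: determine_quantization_settings() in model_converter.py resolves both switches and falls back to the checkpoint's own 'fp8' field when a flag is left unset. A minimal standalone sketch of that behaviour; the function body is taken from the patch, the example configs are invented:

from typing import Optional, Tuple

def determine_quantization_settings(
    nemo_model_config: dict, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None
) -> Tuple[bool, bool]:
    # Explicit flags win; otherwise fall back to whether the NeMo model was trained with FP8.
    is_nemo_quantized = nemo_model_config.get('fp8', False)
    if fp8_quantized is None:
        fp8_quantized = is_nemo_quantized
    if fp8_kvcache is None:
        fp8_kvcache = is_nemo_quantized
    return fp8_quantized, fp8_kvcache

# FP8-trained checkpoint, nothing requested explicitly: both switches turn on.
assert determine_quantization_settings({'fp8': True}) == (True, True)
# Keep FP8 weights but opt out of the FP8 KV-cache.
assert determine_quantization_settings({'fp8': True}, fp8_kvcache=False) == (True, False)
# Plain bf16/fp16 checkpoint: nothing is quantized unless asked for.
assert determine_quantization_settings({}) == (False, False)

Downstream, the two booleans only surface in the 'quantization' section of the generated TRT-LLM config ('quant_algo' and 'kv_cache_quant_algo' become "FP8" or stay None) and in the weight-conversion paths changed in converter/utils.py.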
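
On the loader side, TransformerEngine scaling factors arrive as pickled _extra_state blobs stored under shard_<layer>_<num_layers> keys. A self-contained sketch of what standarize_distributed_scaling_factors() does with them; the basename and scale values are fabricated for illustration, only the key layout follows the patch:

from io import BytesIO
import torch

def make_blob(scale_fwd):
    # Stand-in for what TransformerEngine pickles into the checkpoint.
    buf = BytesIO()
    torch.save({'scale_fwd': torch.tensor(scale_fwd)}, buf)
    return buf

basename = 'model.decoder.layers.mlp.linear_fc1._extra_state'  # illustrative key
size = 2
state_dict = {
    f'{basename}/shard_0_{size}': [make_blob([2.0, 4.0, 1.0])],
    f'{basename}/shard_1_{size}': [make_blob([3.0, 6.0, 1.0])],
}

def retrieve_scale(buf):
    # Mirrors retrieve_scale() from the patch: unpickle the blob and pull out 'scale_fwd'.
    buf.seek(0)
    extra_state = torch.load(buf)
    return extra_state['scale_fwd'].cpu()

# One stacked tensor per basename replaces the per-shard blobs.
scales = torch.stack([retrieve_scale(state_dict[f'{basename}/shard_{i}_{size}'][0]) for i in range(size)])
state_dict = {k: v for k, v in state_dict.items() if f'{basename}/' not in k}
state_dict[basename + '.scale_fwd'] = scales
print(scales.shape)  # torch.Size([2, 3]) -> one row of scales per layer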
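
The long if/elif renaming chain in split_and_save_weight() is replaced by key tables plus sequential_key_map(). A trimmed, runnable sketch of the idea; only two key families are included here, while the patch's tables cover every layer parameter and handle final_layernorm separately:

from typing import List, Optional, Tuple

post_layernorm_keys = ["post_attention_layernorm.weight", "post_self_attn_layernorm.weight"]
mlp_fc_keys = ["mlp.dense_h_to_4h.weight", "mlp.linear_fc1.weight"]
mapping = [(post_layernorm_keys, '.post_layernorm'), (mlp_fc_keys, '.mlp.fc')]

def any_word_in_key(key: str, words: List[str]) -> bool:
    return any(word in key for word in words)

def sequential_key_map(key: str, mapping: List[Tuple[List[str], str]]) -> Optional[str]:
    # First table entry whose keywords appear in the NeMo key decides the TRT-LLM infix.
    for keywords, mapped in mapping:
        if any_word_in_key(key, keywords):
            return mapped
    return None

def get_trt_llm_keyname(key: str) -> str:
    infix = sequential_key_map(key, mapping)
    if infix is None:
        return key
    prefix = f"transformer.layers.{key.split('.')[1]}"
    suffix = '.' + key.split('.')[-1]
    return prefix + infix + suffix

assert get_trt_llm_keyname('layers.3.post_attention_layernorm.weight') == 'transformer.layers.3.post_layernorm.weight'
assert get_trt_llm_keyname('layers.7.mlp.linear_fc1.weight') == 'transformer.layers.7.mlp.fc.weight'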
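
save_scaling_factor() turns each scale_fwd entry into the two per-tensor factors TRT-LLM expects by reusing the weight-key mapping and swapping suffixes, storing reciprocals of the TE scales. A sketch of that key arithmetic for one fc1 entry; the key string and scale values are illustrative, and the gate duplication done for split_gated_activation is omitted:

import torch

key = 'layers.0.mlp.linear_fc1._extra_state.scale_fwd'  # illustrative renamed key
scale_fwd = torch.tensor([4.0, 8.0, 1.0])               # made-up TE scales: [activation, weight, ...]

# 1. Recover the weight key the factors belong to and map it to TRT-LLM naming
#    (the patch calls get_trt_llm_keyname() for this step).
weight_key = '.'.join(key.split('.')[:-2]) + '.weight'      # layers.0.mlp.linear_fc1.weight
trt_llm_weight_key = 'transformer.layers.0.mlp.fc.weight'   # what get_trt_llm_keyname returns here
base_key = '.'.join(trt_llm_weight_key.split('.')[:-1])     # transformer.layers.0.mlp.fc

# 2. Store the reciprocal scales under the two dedicated suffixes.
factors = {
    base_key + '.activation_scaling_factor': (1 / scale_fwd[0].view(1)).numpy(),
    base_key + '.weights_scaling_factor': (1 / scale_fwd[1].view(1)).numpy(),
}
print(sorted(factors))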
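
cast_val_datatype() rescales and downcasts only the tensors that have a recorded weights_scaling_factor. A simplified, runnable sketch of that selection logic; it assumes a PyTorch build that provides torch.float8_e4m3fn, which the patch also relies on:

import torch

weight_scaling_suffix = '.weights_scaling_factor'
scaling_factors = {'transformer.layers.0.mlp.fc' + weight_scaling_suffix: torch.tensor([0.25])}

def cast_for_export(val, trt_llm_key, storage_type, scaling_factors):
    # Weights with a recorded scaling factor are rescaled and stored as FP8;
    # everything else keeps the requested storage dtype.
    for factor_key, scale in scaling_factors.items():
        if not factor_key.endswith(weight_scaling_suffix):
            continue
        base = factor_key[: -len(weight_scaling_suffix)]
        if base in trt_llm_key:
            return (val.to(torch.float32) / scale).to(torch.float8_e4m3fn)
    return val.to(storage_type)

fp8_w = cast_for_export(torch.randn(16, 16), 'transformer.layers.0.mlp.fc.weight', torch.bfloat16, scaling_factors)
other = cast_for_export(torch.randn(16), 'transformer.ln_f.weight', torch.bfloat16, scaling_factors)
print(fp8_w.dtype, other.dtype)  # torch.float8_e4m3fn torch.bfloat16

Note that the KV-cache path does not derive a measured scale yet: when fp8_kvcache is set, the qkv-weight branch simply stores a per-layer kv_cache_scaling_factor of 1.0.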
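
The reworked split() accepts torch tensors as well as numpy arrays and always returns a contiguous slice, which is what lets the lm_head path drop np.ascontiguousarray(). The function below is the patch's version, followed by a small usage check:

import numpy as np
import torch

def split(v, tp_size, idx, dim=0):
    """Splits v on dim and returns the idx's contiguous slice."""
    if tp_size == 1:
        return v
    dim = dim if len(v.shape) != 1 else 0
    if torch.is_tensor(v):
        return torch.split(v, v.size(dim) // tp_size, dim=dim)[idx].contiguous()
    return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx])

lm_head = torch.arange(12.0).reshape(6, 2)
print(split(lm_head, tp_size=2, idx=1))       # rows 3..5 as a contiguous torch tensor
print(split(np.arange(8), tp_size=4, idx=2))  # array([4, 5])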
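
Both export_to_trt_llm.py and the test driver now parse the FP8 switches through a tri-state str_to_bool(): true/1 and false/0 behave as before, and auto maps to None so the decision is deferred to determine_quantization_settings(). A condensed, runnable sketch of the pattern; UsageError mirrors the exception class added to the script:

import argparse
from typing import Optional

class UsageError(Exception):
    pass

def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]:
    s = s.lower()
    if s in ("true", "1"):
        return True
    if s in ("false", "0"):
        return False
    if optional and s == "auto":
        return None
    raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'")

parser = argparse.ArgumentParser()
parser.add_argument("-fp8", "--export_fp8_quantized", default="auto", type=str)
parser.add_argument("-kv_fp8", "--use_fp8_kv_cache", default="auto", type=str)
args = parser.parse_args(["-fp8", "true"])

print(str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True))  # True
print(str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True))          # None (auto)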