diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 4aae83efd6f6..f64506d42d5c 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -176,6 +176,7 @@ def export(
         multiple_profiles: bool = False,
         gpt_attention_plugin: str = "auto",
         gemm_plugin: str = "auto",
+        use_mcore_path: bool = False,
         reduce_fusion: bool = True,
         fp8_quantized: Optional[bool] = None,
         fp8_kvcache: Optional[bool] = None,
@@ -213,11 +214,11 @@ def export(
             multiple_profiles: (bool): enables multiple profiles feature of TRT-LLM. Default = False
             gpt_attention_plugin (str): enable the gpt attention plugin. Default = "auto"
             gemm_plugin (str): enable the gpt plugin. Default = "auto"
+            use_mcore_path (bool): uses the more recent mcore export path. Default = False
             reduce_fusion (bool): enables fusing extra kernels after custom TRT-LLM allReduce
             fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If not set, autodetects the type.
             fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If not set, autodetects the type.
         """
-
         if n_gpus is not None:
             warnings.warn(
                 "Parameter n_gpus is deprecated and will be removed in the next release. "
@@ -326,53 +327,169 @@ def export(
                 "Supported model types are: {1}.".format(model_type, self.get_supported_models_list)
             )
 
-        if model_type == "gpt" or model_type == "starcoder":
-            model_type = "gptnext"
+        model, model_configs, self.tokenizer = load_nemo_model(nemo_checkpoint_path, nemo_export_dir)
+        if use_mcore_path:
+            from megatron.core.export.data_type import DataType
+            from megatron.core.export.export_config import ExportConfig
+            from megatron.core.export.model_type import ModelType
+            from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import (
+                DEFAULT_CONVERSION_DICT,
+            )
+            from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper
+            from megatron.core.transformer.transformer_config import TransformerConfig
+            from tensorrt_llm.layers import MoeConfig
+
+            def get_transformer_config(nemo_model_config):
+                normalization = nemo_model_config.get('normalization', 'layernorm')
+                transformer_config_normalization = 'LayerNorm'
+                layernorm_zero_centered_gamma = False
+                if normalization == 'layernorm1p':
+                    layernorm_zero_centered_gamma = True
+                elif normalization == 'rmsnorm':
+                    transformer_config_normalization = 'RMSNorm'
+
+                conf = TransformerConfig(
+                    num_layers=nemo_model_config.get('num_layers'),
+                    moe_router_topk=nemo_model_config.get('moe_router_topk', 0),
+                    num_attention_heads=nemo_model_config.get('num_attention_heads'),
+                    num_query_groups=nemo_model_config.get(
+                        'num_query_groups', nemo_model_config['num_attention_heads']
+                    ),
+                    kv_channels=nemo_model_config.get("kv_channels", None),
+                    hidden_size=nemo_model_config.get('hidden_size'),
+                    ffn_hidden_size=nemo_model_config.get('ffn_hidden_size'),
+                    layernorm_epsilon=nemo_model_config.get('layernorm_epsilon'),
+                    add_bias_linear=nemo_model_config.get('bias'),
+                    num_moe_experts=nemo_model_config.get('num_moe_experts', None),
+                    normalization=transformer_config_normalization,
+                    layernorm_zero_centered_gamma=layernorm_zero_centered_gamma,
+                )
 
-        if model_type == "mixtral":
-            model_type = "llama"
+                return conf
+
+            # Build the transformer config from the nemo model config.
+            transformer_config = get_transformer_config(model_configs)
+            input_model_type = getattr(ModelType, model_type)
+
+            # MCore export ships a default conversion dictionary per model type.
+            mcore_model_conversion_dict = DEFAULT_CONVERSION_DICT[input_model_type]
+            # MCore conversion-dict keys start with "decoder.layers.<n>...", while NeMo state-dict
+            # keys start with "model.decoder.layers.<n>...", so prefix every key with "model.".
+            nemo_model_conversion_dict = {
+                f'model.{key}': value for key, value in mcore_model_conversion_dict.items()
+            }
+
+            trtllm_helper = TRTLLMHelper(
+                transformer_config=transformer_config,
+                model_type=input_model_type,
+                trtllm_conversion_dict=nemo_model_conversion_dict,
+                position_embedding_type=model_configs.get('position_embedding_type'),
+                max_position_embeddings=model_configs.get('max_position_embeddings'),
+                rotary_percentage=model_configs.get('rotary_percentage', 1.0),
+                rotary_base=model_configs.get('rotary_base', 10000),
+                moe_tp_mode=model_configs.get('moe_tp_mode', 2),
+                multi_query_mode=model_configs.get("multi_query_mode", False),
+                activation=model_configs.get('activation', "gelu"),
+                seq_len_interpolation_factor=model_configs.get("seq_len_interpolation_factor"),
+                moe_renorm_mode=model_configs.get(
+                    'moe_renorm_mode', MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE
+                ),
+                share_embeddings_and_output_weights=model_configs.get(
+                    "share_embeddings_and_output_weights", False
+                ),
+            )
 
-        model, model_configs, self.tokenizer = load_nemo_model(nemo_checkpoint_path, nemo_export_dir)
-        weights_dicts, model_configs = model_to_trtllm_ckpt(
-            model=model,
-            nemo_model_config=model_configs,
-            nemo_export_dir=nemo_export_dir,
-            decoder_type=model_type,
-            dtype=dtype,
-            tensor_parallel_size=tensor_parallelism_size,
-            pipeline_parallel_size=pipeline_parallelism_size,
-            gpus_per_node=gpus_per_node,
-            use_parallel_embedding=use_parallel_embedding,
-            use_embedding_sharing=use_embedding_sharing,
-            fp8_quantized=fp8_quantized,
-            fp8_kvcache=fp8_kvcache,
-        )
+            input_dtype = getattr(DataType, dtype)
+            export_config = ExportConfig(
+                tensor_parallelism_size,
+                pipeline_parallelism_size,
+                use_parallel_embedding,
+                use_embedding_sharing,
+            )
 
-        for weight_dict, model_config in zip(weights_dicts, model_configs):
-            build_and_save_engine(
-                max_input_len=max_input_len,
-                max_output_len=max_output_len,
-                max_batch_size=max_batch_size,
-                model_config=model_config,
-                model_weights=weight_dict,
-                model_dir=self.model_dir,
-                model_type=model_type,
-                lora_ckpt_list=self.lora_ckpt_list,
-                use_lora_plugin=use_lora_plugin,
-                max_lora_rank=max_lora_rank,
-                lora_target_modules=lora_target_modules,
-                max_prompt_embedding_table_size=max_prompt_embedding_table_size,
-                paged_kv_cache=paged_kv_cache,
-                remove_input_padding=remove_input_padding,
-                paged_context_fmha=paged_context_fmha,
-                max_num_tokens=max_num_tokens,
-                opt_num_tokens=opt_num_tokens,
-                max_seq_len=max_seq_len,
-                multiple_profiles=multiple_profiles,
-                gpt_attention_plugin=gpt_attention_plugin,
-                gemm_plugin=gemm_plugin,
+            trtllm_model_weights_list, trtllm_model_config_list = (
+                trtllm_helper.get_trtllm_pretrained_config_and_model_weights(
+                    model_state_dict=model,
+                    export_config=export_config,
+                    dtype=input_dtype,
+                    state_dict_split_by_layer_numbers=False,
+                )
+            )
+
+            for trtllm_model_weights, trtllm_model_config in zip(
+                trtllm_model_weights_list, trtllm_model_config_list
+            ):
+                trtllm_helper.build_and_save_engine(
+                    max_input_len=max_input_len,
+                    max_output_len=max_output_len,
+                    max_batch_size=max_batch_size,
+                    engine_dir=self.model_dir,
+                    trtllm_model_weights=trtllm_model_weights,
+                    trtllm_model_config=trtllm_model_config,
+                    lora_ckpt_list=self.lora_ckpt_list,
+                    use_lora_plugin=use_lora_plugin,
+                    max_lora_rank=max_lora_rank,
+                    lora_target_modules=lora_target_modules,
+                    max_prompt_embedding_table_size=max_prompt_embedding_table_size,
+                    paged_kv_cache=paged_kv_cache,
+                    remove_input_padding=remove_input_padding,
+                    paged_context_fmha=paged_context_fmha,
+                    use_refit=False,
+                    max_num_tokens=max_num_tokens,
+                    max_seq_len=max_seq_len,
+                    opt_num_tokens=opt_num_tokens,
+                    max_beam_width=1,
+                    tokens_per_block=128,
+                    multiple_profiles=multiple_profiles,
+                    gpt_attention_plugin=gpt_attention_plugin,
+                    gemm_plugin=gemm_plugin,
+                )
+        else:
+            if model_type == "gpt" or model_type == "starcoder":
+                model_type = "gptnext"
+
+            if model_type == "mixtral":
+                model_type = "llama"
+
+            weights_dicts, model_configs = model_to_trtllm_ckpt(
+                model=model,
+                nemo_model_config=model_configs,
+                nemo_export_dir=nemo_export_dir,
+                decoder_type=model_type,
+                dtype=dtype,
+                tensor_parallel_size=tensor_parallelism_size,
+                pipeline_parallel_size=pipeline_parallelism_size,
+                gpus_per_node=gpus_per_node,
+                use_parallel_embedding=use_parallel_embedding,
+                use_embedding_sharing=use_embedding_sharing,
+                fp8_quantized=fp8_quantized,
+                fp8_kvcache=fp8_kvcache,
             )
 
+            for weight_dict, model_config in zip(weights_dicts, model_configs):
+                build_and_save_engine(
+                    max_input_len=max_input_len,
+                    max_output_len=max_output_len,
+                    max_batch_size=max_batch_size,
+                    model_config=model_config,
+                    model_weights=weight_dict,
+                    model_dir=self.model_dir,
+                    model_type=model_type,
+                    lora_ckpt_list=self.lora_ckpt_list,
+                    use_lora_plugin=use_lora_plugin,
+                    max_lora_rank=max_lora_rank,
+                    lora_target_modules=lora_target_modules,
+                    max_prompt_embedding_table_size=max_prompt_embedding_table_size,
+                    paged_kv_cache=paged_kv_cache,
+                    remove_input_padding=remove_input_padding,
+                    paged_context_fmha=paged_context_fmha,
+                    max_num_tokens=max_num_tokens,
+                    opt_num_tokens=opt_num_tokens,
+                    max_seq_len=max_seq_len,
+                    multiple_profiles=multiple_profiles,
+                    gpt_attention_plugin=gpt_attention_plugin,
+                    gemm_plugin=gemm_plugin,
+                )
+
         tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
         if os.path.exists(tokenizer_path):
             shutil.copy(tokenizer_path, self.model_dir)
@@ -451,7 +568,6 @@ def convert_to_safe_tensors(
                 weight_dict[k] = numpy_to_torch(v)
 
             safetensors.torch.save_file(weight_dict, os.path.join(self.model_dir, f'rank{rank}.safetensors'))
-
         model_configs[0].to_json_file(os.path.join(self.model_dir, 'config.json'))
 
         tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
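
For reviewers who want to exercise the new flag end to end, a minimal usage sketch follows. The checkpoint path, engine directory, and model type are placeholders, not part of this diff; only `use_mcore_path=True` is the behavior this patch adds.

    from nemo.export.tensorrt_llm import TensorRTLLM

    # Placeholder paths; substitute a real .nemo checkpoint and engine directory.
    exporter = TensorRTLLM(model_dir="/tmp/trtllm_engine")
    exporter.export(
        nemo_checkpoint_path="/checkpoints/model.nemo",
        model_type="llama",
        tensor_parallelism_size=1,
        use_mcore_path=True,  # opt into the Megatron Core export path added by this diff
    )

With `use_mcore_path=False` (the default), the same call goes through the existing `model_to_trtllm_ckpt` / `build_and_save_engine` path, so current callers are unaffected.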