diff --git a/examples/multimodal/speech_llm/export/README.md b/examples/multimodal/speech_llm/export/README.md new file mode 100644 index 000000000000..05e44d112cce --- /dev/null +++ b/examples/multimodal/speech_llm/export/README.md @@ -0,0 +1,83 @@ +## Setup +In this part, we are going to export SALM model into TRTLLM. +First, let's download the [SALM nemo model](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/speechllm_fc_llama2_7b/) from NVIDIA ngc. + +```bash +wget --content-disposition 'https://api.ngc.nvidia.com/v2/models/org/nvidia/team/nemo/speechllm_fc_llama2_7b/1.23.1/files?redirect=true&path=speechllm_fc_llama2_7b.nemo' -O speechllm_fc_llama2_7b.nemo +``` + +Then, we need to extract the different parts of SALM. +```bash +output=$PWD/output +python3 extract_salm_weights.py --model_file_path=speechllm_fc_llama2_7b.nemo --output_dir=$output +``` +It takes a while to run the above command. + +Under the `output` dir, you'll see: +``` +output + |___speechllm_fc_llama2_7b_lora.nemo + |___speechllm_fc_llama2_7b_perception + | |____model_config.yaml + | |____model_weights.ckpt + |___speechllm_fc_llama2_7b_llm.nemo + |___ xxx.tokenizer.model +``` + +After we get the lora nemo model and llm nemo model, we can merge the lora part into the llm by: +```bash +python /opt/NeMo/scripts/nlp_language_modeling/merge_lora_weights/merge.py \ + trainer.accelerator=gpu \ + tensor_model_parallel_size=1 \ + pipeline_model_parallel_size=1 \ + gpt_model_file=output/speechllm_fc_llama2_7b_llm.nemo \ + lora_model_path=output/speechllm_fc_llama2_7b_lora.nemo \ + merged_model_path=speechllm_fc_llama2_7b_llm_merged.nemo +``` + +Now we are able to export the engine by: +```bash +python3 export_salm.py \ + model.perception_model_path=output/speechllm_fc_llama2_7b_perception \ + model.llm_model_path=output/speechllm_fc_llama2_7b_llm_merged.nemo +``` + +You should be able to get the generated engines under `./salm` folder. To run the engines, you may run: +```python +from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter + +output_dir = "/ws/salm" # the engine directory +trt_llm_exporter = TensorRTMMExporter(model_dir=output_dir, load_model=True, modality='audio') +input_text = "Q: what's the transcription of the audio? A:" +input_media = '/ws/data/test_audio.wav' +print(trt_llm_exporter.forward(input_text, input_media)) + +``` + +## Deploy +If you want to generate the engines and deploy them with Triton Inference Server, you may also run: + +```bash +python3 NeMo/scripts/deploy/multimodal/deploy_triton.py \ + --modality="audio" \ + --visual_checkpoint=NeMo/examples/multimodal/speech_llm/export/output/speechllm_fc_llama2_7b_perception \ + --llm_checkpoint=NeMo/examples/multimodal/speech_llm/export/output/speechllm_fc_llama2_7b_llm_merged.nemo \ + --llm_model_type="llama" \ + --model_type="salm" \ + --triton_model_name="salm" \ + --max_input_len=4096 \ + --max_output_len=256 \ + --max_multimodal_len=3072 \ + --triton_model_repository=/tmp/trt_model_dir/ +``` + +And on client side, you may run: +```bash +python3 NeMo/scripts/deploy/multimodal/query.py \ + --model_name="salm" \ + --model_type="salm" \ + --input_text="Q: what's the transcription of the audio? A:" \ + --input_media=/ws/data/test_audio.wav +``` + +For more details, please check `NeMo/scripts/deploy/multimodal/deploy_triton.py` and ` NeMo/scripts/deploy/multimodal/query.py`. \ No newline at end of file diff --git a/examples/multimodal/speech_llm/export/conf/salm_export.yaml b/examples/multimodal/speech_llm/export/conf/salm_export.yaml new file mode 100644 index 000000000000..54ab6e9180c5 --- /dev/null +++ b/examples/multimodal/speech_llm/export/conf/salm_export.yaml @@ -0,0 +1,16 @@ +name: speechllm_salm +infer: + output_dir: ./salm + max_batch_size: 1 + tensor_parallelism: 1 + max_input_len: 4096 + max_output_len: 256 + max_multimodal_len: 3072 + perception_max_batch_size: 1 + +model: + type: salm + precision: float16 + perception_model_path: /path/to/speechllm_llama2_7b_perception + llm_model_path: /path/to/speechllm_llama2_7b_llm.nemo + llm_model_type: llama diff --git a/examples/multimodal/speech_llm/export/export_salm.py b/examples/multimodal/speech_llm/export/export_salm.py new file mode 100644 index 000000000000..00500bf46f50 --- /dev/null +++ b/examples/multimodal/speech_llm/export/export_salm.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.core.config import hydra_runner +from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter + + +@hydra_runner(config_path='conf', config_name='salm_export') +def main(cfg): + exporter = TensorRTMMExporter(model_dir=cfg.infer.output_dir, load_model=False, modality='audio') + exporter.export( + visual_checkpoint_path=cfg.model.perception_model_path, + llm_checkpoint_path=cfg.model.llm_model_path, + model_type=cfg.model.type, + llm_model_type=cfg.model.llm_model_type, + tensor_parallel_size=cfg.infer.tensor_parallelism, + max_input_len=cfg.infer.max_input_len, + max_output_len=cfg.infer.max_output_len, + vision_max_batch_size=cfg.infer.perception_max_batch_size, + max_batch_size=cfg.infer.max_batch_size, + max_multimodal_len=cfg.infer.max_multimodal_len, + dtype=cfg.model.precision, + load_model=False, + ) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/speech_llm/export/extract_salm_weights.py b/examples/multimodal/speech_llm/export/extract_salm_weights.py new file mode 100644 index 000000000000..0698a411110e --- /dev/null +++ b/examples/multimodal/speech_llm/export/extract_salm_weights.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os +import tempfile + +import torch +from megatron.core import dist_checkpointing +from omegaconf import OmegaConf +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.multimodal.speech_llm.modules.perception_modules import AudioPerceptionModule +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper +from nemo.utils import logging +from nemo.utils.model_utils import inject_model_parallel_rank + + +def get_config_and_state_dict_from_nemo(filepath, map_location, output_dir, sharded_state_dict=None): + cwd = os.getcwd() + save_restore_connector = NLPSaveRestoreConnector() + + with tempfile.TemporaryDirectory() as tmpdir: + try: + if os.path.isfile(filepath): + save_restore_connector._unpack_nemo_file(path2file=filepath, out_folder=tmpdir) + else: + tmpdir = filepath + + os.chdir(tmpdir) + config_yaml = "model_config.yaml" + model_weights_ckpt = "model_weights.ckpt" + + # find file in tmpdir that endswith "tokenizer.model" + tokenizer = None + for file in os.listdir(tmpdir): + if file.endswith("tokenizer.model"): + tokenizer = file + break + if tokenizer is None: + raise ValueError(f"Tokenizer not found in {tmpdir}") + tokenizer_path = os.path.join(tmpdir, tokenizer) + # copy tokenizer_path to current directory + os.system(f"cp {tokenizer_path} {output_dir}") + tokenizer_path = os.path.join(output_dir, tokenizer) + + # load conf + with open(config_yaml) as f: + conf = OmegaConf.load(f) + + os.chdir(cwd) + model_weights = os.path.join(tmpdir, model_weights_ckpt) + model_weights = inject_model_parallel_rank(model_weights) + state_dict = save_restore_connector._load_state_dict_from_disk(model_weights, map_location=map_location) + + # distributed checkpointing + if state_dict is None and sharded_state_dict is not None: + checkpoint = dict(state_dict=sharded_state_dict) + tmp_model_weights_ckpt = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt) + tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0] + assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.' + checkpoint = dist_checkpointing.load( + sharded_state_dict=checkpoint, + checkpoint_dir=tmp_model_weights_dir, + ) + state_dict = checkpoint["state_dict"] + + conf.tokenizer.model = tokenizer_path + return conf, state_dict + finally: + os.chdir(cwd) + + +def get_llm_model_state_dict(state_dict, lora_model_state_dict): + llm_model_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("model."): + if key not in lora_model_state_dict and value != None: + llm_model_state_dict[key] = value + return llm_model_state_dict + + +def get_lora_state_dict(state_dict): + lora_model_state_dict = {} + for key, value in state_dict.items(): + if "adapter_layer.lora" in key and value != None: + lora_model_state_dict[key] = value + return lora_model_state_dict + + +def get_perception_state_dict(state_dict): + perception_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("perception."): + key = key.replace("perception.", "", 1) + perception_state_dict[key] = value + return perception_state_dict + + +def save_llm_model(state_dict, nemo_config, output_path): + if nemo_config.get('megatron_amp_O2', False): + keys = list(state_dict.keys()) + for key in keys: + state_dict[key.replace('model.', 'model.module.', 1)] = state_dict['state_dict'].pop(key) + + trainer = Trainer(accelerator='cpu', strategy=NLPDDPStrategy()) + model = load_state_dict_helper(MegatronGPTModel, nemo_config, trainer, state_dict) + model._save_restore_connector = NLPSaveRestoreConnector() + model.cfg.use_cpu_initialization = False + + model.save_to(output_path) + logging.info(f'llm model saved to: {output_path}') + + +def save_nemo_weights(state_dict, output_dir, config, save_nemo_model=True): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + weight_file = os.path.join(output_dir, "model_weights.ckpt") + torch.save(state_dict, weight_file) + # convert config to yaml + config_file = os.path.join(output_dir, "model_config.yaml") + with open(config_file, "w") as f: + f.write(OmegaConf.to_yaml(config)) + + if save_nemo_model: + # create nemo file + nemo_model_name = f"{output_dir}.nemo" + nemo_path = os.path.join(output_dir, nemo_model_name) + # tar model_config.yaml and model_weights.ckpt + os.system(f"tar -C {output_dir} -cvf {nemo_path} model_config.yaml model_weights.ckpt") + # remove model_config.yaml and model_weights.ckpt + os.system(f"rm {config_file} {weight_file}") + # remove the empty directory + os.system(f"rmdir {output_dir}") + + +def separate_speechllm_model(model_file_path, output_dir, map_location="cuda:0"): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + output_dir = os.path.abspath(output_dir) + + logging.info(f"Separating {model_file_path} into perception, lora, and llm model") + filepath = model_file_path + conf, state_dict = get_config_and_state_dict_from_nemo(filepath, map_location, output_dir) + + base_model_name = os.path.basename(filepath).split(".")[0] + + perception_state_dict = get_perception_state_dict(state_dict) + perception_model_dir = None + if perception_state_dict: + perception_model_dir = f"{base_model_name}_perception" + perception_model_dir = os.path.join(output_dir, perception_model_dir) + save_nemo_weights(perception_state_dict, perception_model_dir, conf.perception, save_nemo_model=False) + + # verify if the exported perception model is correct + perception = AudioPerceptionModule(cfg=conf.perception) + perception.load_state_dict(perception_state_dict) + perception.eval() + print(perception) + print(perception(input_signal=torch.randn(1, 1000), input_signal_length=torch.tensor([1000]))) + # absolute path of perception model + logging.info(f"Perception model saved to: {perception_model_dir}") + + lora_model_weights = get_lora_state_dict(state_dict) + lora_model_dir = None + if lora_model_weights: + lora_model_dir = f"{base_model_name}_lora" + lora_model_dir = os.path.join(output_dir, lora_model_dir) + save_nemo_weights(lora_model_weights, lora_model_dir, conf) + logging.info(f"Lora model saved to: {lora_model_dir}.nemo") + # hard code the target model for now + llm_model_weights = get_llm_model_state_dict(state_dict, lora_model_weights) + if llm_model_weights: + llm_model = f"{base_model_name}_llm.nemo" + llm_model = os.path.join(output_dir, llm_model) + conf.target = "nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel" + save_llm_model(llm_model_weights, conf, llm_model) + logging.info(f"LLM model saved to: {llm_model}") + + +# filepath = "/ws/speechllm_fc_llama2_7b.nemo" +# output_dir = "/ws/speechllm_fc_llama2_7b_separated" +# perception_model_dir, lora_model, llm_model = separate_speechllm_model(filepath, output_dir) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Separate speechllm model') + parser.add_argument('--model_file_path', type=str, help='Path to the speechllm model') + parser.add_argument('--output_dir', type=str, help='Output directory to save the separated models') + args = parser.parse_args() + separate_speechllm_model(args.model_file_path, args.output_dir) diff --git a/nemo/deploy/multimodal/query_multimodal.py b/nemo/deploy/multimodal/query_multimodal.py index 1c01c6861048..63e6a3e8c3a6 100644 --- a/nemo/deploy/multimodal/query_multimodal.py +++ b/nemo/deploy/multimodal/query_multimodal.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import soundfile as sf from PIL import Image from nemo.deploy.utils import str_list2numpy @@ -71,6 +72,11 @@ def setup_media(self, input_media): elif self.model_type == "neva" or self.model_type == "vila": media = Image.open(input_media).convert('RGB') return np.expand_dims(np.array(media), axis=0) + elif self.model_type == "salm": + waveform, sample_rate = sf.read(input_media, dtype=np.float32) + input_signal = np.array([waveform], dtype=np.float32) + input_signal_length = np.array([[len(waveform)]], dtype=np.int32) + return {"input_signal": input_signal, "input_signal_length": input_signal_length} else: raise RuntimeError(f"Invalid model type {self.model_type}") @@ -105,8 +111,10 @@ def query( inputs = {"input_text": prompts} media = self.setup_media(input_media) - - inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0) + if isinstance(media, dict): + inputs.update(media) + else: + inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0) if batch_size is not None: inputs["batch_size"] = np.full(prompts.shape, batch_size, dtype=np.int_) diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py index 8ee3fa1c05e7..53c598be47c6 100644 --- a/nemo/export/multimodal/build.py +++ b/nemo/export/multimodal/build.py @@ -23,9 +23,12 @@ import tensorrt as trt import torch import yaml +from omegaconf import OmegaConf from tensorrt_llm.builder import Builder from transformers import AutoModel +from nemo.collections.multimodal.speech_llm.modules.perception_modules import AudioPerceptionModule +from nemo.core.classes.common import typecheck from nemo.export.tensorrt_llm import TensorRTLLM from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import load_nemo_model @@ -76,6 +79,32 @@ def export_visual_wrapper_onnx( ) +def export_perception_wrapper_onnx( + perception_wrapper, + input, + output_dir, + input_names=['processed_signal', 'processed_signal_length'], + output_names=['encoded', 'encoded_length'], + dynamic_axes={ + 'processed_signal': {0: 'batch', 2: 'time'}, + 'processed_signal_length': {0: 'batch'}, + 'encoded': {0: 'batch', 1: 'time'}, + 'encoded_length': {0: 'batch'}, + }, +): + logger.log(trt.Logger.INFO, "Exporting onnx") + os.makedirs(f'{output_dir}/onnx', exist_ok=True) + torch.onnx.export( + perception_wrapper, + input, + f'{output_dir}/onnx/perception_encoder.onnx', + opset_version=17, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + + def build_trt_engine( model_type, input_sizes, @@ -85,8 +114,8 @@ def build_trt_engine( image_size=None, num_frames=None, nemo_config=None, + part_name='visual_encoder', ): - part_name = 'visual_encoder' onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name) engine_file = '%s/%s.engine' % (output_dir, part_name) config_file = '%s/%s' % (output_dir, "config.json") @@ -131,6 +160,10 @@ def build_trt_engine( # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images, # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]). + # or a list of three list of lists + # (e.g., [{input1: min_shape, input2: min_shape, }, \ + # {input1: opt_shape, input2: opt_shape}, \ + # {input1: max_shape, input2: max_shape}] ) assert isinstance(input_sizes, list), "input_sizes must be a list" if isinstance(input_sizes[0], int): logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}") @@ -139,10 +172,23 @@ def build_trt_engine( elif len(input_sizes) == 3 and isinstance(input_sizes[0], list): min_size, opt_size, max_size = input_sizes logger.log(trt.Logger.INFO, f"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}") + elif len(input_sizes) == 3 and isinstance(input_sizes[0], dict): + logger.log(trt.Logger.INFO, f"Processed min/opt/max input sizes {input_sizes}") else: raise ValueError(f"invalid input sizes: {input_sizes}") - profile.set_shape(inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], [nMaxBS, *max_size]) + if isinstance(input_sizes[0], dict): + for i in range(network.num_inputs): + inputT = network.get_input(i) + input_name = inputT.name + min_size = input_sizes[0][input_name] + opt_size = input_sizes[1][input_name] + max_size = input_sizes[2][input_name] + logger.log(trt.Logger.INFO, f"{input_name} min/opt/max input sizes {min_size}/{opt_size}/{max_size}") + profile.set_shape(input_name, min_size, opt_size, max_size) + else: + profile.set_shape(inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], [nMaxBS, *max_size]) + config.add_optimization_profile(profile) t0 = time() @@ -367,6 +413,76 @@ def forward(self, images): ) +def build_perception_engine( + model_dir: str, + perception_checkpoint_path: str, + model_type: str = "salm", + max_batch_size: int = 1, +): + assert model_type == "salm", f"Invalid model type {model_type}" + + def load_perception_model(perception_checkpoint_path): + weights = "model_weights.ckpt" + perception_state_dict = torch.load(os.path.join(perception_checkpoint_path, weights)) + config = "model_config.yaml" + config = OmegaConf.load(os.path.join(perception_checkpoint_path, config)) + perception = AudioPerceptionModule(cfg=config) + perception.load_state_dict(perception_state_dict) + perception.eval() + return perception + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + # load perception model + perception_model = load_perception_model(perception_checkpoint_path) + feature_extractor = perception_model.preprocessor + input_signal = torch.randn(1, 1000, dtype=torch.float32) + input_signal_length = torch.tensor([1000], dtype=torch.int32) + + processed_signal, processed_signal_length = feature_extractor( + input_signal=input_signal, length=input_signal_length + ) + processed_signal_length = processed_signal_length.to(torch.int32) + dump_path = model_dir + "/feature_extractor.ts" # dump the feature extractor as torchscript + feature_extractor.export(dump_path, (input_signal, input_signal_length)) + + class PerceptionWrapper(torch.nn.Module): + def __init__(self, encoder, modality_adapter, proj): + super().__init__() + self.encoder = encoder + self.modality_adapter = modality_adapter + self.proj = proj + + @typecheck.disable_checks() + def forward(self, processed_signal, processed_signal_length): + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + encoded, encoded_len = self.modality_adapter(audio_signal=encoded, length=encoded_len) + # b, c, t -> b, t, c + encoded = self.proj(encoded.transpose(1, 2)) + encoded_len = encoded_len.to(torch.int32) + return encoded, encoded_len + + perception = PerceptionWrapper(perception_model.encoder, perception_model.modality_adapter, perception_model.proj) + export_perception_wrapper_onnx(perception, (processed_signal, processed_signal_length), model_dir) + # export the onnx perception model to tensorrt engine + # 512 -> 5.12 sec, 3072 -> 30.72 sec + opt_batch_size = max(1, max_batch_size // 2) + shapes = [ + {"processed_signal": [1, 80, 64], "processed_signal_length": [1]}, + {"processed_signal": [opt_batch_size, 80, 512], "processed_signal_length": [opt_batch_size]}, + {"processed_signal": [max_batch_size, 80, 3072], "processed_signal_length": [max_batch_size]}, + ] + build_trt_engine( + model_type, + shapes, + model_dir, + max_batch_size, + dtype=torch.float16, + nemo_config=None, + part_name='perception_encoder', + ) + + def build_visual_engine( model_dir: str, visual_checkpoint_path: str, diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py index 149df995c77a..2cde46ca41fa 100644 --- a/nemo/export/multimodal/run.py +++ b/nemo/export/multimodal/run.py @@ -25,6 +25,7 @@ import einops import numpy as np +import soundfile as sf import tensorrt as trt import tensorrt_llm import tensorrt_llm.profiler as profiler @@ -32,7 +33,7 @@ import yaml from PIL import Image from tensorrt_llm import logger -from tensorrt_llm._utils import str_dtype_to_trt +from tensorrt_llm._utils import str_dtype_to_trt, torch_dtype_to_trt from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo from torch.nn import functional as F from torchvision import transforms @@ -54,7 +55,8 @@ def trt_dtype_to_torch(dtype): class MultimodalModelRunner: - def __init__(self, visual_engine_dir, llm_engine_dir): + def __init__(self, visual_engine_dir, llm_engine_dir, modality='vision'): + self.modality = modality self.runtime_rank = tensorrt_llm.mpi_rank() device_id = self.runtime_rank % torch.cuda.device_count() torch.cuda.set_device(device_id) @@ -68,13 +70,15 @@ def __init__(self, visual_engine_dir, llm_engine_dir): config = json.load(f) self.model_type = config['builder_config']['model_type'] self.vision_precision = config['builder_config']['precision'] + self.modality_precision = config['builder_config']['precision'] self.num_frames = config['builder_config'].get('num_frames', None) self.image_size = config['builder_config'].get('image_size', None) self.profiling_iterations = 20 - self.init_image_encoder(visual_engine_dir) + if modality == 'vision': + self.init_image_encoder(visual_engine_dir) self.init_tokenizer(llm_engine_dir) self.init_llm(llm_engine_dir) if self.model_type == 'lita' or self.model_type == 'vila' or self.model_type == 'vita': @@ -242,10 +246,10 @@ def insert_tokens_by_index(self, input_ids, num_frames): def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, batch_size): if not warmup: - profiler.start("Vision") + profiler.start(self.modality.capitalize()) if not warmup: - profiler.stop("Vision") + profiler.stop(self.modality.capitalize()) if self.model_type == 'vila': visual_features, visual_atts = self.get_visual_features(image, attention_mask) @@ -848,7 +852,7 @@ def print_result(self, input_text, output_text, batch_size, num_beams, run_profi if run_profiling: msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec(name) / self.profiling_iterations logger.info('Latencies per batch (msec)') - logger.info('TRT vision encoder: %.1f' % (msec_per_batch('Vision'))) + logger.info(f'TRT {self.modality} encoder: %.1f' % (msec_per_batch(self.modality.capitalize()))) logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM'))) logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate'))) @@ -864,3 +868,278 @@ def load_test_media(self, input_media): raise RuntimeError(f"Invalid model type {self.model_type}") return media + + +class SpeechllmModelRunner(MultimodalModelRunner): + def __init__(self, perception_engine_dir, llm_engine_dir, modality): + """ + perception_engine_dir: path to the perception engine directory + it should contain: + config.json nemo_config.yaml + perception_encoder.engine : tensorrt engine + feature_extractor.ts : torchscript model + llm_engine_dir: path to the LLM engine directory + """ + super().__init__(perception_engine_dir, llm_engine_dir, modality) + assert self.model_type == 'salm' + # init preprocessor + feature_extractor_path = os.path.join(perception_engine_dir, 'feature_extractor.ts') + self.feature_extractor = self.init_speech_preprocessor(feature_extractor_path) + self.init_modality_encoder(perception_engine_dir) + + def init_modality_encoder(self, engine_dir): + """ + Initialize the modality encoder session from the prebuilt engine directory + Args: + engine_dir: str, path to the engine directory + """ + # find file with .engine extension + engine_file = None + for file in os.listdir(engine_dir): + if file.endswith('.engine'): + engine_file = file + break + assert engine_file is not None, f"Engine file not found in {engine_dir}" + encoder_path = os.path.join(engine_dir, engine_file) + logger.info(f'Loading engine from {encoder_path}') + with open(encoder_path, 'rb') as f: + engine_buffer = f.read() + logger.info(f'Creating session from engine {encoder_path}') + self.modality_encoder_session = Session.from_serialized_engine(engine_buffer) + + def init_speech_preprocessor(self, feature_extractor_path): + feature_extractor = torch.jit.load(feature_extractor_path) + feature_extractor.eval() + return feature_extractor + + def process_audio(self, input_signal, input_signal_length): + """ + Args: + input_signal: audio signal in numpy array + input_signal_length: length of the audio signal in numpy array + + Returns: + processed_signal: torch.tensor [B, 80, T] + processed_signal_length [B] + """ + input_signal = torch.tensor(input_signal, dtype=torch.float32) + input_signal_length = torch.tensor(input_signal_length, dtype=torch.int32) + processed_signal, processed_signal_length = self.feature_extractor(input_signal, input_signal_length) + return processed_signal, processed_signal_length + + def setup_inputs(self, input_text, input_media, batch_size): + """ + Args: + input_text: str or List[str] or None + input_media: Tuple[np.array, np.array] + input_signal: audio signal in numpy array [b, -1] + input_signal_length: length of the audio signal in numpy array [b] + batch_size: int + + """ + input_signal, input_signal_length = input_media + processed_signal, processed_signal_length = self.process_audio(input_signal, input_signal_length) + processed_signal = processed_signal.to(self.device) + processed_signal_length = processed_signal_length.to(self.device) + if input_text is None: + input_text = "Q: what's the transcription of the audio? A:" + + if isinstance(input_text, str): + input_text = [input_text] * batch_size + + assert len(input_text) == batch_size + pre_prompt = [''] * batch_size + post_prompt = input_text + decoder_input_ids = None + attention_mask = None + return ( + input_text, + pre_prompt, + post_prompt, + processed_signal, + processed_signal_length, + decoder_input_ids, + attention_mask, + ) + + def load_test_media(self, input_media_path): + """ + Args: + input_media_path: str, path to the audio file + Returns: + input_signal: np.array [1, -1] + input_signal_length: np.array [1] + """ + waveform, sample_rate = sf.read(input_media_path, dtype=np.float32) + input_signal = np.array([waveform], dtype=np.float32) + input_signal_length = np.array([len(waveform)], dtype=np.int32) + return input_signal, input_signal_length + + def get_modality_encoder_features(self, modality_features, attention_mask): + """ + Do inference on the modality encoder engine + Args: + modality_features: dict {'input1': torch.tensor, 'input2': torch.tensor, ..} + attention_mask: None + Returns: + """ + + if attention_mask is not None: + modality_features['attention_mask'] = attention_mask + + tensor_info = [] + for key, tensor in modality_features.items(): + tensor_info.append(TensorInfo(key, torch_dtype_to_trt(tensor.dtype), tensor.shape)) + + output_info = self.modality_encoder_session.infer_shapes(tensor_info) + + outputs = { + t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=self.device) + for t in output_info + } + + ok = self.modality_encoder_session.run(modality_features, outputs, self.stream.cuda_stream) + assert ok, "Runtime execution failed for vision encoder session" + self.stream.synchronize() + + return outputs + + def preprocess(self, warmup, pre_prompt, post_prompt, processed_features, attention_mask, batch_size): + """ + Args: + warmup: bool + pre_prompt: List[str] + post_prompt: List[str] + processed_features: Tuple[torch.tensor, torch.tensor] + processed_signal: torch.tensor [B, 80, T] + processed_signal_length: torch.tensor [B] + attention_mask: None + batch_size: int + Returns: + input_ids: torch.tensor [B, L] + input_lengths: torch.tensor [B] + ptuning_args: List[torch.tensor] + encoded_features: torch.tensor [B, L, D] + """ + if not warmup: + profiler.start(self.modality.capitalize()) + + if not warmup: + profiler.stop(self.modality.capitalize()) + + assert self.model_type == 'salm', f"Invalid model type {self.model_type}" + + processed_features = { + "processed_signal": processed_features[0], + "processed_signal_length": processed_features[1].to(torch.int32), + } + encoded_outputs = self.get_modality_encoder_features(processed_features, attention_mask) + encoded_features, encoded_length = encoded_outputs['encoded'], encoded_outputs['encoded_length'] + pre_input_ids = self.tokenizer(pre_prompt).input_ids + post_input_ids = self.tokenizer(post_prompt).input_ids + input_lengths = [] + input_ids = [] + encoded_length = encoded_length.cpu().numpy() + fake_id_start = self.model.vocab_size + for i in range(batch_size): + feat_len = encoded_length[i] + feat_fake_ids = np.arange(fake_id_start, fake_id_start + feat_len) + cur_input_ids = np.concatenate([pre_input_ids[i], feat_fake_ids, post_input_ids[i]]) + fake_id_start += feat_len + input_lengths.append(len(cur_input_ids)) + input_ids.append(cur_input_ids) + + max_length = max(input_lengths) + # convert input_ids to torch tensor with padding + input_ids = [ + np.pad(ids, (0, max_length - len(ids)), 'constant', constant_values=self.tokenizer.pad_token_id) + for ids in input_ids + ] + input_ids = torch.tensor(input_ids, dtype=torch.int32) + input_lengths = torch.tensor(input_lengths, dtype=torch.int32) + ptuning_args = self.ptuning_setup(encoded_features, input_ids, input_lengths) + + return input_ids, input_lengths, ptuning_args, encoded_features + + def run( + self, + input_text, + input_media=None, + max_new_tokens: int = 30, + batch_size: int = 1, + top_k: int = 1, + top_p: float = 0.0, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + num_beams: int = 1, + run_profiling=False, + check_accuracy=False, + input_signal=None, + input_signal_length=None, + ): + """ + Args: + input_text: str or List[str] or None + input_media: Tuple[np.array, np.array] or None + input_signal: audio signal in numpy array [b, -1] + input_signal_length: length of the audio signal in numpy array [b] + max_new_tokens: int + batch_size: int + top_k: int + top_p: float + temperature: float + repetition_penalty: float + num_beams: int + run_profiling: bool + check_accuracy: bool + """ + if input_media is None: + assert input_signal is not None and input_signal_length is not None + input_media = (input_signal, input_signal_length) + + ( + input_text, + pre_prompt, + post_prompt, + processed_signal, + processed_signal_length, + decoder_input_ids, + attention_mask, + ) = self.setup_inputs(input_text, input_media, batch_size) + processed_media = (processed_signal, processed_signal_length) + + self.generate( + pre_prompt, + post_prompt, + processed_media, + decoder_input_ids, + max_new_tokens, + attention_mask=attention_mask, + warmup=True, + batch_size=batch_size, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + num_beams=num_beams, + ) + num_iters = self.profiling_iterations if run_profiling else 1 + for _ in range(num_iters): + output_text = self.generate( + pre_prompt, + post_prompt, + processed_media, + decoder_input_ids, + max_new_tokens, + attention_mask=attention_mask, + warmup=False, + batch_size=batch_size, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + num_beams=num_beams, + ) + if self.runtime_rank == 0: + self.print_result(input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy) + return output_text diff --git a/nemo/export/tensorrt_mm_exporter.py b/nemo/export/tensorrt_mm_exporter.py index b0536a55f95f..d4da0ac34b1c 100644 --- a/nemo/export/tensorrt_mm_exporter.py +++ b/nemo/export/tensorrt_mm_exporter.py @@ -21,8 +21,8 @@ import wrapt from nemo.deploy import ITritonDeployable -from nemo.export.multimodal.build import build_trtllm_engine, build_visual_engine -from nemo.export.multimodal.run import MultimodalModelRunner +from nemo.export.multimodal.build import build_perception_engine, build_trtllm_engine, build_visual_engine +from nemo.export.multimodal.run import MultimodalModelRunner, SpeechllmModelRunner use_deploy = True try: @@ -74,9 +74,13 @@ def __init__( self, model_dir: str, load_model: bool = True, + modality: str = "vision", ): self.model_dir = model_dir self.runner = None + # vision modality is for image and video + assert modality in ["vision", "audio"] + self.modality = modality if load_model: self._load() @@ -128,8 +132,12 @@ def export( dtype=dtype, ) - visual_dir = os.path.join(self.model_dir, "visual_engine") - build_visual_engine(visual_dir, visual_checkpoint_path, model_type, vision_max_batch_size) + if model_type == "salm": + perception_dir = os.path.join(self.model_dir, "perception_engine") + build_perception_engine(perception_dir, visual_checkpoint_path, model_type, vision_max_batch_size) + else: + visual_dir = os.path.join(self.model_dir, "visual_engine") + build_visual_engine(visual_dir, visual_checkpoint_path, model_type, vision_max_batch_size) if load_model: self._load() @@ -164,19 +172,32 @@ def forward( num_beams, ) + def get_input_media_tensors(self): + if self.modality == "vision": + return [Tensor(name="input_media", shape=(-1, -1, -1, 3), dtype=np.uint8)] + elif self.modality == "audio": + return [ + Tensor(name="input_signal", shape=(-1,), dtype=np.single), + Tensor(name="input_signal_length", shape=(1,), dtype=np.intc), + ] + return [] + @property def get_triton_input(self): inputs = ( - Tensor(name="input_text", shape=(-1,), dtype=bytes), - Tensor(name="input_media", shape=(-1, -1, -1, 3), dtype=np.uint8), - Tensor(name="batch_size", shape=(-1,), dtype=np.int_, optional=True), - Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True), - Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True), - Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True), - Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True), - Tensor(name="repetition_penalty", shape=(-1,), dtype=np.single, optional=True), - Tensor(name="num_beams", shape=(-1,), dtype=np.int_, optional=True), + [Tensor(name="input_text", shape=(-1,), dtype=bytes)] + + self.get_input_media_tensors() + + [ + Tensor(name="batch_size", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="repetition_penalty", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="num_beams", shape=(-1,), dtype=np.int_, optional=True), + ] ) + inputs = tuple(inputs) return inputs @property @@ -198,6 +219,9 @@ def triton_infer_fn(self, **inputs: np.ndarray): infer_input["input_image"] = ndarray2img(inputs.pop("input_media")[0])[0] elif self.runner.model_type in video_model_list: infer_input["input_image"] = inputs.pop("input_media")[0] + elif self.runner.model_type == "salm": + infer_input["input_signal"] = inputs.pop("input_signal") + infer_input["input_signal_length"] = inputs.pop("input_signal_length")[:, 0] if "batch_size" in inputs: infer_input["batch_size"] = inputs.pop("batch_size")[0][0] if "max_output_len" in inputs: @@ -223,5 +247,9 @@ def triton_infer_fn(self, **inputs: np.ndarray): def _load(self): llm_dir = os.path.join(self.model_dir, "llm_engine") - visual_dir = os.path.join(self.model_dir, "visual_engine") - self.runner = MultimodalModelRunner(visual_dir, llm_dir) + if self.modality == "vision": + visual_dir = os.path.join(self.model_dir, "visual_engine") + self.runner = MultimodalModelRunner(visual_dir, llm_dir, self.modality) + elif self.modality == "audio": + perception_dir = os.path.join(self.model_dir, "perception_engine") + self.runner = SpeechllmModelRunner(perception_dir, llm_dir, self.modality) diff --git a/scripts/deploy/multimodal/deploy_triton.py b/scripts/deploy/multimodal/deploy_triton.py index d0bf8f10548a..18463a3fc24a 100755 --- a/scripts/deploy/multimodal/deploy_triton.py +++ b/scripts/deploy/multimodal/deploy_triton.py @@ -35,6 +35,16 @@ def get_args(argv): formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=f"Deploy nemo models to Triton", ) + # default modality is vision, can be changed to audio + parser.add_argument( + "-mod", + "--modality", + type=str, + required=False, + default="vision", + choices=["vision", "audio"], + help="Modality of the model", + ) parser.add_argument("-vc", "--visual_checkpoint", type=str, help="Source .nemo file for visual model") parser.add_argument( "-lc", @@ -48,7 +58,7 @@ def get_args(argv): "--model_type", type=str, required=True, - choices=["neva", "video-neva", "lita", "vila", "vita"], + choices=["neva", "video-neva", "lita", "vila", "vita", "salm"], help="Type of the model that is supported.", ) parser.add_argument( @@ -123,8 +133,7 @@ def get_trt_deployable(args): raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") exporter = TensorRTMMExporter( - model_dir=trt_path, - load_model=(args.visual_checkpoint is None), + model_dir=trt_path, load_model=(args.visual_checkpoint is None), modality=args.modality ) if args.visual_checkpoint is not None: