diff --git a/examples/whisper/README.md b/examples/whisper/README.md
index 6766607f0..8f3f97a2e 100755
--- a/examples/whisper/README.md
+++ b/examples/whisper/README.md
@@ -87,7 +87,7 @@ trtllm-build --checkpoint_dir ${checkpoint_dir}/decoder \
               --gemm_plugin ${INFERENCE_PRECISION} \
               --bert_attention_plugin ${INFERENCE_PRECISION} \
               --gpt_attention_plugin ${INFERENCE_PRECISION} \
-              --remove_input_padding disable
+              --remove_input_padding enable
 ```
 
 ### Run
@@ -121,13 +121,14 @@ WEIGHT_ONLY_PRECISION=int8
 MAX_BEAM_WIDTH=4
 MAX_BATCH_SIZE=8
 checkpoint_dir=distil_whisper_medium_en_weights_${WEIGHT_ONLY_PRECISION}
-output_dir=distil_whisper_medium_en${WEIGHT_ONLY_PRECISION}
+output_dir=distil_whisper_medium_en_${WEIGHT_ONLY_PRECISION}
 
 python3 convert_checkpoint.py \
                 --use_weight_only \
                 --weight_only_precision $WEIGHT_ONLY_PRECISION \
                 --output_dir $checkpoint_dir \
-                --model_name distil-medium.en
+                --model_name distil-medium.en \
+                --chunk_length 15
 ```
 
 Now, we can build and run the model like before:
@@ -160,7 +161,7 @@ trtllm-build --checkpoint_dir ${checkpoint_dir}/decoder \
               --gemm_plugin ${INFERENCE_PRECISION} \
               --bert_attention_plugin ${INFERENCE_PRECISION} \
               --gpt_attention_plugin ${INFERENCE_PRECISION} \
-              --remove_input_padding disable
+              --remove_input_padding enable
 
 python3 run.py --engine_dir $output_dir --dataset hf-internal-testing/librispeech_asr_dummy --name librispeech_dummy_${output_dir}
 ```
diff --git a/examples/whisper/convert_checkpoint.py b/examples/whisper/convert_checkpoint.py
index 31ae15cf9..4bd5bfe6b 100644
--- a/examples/whisper/convert_checkpoint.py
+++ b/examples/whisper/convert_checkpoint.py
@@ -25,6 +25,7 @@
 from tensorrt_llm.functional import LayerNormPositionType, LayerNormType
 from tensorrt_llm.models.convert_utils import weight_only_quantize_dict
 from tensorrt_llm.quantization import QuantAlgo
+from whisper_utils import SAMPLE_RATE, HOP_LENGTH
 
 
 def parse_arguments():
@@ -50,6 +51,10 @@ def parse_arguments():
                             "distil-medium.en",
                             "distil-small.en",
                         ])
+    parser.add_argument("--chunk_length",
+                        type=int,
+                        default=30,
+                        help="Chunk length in seconds for encoder input")
     parser.add_argument('--dtype',
                         type=str,
                         default='float16',
@@ -83,7 +88,7 @@ def parse_arguments():
     return args
 
 
-def get_encoder_config(model_metadata: dict, dtype: str,
+def get_encoder_config(model_metadata: dict, dtype: str, chunk_length: int,
                        quant_algo: QuantAlgo) -> dict:
     model_is_multilingual = (model_metadata['n_vocab'] >= 51865)
     num_languages = model_metadata['n_vocab'] - 51765 - int(
@@ -96,6 +101,7 @@ def get_encoder_config(model_metadata: dict, dtype: str,
         'hidden_size': model_metadata['n_audio_state'],
         'n_mels': model_metadata['n_mels'],
         'n_audio_ctx': model_metadata['n_audio_ctx'],
+        'chunk_length': int(chunk_length * SAMPLE_RATE / HOP_LENGTH),
         'vocab_size': model_metadata['n_vocab'],
         'hidden_act': "gelu",
         'num_languages': num_languages,
@@ -397,7 +403,9 @@ def convert_openai_whisper_decoder(model_metadata: dict,
     def convert_and_save(component: str = "encoder"):
         # call get_encoder_config or get_decoder_config according to component
         if component == "encoder":
-            config = get_encoder_config(model_metadata, args.dtype, quant_algo)
+            config = get_encoder_config(
+                model_metadata, args.dtype, args.chunk_length, quant_algo
+            )
         else:
             config = get_decoder_config(model_metadata, args.dtype,
                                         args.logits_dtype, quant_algo)
diff --git a/examples/whisper/distil_whisper/convert_from_distil_whisper.py b/examples/whisper/distil_whisper/convert_from_distil_whisper.py
index 479961d13..d96cccf51 100644
--- a/examples/whisper/distil_whisper/convert_from_distil_whisper.py
+++ b/examples/whisper/distil_whisper/convert_from_distil_whisper.py
@@ -53,10 +53,10 @@ def main():
         print("Trying to load the model from the cache")
         model = AutoModel.from_pretrained(model_name,
                                           cache_dir=cache_dir,
-                                          use_safetensors=True)
+                                          use_safetensors=True).half()
     else:
         print("Downloading the model:")
-        model = AutoModel.from_pretrained(model_name, use_safetensors=True)
+        model = AutoModel.from_pretrained(model_name, use_safetensors=True).half()
 
     config = model.config
     model_dims = {
diff --git a/examples/whisper/run.py b/examples/whisper/run.py
index bcc990b1e..dfc0f2c68 100644
--- a/examples/whisper/run.py
+++ b/examples/whisper/run.py
@@ -163,42 +163,50 @@ def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
 
         return decoder_generation_session
 
-    def generate(self,
-                 decoder_input_ids,
-                 encoder_outputs,
-                 eot_id,
-                 max_new_tokens=40,
-                 num_beams=1):
+    def generate(
+        self, decoder_input_ids, encoder_outputs, eot_id, max_new_tokens=40, num_beams=1
+    ):
+
+        if torch.is_tensor(encoder_outputs):
+            encoder_outputs = list(encoder_outputs)
+        if torch.is_tensor(decoder_input_ids):
+            decoder_input_ids = list(decoder_input_ids)
+
         encoder_input_lengths = torch.tensor(
-            [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])],
+            [encoder_output.shape[0] for encoder_output in encoder_outputs],
             dtype=torch.int32,
-            device='cuda')
-        decoder_input_lengths = torch.tensor([
-            decoder_input_ids.shape[-1]
-            for _ in range(decoder_input_ids.shape[0])
-        ],
-                                             dtype=torch.int32,
-                                             device='cuda')
+            device="cuda",
+        )
+        decoder_input_lengths = torch.tensor(
+            [decoder_input_id.shape[-1] for decoder_input_id in decoder_input_ids],
+            dtype=torch.int32,
+            device="cuda",
+        )
         decoder_max_input_length = torch.max(decoder_input_lengths).item()
 
-        cross_attention_mask = torch.ones(
-            [encoder_outputs.shape[0], 1,
-             encoder_outputs.shape[1]]).int().cuda()
+        cross_attention_mask = (
+            torch.ones([len(encoder_outputs), 1, encoder_input_lengths.sum().item()])
+            .int()
+            .cuda()
+        )
+
+        encoder_outputs = torch.cat(encoder_outputs).half().cuda()
+        decoder_input_ids = torch.cat(decoder_input_ids).int().cuda()
 
         # generation config
-        sampling_config = SamplingConfig(end_id=eot_id,
-                                         pad_id=eot_id,
-                                         num_beams=num_beams)
+        sampling_config = SamplingConfig(
+            end_id=eot_id, pad_id=eot_id, num_beams=num_beams
+        )
 
         self.decoder_generation_session.setup(
             decoder_input_lengths.size(0),
             decoder_max_input_length,
             max_new_tokens,
             beam_width=num_beams,
-            encoder_max_input_length=encoder_outputs.shape[1])
+            encoder_max_input_length=encoder_input_lengths.max().item(),
+        )
         torch.cuda.synchronize()
 
-        decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
         output_ids = self.decoder_generation_session.decode(
             decoder_input_ids,
             decoder_input_lengths,
@@ -233,7 +241,7 @@ def __init__(self, engine_dir, debug_mode=False, assets_dir=None):
             assert (Path(assets_dir) / "multilingual.tiktoken").exists(
             ), "multilingual.tiktoken file is not existed in assets_dir"
         else:
-            tokenizer_name == "gpt2"
+            tokenizer_name = "gpt2"
             assert (Path(assets_dir) / "gpt2.tiktoken").exists(
             ), "gpt2.tiktoken file is not existed in assets_dir"
         self.tokenizer = get_tokenizer(name=tokenizer_name,
diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py
index 06b62c9ee..d136932e6 100644
--- a/tensorrt_llm/functional.py
+++ b/tensorrt_llm/functional.py
@@ -3393,10 +3393,8 @@ def conv1d(input: Tensor,
                           and bias.producer.type == trt.LayerType.CONSTANT)
     bias = bias.producer.weights if is_bias_constant else trt.Weights()
 
-    input_shuffle_layer = default_trtnet().add_shuffle(input.trt_tensor)
-    input_shuffle_layer.reshape_dims = trt.Dims([*(input.size()), 1])
-    input_shuffled = _create_tensor(input_shuffle_layer.get_output(0),
-                                    input_shuffle_layer)
+    # equivalent to expand_dims(input, -1) but supports multiple dynamic axes
+    input_shuffled = stack([input], dim=input.ndim())
 
     kernel_size = trt.Dims([kernel_size, 1])
 
@@ -3414,13 +3412,7 @@ def conv1d(input: Tensor,
         layer.set_input(2, bias.trt_tensor)
 
     output_2d = _create_tensor(layer.get_output(0), layer)
-    output_2d_shuffle_layer = default_trtnet().add_shuffle(output_2d.trt_tensor)
-    output_2d_shuffle_layer.reshape_dims = trt.Dims(
-        [output_2d.size()[0],
-         output_2d.size()[1],
-         output_2d.size()[2]])
-    output_1d = _create_tensor(output_2d_shuffle_layer.get_output(0),
-                               output_2d_shuffle_layer)
+    output_1d = squeeze(output_2d, dim=-1)
 
     return output_1d
 
diff --git a/tensorrt_llm/models/enc_dec/model.py b/tensorrt_llm/models/enc_dec/model.py
index 8aee9977c..47b5f15ac 100644
--- a/tensorrt_llm/models/enc_dec/model.py
+++ b/tensorrt_llm/models/enc_dec/model.py
@@ -25,7 +25,7 @@
                                      MLPType, PositionEmbeddingType, Tensor,
                                      assertion, cast, gather_last_token_logits,
                                      gelu, maximum, minimum, recv, send, shape,
-                                     transpose)
+                                     transpose, slice)
 from tensorrt_llm.layers import (MLP, Attention, AttentionMaskType,
                                  AttentionParams, BertAttention, ColumnLinear,
                                  Conv1d, Embedding, FusedGatedMLP, GatedMLP,
@@ -1885,7 +1885,11 @@ def forward(self, x: Tensor, input_lengths=None):
         x = cast(x, x_type)
         x = gelu(x)
         x = transpose(x, 2, 1)
-        x = x + cast(self.positional_embedding.value, x_type)
+        x = x + cast(slice(input=self.positional_embedding.value,
+                           starts=[0,0],
+                           sizes=[x.shape[1], self.positional_embedding.shape[1]],
+                           strides=[1,1]),
+                     x.dtype)
 
         hidden_states = x
         for encoder_layer in self.encoder_layers:
@@ -1903,11 +1907,11 @@ def prepare_inputs(self, max_batch_size=16):
 
         x = Tensor(name="x",
                    dtype=self._dtype,
-                   shape=[-1, self.config.n_mels, 3000],
+                   shape=[-1, self.config.n_mels, self.config.chunk_length],
                    dim_range=OrderedDict([
                        ("batch_size", [bs_range]),
                        ("feature_dim", [self.config.n_mels]),
-                       ("feature_len_range", [3000]),
+                       ("feature_len_range", [self.config.chunk_length]),
                    ]))
         input_lengths = Tensor(
             name="input_lengths",