add chunk_length parameter to Whisper #1909

Closed · wants to merge 6 commits
9 changes: 5 additions & 4 deletions examples/whisper/README.md
@@ -87,7 +87,7 @@ trtllm-build --checkpoint_dir ${checkpoint_dir}/decoder \
--gemm_plugin ${INFERENCE_PRECISION} \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--gpt_attention_plugin ${INFERENCE_PRECISION} \
-    --remove_input_padding disable
+    --remove_input_padding enable
```

### Run
@@ -121,13 +121,14 @@ WEIGHT_ONLY_PRECISION=int8
MAX_BEAM_WIDTH=4
MAX_BATCH_SIZE=8
checkpoint_dir=distil_whisper_medium_en_weights_${WEIGHT_ONLY_PRECISION}
-output_dir=distil_whisper_medium_en${WEIGHT_ONLY_PRECISION}
+output_dir=distil_whisper_medium_en_${WEIGHT_ONLY_PRECISION}

python3 convert_checkpoint.py \
--use_weight_only \
--weight_only_precision $WEIGHT_ONLY_PRECISION \
--output_dir $checkpoint_dir \
-    --model_name distil-medium.en
+    --model_name distil-medium.en \
+    --chunk_length 15
```
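For reference, `--chunk_length` is given in seconds and is converted to mel-spectrogram frames inside `convert_checkpoint.py` as `int(chunk_length * SAMPLE_RATE / HOP_LENGTH)`. A quick check of that arithmetic, assuming Whisper's standard feature-extraction constants (`SAMPLE_RATE = 16000`, `HOP_LENGTH = 160`):

```python
SAMPLE_RATE = 16000  # Whisper's audio sample rate (Hz)
HOP_LENGTH = 160     # samples between successive mel frames

def chunk_frames(chunk_length_s: int) -> int:
    """Mirror of the conversion used in get_encoder_config()."""
    return int(chunk_length_s * SAMPLE_RATE / HOP_LENGTH)

print(chunk_frames(30))  # 3000 frames: the default full-Whisper window
print(chunk_frames(15))  # 1500 frames: the shorter distil-whisper chunk
```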

<details><summary> Now, we can build and run the model like before: </summary><p>
@@ -160,7 +161,7 @@ trtllm-build --checkpoint_dir ${checkpoint_dir}/decoder \
--gemm_plugin ${INFERENCE_PRECISION} \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--gpt_attention_plugin ${INFERENCE_PRECISION} \
-    --remove_input_padding disable
+    --remove_input_padding enable

python3 run.py --engine_dir $output_dir --dataset hf-internal-testing/librispeech_asr_dummy --name librispeech_dummy_${output_dir}
```
12 changes: 10 additions & 2 deletions examples/whisper/convert_checkpoint.py
@@ -25,6 +25,7 @@
from tensorrt_llm.functional import LayerNormPositionType, LayerNormType
from tensorrt_llm.models.convert_utils import weight_only_quantize_dict
from tensorrt_llm.quantization import QuantAlgo
+from whisper_utils import SAMPLE_RATE, HOP_LENGTH


def parse_arguments():
@@ -50,6 +51,10 @@ def parse_arguments():
"distil-medium.en",
"distil-small.en",
])
parser.add_argument("--chunk_length",
type=int,
default=30,
help="Chunk length in seconds for encoder input")
parser.add_argument('--dtype',
type=str,
default='float16',
@@ -83,7 +88,7 @@ def parse_arguments():
return args


-def get_encoder_config(model_metadata: dict, dtype: str,
+def get_encoder_config(model_metadata: dict, dtype: str, chunk_length: int,
quant_algo: QuantAlgo) -> dict:
model_is_multilingual = (model_metadata['n_vocab'] >= 51865)
num_languages = model_metadata['n_vocab'] - 51765 - int(
@@ -96,6 +101,7 @@ def get_encoder_config(model_metadata: dict, dtype: str,
'hidden_size': model_metadata['n_audio_state'],
'n_mels': model_metadata['n_mels'],
'n_audio_ctx': model_metadata['n_audio_ctx'],
+        'chunk_length': int(chunk_length * SAMPLE_RATE / HOP_LENGTH),
'vocab_size': model_metadata['n_vocab'],
'hidden_act': "gelu",
'num_languages': num_languages,
@@ -397,7 +403,9 @@ def convert_openai_whisper_decoder(model_metadata: dict,
def convert_and_save(component: str = "encoder"):
# call get_encoder_config or get_decoder_config according to component
if component == "encoder":
-        config = get_encoder_config(model_metadata, args.dtype, quant_algo)
+        config = get_encoder_config(
+            model_metadata, args.dtype, args.chunk_length, quant_algo
+        )
else:
config = get_decoder_config(model_metadata, args.dtype,
args.logits_dtype, quant_algo)
@@ -53,10 +53,10 @@ def main():
print("Trying to load the model from the cache")
model = AutoModel.from_pretrained(model_name,
cache_dir=cache_dir,
-                                      use_safetensors=True)
+                                      use_safetensors=True).half()
else:
print("Downloading the model:")
-        model = AutoModel.from_pretrained(model_name, use_safetensors=True)
+        model = AutoModel.from_pretrained(model_name, use_safetensors=True).half()

config = model.config
model_dims = {
54 changes: 31 additions & 23 deletions examples/whisper/run.py
@@ -163,42 +163,50 @@ def get_session(self, engine_dir, runtime_mapping, debug_mode=False):

return decoder_generation_session

-    def generate(self,
-                 decoder_input_ids,
-                 encoder_outputs,
-                 eot_id,
-                 max_new_tokens=40,
-                 num_beams=1):
+    def generate(
+        self, decoder_input_ids, encoder_outputs, eot_id, max_new_tokens=40, num_beams=1
+    ):

+        if torch.is_tensor(encoder_outputs):
+            encoder_outputs = list(encoder_outputs)
+        if torch.is_tensor(decoder_input_ids):
+            decoder_input_ids = list(decoder_input_ids)

encoder_input_lengths = torch.tensor(
-            [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])],
+            [encoder_output.shape[0] for encoder_output in encoder_outputs],
             dtype=torch.int32,
-            device='cuda')
-        decoder_input_lengths = torch.tensor([
-            decoder_input_ids.shape[-1]
-            for _ in range(decoder_input_ids.shape[0])
-        ],
-                                             dtype=torch.int32,
-                                             device='cuda')
+            device="cuda",
+        )
+        decoder_input_lengths = torch.tensor(
+            [decoder_input_id.shape[-1] for decoder_input_id in decoder_input_ids],
+            dtype=torch.int32,
+            device="cuda",
+        )
decoder_max_input_length = torch.max(decoder_input_lengths).item()

-        cross_attention_mask = torch.ones(
-            [encoder_outputs.shape[0], 1,
-             encoder_outputs.shape[1]]).int().cuda()
+        cross_attention_mask = (
+            torch.ones([len(encoder_outputs), 1, encoder_input_lengths.sum().item()])
+            .int()
+            .cuda()
+        )
+
+        encoder_outputs = torch.cat(encoder_outputs).half().cuda()
+        decoder_input_ids = torch.cat(decoder_input_ids).int().cuda()

**galv** commented on the `cross_attention_mask` change:

> Hi @MahmoudAshraf97, if I understand correctly, you are making this change because distil-whisper can work on dynamic chunk sizes, unlike Whisper, which must use fixed 30-second chunks. Am I understanding correctly? Thank you.

**MahmoudAshraf97** (Contributor, Author) replied:

> Hi @galv, this PR contains 2 main changes:
>
> 1. The encoder is no longer restricted to 30 s inputs, which helps in the case of distil-whisper, as you mentioned.
> 2. The decoder now supports `remove_input_padding` and accepts packed inputs to save memory.
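To make the second change concrete, here is a minimal sketch of the packed (padding-free) layout the reworked `generate()` assumes; the tensors and sizes below are illustrative only:

```python
import torch

hidden = 1024  # illustrative encoder hidden size

# Two encoder outputs with different frame counts, e.g. from a 30 s chunk
# and a 15 s chunk; real values come from the encoder engine.
encoder_outputs = [torch.randn(1500, hidden), torch.randn(750, hidden)]

# Per-sample lengths, computed the same way as in generate().
encoder_input_lengths = torch.tensor(
    [e.shape[0] for e in encoder_outputs], dtype=torch.int32)

# Packing concatenates samples along the time axis instead of padding both
# to the longest length, so no memory is spent on padding positions.
packed = torch.cat(encoder_outputs)  # shape: (2250, hidden)
padded_elems = len(encoder_outputs) * int(encoder_input_lengths.max()) * hidden
print(packed.numel(), "packed elements vs", padded_elems, "padded elements")
```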

# generation config
-        sampling_config = SamplingConfig(end_id=eot_id,
-                                         pad_id=eot_id,
-                                         num_beams=num_beams)
+        sampling_config = SamplingConfig(
+            end_id=eot_id, pad_id=eot_id, num_beams=num_beams
+        )
self.decoder_generation_session.setup(
decoder_input_lengths.size(0),
decoder_max_input_length,
max_new_tokens,
beam_width=num_beams,
-            encoder_max_input_length=encoder_outputs.shape[1])
+            encoder_max_input_length=encoder_input_lengths.max().item(),
+        )

torch.cuda.synchronize()

-        decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
output_ids = self.decoder_generation_session.decode(
decoder_input_ids,
decoder_input_lengths,
@@ -233,7 +241,7 @@ def __init__(self, engine_dir, debug_mode=False, assets_dir=None):
assert (Path(assets_dir) / "multilingual.tiktoken").exists(
), "multilingual.tiktoken file is not existed in assets_dir"
else:
-            tokenizer_name == "gpt2"
+            tokenizer_name = "gpt2"
assert (Path(assets_dir) / "gpt2.tiktoken").exists(
), "gpt2.tiktoken file is not existed in assets_dir"
self.tokenizer = get_tokenizer(name=tokenizer_name,
14 changes: 3 additions & 11 deletions tensorrt_llm/functional.py
@@ -3393,10 +3393,8 @@ def conv1d(input: Tensor,
and bias.producer.type == trt.LayerType.CONSTANT)
bias = bias.producer.weights if is_bias_constant else trt.Weights()

-    input_shuffle_layer = default_trtnet().add_shuffle(input.trt_tensor)
-    input_shuffle_layer.reshape_dims = trt.Dims([*(input.size()), 1])
-    input_shuffled = _create_tensor(input_shuffle_layer.get_output(0),
-                                    input_shuffle_layer)
+    # equivalent to expand_dims(input, -1) but supports multiple dynamic axes
+    input_shuffled = stack([input], dim=input.ndim())

kernel_size = trt.Dims([kernel_size, 1])

@@ -3414,13 +3412,7 @@
layer.set_input(2, bias.trt_tensor)

output_2d = _create_tensor(layer.get_output(0), layer)
-    output_2d_shuffle_layer = default_trtnet().add_shuffle(output_2d.trt_tensor)
-    output_2d_shuffle_layer.reshape_dims = trt.Dims(
-        [output_2d.size()[0],
-         output_2d.size()[1],
-         output_2d.size()[2]])
-    output_1d = _create_tensor(output_2d_shuffle_layer.get_output(0),
-                               output_2d_shuffle_layer)
+    output_1d = squeeze(output_2d, dim=-1)

return output_1d
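The refactor relies on `stack([input], dim=input.ndim())` behaving like a trailing `expand_dims`. A quick PyTorch analogue of that identity (PyTorch stands in here only to illustrate the shape algebra; the diff itself uses TensorRT-LLM's functional ops):

```python
import torch

x = torch.randn(4, 80, 3000)  # e.g. (batch, n_mels, frames)

# Stacking a single tensor at position ndim appends a new axis of size 1,
# which is exactly unsqueeze(-1) / expand_dims(x, -1).
stacked = torch.stack([x], dim=x.ndim)
assert stacked.shape == (4, 80, 3000, 1)
assert torch.equal(stacked, x.unsqueeze(-1))
```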

12 changes: 8 additions & 4 deletions tensorrt_llm/models/enc_dec/model.py
@@ -25,7 +25,7 @@
MLPType, PositionEmbeddingType, Tensor,
assertion, cast, gather_last_token_logits,
gelu, maximum, minimum, recv, send, shape,
-                                       transpose)
+                                       transpose, slice)
from tensorrt_llm.layers import (MLP, Attention, AttentionMaskType,
AttentionParams, BertAttention, ColumnLinear,
Conv1d, Embedding, FusedGatedMLP, GatedMLP,
@@ -1885,7 +1885,11 @@ def forward(self, x: Tensor, input_lengths=None):
x = cast(x, x_type)
x = gelu(x)
x = transpose(x, 2, 1)
-        x = x + cast(self.positional_embedding.value, x_type)
+        x = x + cast(slice(input=self.positional_embedding.value,
+                           starts=[0, 0],
+                           sizes=[x.shape[1], self.positional_embedding.shape[1]],
+                           strides=[1, 1]),
+                     x.dtype)

hidden_states = x
for encoder_layer in self.encoder_layers:
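The `slice` here trims the positional table to the actual number of frames instead of assuming the full 30 s context. A small PyTorch sketch of the same idea, with illustrative shapes:

```python
import torch

n_audio_ctx, hidden = 1500, 1024    # full positional table (30 s of audio)
positional_embedding = torch.randn(n_audio_ctx, hidden)

x = torch.randn(2, 750, hidden)     # a batch of shorter (15 s) encoder inputs

# Slice the table to x's sequence length before adding, rather than
# requiring x to always span the full n_audio_ctx positions.
x = x + positional_embedding[: x.shape[1]]
print(x.shape)  # torch.Size([2, 750, 1024])
```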
@@ -1903,11 +1907,11 @@ def prepare_inputs(self, max_batch_size=16):

x = Tensor(name="x",
dtype=self._dtype,
-                   shape=[-1, self.config.n_mels, 3000],
+                   shape=[-1, self.config.n_mels, self.config.chunk_length],
dim_range=OrderedDict([
("batch_size", [bs_range]),
("feature_dim", [self.config.n_mels]),
("feature_len_range", [3000]),
("feature_len_range", [self.config.chunk_length]),
]))
input_lengths = Tensor(
name="input_lengths",