diff --git a/whisper/audio.py b/whisper/audio.py
index a19b7ab0d..513ab7c9d 100644
--- a/whisper/audio.py
+++ b/whisper/audio.py
@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from typing import Union
+from typing import Optional, Union
 
 import ffmpeg
 import numpy as np
@@ -15,10 +15,8 @@
 N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-N_FRAMES = exact_div(
-    N_SAMPLES, HOP_LENGTH
-)  # 3000: number of frames in a mel spectrogram input
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
 
 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
 
 
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
+    audio: Union[str, np.ndarray, torch.Tensor],
+    n_mels: int = N_MELS,
+    padding: int = 0,
+    device: Optional[Union[str, torch.device]] = None,
 ):
     """
     Compute the log-Mel spectrogram of
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
     n_mels: int
         The number of Mel-frequency filters, only 80 is supported
 
+    padding: int
+        Number of zero samples to pad to the right
+
+    device: Optional[Union[str, torch.device]]
+        If given, the audio tensor is moved to this device before STFT
+
     Returns
     -------
     torch.Tensor, shape = (80, n_frames)
@@ -123,6 +130,10 @@ def log_mel_spectrogram(
             audio = load_audio(audio)
         audio = torch.from_numpy(audio)
 
+    if device is not None:
+        audio = audio.to(device)
+    if padding > 0:
+        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
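A minimal sketch of how the two new `log_mel_spectrogram` keyword arguments combine; this is not part of the patch, and the 5-second silent input and the device fallback are invented for illustration:

```python
import numpy as np
import torch

from whisper.audio import N_SAMPLES, SAMPLE_RATE, log_mel_spectrogram

# hypothetical input: 5 seconds of silence at 16 kHz
audio = np.zeros(5 * SAMPLE_RATE, dtype=np.float32)

# move the waveform to the GPU (if available) before the STFT, and pad
# 30 seconds of zero samples on the right, as transcribe() now does
device = "cuda" if torch.cuda.is_available() else "cpu"
mel = log_mel_spectrogram(audio, padding=N_SAMPLES, device=device)

print(mel.shape)  # torch.Size([80, 3500]): 500 content frames + 3000 padding frames
```

With `padding=N_SAMPLES`, the expression `mel.shape[-1] - N_FRAMES` recovers the number of content frames, which is exactly how `content_frames` is computed in `transcribe.py` below.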
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 20f01477e..773e6365e 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -11,6 +11,7 @@
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
+    N_SAMPLES,
     SAMPLE_RATE,
     log_mel_spectrogram,
     pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    mel = log_mel_spectrogram(audio)
+    # Pad 30-seconds of silence to the input audio, for slicing
+    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    content_frames = mel.shape[-1] - N_FRAMES
 
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
@@ -212,14 +215,13 @@ def new_segment(
         }
 
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    num_frames = mel.shape[-1]
     with tqdm.tqdm(
-        total=num_frames, unit="frames", disable=verbose is not False
+        total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
-        while seek < num_frames:
+        while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek:]
-            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            mel_segment = mel[:, seek : seek + N_FRAMES]
+            segment_size = min(N_FRAMES, content_frames - seek)
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
 
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
@@ -246,20 +248,18 @@ def new_segment(
             current_tokens = []
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
-                0
-            ].add_(1)
-            if (
-                len(consecutive) > 0
-            ):  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
-                    False,
-                    True,
-                ]:
-                    consecutive = consecutive.tolist() + [len(tokens)]
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+            consecutive.add_(1)
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))
                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin
                     )
@@ -278,7 +278,7 @@ def new_segment(
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
 
-                if ended_with_single_timestamp:
+                if single_timestamp_ending:
                     # single timestamp at the end means no speech after the last timestamp.
                     seek += segment_size
                 else:
@@ -329,7 +329,7 @@ def new_segment(
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
-                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                     )
@@ -356,7 +356,7 @@ def new_segment(
             )
 
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek)
+            pbar.update(min(content_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
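The `@@ -246` hunk replaces the walrus-scoped `ended_with_single_timestamp` with `single_timestamp_ending`, computed unconditionally, so the later `@@ -329` hunk can test it directly instead of reusing `len(consecutive) > 0`. A standalone sketch of the new slicing logic, run on a hypothetical timestamp mask invented for illustration:

```python
import torch

# hypothetical mask: True marks a timestamp token in the decoded output
timestamp_tokens = torch.tensor([True, False, False, True, True, False, True])

# a single timestamp at the end looks like [..., text, timestamp]
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

# indices where two timestamp tokens are adjacent; add_(1) shifts each hit
# to the second token of the pair, i.e. the start of the next segment
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)

slices = consecutive.tolist()              # [4]
if single_timestamp_ending:
    slices.append(len(timestamp_tokens))   # treat the end as a final boundary

last_slice = 0
for current_slice in slices:
    print(f"segment tokens [{last_slice}:{current_slice}]")  # [0:4], then [4:7]
    last_slice = current_slice
```

Keeping `slices` as a plain list also removes the old code path in which `consecutive` silently changed type from tensor to list depending on how the segment ended.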