[dataset] support pad or trim for whisper decoding (#2378)
* [dataset] support pad or trim for whisper decoding

* refactor

* refactor
Mddct authored Mar 2, 2024
1 parent ad663fd commit 77d951b
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion wenet/dataset/processor.py
@@ -287,13 +287,16 @@ def compute_log_mel_spectrogram(sample,
n_fft=400,
hop_length=160,
num_mel_bins=80,
- padding=0):
+ padding=0,
+ pad_or_trim: bool = False,
+ max_duration: int = 30):
""" Extract log mel spectrogram, modified from openai-whisper, see:
- https://github.com/openai/whisper/blob/main/whisper/audio.py
- https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040
Args:
sample: {key, wav, sample_rate, ...}
+ max_duration: valid when pad_or_trim is True (original whisper style)
Returns:
{key, feat, wav, sample_rate, ...}
@@ -305,6 +308,13 @@ def compute_log_mel_spectrogram(sample,
waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,)
if padding > 0:
waveform = F.pad(waveform, (0, padding))
+ if pad_or_trim:
+     length = max_duration * sample_rate
+     if waveform.size(0) >= length:
+         waveform = waveform[:length]
+     else:
+         waveform = F.pad(waveform, (0, length - waveform.size(0)))
+
window = torch.hann_window(n_fft)
stft = torch.stft(waveform,
n_fft,
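The pad-or-trim step added in this commit can be sketched in isolation as below. This is a minimal illustration, not the code in wenet/dataset/processor.py: it uses plain Python lists instead of torch tensors so it runs without dependencies, the standalone function name `pad_or_trim` is hypothetical, and the 16 kHz sample rate in the usage lines is an assumption. As in the diff, `max_duration` is multiplied by `sample_rate`, so it is interpreted as seconds.

```python
def pad_or_trim(samples, sample_rate, max_duration=30):
    """Return a copy of `samples` with exactly max_duration * sample_rate items.

    Hypothetical standalone helper mirroring the logic in the diff:
    longer inputs are cut at the end, shorter ones are zero-padded on the right.
    """
    length = max_duration * sample_rate  # target length in samples
    if len(samples) >= length:
        # Trim: keep only the first `length` samples.
        return list(samples[:length])
    # Pad: append zeros until the target length is reached.
    return list(samples) + [0.0] * (length - len(samples))


if __name__ == "__main__":
    sr = 16000  # assumed sample rate, not stated in the diff
    five_seconds = [0.1] * (5 * sr)
    forty_seconds = [0.1] * (40 * sr)
    # Both calls yield the same fixed length, max_duration * sr samples.
    print(len(pad_or_trim(five_seconds, sr)))
    print(len(pad_or_trim(forty_seconds, sr)))
```

Because every waveform leaves this step with the same length, the spectrogram computed afterwards has a fixed number of frames regardless of the input duration, at the cost of discarding audio beyond `max_duration`.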