From f56726bffc388f57b2e0553603cbe1fa52ad3935 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Sat, 2 Mar 2024 23:52:07 +0800
Subject: [PATCH] refactor

---
 wenet/dataset/dataset.py   |  3 ---
 wenet/dataset/processor.py | 28 +++++++++++-----------------
 2 files changed, 11 insertions(+), 20 deletions(-)

diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py
index 30e552977..4853e5135 100644
--- a/wenet/dataset/dataset.py
+++ b/wenet/dataset/dataset.py
@@ -47,9 +47,6 @@ def Dataset(data_type,
         dataset = WenetTarShardDatasetSource(data_list_file,
                                              partition=partition)
     dataset = dataset.map_ignore_error(processor.decode_wav)
-    wav_conf = conf.get('wav_conf', None)
-    if wav_conf is not None and wav_conf.get('pad_or_trim', False):
-        dataset = dataset.map(processor.pad_or_trim_wav)
 
     speaker_conf = conf.get('speaker_conf', None)
     if speaker_conf is not None:
diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py
index 72abb3020..4f90ac78f 100644
--- a/wenet/dataset/processor.py
+++ b/wenet/dataset/processor.py
@@ -107,22 +107,6 @@ def detect_task(sample):
     return sample
 
 
-def pad_or_trim_wav(sample, max_duration=30):
-    assert "wav" in sample
-    assert 'sample_rate' in sample
-    assert sample['sample_rate'] == 16000
-    length = max_duration * 16000
-    wav = sample['wav']
-    assert isinstance(wav, torch.Tensor)
-    if wav.size(1) >= length:
-        sample['wav'] = wav[:, :length]
-    else:
-        pad_zeros = [[0.0] * (length - wav.size(1))]
-        pad_zeros = torch.tensor(pad_zeros).detach()
-        sample['wav'] = torch.cat([wav, pad_zeros], dim=-1)
-    return sample
-
-
 def decode_wav(sample):
     """ Parse key/wav/txt from json line
 
@@ -303,13 +287,16 @@ def compute_log_mel_spectrogram(sample,
                                 n_fft=400,
                                 hop_length=160,
                                 num_mel_bins=80,
-                                padding=0):
+                                padding=0,
+                                pad_or_trim: bool = False,
+                                max_duration: int = 30):
     """ Extract log mel spectrogram, modified from openai-whisper, see:
         - https://github.com/openai/whisper/blob/main/whisper/audio.py
        - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040
 
         Args:
             sample: {key, wav, sample_rate, ...}
+            max_duration: only valid when pad_or_trim is True (original whisper style)
 
         Returns:
             {key, feat, wav, sample_rate, ...}
@@ -321,6 +308,13 @@
     waveform = sample['wav'].squeeze(0)  # (channel=1, sample) -> (sample,)
     if padding > 0:
         waveform = F.pad(waveform, (0, padding))
+    if pad_or_trim:
+        length = max_duration * sample_rate
+        if waveform.size(0) >= length:
+            waveform = waveform[:length]
+        else:
+            waveform = F.pad(waveform, (0, length - waveform.size(0)))
+
     window = torch.hann_window(n_fft)
     stft = torch.stft(waveform,
                       n_fft,
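
For readers following the refactor, here is a minimal standalone sketch of the pad-or-trim step that now lives inside `compute_log_mel_spectrogram`. The helper name `pad_or_trim_waveform`, the 16 kHz rate, and the synthetic tensors below are illustrative assumptions, not part of the patch:

```python
import torch
import torch.nn.functional as F


def pad_or_trim_waveform(waveform: torch.Tensor,
                         sample_rate: int = 16000,
                         max_duration: int = 30) -> torch.Tensor:
    """Zero-pad or trim a mono (1-D) waveform to exactly max_duration seconds.

    Illustrative helper mirroring the pad_or_trim branch added to
    compute_log_mel_spectrogram; not part of the wenet API.
    """
    length = max_duration * sample_rate
    if waveform.size(0) >= length:
        return waveform[:length]                              # trim the tail
    return F.pad(waveform, (0, length - waveform.size(0)))    # zero-pad the tail


# A 10 s clip is padded up to 30 s; a 40 s clip is trimmed down to 30 s.
short_clip = torch.randn(10 * 16000)
long_clip = torch.randn(40 * 16000)
assert pad_or_trim_waveform(short_clip).size(0) == 30 * 16000
assert pad_or_trim_waveform(long_clip).size(0) == 30 * 16000
```

With this change, whisper-style 30 s padding is presumably enabled through the keyword arguments passed to `compute_log_mel_spectrogram` (i.e. `pad_or_trim=True`, optionally with `max_duration`) rather than the removed `wav_conf['pad_or_trim']` flag in `dataset.py`; the exact config key depends on how `Dataset` wires these kwargs.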