Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Mar 2, 2024
1 parent b1e80f6 commit f56726b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 20 deletions.
3 changes: 0 additions & 3 deletions wenet/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ def Dataset(data_type,
dataset = WenetTarShardDatasetSource(data_list_file,
partition=partition)
dataset = dataset.map_ignore_error(processor.decode_wav)
wav_conf = conf.get('wav_conf', None)
if wav_conf is not None and wav_conf.get('pad_or_trim', False):
dataset = dataset.map(processor.pad_or_trim_wav)

speaker_conf = conf.get('speaker_conf', None)
if speaker_conf is not None:
Expand Down
28 changes: 11 additions & 17 deletions wenet/dataset/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,22 +107,6 @@ def detect_task(sample):
return sample


def pad_or_trim_wav(sample, max_duration=30):
assert "wav" in sample
assert 'sample_rate' in sample
assert sample['sample_rate'] == 16000
length = max_duration * 16000
wav = sample['wav']
assert isinstance(wav, torch.Tensor)
if wav.size(1) >= length:
sample['wav'] = wav[:, :length]
else:
pad_zeros = [[0.0] * (length - wav.size(1))]
pad_zeros = torch.tensor(pad_zeros).detach()
sample['wav'] = torch.cat([wav, pad_zeros], dim=-1)
return sample


def decode_wav(sample):
""" Parse key/wav/txt from json line
Expand Down Expand Up @@ -303,13 +287,16 @@ def compute_log_mel_spectrogram(sample,
n_fft=400,
hop_length=160,
num_mel_bins=80,
padding=0):
padding=0,
pad_or_trim: bool = False,
max_duration: int = 30):
""" Extract log mel spectrogram, modified from openai-whisper, see:
- https://github.com/openai/whisper/blob/main/whisper/audio.py
- https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040
Args:
sample: {key, wav, sample_rate, ...}
max_duration: valid when pad_or_trim is True (orign whisper style)
Returns:
{key, feat, wav, sample_rate, ...}
Expand All @@ -321,6 +308,13 @@ def compute_log_mel_spectrogram(sample,
waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,)
if padding > 0:
waveform = F.pad(waveform, (0, padding))
if pad_or_trim:
length = max_duration * sample_rate
if waveform.size(1) >= length:
waveform = waveform[:, :length]
else:
waveform = F.pad(waveform, (0, length - waveform.size(1)))

window = torch.hann_window(n_fft)
stft = torch.stft(waveform,
n_fft,
Expand Down

0 comments on commit f56726b

Please sign in to comment.