# -*- coding: utf-8 -*-
# @Time : 6/19/21 12:23 AM
# @Author : Yuan Gong
# @Affiliation : Massachusetts Institute of Technology
# @Email : yuangong@mit.edu
# @File : dataloader.py
# modified from:
# Author: David Harwath
# with some functions borrowed from https://github.com/SeanNaren/deepspeech.pytorch
import csv
import json
import random

import numpy as np
import torch
import torch.nn.functional
import torchaudio
from torch.utils.data import Dataset


def make_index_dict(label_csv):
    """Map each label id ('mid') in the label csv to its integer index."""
    index_lookup = {}
    with open(label_csv, 'r') as f:
        csv_reader = csv.DictReader(f)
        for row in csv_reader:
            index_lookup[row['mid']] = row['index']
    return index_lookup


def make_name_dict(label_csv):
    """Map each integer index in the label csv to its human-readable display name."""
    name_lookup = {}
    with open(label_csv, 'r') as f:
        csv_reader = csv.DictReader(f)
        for row in csv_reader:
            name_lookup[row['index']] = row['display_name']
    return name_lookup
def lookup_list(index_list, label_csv):
    """Translate a list of label indices into their display names."""
    label_list = []
    table = make_name_dict(label_csv)
    for item in index_list:
        label_list.append(table[item])
    return label_list


def preemphasis(signal, coeff=0.97):
    """Perform preemphasis on the input signal.

    :param signal: The signal to filter.
    :param coeff: The preemphasis coefficient. 0 means no filtering; default is 0.97.
    :returns: the filtered signal.
    """
    return np.append(signal[0], signal[1:] - coeff * signal[:-1])
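

# A quick sanity check (illustrative, not part of the original file): for a
# constant signal the filter keeps the first sample and shrinks the rest,
#   preemphasis(np.array([1.0, 1.0, 1.0]))  ->  array([1.0, 0.03, 0.03])
# i.e. the classic first-order high-pass y[t] = x[t] - 0.97 * x[t-1].
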
class AudiosetDataset(Dataset):
    def __init__(self, dataset_json_file, audio_conf, label_csv=None):
        """
        Dataset that manages audio recordings.
        :param dataset_json_file: Path to the dataset json; its 'data' entry is a list of samples.
        :param audio_conf: Dictionary containing the audio loading and preprocessing settings.
        :param label_csv: Path to the label csv that maps label ids to indices.
        """
        self.datapath = dataset_json_file
        with open(dataset_json_file, 'r') as fp:
            data_json = json.load(fp)
        self.data = data_json['data']
        self.audio_conf = audio_conf
        print('---------------the {:s} dataloader---------------'.format(self.audio_conf.get('mode')))
        self.melbins = self.audio_conf.get('num_mel_bins')
        self.freqm = self.audio_conf.get('freqm')
        self.timem = self.audio_conf.get('timem')
        print('now using following mask: {:d} freq, {:d} time'.format(self.freqm, self.timem))
        self.mixup = self.audio_conf.get('mixup')
        print('now using mix-up with rate {:f}'.format(self.mixup))
        self.dataset = self.audio_conf.get('dataset')
        print('now process ' + self.dataset)
        # dataset spectrogram mean and std, used to normalize the input
        self.norm_mean = self.audio_conf.get('mean')
        self.norm_std = self.audio_conf.get('std')
        print('use dataset mean {:.3f} and std {:.3f} to normalize the input'.format(self.norm_mean, self.norm_std))
        # whether to add random noise for data augmentation
        self.noise = self.audio_conf.get('noise')
        if self.noise:
            print('now use noise augmentation')
        self.index_dict = make_index_dict(label_csv)
        self.label_num = len(self.index_dict)
        print('number of classes is {:d}'.format(self.label_num))

    def _wav2fbank(self, filename, filename2=None):
        # no mixup: load a single clip
        if filename2 is None:
            waveform, sr = torchaudio.load(filename)
            waveform = waveform - waveform.mean()
        # mixup: load both clips and blend them
        else:
            waveform1, sr = torchaudio.load(filename)
            waveform2, _ = torchaudio.load(filename2)
            waveform1 = waveform1 - waveform1.mean()
            waveform2 = waveform2 - waveform2.mean()
            # match the two lengths by padding or cutting the second clip
            if waveform1.shape[1] != waveform2.shape[1]:
                if waveform1.shape[1] > waveform2.shape[1]:
                    # padding
                    temp_wav = torch.zeros(1, waveform1.shape[1])
                    temp_wav[0, 0:waveform2.shape[1]] = waveform2
                    waveform2 = temp_wav
                else:
                    # cutting
                    waveform2 = waveform2[0, 0:waveform1.shape[1]]
            # sample the mixing weight lambda; a uniform draw
            # (mix_lambda = random.random()) is an alternative, but a
            # Beta(10, 10) draw, which concentrates around 0.5, is used here
            mix_lambda = np.random.beta(10, 10)
            mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2
            waveform = mix_waveform - mix_waveform.mean()
        fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
                                                  window_type='hanning', num_mel_bins=self.melbins, dither=0.0,
                                                  frame_shift=10)
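        # with the 10 ms frame shift set above (and kaldi's default 25 ms
        # window), fbank holds roughly 100 frames per second of audio and has
        # shape [n_frames, num_mel_bins]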
        target_length = self.audio_conf.get('target_length')
        n_frames = fbank.shape[0]
        p = target_length - n_frames
        # pad with zeros or cut so the output has exactly target_length frames
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            fbank = m(fbank)
        elif p < 0:
            fbank = fbank[0:target_length, :]
        if filename2 is None:
            return fbank, 0
        else:
            return fbank, mix_lambda

    def __getitem__(self, index):
        """
        returns: fbank, label_indices
        where fbank is a FloatTensor of size [time_frame_num, num_mel_bins], e.g., [1024, 128],
        and label_indices is a FloatTensor of size [label_num] holding the (possibly mixed) multi-hot targets
        """
        # do mix-up for this sample (controlled by the given mixup rate)
        if random.random() < self.mixup:
            datum = self.data[index]
            # find another sample to mix with; balanced sampling from a multinomial
            # distribution was tried but made the performance worse:
            # mix_sample_idx = np.random.choice(len(self.data), p=self.sample_weight_file)
            # so the other sample is drawn uniformly at random
            mix_sample_idx = random.randint(0, len(self.data) - 1)
            mix_datum = self.data[mix_sample_idx]
            # get the mixed fbank
            fbank, mix_lambda = self._wav2fbank(datum['wav'], mix_datum['wav'])
            # initialize the label, then weight the two samples' labels by lambda
            label_indices = np.zeros(self.label_num)
            # add sample 1 labels
            for label_str in datum['labels'].split(','):
                label_indices[int(self.index_dict[label_str])] += mix_lambda
            # add sample 2 labels
            for label_str in mix_datum['labels'].split(','):
                label_indices[int(self.index_dict[label_str])] += 1.0 - mix_lambda
            label_indices = torch.FloatTensor(label_indices)
        # no mixup: plain multi-hot labels
        else:
            datum = self.data[index]
            label_indices = np.zeros(self.label_num)
            fbank, mix_lambda = self._wav2fbank(datum['wav'])
            for label_str in datum['labels'].split(','):
                label_indices[int(self.index_dict[label_str])] = 1.0
            label_indices = torch.FloatTensor(label_indices)
        # SpecAugment; should not be applied to the eval set (set freqm/timem to 0 there)
        freqm = torchaudio.transforms.FrequencyMasking(self.freqm)
        timem = torchaudio.transforms.TimeMasking(self.timem)
        # the masking transforms expect [..., freq, time], so transpose, mask, transpose back
        fbank = torch.transpose(fbank, 0, 1)
        if self.freqm != 0:
            fbank = freqm(fbank)
        if self.timem != 0:
            fbank = timem(fbank)
        fbank = torch.transpose(fbank, 0, 1)
        # normalize the input
        fbank = (fbank - self.norm_mean) / (self.norm_std * 2)
        if self.noise:
            # add random noise and a small random time shift (roll of up to 10 frames)
            fbank = fbank + torch.rand(fbank.shape[0], fbank.shape[1]) * np.random.rand() / 10
            fbank = torch.roll(fbank, np.random.randint(-10, 10), 0)
        # the output fbank shape is [time_frame_num, frequency_bins], e.g., [1024, 128]
        return fbank, label_indices

    def __len__(self):
        return len(self.data)
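

# A minimal usage sketch (illustrative only; the file paths, the json/csv
# layout, and the audio_conf values below are assumptions in the style of an
# AudioSet recipe, not files shipped with this module):
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    audio_conf = {'num_mel_bins': 128, 'target_length': 1024, 'freqm': 48,
                  'timem': 192, 'mixup': 0.5, 'dataset': 'audioset',
                  'mode': 'train', 'mean': -4.27, 'std': 4.57, 'noise': False}
    # the dataset json is expected to look like:
    # {"data": [{"wav": "/path/to/clip.wav", "labels": "/m/09x0r,/m/05zppz"}, ...]}
    dataset = AudiosetDataset('data/train.json', audio_conf,
                              label_csv='data/class_labels_indices.csv')
    loader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=4)
    for fbank, label_indices in loader:
        # e.g. torch.Size([12, 1024, 128]) and torch.Size([12, 527]) for AudioSet
        print(fbank.shape, label_indices.shape)
        break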