data.py
import librosa
import os
import numpy as np
import torch
import torch.utils.data as data

# Path to the dataset root; each class has its own sub-directory of audio files.
path = 'your path'

# The class names are the immediate sub-directories of the dataset root.
CLASSES = []
for _, dirs, _ in os.walk(path):
    CLASSES = dirs
    break
print(CLASSES)


class SolarDataset(data.Dataset):
    """Loads audio files from `<root>/<mode>/<class name>/` and returns
    (MFCC features, class index) pairs."""

    def __init__(self, mode='train', root=path):
        super(SolarDataset, self).__init__()
        self.root = os.path.join(root, mode)
        self.data = list()
        self.prep_dataset()

    def prep_dataset(self):
        # Collect every file path together with its class label,
        # i.e. the name of its parent directory.
        for root, dirs, files in os.walk(self.root):
            for file in files:
                f_path, cmd = os.path.join(root, file), os.path.basename(root)
                self.data.append((f_path, cmd))

    def __getitem__(self, idx):
        f_path, cmd = self.data[idx]
        x = self.transform(f_path)
        y = CLASSES.index(cmd)
        return x, y

    def __len__(self):
        return len(self.data)

    def transform(self, path, sr=16000):
        # Resample to 16 kHz and extract a 40-coefficient MFCC feature matrix
        # of shape (n_mfcc, time).
        sig, sr = librosa.load(path, sr=sr)
        spec = librosa.feature.mfcc(y=sig, sr=sr, n_mfcc=40)
        x = np.asarray(spec, dtype=np.float32)
        x = torch.from_numpy(x)
        return x


def _collate_fn(batch):
    # Zero-pad the variable-length MFCC features to the longest sample in the
    # batch, then add a channel dimension: (B, 1, n_mfcc, max_len).
    inputs = [s[0] for s in batch]
    targets = [s[1] for s in batch]
    B = len(batch)
    F = inputs[0].size(0)
    max_len = max(x.size(1) for x in inputs)
    padded = torch.zeros(B, F, max_len)
    for i, x in enumerate(inputs):
        padded[i, :, :x.size(1)] = x
    inputs = padded.unsqueeze(1)
    targets = torch.LongTensor(targets)
    return inputs, targets
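

# A minimal usage sketch, assuming the directory layout described above:
# SolarDataset and _collate_fn are wired into a standard torch DataLoader.
# The batch size and worker count are illustrative assumptions, not values
# taken from this repository.
if __name__ == '__main__':
    train_set = SolarDataset(mode='train', root=path)
    train_loader = data.DataLoader(
        train_set,
        batch_size=32,        # assumed value, for illustration only
        shuffle=True,
        num_workers=2,        # assumed value, for illustration only
        collate_fn=_collate_fn,
    )
    for inputs, targets in train_loader:
        # inputs: (B, 1, 40, max_len) padded MFCCs; targets: (B,) class indices
        print(inputs.shape, targets.shape)
        break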