-
Notifications
You must be signed in to change notification settings - Fork 6
/
data.py
149 lines (108 loc) · 4.03 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from pathlib import Path
from functools import partial, wraps
from beartype import beartype
from beartype.typing import Tuple, Union, Optional
from beartype.door import is_bearable
import torchaudio
from torchaudio.functional import resample
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import random
from utils import curtail_to_multiple
from scipy.io.wavfile import read
from einops import rearrange
# helper functions
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
def cast_tuple(val, length = 1):
    """Promote *val* to a tuple.

    A value that is already a tuple is returned untouched (its own length
    wins); anything else is repeated *length* times.
    """
    if isinstance(val, tuple):
        return val
    return (val,) * length
# type
# A single int, a per-output tuple of (possibly None) ints, or None.
OptionalIntOrTupleInt = Optional[Union[int, Tuple[Optional[int], ...]]]
# Peak magnitude of 16-bit PCM audio; used to scale integer samples into [-1, 1).
MAX_WAV_VALUE = 32768.0
# dataset functions
def load_wav(full_path):
    """Read a wav file from *full_path*.

    Returns a ``(samples, sampling_rate)`` pair — note this flips the
    ``(rate, data)`` order that ``scipy.io.wavfile.read`` uses.
    """
    rate, samples = read(full_path)
    return samples, rate
@beartype
class SoundDataset(Dataset):
    """Dataset over a pre-collected list of audio files.

    Each item is a 1-D float tensor of samples normalized to [-1, 1) by
    MAX_WAV_VALUE.  When ``split`` is truthy, items are randomly cropped
    (or right-zero-padded) to exactly ``segment_size`` samples.

    NOTE(review): ``validate``, ``hop_length``, ``exts``, ``max_length``
    and ``seq_len_multiple_of`` are stored/validated but never used in this
    class's visible code — kept for interface compatibility with callers.
    """

    def __init__(
        self,
        training_files,
        split,
        segment_size,
        shuffle,
        validate=False,
        hop_length=80,
        exts = ['flac', 'wav'],
        max_length: OptionalIntOrTupleInt = None,
        target_sample_hz: OptionalIntOrTupleInt = None,
        seq_len_multiple_of: OptionalIntOrTupleInt = None
    ):
        super().__init__()
        # copy so shuffling below never mutates the caller's list
        self.files = list(training_files)
        self.split = split
        self.validate = validate
        self.hop_length = hop_length
        self.shuffle = shuffle
        self.segment_size = segment_size

        # BUGFIX: the original seeded the RNG with a constant and shuffled
        # self.files inside __getitem__.  That made every "random" crop start
        # at the same offset on every call/epoch, and re-permuted the file
        # list on each item access, breaking the idx -> file mapping.
        # Seed and shuffle exactly once, at construction time.
        if self.shuffle:
            random.seed(1234)
            random.shuffle(self.files)

        self.target_sample_hz = cast_tuple(target_sample_hz)
        num_outputs = len(self.target_sample_hz)
        self.max_length = cast_tuple(max_length, num_outputs)
        self.seq_len_multiple_of = cast_tuple(seq_len_multiple_of, num_outputs)
        assert len(self.max_length) == len(self.target_sample_hz) == len(self.seq_len_multiple_of)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]
        data, sample_hz = load_wav(file)
        # int PCM -> float in [-1, 1)
        data = data / MAX_WAV_VALUE
        data = torch.FloatTensor(data)
        assert data.numel() > 0, f'one of your audio file ({file}) is empty. please remove it from your folder'
        data = data.unsqueeze(0)  # -> (1, num_samples)
        if self.split:
            if data.size(1) >= self.segment_size:
                # random crop of exactly segment_size samples
                max_start = data.size(1) - self.segment_size
                data_start = random.randint(0, max_start)
                data = data[:, data_start:data_start + self.segment_size]
            else:
                # too short: right-pad with zeros up to segment_size
                data = F.pad(data, (0, self.segment_size - data.size(1)), 'constant')
        return data.squeeze(0)
# dataloader functions
def collate_one_or_multiple_tensors(fn):
    """Decorator adapting a collate function to datasets with one or many outputs.

    When items are bare tensors, the batch is stacked directly and returned
    as a 1-tuple.  When items are tuples, each field is collated separately:
    string fields become plain lists, tensor fields go through *fn*.
    """
    @wraps(fn)
    def inner(batch):
        if not isinstance(batch[0], tuple):
            # single-output dataset: stack and wrap for a uniform return shape
            return (torch.stack(batch),)

        collated = tuple(
            list(field) if is_bearable(field, Tuple[str, ...]) else fn(field)
            for field in zip(*batch)
        )
        return collated
    return inner
@collate_one_or_multiple_tensors
def curtail_to_shortest_collate(data):
    """Trim every tensor in the batch to the shortest length, then stack.

    BUGFIX: the original used ``min(*[...])``, which unpacks the lengths as
    positional arguments — with a batch of size one that becomes
    ``min(<int>)`` and raises TypeError.  A generator expression handles any
    batch size.
    """
    min_len = min(datum.shape[0] for datum in data)
    return torch.stack([datum[:min_len] for datum in data])
@collate_one_or_multiple_tensors
def pad_to_longest_fn(tensors):
    """Zero-pad every tensor in the batch to the longest one, then stack batch-first."""
    return pad_sequence(list(tensors), batch_first = True)
def get_dataloader(ds, pad_to_longest = True, **kwargs):
    """Build a DataLoader for *ds*.

    ``pad_to_longest`` selects the collate strategy: zero-pad each batch to
    its longest item (default) or curtail each batch to its shortest item.
    Remaining keyword arguments pass straight through to DataLoader.
    """
    if pad_to_longest:
        collate = pad_to_longest_fn
    else:
        collate = curtail_to_shortest_collate
    return DataLoader(ds, collate_fn = collate, **kwargs)