-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdata_loader.py
41 lines (32 loc) · 1.23 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import random
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.utils import data
class HDF5Dataset(data.Dataset):
    """Dataset that reads samples from an HDF5 file on every access.

    The file is expected to contain a ``dataset`` group with ``data``,
    ``label`` and ``id`` datasets that are indexed by sample position.

    The read-only file handle is opened lazily on first use, not in
    ``__init__``. It is re-opened whenever the object is used from a
    different process. Two constraints drive this design:

    - An open ``h5py.File`` cannot be pickled, which DataLoader workers
      started with ``spawn`` require.
    - h5py advises opening files separately in each process rather than
      sharing a handle across ``fork``.

    Args:
        h5_path: Path of the HDF5 file to read.
    """

    def __init__(self, h5_path):
        self.h5_path = h5_path
        self._h_file = None  # opened lazily by the ``h_file`` property
        self._pid = None     # PID of the process that opened ``_h_file``
        # Open the file only long enough to learn the number of samples.
        with h5py.File(h5_path, 'r') as h_file:
            self.length = len(h_file['dataset']['id'])

    @property
    def h_file(self):
        """Return the open read-only ``h5py.File`` for the current process."""
        if self._h_file is None or self._pid != os.getpid():
            self._h_file = h5py.File(self.h5_path, 'r')
            self._pid = os.getpid()
        return self._h_file

    def __getstate__(self):
        # Drop the unpicklable handle; the unpickled copy re-opens it lazily.
        state = self.__dict__.copy()
        state['_h_file'] = None
        return state

    def close(self):
        """Close the file handle if it was opened by this process."""
        if self._h_file is not None and self._pid == os.getpid():
            self._h_file.close()
        self._h_file = None

    def __getitem__(self, index):
        """Return ``(data, label, id)`` read from the file at position ``index``."""
        group = self.h_file['dataset']
        return group['data'][index], group['label'][index], group['id'][index]

    def __len__(self):
        return self.length


class InMemoryDataset(data.Dataset):
    """Dataset that loads a whole HDF5 file into memory and returns random crops.

    The ``data``, ``cf_data``, ``label`` and ``id`` datasets of the
    ``dataset`` group are read into NumPy arrays at construction time. The
    file is closed again before any item is accessed.

    Args:
        h5_path: Path of the HDF5 file to read.
        crop_len: Length of the random window cut along axis 0 of each
            ``data`` sample. Defaults to 256.
    """

    def __init__(self, h5_path, crop_len=256):
        self.h5_path = h5_path
        self.crop_len = crop_len
        # Read everything eagerly; the context manager closes the file afterwards.
        with h5py.File(h5_path, 'r') as h_file:
            group = h_file['dataset']
            self.data = group['data'][:]
            self.cf_data = group['cf_data'][:]
            self.label = group['label'][:]
            self.id = group['id'][:]
        self.length = len(self.id)

    def __getitem__(self, index):
        """Return ``(data_crop, label, cf_data, id)`` for sample ``index``.

        ``data_crop`` is ``crop_len`` consecutive entries along axis 0 of
        ``self.data[index]``. The start offset is drawn uniformly from every
        valid position, including the one whose window ends exactly at the
        last entry. A sample that is not longer than ``crop_len`` is returned
        whole.
        """
        sample = self.data[index]
        # Count of valid start offsets. randint's upper bound is exclusive,
        # hence the +1 so the last full window can be drawn.
        n_starts = len(sample) - self.crop_len + 1
        start = np.random.randint(0, n_starts) if n_starts > 1 else 0
        return (
            sample[start:start + self.crop_len],
            self.label[index],
            self.cf_data[index],
            self.id[index],
        )

    def __len__(self):
        return self.length