-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path data.py
29 lines (22 loc) · 789 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import yaml
import torch
from torch.utils.data import Dataset, DataLoader
class ClassificationDataset(Dataset):
    """Map-style dataset over parallel sentence, label and length sequences.

    Each item is a dict with the keys ``'sentences'``, ``'labels'`` and
    ``'lengths'``, holding the entries of the three sequences at the
    requested index.
    """

    def __init__(self, data):
        """
        :param data: dictionary of sentences, lengths and labels; the values
            stored under ``'sentences'``, ``'labels'`` and ``'lengths'`` must
            be indexable, sized sequences of equal length.
        :raises KeyError: if one of the three keys is missing from ``data``.
        :raises ValueError: if the three sequences differ in length.
        """
        self.sentences = data['sentences']
        self.labels = data['labels']
        self.lengths = data['lengths']
        # Raise explicitly rather than ``assert`` so the check still runs
        # under ``python -O``; otherwise misaligned data would only surface
        # later as IndexErrors or silently mismatched items.
        if not len(self.sentences) == len(self.labels) == len(self.lengths):
            raise ValueError(
                "sentences, labels and lengths must have equal length, got "
                f"{len(self.sentences)}, {len(self.labels)} and "
                f"{len(self.lengths)}"
            )

    def __len__(self):
        """Return the number of examples (the length of ``sentences``)."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return the sentence, label and length at ``idx`` as a dict."""
        return {
            'sentences': self.sentences[idx],
            'labels': self.labels[idx],
            'lengths': self.lengths[idx],
        }
if __name__ == "__main__":
    # Use a context manager so the config file handle is closed
    # deterministically, and pin the encoding instead of relying on the
    # platform's locale default.
    with open("config.yml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    # NOTE(review): ``config`` is loaded but not used in the visible code —
    # confirm whether further setup (e.g. building a DataLoader) is missing.