-
Notifications
You must be signed in to change notification settings - Fork 1
/
gtsrb_dataset.py
48 lines (37 loc) · 1.39 KB
/
gtsrb_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import os
import pandas as pd
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
class GTSRB(Dataset):
    """German Traffic Sign Recognition Benchmark (GTSRB) dataset.

    Expects this layout under ``root_dir``::

        root_dir/GTSRB/trainingset/training.csv
        root_dir/GTSRB/testset/test.csv

    Each CSV row gives an image file name in column 0 (relative to the
    directory containing the CSV) and the sample's class id in column 1.
    """

    base_folder = 'GTSRB'

    def __init__(self, root_dir, train=False, transform=None,
                 target_transform=None):
        """
        Args:
            root_dir (string): Directory containing the GTSRB folder.
            train (bool): Load the training set if True, else the test set.
            transform (callable, optional): Optional transform applied to
                the PIL image of a sample.
            target_transform (callable, optional): Optional transform
                applied to the integer class id of a sample.
        """
        self.root_dir = root_dir
        self.sub_directory = 'trainingset' if train else 'testset'
        self.csv_file_name = 'training.csv' if train else 'test.csv'
        # Directory holding both the CSV and the image files it lists;
        # computed once here instead of on every __getitem__ call.
        self._data_dir = os.path.join(
            root_dir, self.base_folder, self.sub_directory)
        csv_file_path = os.path.join(self._data_dir, self.csv_file_name)
        # NOTE(review): training.csv is parsed with no header row, while the
        # first row of test.csv is treated as a header. Presumably this
        # matches how the two files were generated -- confirm against data.
        if train:
            self.csv_data = pd.read_csv(csv_file_path, header=None)
        else:
            self.csv_data = pd.read_csv(csv_file_path)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples (rows of the loaded CSV)."""
        return len(self.csv_data)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Args:
            idx: Row position in the CSV; anything accepted by ``int()``
                works (e.g. a Python int or a 0-d tensor).

        Returns:
            tuple: ``(img, class_id)``. ``img`` is the PIL image, passed
            through ``transform`` if given. ``class_id`` is a Python ``int``,
            passed through ``target_transform`` if given.
        """
        idx = int(idx)
        img_path = os.path.join(self._data_dir, self.csv_data.iloc[idx, 0])
        # Image.open() is lazy: it leaves the file open until the pixel data
        # is first accessed. Reading through an explicit handle and calling
        # load() decodes eagerly (errors surface here, not in the transform)
        # and closes the file deterministically before returning.
        with open(img_path, 'rb') as f:
            img = Image.open(f)
            img.load()
        # Plain int rather than a numpy/pandas scalar, for consistent labels.
        class_id = int(self.csv_data.iloc[idx, 1])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            class_id = self.target_transform(class_id)
        return img, class_id