-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
119 lines (95 loc) · 4.24 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Markus Enzweiler - markus.enzweiler@hs-esslingen.de
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import utils
def get_loaders(dataset_name, img_size, batch_size, root="./data"):
    """Build train/test DataLoaders for one of the supported torchvision datasets.

    Args:
        dataset_name: Exact dataset key. One of "mnist", "fashion-mnist",
            "cifar-10", "cifar-100" or "celeb-a".
        img_size: Target size, passed to ``transforms.Resize``.
        batch_size: Batch size used for both loaders.
        root: Directory where the dataset is downloaded / cached.

    Returns:
        Tuple ``(train_loader, test_loader, classes_list, num_img_channels)``.
        ``classes_list`` is ``None`` for CelebA. ``num_img_channels`` is 1 for
        the grayscale datasets and 3 for the RGB ones.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported keys.
    """
    # Exact-match lookup: dataset key -> (torchvision dataset class, channels).
    # A dict avoids substring matching (e.g. `name in "mnist"` accepted "m")
    # and keeps each key next to its correct class.
    known_datasets = {
        "mnist": (torchvision.datasets.MNIST, 1),
        "fashion-mnist": (torchvision.datasets.FashionMNIST, 1),
        "cifar-10": (torchvision.datasets.CIFAR10, 3),
        "cifar-100": (torchvision.datasets.CIFAR100, 3),
        "celeb-a": (torchvision.datasets.CelebA, 3),
    }
    try:
        load_fn, num_img_channels = known_datasets[dataset_name]
    except KeyError:
        raise ValueError(f"Unknown dataset {dataset_name}") from None

    train_loader, test_loader, classes_list = torchvision_load(
        dataset_name, batch_size, load_fn, img_size, root
    )
    return train_loader, test_loader, classes_list, num_img_channels
def torchvision_load(
    dataset_name, batch_size, load_fn, img_size=(64, 64), root="./data"
):
    """Load the train and test splits with ``load_fn`` and wrap them in DataLoaders.

    Every image is resized to ``img_size``, converted to a tensor in [0, 1]
    and then normalized to the [-1, 1] range. The grayscale pipeline is used
    for MNIST-style data and the RGB pipeline for CIFAR / CelebA.

    Args:
        dataset_name: One of "celeb-a", "cifar-10", "cifar-100", "mnist",
            "fashion-mnist". It selects the split keyword and the transform.
        batch_size: Batch size for both loaders.
        load_fn: torchvision dataset constructor to call for each split.
        img_size: Target size passed to ``transforms.Resize``.
        root: Dataset download / cache directory.

    Returns:
        Tuple ``(train_loader, test_loader, classes_list)``. ``classes_list``
        is the training set's ``classes`` attribute, or ``None`` for CelebA.

    Raises:
        ValueError: If ``dataset_name`` is not recognized.
    """

    def _pipeline(channels):
        # Resize -> [0, 1] tensor -> (x - 0.5) / 0.5, i.e. the [-1, 1] range.
        half = (0.5,) * channels
        return transforms.Compose(
            [
                transforms.Resize(img_size),
                transforms.ToTensor(),
                transforms.Normalize(half, half),
            ]
        )

    gray_pipeline = _pipeline(1)
    rgb_pipeline = _pipeline(3)

    # CelebA selects its split by name; the other datasets take a train flag.
    if dataset_name == "celeb-a":
        split_args = ({"split": "train"}, {"split": "test"})
        pipeline = rgb_pipeline
    elif dataset_name in ("cifar-100", "cifar-10"):
        split_args = ({"train": True}, {"train": False})
        pipeline = rgb_pipeline
    elif dataset_name in ("mnist", "fashion-mnist"):
        split_args = ({"train": True}, {"train": False})
        pipeline = gray_pipeline
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    # The generator is consumed in order: train split first, then test split.
    train_set, test_set = (
        load_fn(root=root, download=True, transform=pipeline, **extra)
        for extra in split_args
    )
    # CelebA has no simple class list (the "identity" attribute could serve).
    classes_list = None if dataset_name == "celeb-a" else train_set.classes

    def _wrap(dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=2,
        )

    return _wrap(train_set, True), _wrap(test_set, False), classes_list
if __name__ == "__main__":
    # Smoke test: load Fashion-MNIST, check batch shapes and the value range,
    # and write one image to disk for visual inspection.
    batch_size = 32
    img_size = (28, 28)
    tr_loader, test_loader, classes_list, num_img_channels = get_loaders(
        "fashion-mnist", img_size=img_size, batch_size=batch_size
    )
    B, C, H, W = batch_size, num_img_channels, img_size[0], img_size[1]
    print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}")

    images, labels = next(iter(tr_loader))
    assert images.shape == (B, C, H, W), "Wrong training set size"
    assert labels.shape == (B,), "Wrong training set size"

    images, labels = next(iter(test_loader))
    assert images.shape == (B, C, H, W), "Wrong test set size"
    assert labels.shape == (B,), "Wrong test set size"

    print(f"Classes : {classes_list}")

    # print min / max of the (test) batch; expected roughly -1 .. 1
    print(f"Images: min {torch.min(images):6.5f} | max {torch.max(images):6.5f}")

    # Save an image as a sanity check (first image of the test batch above).
    # The loaders normalize pixels to [-1, 1], so map back to [0, 1] before
    # scaling to 0-255. Casting negative floats to uint8 would wrap around
    # and corrupt the image.
    img = (images[0] * 0.5 + 0.5).clamp(0.0, 1.0)
    img_data = (img.detach().cpu().numpy() * 255).astype(np.uint8)
    # NOTE(review): utils.save_image is project-local; it presumably accepts
    # a (C, H, W) uint8 array — confirm against its implementation.
    utils.save_image(img_data, "/tmp/trainTmp.png")
    print("Dataset prepared successfully!")