This repository has been archived by the owner on Jan 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 47
/
data_loader.py
55 lines (45 loc) · 1.82 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torchvision
import torchvision.transforms as transforms
# Number of worker subprocesses passed as ``num_workers`` to every
# DataLoader that get_cifar builds (both the train and the test loader).
NUM_WORKERS = 2
def get_cifar(num_classes=100, dataset_dir='./data', batch_size=128, crop=False):
    """Build train/test DataLoaders for CIFAR-10 or CIFAR-100.

    Both splits are downloaded into ``dataset_dir`` if not already present.
    Every image is converted to a tensor and normalized with the same
    per-channel mean/std; the training split can optionally be augmented.

    :param num_classes: 10 for CIFAR-10, 100 for CIFAR-100 (default 100).
    :param dataset_dir: directory the datasets are stored in / downloaded to,
        default ``'./data'``.
    :param batch_size: number of images per batch for both loaders,
        default 128.
    :param crop: if truthy, augment the training split with a random 32x32
        crop (4-pixel padding) followed by a random horizontal flip;
        default False (no augmentation). The test split is never augmented.
    :return: ``(trainloader, testloader)``. The train loader shuffles every
        epoch; the test loader keeps a fixed order.
    :raises ValueError: if ``num_classes`` is neither 10 nor 100.
    """
    # Choose the dataset class first so an unsupported value fails fast
    # instead of silently falling back to CIFAR-10.
    if num_classes == 100:
        dataset_cls = torchvision.datasets.CIFAR100
    elif num_classes == 10:
        dataset_cls = torchvision.datasets.CIFAR10
    else:
        raise ValueError(
            f"num_classes must be 10 (CIFAR-10) or 100 (CIFAR-100), got {num_classes!r}"
        )

    # NOTE(review): these per-channel statistics look like CIFAR-100's and are
    # also applied to CIFAR-10. Kept unchanged so results stay comparable with
    # models trained through this loader; confirm before switching them.
    normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
    simple_transform = transforms.Compose([transforms.ToTensor(), normalize])

    if crop:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = simple_transform

    trainset = dataset_cls(root=dataset_dir, train=True,
                           download=True, transform=train_transform)
    # The test split always uses the un-augmented transform.
    testset = dataset_cls(root=dataset_dir, train=False,
                          download=True, transform=simple_transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=NUM_WORKERS,
                                              pin_memory=True, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, num_workers=NUM_WORKERS,
                                             pin_memory=True, shuffle=False)
    return trainloader, testloader
if __name__ == "__main__":
    # Smoke check: build both dataset variants and print the loader pairs,
    # separated by two rule lines.
    rule = "---" * 20
    variants = (("CIFAR10", 10), ("CIFAR100", 100))
    for position, (label, n_classes) in enumerate(variants):
        if position > 0:
            print(rule)
            print(rule)
        print(label)
        print(get_cifar(n_classes))