-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
42 lines (33 loc) · 1.88 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Import libraries
import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
def _random_subset(dataset, sample_size):
    """Return a random subset of *dataset* of at most *sample_size* items.

    Returns *dataset* unchanged when *sample_size* is None. Indices come from
    ``torch.randperm``, so the selection follows the global torch RNG state.
    """
    if sample_size is None:
        return dataset
    indices = torch.randperm(len(dataset))[:sample_size]
    return torch.utils.data.Subset(dataset, indices)


def prepare_data(batch_size=4, num_workers=2, train_sample_size=None,
                 test_sample_size=None, root='./data'):
    """Build CIFAR-10 train and test DataLoaders.

    The training pipeline adds light augmentation (random horizontal flip and
    a random resized crop back to 32x32). The test pipeline only converts to
    a tensor, resizes, and normalizes. Both normalize each channel with
    mean 0.5 and std 0.5, which maps ToTensor's [0, 1] range to [-1, 1].

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Number of DataLoader worker processes for both loaders.
        train_sample_size: If not None, use a random subset of this many
            training images instead of the full training split.
        test_sample_size: If not None, use a random subset of this many
            test images instead of the full test split.
        root: Directory where CIFAR-10 is stored (downloaded if missing).

    Returns:
        ``(trainloader, testloader, classes)``. The training loader shuffles
        and the test loader does not. ``classes[i]`` is the name of label
        ``i``.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(p=0.5),
        # Integer interpolation codes are deprecated in torchvision; code 2
        # is PIL's BILINEAR, so use the equivalent enum member.
        transforms.RandomResizedCrop(
            (32, 32),
            scale=(0.8, 1.0),
            ratio=(0.75, 1.3333333333333333),
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                            download=True,
                                            transform=train_transform)
    trainset = _random_subset(trainset, train_sample_size)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers)

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32)),
        normalize,
    ])
    testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                           download=True,
                                           transform=test_transform)
    testset = _random_subset(testset, test_sample_size)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)

    # Order must match CIFAR-10's label indices (torchvision's
    # CIFAR10.classes: airplane, automobile, bird, cat, deer, dog, frog,
    # horse, ship, truck), so classes[label] names the right category.
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')
    return trainloader, testloader, classes