data.py
from datasets.dataset import *
from torch.utils.data import Dataset, DataLoader
import utils


def get_dataset(args):
    """Build the train/validation loaders and report the image geometry
    (image size, channel count, patch size) for the dataset named in args.dataset."""
    if args.dataset == 'mnist':
        data = MultiDigitSplits(dataset=args.dataset, num_compare=args.num_compare)
        data_loader_train = data.get_train_loader(batch_size=args.batch_size, drop_last=True)
        data_loader_valid = data.get_valid_loader(batch_size=args.batch_size, drop_last=True)
        image_size = (28, 112)
        channels = 1
        image_patch_size = 7
    elif args.dataset == 'svhn':
        data = MultiDigitSplits(dataset=args.dataset, num_compare=args.num_compare)
        data_loader_train = data.get_train_loader(batch_size=args.batch_size, drop_last=True)
        # SVHN has no separate validation split here; the test loader serves for validation.
        data_loader_valid = data.get_test_loader(batch_size=args.batch_size, drop_last=True)
        image_size = (54, 54)
        channels = 3
        image_patch_size = 6
    elif args.dataset == 'rds':
        train_dataset = RandomDotsDataset(args.num_compare, args.reversible, 100000)
        val_dataset = RandomDotsDataset(args.num_compare, args.reversible, 1000)
        data_loader_train = utils.FastDataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=2, pin_memory=True, shuffle=True)
        data_loader_valid = utils.FastDataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, num_workers=2, pin_memory=True, shuffle=True)
        image_size = 42
        channels = 1
        image_patch_size = 7
    elif args.dataset == 'clocks_cropped':
        train_dataset = TimelapseClocks(args.num_compare, 100000, True)
        val_dataset = TimelapseClocks(args.num_compare, 1000, False)
        data_loader_train = utils.FastDataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        data_loader_valid = utils.FastDataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        image_size = 196
        channels = 3
        image_patch_size = 14
    elif args.dataset == 'clocks_full':
        train_dataset = TimelaspseFull(train=True, t=args.num_compare, dt=5)
        val_dataset = TimelaspseFull(train=False, t=args.num_compare, dt=5)
        data_loader_train = utils.FastDataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        data_loader_valid = utils.FastDataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        image_size = (160 * 2, 240 * 2)
        channels = 3
        image_patch_size = 20
    elif args.dataset == 'scenes':
        train_dataset = SkyLapseDataset(num_compare=args.num_compare, train=True)
        val_dataset = SkyLapseDataset(num_compare=args.num_compare, train=False)
        data_loader_train = utils.FastDataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        data_loader_valid = utils.FastDataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        image_size = (336, 336)
        channels = 3
        image_patch_size = 21
    elif args.dataset == 'moca':
        train_dataset = MoCADataset(num_compare=args.num_compare, train=True)
        val_dataset = MoCADataset(num_compare=args.num_compare, train=False)
        data_loader_train = utils.FastDataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        data_loader_valid = utils.FastDataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, num_workers=1, pin_memory=False, shuffle=True)
        image_size = (336, 336)
        channels = 3
        image_patch_size = 21
    elif args.dataset == 'muds':
        train_dataset = SpaceDataset(train=True, t=args.num_compare)
        val_dataset = SpaceMonoDataset(train=False, t=args.num_compare)
        data_loader_train = utils.FastDataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        # Validation here uses a fixed batch size without shuffling or dropping the last batch.
        data_loader_valid = utils.FastDataLoader(val_dataset, batch_size=60, drop_last=False, num_workers=8, pin_memory=True, shuffle=False)
        image_size = (196, 196)
        channels = 3
        image_patch_size = 7
    elif args.dataset == 'calfire':
        train_dataset = CalFireDataset(train=True, t=args.num_compare)
        val_dataset = CalFireDataset(train=False, t=args.num_compare)
        data_loader_train = utils.FastDataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        data_loader_valid = utils.FastDataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        image_size = (196, 196)
        channels = 3
        image_patch_size = 14
    elif args.dataset == 'mri3':
        train_dataset = MRI3Dataset(train=True, t=args.num_compare)
        val_dataset = MRI3Dataset(train=False, t=args.num_compare)
        data_loader_train = utils.FastDataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        data_loader_valid = utils.FastDataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, num_workers=8, pin_memory=True, shuffle=True)
        image_size = (224, 154)
        channels = 1
        image_patch_size = 14
    else:
        raise ValueError(f'Unknown dataset: {args.dataset}')
    # Normalize scalar image sizes (e.g. 42, 196) to (height, width) tuples.
    image_size = utils.pair(image_size)
    return data_loader_train, data_loader_valid, image_size, channels, image_patch_size
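

# A minimal usage sketch of get_dataset(), assuming an argparse-style `args` object
# with the fields the function reads above (dataset, num_compare, batch_size, and
# reversible for the 'rds' split). The argument defaults below are illustrative,
# not taken from the original training configuration.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='mnist',
                        choices=['mnist', 'svhn', 'rds', 'clocks_cropped', 'clocks_full',
                                 'scenes', 'moca', 'muds', 'calfire', 'mri3'])
    parser.add_argument('--num_compare', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--reversible', action='store_true')
    args = parser.parse_args()

    train_loader, valid_loader, image_size, channels, patch_size = get_dataset(args)

    # In every branch above both image dimensions are divisible by the patch size,
    # as patch-based (ViT-style) models typically require.
    print(f'{args.dataset}: image_size={image_size}, channels={channels}, patch_size={patch_size}')

    # Pull one batch to confirm the loader works; tensor shapes depend on the dataset class.
    batch = next(iter(train_loader))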