-
Notifications
You must be signed in to change notification settings - Fork 3
/
dataset.py
58 lines (46 loc) · 1.9 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from typing import Iterable

import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
# Location the EuroSAT archive is downloaded from.
URL = "http://madm.dfki.de/files/sentinel/EuroSAT.zip"
# MD5 checksum handed to torchvision's downloader to verify the archive.
MD5 = "c8fa014336c82ac7804f0398fcb19387"
# Sub-directory of the data root that EuroSAT passes to ImageFolder as its root
# (presumably the top-level folder inside the extracted archive -- TODO confirm).
SUBDIR = '2750'
def random_split(dataset, ratio=0.9, random_state=None):
    """Randomly split ``dataset`` into two subsets.

    Args:
        dataset: A sized dataset accepted by ``torch.utils.data.random_split``.
        ratio: Fraction of the samples, within ``[0, 1]``, assigned to the
            first subset. The first subset receives
            ``int(len(dataset) * ratio)`` samples, the second one the rest.
        random_state: Optional seed. When given, the global torch RNG is
            seeded with it for the split so the result is reproducible, and
            the previous RNG state is restored afterwards, even if the split
            fails.

    Returns:
        The value returned by ``torch.utils.data.random_split`` for the two
        computed lengths (the first subset, then the second).

    Raises:
        ValueError: If ``ratio`` is outside ``[0, 1]``.
    """
    # Without this check a ratio > 1 produces lengths such as [2n, -n], which
    # still sum to len(dataset) and therefore slip through torch's own check.
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be within [0, 1], got {ratio!r}")

    total = len(dataset)
    first = int(total * ratio)
    lengths = [first, total - first]

    if random_state is None:
        return torch.utils.data.random_split(dataset, lengths)

    # NOTE(review): only the CPU generator state is saved and restored here;
    # torch.manual_seed is documented to seed all devices -- confirm whether
    # the CUDA generators need the same treatment.
    saved_state = torch.random.get_rng_state()
    try:
        torch.random.manual_seed(random_state)
        return torch.utils.data.random_split(dataset, lengths)
    finally:
        # Always hand the caller back the RNG stream it had before the call.
        torch.random.set_rng_state(saved_state)
class EuroSAT(ImageFolder):
    """EuroSAT images exposed through torchvision's ``ImageFolder``.

    Constructing the dataset first makes sure the archive has been downloaded
    and extracted under ``root``, then hands the ``SUBDIR`` folder inside
    ``root`` to ``ImageFolder``.

    Args:
        root: Directory the archive is downloaded to and extracted in.
        transform: Forwarded unchanged to ``ImageFolder``.
        target_transform: Forwarded unchanged to ``ImageFolder``.
    """

    def __init__(self, root='data', transform=None, target_transform=None):
        self.download(root)
        super().__init__(
            os.path.join(root, SUBDIR),
            transform=transform,
            target_transform=target_transform,
        )

    @staticmethod
    def download(root):
        """Download and extract the EuroSAT archive into ``root`` if needed.

        Nothing is done when a checksum-verified archive and the extracted
        ``SUBDIR`` folder are both already present.
        """
        archive = os.path.join(root, "EuroSAT.zip")
        extracted = os.path.join(root, SUBDIR)
        # check_integrity without an md5 only tests that the file exists, so a
        # truncated or corrupted download would be accepted; verify the hash.
        if check_integrity(archive, MD5) and os.path.isdir(extracted):
            return
        # Also reached when the archive is fine but was never extracted; the
        # same md5 is passed along so the downloader can verify the archive.
        download_and_extract_archive(URL, root, md5=MD5)
# torchvision does not appear to provide a dataset for a plain list of image
# paths, so this one fills the gap. Compared with loading the files by hand,
# wrapping them in a Dataset gives us transforms, DataLoader (workers), etc.
# "for free".
class ImageFiles(Dataset):
    """Dataset over an explicit collection of image file paths.

    Each item is an ``(image, -1)`` pair: the image is produced by calling
    ``loader`` on the path and then, if one was given, ``transform`` on the
    result.

    Args:
        paths: The image paths, in the order they should be indexed.
        loader: Callable turning a path into an image; defaults to
            torchvision's ``default_loader``.
        transform: Optional callable applied to every loaded image.
    """

    def __init__(self, paths: Iterable[str], loader=default_loader, transform=None):
        # Materialise the paths so that any iterable (a generator, a
        # label-indexed container, ...) supports len() and positional indexing.
        self.paths = list(paths)
        self.loader = loader
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load, optionally transform, and return the image at ``idx``."""
        image = self.loader(self.paths[idx])
        if self.transform is not None:
            image = self.transform(image)
        # WARNING: -1 means "no target"; returning a pair keeps the same
        # (sample, target) interface as torchvision datasets.
        return image, -1