# dataset.py
import os
import pickle

import lmdb
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
FFHQ_DATA_DIR = 'VQVAE/data/ffhq_images'
FFHQ_LABELS_DIR = 'VQVAE/data/ffhq-features-dataset-master'
CAT_DATA_DIR = 'VQVAE/data/cat_faces/cats'
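
# Expected on-disk layout (an assumption inferred from the loaders below;
# adjust the constants above if your data lives elsewhere):
#   VQVAE/data/ffhq_images/<folder containing "resized">/<file_id>.<ext>
#   VQVAE/data/ffhq-features-dataset-master/all_labels.npy  # dict: file_id -> label
#   VQVAE/data/cat_faces/cats/<image files, flat directory>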
class FFHQDataset(Dataset):
    def __init__(self, data_dir=FFHQ_DATA_DIR, labels_dir=FFHQ_LABELS_DIR):
        # Helper: flatten a nested dict, joining nested keys with "_".
        # (Currently unused here; useful when preprocessing the raw FFHQ
        # feature JSONs into a flat label dict.)
        def flatten_dict(dic, pfx=""):
            flat = {}
            for k, v in dic.items():
                key = k if pfx == "" else pfx + "_" + k
                if isinstance(v, dict):
                    flat.update(flatten_dict(v, key))
                else:
                    flat[key] = v
            return flat

        self.img_files = []
        print("Creating memory-efficient dataset from", data_dir)
        # all_labels.npy stores a pickled dict mapping file id -> label.
        self.labels = np.load(f'{labels_dir}/all_labels.npy', allow_pickle=True).item()
        print("All labels loaded. Label count:", len(self.labels))

        all_cnt = 0
        invalid_files = []
        processed = 0
        for path_dir in os.listdir(data_dir):
            if not os.path.isdir(os.path.join(data_dir, path_dir)):
                continue
            if "resized" not in path_dir:
                continue  # only use the pre-resized image folders
            path_dir = os.path.join(data_dir, path_dir)
            print("Processing", path_dir)
            for file in os.listdir(path_dir):
                processed += 1
                file_id = file.split('.')[0]
                if file_id not in self.labels:
                    invalid_files.append(file_id)
                    continue  # skip images without a label entry
                # if all_cnt >= 10000: break  # optional cap for quick debugging
                self.img_files.append(os.path.join(path_dir, file))
                all_cnt += 1
        print("\tTotal image count:", len(self.img_files))
        print(f"{len(invalid_files)} invalid files out of {processed}.")

        self.transforms = transforms.Compose([
            transforms.ToTensor(),  # do NOT normalize; keep pixel values in [0, 1]
        ])

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, ind):
        # Load and transform lazily, per item, to keep memory usage low.
        data = Image.open(self.img_files[ind])
        data = self.transforms(data)
        file_id = os.path.basename(self.img_files[ind]).split('.')[0]
        return data, self.labels[file_id]
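
# Minimal usage sketch for FFHQDataset (assumes the default paths above exist
# and contain "resized" folders plus all_labels.npy; the batch size and worker
# count are illustrative, not part of this module):
#
#   from torch.utils.data import DataLoader
#   dataset = FFHQDataset()
#   loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
#   images, labels = next(iter(loader))  # images: float tensor in [0, 1]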
class CatsDataset(Dataset):
    def __init__(self, data_dir=CAT_DATA_DIR):
        self.img_files = []
        self.data_dir = data_dir
        print("Creating memory-efficient dataset from", data_dir)
        for entry in os.listdir(data_dir):
            if os.path.isdir(os.path.join(data_dir, entry)):
                continue  # the cat images live in a flat directory; skip subfolders
            self.img_files.append(entry)
        print("\tTotal image count:", len(self.img_files))

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, ind):
        path = os.path.join(self.data_dir, self.img_files[ind])
        data = Image.open(path)
        data = transforms.ToTensor()(data)  # unnormalized, in [0, 1]
        return data, self.img_files[ind]
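
# Usage sketch for CatsDataset (assumes CAT_DATA_DIR holds image files; note
# that __getitem__ returns (tensor, filename), so the "label" is the file name):
#
#   cats = CatsDataset()
#   img, fname = cats[0]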
class LmdbDataset(Dataset):
    def __init__(self, data_dir, labels_dir=None, keys=["code"]):
        self.data_dir = data_dir
        print("Creating LMDB dataset from", data_dir)
        self.db = lmdb.open(
            data_dir,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self.keys = keys

        # Labels are optional: only load (and report) them when a directory is given.
        if labels_dir is not None:
            self.labels = np.load(f'{labels_dir}/all_labels.npy', allow_pickle=True).item()
            print("All labels loaded. Label count:", len(self.labels))
        else:
            self.labels = None

        # The writer stores the row count under the "length" key.
        with self.db.begin(write=False) as txn:
            self.length = int(txn.get("length".encode('utf-8')).decode('utf-8'))
        print(f"\tContains {self.length} rows.")

    def __len__(self):
        return self.length

    def __getitem__(self, ind):
        # Each row is a pickled dict keyed by the stringified index.
        with self.db.begin(write=False) as txn:
            data = pickle.loads(txn.get(str(ind).encode('utf-8')))
        ret = [torch.from_numpy(data[k]) for k in self.keys]
        file_id = data["filename"]
        label = self.labels[file_id] if self.labels else None
        return ret, label, file_id
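
if __name__ == "__main__":
    # Smoke-test sketch: open an LMDB produced by the companion extraction
    # script. The default path and the "code" key are assumptions; without
    # labels_dir the label comes back as None.
    import sys

    db_path = sys.argv[1] if len(sys.argv) > 1 else "VQVAE/data/ffhq_codes_lmdb"
    ds = LmdbDataset(db_path, labels_dir=None, keys=["code"])
    (code,), label, file_id = ds[0]
    print(file_id, code.shape, label)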