-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathcreate_lmdb.py
87 lines (79 loc) · 2.76 KB
/
create_lmdb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import os.path as osp
import sys
import glob
import pickle
import lmdb
import cv2
# ProgressBar lives in the project-level ``utils`` module, two directories
# above this file; make that directory importable before trying the import.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
try:
    from utils import ProgressBar
except ImportError:
    # Progress reporting is not essential for building the database. Instead of
    # swallowing the error (which only postpones the failure to a NameError at
    # the first use of ProgressBar), fall back to a minimal stand-in exposing
    # the two members this script relies on.
    class ProgressBar(object):
        """Plain-text replacement used when ``utils.ProgressBar`` is unavailable."""

        def __init__(self, task_num=0):
            self.task_num = task_num
            self.completed = 0

        def update(self, msg="In progress..."):
            """Count one finished task and print a one-line status message."""
            self.completed += 1
            print("[{}/{}] {}".format(self.completed, self.task_num, msg))
# ---------------------------------------------------------------------------
# Configurations
# ---------------------------------------------------------------------------
# Name recorded in the meta information stored next to the database.
meta_info = {"name": "rolf48"}

# Glob matching pattern selecting the source images,
# e.g. '/data/DIV2K/DIV2K_train/LR/x4/*'.
img_folder = "/data/rolf48/blurry/*"

# Where the lmdb is written, e.g. '/data/DIV2K/DIV2K_train_LR_sub.lmdb'.
lmdb_save_path = "/data/rolf48/blurry.lmdb"

# Writing strategy:
#   1 - read all the images to memory and then write to lmdb (more memory);
#   2 - read several images and then write to lmdb, loop over (less memory).
mode = 2

# Used in mode 2 only: number of images between two intermediate lmdb commits.
batch = 1000
###########################################
# Validate the configuration, list the input images and estimate how many
# bytes of image data are going to be stored.
if not lmdb_save_path.endswith(".lmdb"):
    raise ValueError("lmdb_save_path must end with '.lmdb'.")

#### refuse to overwrite an lmdb that already exists
if osp.exists(lmdb_save_path):
    print("Folder [{:s}] already exists. Exit...".format(lmdb_save_path))
    sys.exit(1)

img_list = sorted(glob.glob(img_folder))
if not img_list:
    # Without this check an empty match only fails later, and less clearly
    # (zero-sized database map, then an IndexError on the resolution list).
    raise ValueError("No images found matching [{:s}].".format(img_folder))

if mode == 1:
    print("Read images...")
    dataset = []
    for v in img_list:
        img = cv2.imread(v, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            raise IOError("Failed to read image: {}".format(v))
        dataset.append(img)
    data_size = sum(img.nbytes for img in dataset)
elif mode == 2:
    print("Calculating the total size of images...")
    file_bytes = sum(os.stat(v).st_size for v in img_list)
    # The database stores decoded pixel arrays, which can be far larger than
    # the (possibly compressed) files on disk. Decode the first image and use
    # its in-memory size as a per-image lower bound so the estimate is never
    # smaller than the plain sum of file sizes.
    probe = cv2.imread(img_list[0], cv2.IMREAD_UNCHANGED)
    if probe is None:
        raise IOError("Failed to read image: {}".format(img_list[0]))
    data_size = max(file_bytes, probe.nbytes * len(img_list))
else:
    raise ValueError("mode should be 1 or 2")
# Write every image into the lmdb, keyed by its file name without extension,
# and collect the keys and "C_H_W" resolution strings for the meta information.
key_l = []
resolution_l = []
pbar = ProgressBar(len(img_list))
# Reserve ten times the estimated data size for the database map.
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
try:
    txn = env.begin(write=True)  # txn is a Transaction object
    for i, v in enumerate(img_list):
        pbar.update("Write {}".format(v))
        base_name = osp.splitext(osp.basename(v))[0]
        key = base_name.encode("ascii")
        data = dataset[i] if mode == 1 else cv2.imread(v, cv2.IMREAD_UNCHANGED)
        if data is None:
            # cv2.imread signals an unreadable file by returning None; fail
            # with the offending path instead of an AttributeError below.
            raise IOError("Failed to read image: {}".format(v))
        if data.ndim == 2:
            # Single-channel image: the array has no channel axis.
            H, W = data.shape
            C = 1
        else:
            H, W, C = data.shape
        txn.put(key, data)
        key_l.append(base_name)
        resolution_l.append("{:d}_{:d}_{:d}".format(C, H, W))
        # commit in mode 2 after every `batch` images; counting with (i + 1)
        # also works for batch == 1, where `i % batch == 1` is never true
        if mode == 2 and (i + 1) % batch == 0:
            txn.commit()
            txn = env.begin(write=True)
    # Flush whatever was written since the last intermediate commit.
    txn.commit()
finally:
    # Release the environment even when reading or writing failed midway.
    env.close()
print("Finish writing lmdb.")
#### create meta information
# The key list is stored in full in either case.
meta_info["keys"] = key_l
# check whether all the images are the same size
same_resolution = len(set(resolution_l)) <= 1
if same_resolution:
    # Store the shared resolution once. Slicing instead of indexing keeps this
    # from raising IndexError when no image was processed at all.
    meta_info["resolution"] = resolution_l[:1]
    print("All images have the same resolution. Simplify the meta info...")
else:
    meta_info["resolution"] = resolution_l
    print("Not all images have the same resolution. Save meta info for each image...")

#### pickle dump
# The meta file is written under lmdb_save_path; the context manager makes sure
# the file is flushed and closed once the dump is done.
with open(osp.join(lmdb_save_path, "meta_info.pkl"), "wb") as meta_file:
    pickle.dump(meta_info, meta_file)
print("Finish creating lmdb meta info.")