-
Notifications
You must be signed in to change notification settings - Fork 15
/
load_db.py
69 lines (53 loc) · 1.84 KB
/
load_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
def load_dataset(fp_data='./data/youtube_val.json'):
    """Load the YouTube hand mesh dataset from a JSON file.

    Args:
        fp_data: Filepath to the json file.

    Returns:
        The parsed JSON content (the hand mesh dataset).
    """
    # JSON text is UTF-8 by specification; without an explicit encoding,
    # open() falls back to the platform locale encoding (e.g. cp1252 on
    # Windows), which can mis-decode or fail on non-ASCII content.
    with open(fp_data, "r", encoding="utf-8") as file:
        data = json.load(file)
    return data
def retrieve_sample(data, ann_index):
    """Look up one annotation together with the image it belongs to.

    Args:
        data: Hand mesh dataset; a dict holding an 'annotations' list and an
            'images' list.
        ann_index: Position of the wanted entry in data['annotations'].

    Returns:
        A tuple (annotation, image), where image is the first entry of
        data['images'] whose 'id' equals the annotation's 'image_id'.

    Raises:
        IndexError: If ann_index is out of range for data['annotations'].
        ValueError: If no image carries the annotation's 'image_id'.
    """
    annotation = data['annotations'][ann_index]
    image_records = data['images']
    # list.index returns the first match and raises ValueError when absent.
    position = [record['id'] for record in image_records].index(annotation['image_id'])
    return annotation, image_records[position]
def viz_sample(data, ann_index, faces=None, db_root='./data/'):
    """Overlay one annotation's mesh vertices on its source image and show it.

    Args:
        data: Hand mesh dataset; a dict holding 'annotations' and 'images'.
        ann_index: Position of the annotation in data['annotations'].
        faces: Optional MANO faces. When provided, the vertices are drawn as a
            triangulation with these faces; otherwise each vertex is drawn as
            a small green dot.
        db_root: Directory that the image's 'name' is joined onto to locate
            the image file (the youtube parent directory).
    """
    import imageio
    import matplotlib.pyplot as plt
    import numpy as np
    from os.path import join

    annotation, image_info = retrieve_sample(data, ann_index)
    picture = imageio.imread(join(db_root, image_info['name']))
    verts = np.array(annotation['vertices'])

    plt.figure(figsize=(10, 10))
    plt.imshow(picture)
    # Only the first two vertex columns are used as plot coordinates.
    xs, ys = verts[:, 0], verts[:, 1]
    if faces is not None:
        plt.triplot(xs, ys, faces, lw=0.2)
    else:
        plt.plot(xs, ys, 'o', color='green', markersize=1)
    plt.show()
if __name__ == "__main__":
    # Print a short summary of the dataset layout when run as a script.
    data = load_dataset()
    images = data['images']
    annotations = data['annotations']
    # Iterating a dict yields its keys, so list(d) replaces [k for k in d.keys()].
    print("Data keys:", list(data))
    # Guard the [0] lookups: an empty split would otherwise raise IndexError.
    print("Image keys:", list(images[0]) if images else [])
    print("Annotations keys:", list(annotations[0]) if annotations else [])
    print("The number of images:", len(images))
    print("The number of annotations:", len(annotations))