-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare.py
69 lines (53 loc) · 2.75 KB
/
prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import json
from tqdm.notebook import tqdm
from PIL import Image
# Convert every .jpg in the raw dump to a .png with the same stem, then delete
# the .jpg so only the PNG copy remains in the directory.
raw_dir = os.path.join('data', 'hiu_dmtl_data', 'to_cv_community')
for file in tqdm(os.listdir(raw_dir)):
    if file.endswith('.jpg'):
        # splitext keeps dotted stems intact ("a.b.jpg" -> "a.b"); split('.')[0]
        # would truncate them and could map two images onto one PNG name.
        name = os.path.splitext(file)[0]
        jpg_path = os.path.join(raw_dir, file)
        # Context manager releases the source file handle deterministically
        # before the file is deleted below.
        with Image.open(jpg_path) as img:
            img.save(os.path.join(raw_dir, name + '.png'))
        os.remove(jpg_path)
# Create the train/ destination directory. exist_ok makes re-runs harmless and
# avoids the race between an exists() check and mkdir().
os.makedirs(os.path.join('data', 'hiu_dmtl_data', 'train'), exist_ok=True)
# Collect the ids (file stems) of samples whose label has any non-zero value in
# 'pts3d_w_2hand'. An all-zero array presumably means "no 3D annotation"
# -- TODO confirm against the dataset documentation.
usable = []
raw_dir = os.path.join('data', 'hiu_dmtl_data', 'to_cv_community')
# sorted(): os.listdir order is filesystem-dependent, and the split further
# down assigns samples by their index in `usable`, so sort for reproducibility.
for file in tqdm(sorted(os.listdir(raw_dir))):
    if file.endswith('.json'):
        sample_id = os.path.splitext(file)[0]
        with open(os.path.join(raw_dir, file), 'r', encoding='utf-8') as jsonfile:
            data = json.load(jsonfile)
        # Assumes 'pts3d_w_2hand' is a flat list of numbers -- TODO confirm.
        # any() instead of sum() != 0: signed coordinates can cancel to zero
        # even when the pose is present.
        if any(data['pts3d_w_2hand']):
            usable.append(sample_id)
# Move each usable sample's image (.png) and label (.json) out of the raw dump
# and into train/; the image is moved first, then the label.
for sample_id in tqdm(usable):
    for ext in ('.png', '.json'):
        os.rename(
            os.path.join('data', 'hiu_dmtl_data', 'to_cv_community', sample_id + ext),
            os.path.join('data', 'hiu_dmtl_data', 'train', sample_id + ext),
        )
# Create the test/ and valid/ destination directories. exist_ok makes re-runs
# harmless and avoids the race between an exists() check and mkdir().
os.makedirs(os.path.join('data', 'hiu_dmtl_data', 'test'), exist_ok=True)
os.makedirs(os.path.join('data', 'hiu_dmtl_data', 'valid'), exist_ok=True)
# Split the samples currently in train/ by their position in `usable`:
#   index % 10 == 0 -> moved to test/
#   index % 10 == 3 -> moved to valid/
#   otherwise       -> stays in train/
# giving roughly an 80/10/10 split. Each id is recorded in the matching list.
train_ids = []
test_ids = []
valid_ids = []
base_dir = os.path.join('data', 'hiu_dmtl_data')
# enumerate(tqdm(...)) rather than tqdm(enumerate(...)): an enumerate object has
# no len(), so tqdm could not show a total for the progress bar.
for idx, sample_id in enumerate(tqdm(usable)):
    remainder = idx % 10
    if remainder == 0:
        dest, bucket = 'test', test_ids
    elif remainder == 3:
        dest, bucket = 'valid', valid_ids
    else:
        train_ids.append(sample_id)
        continue
    # Move the image first, then the label.
    for ext in ('.png', '.json'):
        os.rename(
            os.path.join(base_dir, 'train', sample_id + ext),
            os.path.join(base_dir, dest, sample_id + ext),
        )
    bucket.append(sample_id)
# Persist each split's id list as JSON at <split>/ids.json
# (written in the order train, test, valid).
split_lists = (
    ('train', train_ids),
    ('test', test_ids),
    ('valid', valid_ids),
)
for split_name, split_ids in split_lists:
    ids_path = os.path.join('data', 'hiu_dmtl_data', split_name, 'ids.json')
    with open(ids_path, 'w') as jsonfile:
        json.dump(split_ids, jsonfile)