forked from lRomul/argus-tgs-salt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
make_inpaint_images.py
66 lines (50 loc) · 2.09 KB
/
make_inpaint_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
from os.path import join
import cv2
import numpy as np
from skimage.restoration import inpaint_biharmonic
import multiprocessing as mp
from multiprocessing import Pool
from src.utils import make_dir
from src.config import TRAIN_DIR
from src.config import TEST_DIR
# Size of the source tiles and the padded size that is written to disk.
ORIG_SIZE = (101, 101)
SAVE_SIZE = (148, 148)
# Suffix appended to the 'images' / 'masks' folder names for the padded output.
SAVE_NAME = '148'
# Cut-off applied to the inpainted mask when it is turned back into 0/255.
TARGET_THRESHOLD = 0.7
# One worker process per CPU reported by the OS.
N_WORKERS = mp.cpu_count()

# Prepare the output folders (make_dir is imported from src.utils).
make_dir(join(TRAIN_DIR, 'images' + SAVE_NAME))
make_dir(join(TRAIN_DIR, 'masks' + SAVE_NAME))
make_dir(join(TEST_DIR, 'images' + SAVE_NAME))

# Split the size difference per axis into a leading and a trailing margin;
# when the difference is odd the extra pixel goes to the trailing side.
diff = np.subtract(SAVE_SIZE, ORIG_SIZE)
pad_left = diff // 2
pad_right = diff - pad_left
PAD_WIDTH = tuple(zip(pad_left, pad_right))
print('Pad width:', PAD_WIDTH)

# Inpainting mask of shape SAVE_SIZE: 0 over the original tile, 255 over the
# padded border, i.e. the border is the region to be filled in.
MASK_INPAINT = np.pad(
    np.zeros(ORIG_SIZE, dtype=np.uint8),
    PAD_WIDTH,
    mode='constant',
    constant_values=255,
)
def inpaint_train(img_file):
    """Pad one training image and its mask and inpaint the padded border.

    Reads ``img_file`` as grayscale from ``TRAIN_DIR/images`` and
    ``TRAIN_DIR/masks``, pads both by ``PAD_WIDTH``, fills the padded region
    marked in ``MASK_INPAINT`` with ``inpaint_biharmonic`` and writes the
    results under the same file name to ``TRAIN_DIR/images<SAVE_NAME>`` and
    ``TRAIN_DIR/masks<SAVE_NAME>``.

    Args:
        img_file: File name (not a full path) that exists in both the
            ``images`` and ``masks`` folders of ``TRAIN_DIR``.

    Raises:
        OSError: If OpenCV cannot read the image or the mask.
    """
    img_path = join(TRAIN_DIR, 'images', img_file)
    trg_path = join(TRAIN_DIR, 'masks', img_file)
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    trg = cv2.imread(trg_path, cv2.IMREAD_GRAYSCALE)

    # cv2.imread reports failure by returning None instead of raising; fail
    # here with the offending path rather than with an obscure np.pad error.
    if img is None:
        raise OSError('Cannot read image: {}'.format(img_path))
    if trg is None:
        raise OSError('Cannot read mask: {}'.format(trg_path))

    img = np.pad(img, PAD_WIDTH, mode='constant')
    trg = np.pad(trg, PAD_WIDTH, mode='constant')

    # The inpainting result is floating point (scikit-image rescales 8-bit
    # input to [0, 1]), so the image is scaled back to the 8-bit range.
    img = (inpaint_biharmonic(img, MASK_INPAINT) * 255).astype(np.uint8)

    # The mask is inpainted the same way and then binarised to 0/255.
    # np.where with Python int branches yields a platform-default integer
    # array, so cast explicitly to the 8-bit type the file is stored in.
    trg = inpaint_biharmonic(trg, MASK_INPAINT)
    trg = np.where(trg > TARGET_THRESHOLD, 255, 0).astype(np.uint8)

    cv2.imwrite(join(TRAIN_DIR, 'images' + SAVE_NAME, img_file), img)
    cv2.imwrite(join(TRAIN_DIR, 'masks' + SAVE_NAME, img_file), trg)
def inpaint_test(img_file):
    """Pad one test image and inpaint the padded border.

    Reads ``img_file`` as grayscale from ``TEST_DIR/images``, pads it by
    ``PAD_WIDTH``, fills the padded region marked in ``MASK_INPAINT`` with
    ``inpaint_biharmonic`` and writes the result under the same file name to
    ``TEST_DIR/images<SAVE_NAME>``.

    Args:
        img_file: File name (not a full path) inside ``TEST_DIR/images``.

    Raises:
        OSError: If OpenCV cannot read the image.
    """
    img_path = join(TEST_DIR, 'images', img_file)
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

    # cv2.imread reports failure by returning None instead of raising; fail
    # here with the offending path rather than with an obscure np.pad error.
    if img is None:
        raise OSError('Cannot read image: {}'.format(img_path))

    img = np.pad(img, PAD_WIDTH, mode='constant')

    # The inpainting result is floating point (scikit-image rescales 8-bit
    # input to [0, 1]), so it is scaled back to the 8-bit range before saving.
    img = (inpaint_biharmonic(img, MASK_INPAINT) * 255).astype(np.uint8)

    cv2.imwrite(join(TEST_DIR, 'images' + SAVE_NAME, img_file), img)
if __name__ == '__main__':
    # Each stage is (progress message, per-file worker, dataset root).
    # Train is processed first, then test, each with its own worker pool.
    stages = (
        ('Start train inpaint', inpaint_train, TRAIN_DIR),
        ('Start test inpaint', inpaint_test, TEST_DIR),
    )
    for message, worker, root in stages:
        print(message)
        # The pool is shut down when the with-block exits; map blocks until
        # every file name in the folder has been handled.
        with Pool(processes=N_WORKERS) as pool:
            pool.map(worker, os.listdir(join(root, 'images')))
    print('Inpaint complete')