-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
73 lines (54 loc) · 2.16 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import sys
import cv2
import torch
import numpy as np
import skimage.io
def save_checkpoint(state, save_path: str):
    """Write ``state`` to ``save_path`` using ``torch.save``.

    :param state: object to serialize (anything ``torch.save`` accepts).
    :param save_path: destination file path.
    :return: None
    """
    destination = save_path
    torch.save(obj=state, f=destination)
def load_checkpoint(ckpt_path, map_location=None):
    """Load and return a checkpoint previously written with ``torch.save``.

    :param ckpt_path: path of the checkpoint file.
    :param map_location: forwarded unchanged to ``torch.load``. For example, pass
        ``"cpu"`` to load a checkpoint that was saved on GPU onto a CPU-only
        machine. The default ``None`` keeps the original ``torch.load`` behaviour.
    :return: the deserialized checkpoint object.
    """
    # SECURITY: torch.load unpickles its input; only load checkpoints from
    # trusted sources.
    return torch.load(ckpt_path, map_location=map_location)
def progress_bar(count, total, prefix='', suffix='', bar_len=60):
    """Draw a single-line text progress bar on stdout.

    The line ends with a carriage return and no newline, so successive calls
    overwrite the same terminal line.

    :param count: number of completed steps.
    :param total: total number of steps. A non-positive ``total`` is drawn as a
        full bar instead of raising ``ZeroDivisionError``.
    :param prefix: text written before the bar.
    :param suffix: text written after the step counter.
    :param bar_len: width of the bar in characters (default 60).
    """
    if total > 0:
        filled_len = int(round(bar_len * count / float(total)))
    else:
        # Nothing to do counts as done; avoids dividing by zero.
        filled_len = bar_len
    # Clamp so count > total (or count < 0) cannot overfill or underfill the bar.
    filled_len = max(0, min(filled_len, bar_len))
    bar = '=' * filled_len + '-' * (bar_len - filled_len)
    sys.stdout.write(prefix + '[%s]-Step [%s/%s]-%s\r' % (bar, count, total, suffix))
    sys.stdout.flush()
def experiment_record(*args, log_path="ckpt/log.txt"):
    """Append one experiment summary block to a plain-text log file.

    :param args: values substituted, in order, into the fields UUID, Time,
        Batch size, Lr, Epoch and Valid IoU. At least six are required, otherwise
        ``IndexError`` is raised and the log is not touched. Extra values are
        ignored.
    :param log_path: file to append to (default ``ckpt/log.txt``). Its parent
        directory is created if it does not exist.
    """
    template = (
        "=======================================================\n"
        "UUID: {}\n"
        "Time: {}\n"
        "Batch size: {}\n"
        "Lr: {}\n"
        "Result:\n"
        "Epoch: {}\n"
        "Valid IoU: {}\n"
        "======================================================="
    )
    # Format first so a wrong argument count fails before any file/dir is created.
    record = template.format(*args)
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(log_path, 'a', encoding='utf-8') as f:
        print(record, file=f)
def _write_image_or_raise(path, image):
    """Write ``image`` to ``path`` with cv2.imwrite, raising OSError on failure."""
    if not cv2.imwrite(path, image):
        raise OSError("Failed to write image: %s" % path)


def save_overlap_image(mask_filenames, pred, out_dir='./test'):
    """Save each original image and its prediction overlaid with ground-truth contours.

    ``mask_filenames[i]`` is paired with ``pred[i]``. Let ``name`` be the file
    name of the mask path. For each pair, two images are written:

    * The original image is read from the mask path with its last ``"_mask"``
      removed. It is written to ``<out_dir>/<name with "_mask.tif" -> ".jpg">``.
    * The prediction is rendered white where it equals 1 (tumor) and black
      elsewhere. The contours of the thresholded ground-truth mask are drawn on
      it in green (BGR ``(0, 255, 0)``). It is written to
      ``<out_dir>/<name with ".tif" -> "_gt.jpg">``.

    :param mask_filenames: paths of the ground-truth mask images. Both ``/`` and
        ``\\`` are accepted as path separators.
    :param pred: sequence/array of 2-D binary predictions (1 = tumor, 0 = not
        tumor), one per mask file. Presumably each has the same height/width as
        its mask -- TODO confirm against callers.
    :param out_dir: output directory, created if missing (default ``./test``).
    :raises FileNotFoundError: if a mask or its original image cannot be read.
    :raises OSError: if an output image cannot be written.
    """
    pred = np.asarray(pred)
    # (White: 255) tumor, (Black: 0) everything else. np.zeros rather than
    # np.empty so no pixel is left uninitialised, and the shape follows pred
    # instead of a hard-coded 256x256.
    masks_rgb = np.zeros(pred.shape + (3,), dtype=np.uint8)
    masks_rgb[pred == 1] = 255

    os.makedirs(out_dir, exist_ok=True)

    for i, mask_fn in enumerate(mask_filenames):
        # Extract the file name for both Windows and POSIX separators.
        file_name = mask_fn.replace('\\', '/').rsplit('/', 1)[-1]
        # Strip only the last "_mask" so directory names containing it stay intact.
        head, _, tail = mask_fn.rpartition("_mask")
        original_fn = head + tail

        ground_truth = cv2.imread(mask_fn, cv2.IMREAD_GRAYSCALE)
        if ground_truth is None:
            raise FileNotFoundError("Cannot read mask image: %s" % mask_fn)
        original_img = cv2.imread(original_fn)
        if original_img is None:
            raise FileNotFoundError("Cannot read original image: %s" % original_fn)

        _, thresh_gt = cv2.threshold(ground_truth, 127, 255, cv2.THRESH_BINARY)
        # [-2] selects the contour list on OpenCV 3.x (3 return values) and 4.x (2).
        contours_gt = cv2.findContours(thresh_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

        _write_image_or_raise(
            os.path.join(out_dir, file_name.replace('_mask.tif', '.jpg')),
            original_img)

        # contourIdx=-1 draws every ground-truth contour; index 0 would outline
        # only the first region.
        overlap_mask_gt = cv2.drawContours(masks_rgb[i], contours_gt, -1, (0, 255, 0), 1)
        _write_image_or_raise(
            os.path.join(out_dir, file_name.replace('.tif', '_gt.jpg')),
            overlap_mask_gt)