-
Notifications
You must be signed in to change notification settings - Fork 0
/
4_get_metrics.py
104 lines (82 loc) · 2.81 KB
/
4_get_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# -*- coding: utf-8 -*-
"""
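Evaluate segmentation predictions against ground-truth masks.

Scores every PNG in the prediction folder against the ground-truth mask with
the same file name using get_metrics, prints the mean scores, and writes them
to FINALscores.txt in the output folder.

Created on Tue Jan 5 17:29:35 2021
@author: Joana Rocha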
"""
from utils_segmentation_models import *
import numpy as np
import shutil
import cv2
import os
#%% DEFINE PATH TO SAVE
# Output folder for this evaluation run. Any existing folder at this path is
# deleted first, so every run starts from an empty directory.
saveto_path = 'C:/Users/SofiaPereira/Documents/Projects/COVID_segmentation/get_metrics/'
if os.path.isdir(saveto_path):
    shutil.rmtree(saveto_path)
try:
    # makedirs (unlike mkdir) also creates any missing parent folders.
    os.makedirs(saveto_path)
except OSError:
    print("Creation of the directory %s failed" % saveto_path)
#%% CHECK ALL IMAGES IN FOLDER
# Folder of ground-truth masks and folder of predicted masks.
test_gt_path = "C:/Users/SofiaPereira/Documents/Projects/COVID_segmentation/preprocessed_data/test/masks/masks/"
test_pred_path = 'C:/Users/SofiaPereira/Documents/Projects/COVID_segmentation/14-09-2021_09h09/predictions_14-09-2021_09h09/'
# Target size (width, height) for cv2.resize when masks are resized.
dim = (512,512)
# Every PNG in the prediction folder is evaluated. The listing is sorted
# because os.listdir returns entries in arbitrary order; sorting makes the
# processing (and printed) order reproducible between runs.
img_list = sorted(fname for fname in os.listdir(test_pred_path) if fname.endswith('.png'))
#%% GET EVALUATION METRICS
# Per-image scores: one entry per image that is actually scored; each list
# is averaged after the loop.
accs = []
recs = []
precs = []
acc_per_class0 = []   # acc_per_class[0] from get_metrics (presumably background)
acc_per_class1 = []   # acc_per_class[1] from get_metrics (presumably foreground)
dices = []
jaccs = []
n_unreadable = 0      # images skipped because a mask file could not be read
n_empty_pred = 0      # images skipped because the prediction has no foreground
for i in img_list:
    print(i)
    # Read both masks as single-channel grayscale (imread flag 0).
    gt_raw = cv2.imread(test_gt_path + i, 0)
    pred_raw = cv2.imread(test_pred_path + i, 0)
    # cv2.imread returns None (it does not raise) for a missing or unreadable
    # file; dividing None by 255 would abort the entire evaluation.
    if gt_raw is None or pred_raw is None:
        print("Skipping %s: ground-truth or prediction mask could not be read" % i)
        n_unreadable += 1
        continue
    # Bring both masks to the same size so the flattened arrays line up
    # pixel-for-pixel. Nearest-neighbour resizing introduces no new
    # intensity values, so 0/255 masks stay 0/255 before scaling to 0/1.
    gt_mask = cv2.resize(gt_raw, dim, interpolation = cv2.INTER_NEAREST) / 255
    pred_mask = cv2.resize(pred_raw, dim, interpolation = cv2.INTER_NEAREST) / 255
    # NOTE(review): images whose prediction is entirely background are left
    # out of the averages. If their ground truth contains foreground, this
    # likely biases the reported scores upward; the skip count is printed below.
    if pred_mask.sum() == 0:
        n_empty_pred += 1
        continue
    acc, rec, prec, acc_per_class, dice, jacc = get_metrics(gt_mask.ravel(), pred_mask.ravel(), print=False)
    accs.append(acc)
    recs.append(rec)
    precs.append(prec)
    acc_per_class0.append(acc_per_class[0])
    acc_per_class1.append(acc_per_class[1])
    dices.append(dice)
    jaccs.append(jacc)
if not accs:
    # np.mean of an empty list returns nan (with a RuntimeWarning).
    print("No images were scored; all metrics below are nan")
final_acc = np.mean(accs)
final_rec = np.mean(recs)
final_precs = np.mean(precs)
final_acc0 = np.mean(acc_per_class0)
final_acc1 = np.mean(acc_per_class1)
final_dice = np.mean(dices)
final_jacc = np.mean(jaccs)
print("Acc: ", final_acc)
print("Acc per class: " + str(final_acc0) + "/" + str(final_acc1))
print("Recall: ", final_rec)
print("Precision: ", final_precs)
print("Dice: ", final_dice)
print("Jaccard: ", final_jacc)
print("Skipped (unreadable masks): ", n_unreadable)
print("Skipped (empty predictions): ", n_empty_pred)
#%% SAVE SCORES TO TXT
# Write the mean scores as "Label: value" lines (the file starts with an
# empty line because the first entry begins with a newline), followed by the
# prediction folder that was evaluated. The `with` block guarantees the file
# is closed even if a write fails.
with open(saveto_path + 'FINALscores.txt', 'w', encoding='utf-8') as scores_file:
    scores_file.writelines([
        '\nAcc: ', str(final_acc),
        '\nAcc per class: ', str(final_acc0) + "/" + str(final_acc1),
        '\nRecall: ', str(final_rec),
        '\nPrecision: ', str(final_precs),
        '\nDice: ', str(final_dice),
        '\nJaccard: ', str(final_jacc),
        '\n',
        str(test_pred_path),
    ])