Skip to content

Commit

Permalink
test(Evaluation): added edge case support of non-complete data for co…
Browse files Browse the repository at this point in the history
…nfusion matrix
  • Loading branch information
muellerdo committed Jul 27, 2022
1 parent cbe64f9 commit ed10587
Showing 1 changed file with 39 additions and 13 deletions.
52 changes: 39 additions & 13 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def setUpClass(self):
class_index = np.random.randint(0, 4)
self.labels_ohe[i][class_index] = 1
# Create predictions
for i in range(0, 50):
self.preds = np.random.rand(50, 4)
self.preds = np.random.rand(50, 4)

# Create imaging data indices
self.sample_list = []
Expand Down Expand Up @@ -100,17 +99,17 @@ def test_evaluate_fitting_standard_tl(self):
self.assertTrue(os.path.exists(os.path.join(self.tmp_plot.name,
"plot.fitting_course.standard.png")))

# def test_evaluate_fitting_standard_tl_noft(self):
# # Create artificial history data - transfer learning
# hist_standard_tl = {"tl_loss": [], "tl_val_loss": []}
# for i in range(0, 150):
# hist_standard_tl["tl_loss"].append(random.uniform(0, 1))
# hist_standard_tl["tl_val_loss"].append(random.uniform(0, 1))
# # Apply fitting evaluation
# evaluate_fitting(hist_standard_tl, out_path=self.tmp_plot.name,
# monitor=["loss"], suffix="standard")
# self.assertTrue(os.path.exists(os.path.join(self.tmp_plot.name,
# "plot.fitting_course.standard.png")))
def test_evaluate_fitting_standard_tl_noft(self):
# Create artificial history data - transfer learning
hist_standard_tl = {"tl_loss": [], "tl_val_loss": []}
for i in range(0, 150):
hist_standard_tl["tl_loss"].append(random.uniform(0, 1))
hist_standard_tl["tl_val_loss"].append(random.uniform(0, 1))
# Apply fitting evaluation
evaluate_fitting(hist_standard_tl, out_path=self.tmp_plot.name,
monitor=["loss"], suffix="standard")
self.assertTrue(os.path.exists(os.path.join(self.tmp_plot.name,
"plot.fitting_course.standard.png")))

def test_evaluate_fitting_advanced(self):
# Create artificial history data - advanced
Expand Down Expand Up @@ -304,6 +303,33 @@ def test_evaluate_performance_confusionmatrix(self):
"plot.performance.confusion_matrix.test.png")
self.assertTrue(os.path.exists(path_plot))

def test_evaluate_performance_confusionmatrix_edgecases(self):
# Create classification labels
labels_ohe = np.zeros((50, 4), dtype=np.uint8)
for i in range(0, 50):
class_index = np.random.randint(0, 3)
labels_ohe[i][class_index] = 1
# Confusion Mat
evaluate_performance(self.preds, labels_ohe,
out_path=self.tmp_plot.name,
multi_label=False,
store_csv=False,
plot_barplot=False,
plot_confusion_matrix=True,
plot_roc_curve=False)
# Create predictions
preds = np.random.rand(50, 4)
for i in range(0, 50):
preds[i][3] = 0.0
# Confusion Mat
evaluate_performance(preds, self.labels_ohe,
out_path=self.tmp_plot.name,
multi_label=False,
store_csv=False,
plot_barplot=False,
plot_confusion_matrix=True,
plot_roc_curve=False)

def test_evaluate_performance_roc(self):
evaluate_performance(self.preds, self.labels_ohe,
out_path=self.tmp_plot.name,
Expand Down

0 comments on commit ed10587

Please sign in to comment.