Skip to content

Commit

Permalink
feat(AutoML): added evaluation plot generation in block function trai…
Browse files Browse the repository at this point in the history
…ning
  • Loading branch information
muellerdo committed Jun 15, 2022
1 parent 5b4c095 commit 47aff85
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
14 changes: 9 additions & 5 deletions aucmedi/automl/block_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from aucmedi.data_processing.subfunctions import *
from aucmedi.neural_network.loss_functions import *
from aucmedi.ensemble import *
from aucmedi.evaluation import evaluate_fitting

#-----------------------------------------------------#
# Building Blocks for Training #
Expand Down Expand Up @@ -212,7 +213,7 @@ def block_train(config):
**paras_datagen)

# Start model training
model.train(training_generator=train_gen, **paras_train)
hist = model.train(training_generator=train_gen, **paras_train)
# Store model
path_model = os.path.join(config["output"], "model.last.hdf5")
model.dump(path_model)
Expand Down Expand Up @@ -243,9 +244,9 @@ def block_train(config):
**paras_datagen)

# Start model training
model.train(training_generator=train_gen,
validation_generator=val_gen,
**paras_train)
hist = model.train(training_generator=train_gen,
validation_generator=val_gen,
**paras_train)
# Store model
path_model = os.path.join(config["output"], "model.last.hdf5")
model.dump(path_model)
Expand All @@ -272,6 +273,9 @@ def block_train(config):
standardize_mode=None,
**paras_datagen)
# Start model training
el.train(training_generator=train_gen, **paras_train)
hist = el.train(training_generator=train_gen, **paras_train)
# Store model directory
el.dump(config["output"])

# Plot fitting history
evaluate_fitting(train_history=hist, out_path=config["output"])
10 changes: 8 additions & 2 deletions tests/test_automl_block_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_minimal(self):
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.last.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "plot.fitting_course.png")))

def test_minimal_multilabel(self):
# Initialize temporary directory
Expand Down Expand Up @@ -146,6 +147,7 @@ def test_minimal_multilabel(self):
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.last.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "plot.fitting_course.png")))

def test_minimal_3D(self):
# Initialize temporary directory
Expand Down Expand Up @@ -173,6 +175,7 @@ def test_minimal_3D(self):
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.last.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "plot.fitting_course.png")))

#-------------------------------------------------#
# Analysis: Standard #
Expand Down Expand Up @@ -203,6 +206,7 @@ def test_standard(self):
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.best_loss.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "plot.fitting_course.png")))

def test_standard_multilabel(self):
# Initialize temporary directory
Expand Down Expand Up @@ -230,6 +234,7 @@ def test_standard_multilabel(self):
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.best_loss.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "plot.fitting_course.png")))

def test_standard_3D(self):
# Initialize temporary directory
Expand Down Expand Up @@ -258,6 +263,7 @@ def test_standard_3D(self):
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.best_loss.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "plot.fitting_course.png")))

#-------------------------------------------------#
# Analysis: Composite #
Expand All @@ -284,7 +290,7 @@ def test_composite(self):
# Run AutoML training block
block_train(config)

self.assertTrue(len(os.listdir(output_dir.name))==6)
self.assertTrue(len(os.listdir(output_dir.name))==7)

def test_composite_multilabel(self):
# Initialize temporary directory
Expand Down Expand Up @@ -333,4 +339,4 @@ def test_composite_3D(self):
# Run AutoML training block
block_train(config)

self.assertTrue(len(os.listdir(output_dir.name))==6)
self.assertTrue(len(os.listdir(output_dir.name))==7)

0 comments on commit 47aff85

Please sign in to comment.