Skip to content

Commit

Permalink
feat(AutoML): added metadata storing in training block
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed Jun 13, 2022
1 parent 173d235 commit 74fcfbd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
7 changes: 7 additions & 0 deletions aucmedi/automl/block_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#-----------------------------------------------------#
# External libraries
import os
import json
from tensorflow.keras.metrics import AUC
from tensorflow_addons.metrics import F1Score
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, \
Expand Down Expand Up @@ -69,6 +70,12 @@ def block_train(config):
# Create output directory
if not os.path.exists(config["output"]) : os.mkdir(config["output"])

# Store meta information
config["class_names"] = class_names
path_meta = os.path.join(config["output"], "meta.training.json")
with open(path_meta, "w") as json_io:
json.dump(config, json_io)

# Define Callbacks
callbacks = []
if config["analysis"] == "standard":
Expand Down
12 changes: 9 additions & 3 deletions tests/test_automl_block_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def test_minimal(self):

self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.last.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))

def test_minimal_multilabel(self):
# Initialize temporary directory
Expand Down Expand Up @@ -146,6 +147,7 @@ def test_minimal_multilabel(self):

self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.last.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))

def test_minimal_3D(self):
# Initialize temporary directory
Expand Down Expand Up @@ -174,6 +176,7 @@ def test_minimal_3D(self):

self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.last.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))

#-------------------------------------------------#
# Analysis: Standard #
Expand Down Expand Up @@ -205,6 +208,7 @@ def test_standard(self):
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.last.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.best_loss.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))

def test_standard_multilabel(self):
# Initialize temporary directory
Expand Down Expand Up @@ -233,6 +237,7 @@ def test_standard_multilabel(self):
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.last.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.best_loss.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))

def test_standard_3D(self):
# Initialize temporary directory
Expand Down Expand Up @@ -262,6 +267,7 @@ def test_standard_3D(self):
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.last.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "model.best_loss.hdf5")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "logs.training.csv")))
self.assertTrue(os.path.exists(os.path.join(output_dir.name, "meta.training.json")))

#-------------------------------------------------#
# Analysis: Composite #
Expand Down Expand Up @@ -290,7 +296,7 @@ def test_composite(self):
# Run AutoML training block
block_train(config)

self.assertTrue(len(os.listdir(output_dir.name))==5)
self.assertTrue(len(os.listdir(output_dir.name))==6)

def test_composite_multilabel(self):
# Initialize temporary directory
Expand All @@ -316,7 +322,7 @@ def test_composite_multilabel(self):
# Run AutoML training block
block_train(config)

self.assertTrue(len(os.listdir(output_dir.name))==5)
self.assertTrue(len(os.listdir(output_dir.name))==6)

def test_composite_3D(self):
# Initialize temporary directory
Expand All @@ -343,4 +349,4 @@ def test_composite_3D(self):
# Run AutoML training block
block_train(config)

self.assertTrue(len(os.listdir(output_dir.name))==5)
self.assertTrue(len(os.listdir(output_dir.name))==6)

0 comments on commit 74fcfbd

Please sign in to comment.