diff --git a/nobrainer/processing/base.py b/nobrainer/processing/base.py
index 75db6b93..5f22be7c 100644
--- a/nobrainer/processing/base.py
+++ b/nobrainer/processing/base.py
@@ -20,7 +20,12 @@ class BaseEstimator:
     state_variables = []
     model_ = None

-    def __init__(self, multi_gpu=False):
+    def __init__(self, checkpoint_filepath=None, multi_gpu=False):
+        self.checkpoint_tracker = None
+        if checkpoint_filepath:
+            from .checkpoint import CheckpointTracker
+            self.checkpoint_tracker = CheckpointTracker(self, checkpoint_filepath)
+
         self.strategy = get_strategy(multi_gpu)

     @property
@@ -38,7 +43,7 @@ def save(self, save_dir):
             # are stored as members, which doesn't leave room for
             # parameters that are specific to the runtime context.
             # (e.g. multi_gpu).
-            if key == "multi_gpu":
+            if key == "multi_gpu" or key == "checkpoint_filepath":
                 continue
             model_info["__init__"][key] = getattr(self, key)
         for val in self.state_variables:
@@ -49,7 +54,7 @@ def save(self, save_dir):

     @classmethod
     def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
-        """Saves a trained model"""
+        """Loads a trained model from a save directory"""
         model_dir = Path(str(model_dir).rstrip(os.pathsep))
         assert model_dir.exists() and model_dir.is_dir()
         model_file = model_dir / "model_params.pkl"
@@ -70,6 +75,14 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
         )
         return klass

+    @classmethod
+    def load_latest(cls, checkpoint_filepath):
+        from .checkpoint import CheckpointTracker
+        checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath)
+        estimator = checkpoint_tracker.load()
+        estimator.checkpoint_tracker = checkpoint_tracker
+        return estimator
+

 class TransformerMixin:
     """Mixin class for all transformers in scikit-learn."""
diff --git a/nobrainer/processing/checkpoint.py b/nobrainer/processing/checkpoint.py
index b576037e..1787922b 100644
--- a/nobrainer/processing/checkpoint.py
+++ b/nobrainer/processing/checkpoint.py
@@ -40,8 +40,8 @@ def load(self):
         """Loads the most-recently created checkpoint from the
         checkpoint directory.
         """
-        latest = max(glob(os.path.join(os.path.dirname(self.filepath), '*')),
-                     key=os.path.getctime)
+        checkpoints = glob(os.path.join(os.path.dirname(self.filepath), '*/'))
+        latest = max(checkpoints, key=os.path.getctime)
         self.estimator = self.estimator.load(latest)
         logging.info(f"Loaded estimator from {latest}.")
         return self.estimator
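Aside: the new checkpoint-resolution logic in `CheckpointTracker.load` can be read in isolation as the sketch below. The helper name and example path are illustrative, not part of the diff:

```python
import os
from glob import glob


def find_latest_checkpoint(checkpoint_filepath):
    """Return the most recently created checkpoint directory.

    `checkpoint_filepath` is the Keras-style template the tracker was
    given, e.g. "/tmp/ckpts/checkpoint-epoch_{epoch:03d}" (hypothetical).
    """
    # The trailing slash makes glob match directories only, so stray
    # files next to the checkpoints can no longer be picked up as the
    # "latest" entry, which the old '*' pattern allowed.
    checkpoints = glob(os.path.join(os.path.dirname(checkpoint_filepath), "*/"))
    return max(checkpoints, key=os.path.getctime)
```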
diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py
index 4f42c567..8a4c3ea1 100644
--- a/nobrainer/processing/segmentation.py
+++ b/nobrainer/processing/segmentation.py
@@ -5,7 +5,6 @@
 import tensorflow as tf

 from .base import BaseEstimator
-from .checkpoint import CheckpointTracker
 from .. import losses, metrics
 from ..dataset import get_steps_per_epoch

@@ -18,8 +17,8 @@ class Segmentation(BaseEstimator):

     state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]

-    def __init__(self, base_model, model_args=None, multi_gpu=False):
-        super().__init__(multi_gpu=multi_gpu)
+    def __init__(self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False):
+        super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)
         if not isinstance(base_model, str):
             self.base_model = base_model.__name__

@@ -36,8 +35,6 @@ def fit(
         dataset_train,
         dataset_validate=None,
         epochs=1,
-        checkpoint_file_path=None,
-        warm_start=False,
         # TODO: figure out whether optimizer args should be flattened
         optimizer=None,
         opt_args=None,
@@ -63,10 +60,6 @@ def fit(
             opt_args_tmp.update(**opt_args)
             opt_args = opt_args_tmp

-        checkpoint_tracker = None
-        if checkpoint_file_path:
-            checkpoint_tracker = CheckpointTracker(self, checkpoint_file_path)
-
         def _create(base_model):
             # Instantiate and compile the model
             self.model_ = base_model(
@@ -82,15 +75,7 @@ def _compile():
                 metrics=metrics,
             )

-        if warm_start:
-            if checkpoint_tracker:
-                self = checkpoint_tracker.load()
-
-            if self.model is None:
-                raise ValueError("warm_start requested, but model is undefined and no checkpoints were found")
-            with self.strategy.scope():
-                _compile()
-        else:
+        if self.model is None:
             mod = importlib.import_module("..models", "nobrainer.processing")
             base_model = getattr(mod, self.base_model)
             if batch_size % self.strategy.num_replicas_in_sync:
@@ -98,7 +83,8 @@ def _compile():
             with self.strategy.scope():
                 _create(base_model)
-            _compile()
+        with self.strategy.scope():
+            _compile()

         self.model_.summary()

         train_steps = get_steps_per_epoch(
@@ -118,8 +104,8 @@ def _compile():
         )

         callbacks = []
-        if checkpoint_tracker:
-            callbacks.append(checkpoint_tracker)
+        if self.checkpoint_tracker:
+            callbacks.append(self.checkpoint_tracker)

         self.model_.fit(
             dataset_train,
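With these changes, checkpointing is configured once at construction time rather than per `fit()` call, and `warm_start` disappears: an estimator whose model is already populated simply skips model creation and is recompiled before training continues. A minimal usage sketch, assuming a prepared `train` dataset as in the test below and a placeholder path:

```python
from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation

# Checkpointing is now set up in the constructor; fit() no longer takes
# checkpoint_file_path or warm_start arguments. The path is a placeholder.
model = Segmentation(meshnet, checkpoint_filepath="/tmp/ckpts/checkpoint-epoch_{epoch:03d}")
model.fit(dataset_train=train, epochs=2)
```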
diff --git a/nobrainer/tests/checkpoint_test.py b/nobrainer/tests/checkpoint_test.py
index a3f2a506..197e55ca 100644
--- a/nobrainer/tests/checkpoint_test.py
+++ b/nobrainer/tests/checkpoint_test.py
@@ -1,15 +1,22 @@
 """Tests for `nobrainer.processing.checkpoint`."""

-from nobrainer.processing.checkpoint import CheckpointTracker
 from nobrainer.processing.segmentation import Segmentation
 from nobrainer.models import meshnet
 import numpy as np
-from numpy.testing import assert_array_equal
+from numpy.testing import assert_allclose
 import os
 import pytest
 import tensorflow as tf


+def _assert_model_weights_allclose(model1, model2):
+    for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
+        weights1 = layer1.get_weights()
+        weights2 = layer2.get_weights()
+        assert len(weights1) == len(weights2)
+        for index in range(len(weights1)):
+            assert_allclose(weights1[index], weights2[index], rtol=1e-06, atol=1e-08)
+
 def test_checkpoint(tmp_path):
     data_shape = (8, 8, 8, 8, 1)
     train = tf.data.Dataset.from_tensors(
@@ -20,21 +27,19 @@ def test_checkpoint(tmp_path):
     train.n_volumes = data_shape[0]
     train.volume_shape = data_shape[1:4]

-    checkpoint_file_path = os.path.join(tmp_path, 'checkpoint-epoch_{epoch:03d}')
-    model1 = Segmentation(meshnet)
+    checkpoint_filepath = os.path.join(tmp_path, 'checkpoint-epoch_{epoch:03d}')
+    model1 = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
     model1.fit(
         dataset_train=train,
-        checkpoint_file_path=checkpoint_file_path,
         epochs=2,
     )

-    model2 = Segmentation(meshnet)
-    checkpoint_tracker = CheckpointTracker(model2, checkpoint_file_path)
-    model2 = checkpoint_tracker.load()
+    model2 = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
+    _assert_model_weights_allclose(model1, model2)
+    model2.fit(
+        dataset_train=train,
+        epochs=3,
+    )

-    for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
-        weights1 = layer1.get_weights()
-        weights2 = layer2.get_weights()
-        assert len(weights1) == len(weights2)
-        for index in range(len(weights1)):
-            assert_array_equal(weights1[index], weights2[index])
+    model3 = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
+    _assert_model_weights_allclose(model2, model3)
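For completeness, resuming mirrors what the test exercises: `load_latest` restores the newest checkpoint and re-attaches the tracker, so subsequent `fit()` calls continue from the restored weights and keep writing checkpoints. Again assuming a prepared `train` dataset and a placeholder path:

```python
from nobrainer.processing.segmentation import Segmentation

# Resume training from the most recent checkpoint; no warm_start flag
# is needed. The path is a placeholder matching the one used above.
model = Segmentation.load_latest(checkpoint_filepath="/tmp/ckpts/checkpoint-epoch_{epoch:03d}")
model.fit(dataset_train=train, epochs=3)
```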