Rework checkpoint loading to statically load.
ohinds committed Aug 22, 2023
1 parent 61ccc9f commit cded43c
Showing 4 changed files with 44 additions and 40 deletions.
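In short, this commit moves checkpoint configuration from fit() keyword arguments into the estimator constructor and adds a load_latest classmethod for resuming training. A minimal sketch of the resulting workflow, using only calls visible in this diff (the checkpoint path pattern is an example, and `train` is a placeholder for a nobrainer-compatible tf.data.Dataset prepared as in the updated test):

from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation

# Placeholder: a nobrainer-compatible tf.data.Dataset with the volume
# metadata attributes the updated test sets (n_volumes, volume_shape).
train = ...

checkpoint_filepath = "ckpts/checkpoint-epoch_{epoch:03d}"  # example pattern

# Checkpointing is now configured at construction time ...
model = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
model.fit(dataset_train=train, epochs=2)

# ... and training resumes from the newest checkpoint via the new classmethod,
# replacing the removed fit(checkpoint_file_path=..., warm_start=True) path.
model = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
model.fit(dataset_train=train, epochs=3)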
19 changes: 16 additions & 3 deletions nobrainer/processing/base.py
@@ -20,7 +20,12 @@ class BaseEstimator:
     state_variables = []
     model_ = None
 
-    def __init__(self, multi_gpu=False):
+    def __init__(self, checkpoint_filepath=None, multi_gpu=False):
+        self.checkpoint_tracker = None
+        if checkpoint_filepath:
+            from .checkpoint import CheckpointTracker
+            self.checkpoint_tracker = CheckpointTracker(self, checkpoint_filepath)
+
         self.strategy = get_strategy(multi_gpu)
 
     @property
@@ -38,7 +43,7 @@ def save(self, save_dir):
             # are stored as members, which doesn't leave room for
             # parameters that are specific to the runtime context.
             # (e.g. multi_gpu).
-            if key == "multi_gpu":
+            if key == "multi_gpu" or key == "checkpoint_filepath":
                 continue
             model_info["__init__"][key] = getattr(self, key)
         for val in self.state_variables:
@@ -49,7 +54,7 @@ def save(self, save_dir):
 
     @classmethod
     def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
-        """Saves a trained model"""
+        """Loads a trained model from a save directory"""
         model_dir = Path(str(model_dir).rstrip(os.pathsep))
         assert model_dir.exists() and model_dir.is_dir()
         model_file = model_dir / "model_params.pkl"
@@ -70,6 +75,14 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
         )
         return klass
 
+    @classmethod
+    def load_latest(cls, checkpoint_filepath):
+        from .checkpoint import CheckpointTracker
+        checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath)
+        estimator = checkpoint_tracker.load()
+        estimator.checkpoint_tracker = checkpoint_tracker
+        return estimator
+
 
 class TransformerMixin:
     """Mixin class for all transformers in scikit-learn."""
4 changes: 2 additions & 2 deletions nobrainer/processing/checkpoint.py
@@ -40,8 +40,8 @@ def load(self):
         """Loads the most-recently created checkpoint from the
        checkpoint directory.
        """
-        latest = max(glob(os.path.join(os.path.dirname(self.filepath), '*')),
-                     key=os.path.getctime)
+        checkpoints = glob(os.path.join(os.path.dirname(self.filepath), '*/'))
+        latest = max(checkpoints, key=os.path.getctime)
         self.estimator = self.estimator.load(latest)
         logging.info(f"Loaded estimator from {latest}.")
         return self.estimator
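For reference, the selection logic above can be exercised on its own. A standalone sketch (the path pattern and the layout of one sub-directory per saved checkpoint are assumptions):

import os
from glob import glob

# Example path pattern; only its directory part matters for discovery.
filepath = "ckpts/checkpoint-epoch_{epoch:03d}"

# The trailing "/" restricts the glob to sub-directories (saved checkpoints)
# rather than matching any file in the checkpoint directory.
checkpoints = glob(os.path.join(os.path.dirname(filepath), "*/"))
if checkpoints:
    latest = max(checkpoints, key=os.path.getctime)
    print(f"most recently created checkpoint: {latest}")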
28 changes: 7 additions & 21 deletions nobrainer/processing/segmentation.py
@@ -5,7 +5,6 @@
 import tensorflow as tf
 
 from .base import BaseEstimator
-from .checkpoint import CheckpointTracker
 from .. import losses, metrics
 from ..dataset import get_steps_per_epoch
 
@@ -18,8 +17,8 @@ class Segmentation(BaseEstimator):
 
     state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]
 
-    def __init__(self, base_model, model_args=None, multi_gpu=False):
-        super().__init__(multi_gpu=multi_gpu)
+    def __init__(self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False):
+        super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)
 
         if not isinstance(base_model, str):
             self.base_model = base_model.__name__
@@ -36,8 +35,6 @@ def fit(
         dataset_train,
         dataset_validate=None,
         epochs=1,
-        checkpoint_file_path=None,
-        warm_start=False,
         # TODO: figure out whether optimizer args should be flattened
         optimizer=None,
         opt_args=None,
@@ -63,10 +60,6 @@ def fit(
             opt_args_tmp.update(**opt_args)
             opt_args = opt_args_tmp
 
-        checkpoint_tracker = None
-        if checkpoint_file_path:
-            checkpoint_tracker = CheckpointTracker(self, checkpoint_file_path)
-
         def _create(base_model):
             # Instantiate and compile the model
             self.model_ = base_model(
@@ -82,23 +75,16 @@ def _compile():
                 metrics=metrics,
             )
 
-        if warm_start:
-            if checkpoint_tracker:
-                self = checkpoint_tracker.load()
-
-            if self.model is None:
-                raise ValueError("warm_start requested, but model is undefined and no checkpoints were found")
-            with self.strategy.scope():
-                _compile()
-        else:
+        if self.model is None:
             mod = importlib.import_module("..models", "nobrainer.processing")
             base_model = getattr(mod, self.base_model)
             if batch_size % self.strategy.num_replicas_in_sync:
                 raise ValueError("batch size must be a multiple of the number of GPUs")
 
             with self.strategy.scope():
                 _create(base_model)
-                _compile()
+        with self.strategy.scope():
+            _compile()
         self.model_.summary()
 
         train_steps = get_steps_per_epoch(
@@ -118,8 +104,8 @@ def _compile():
         )
 
         callbacks = []
-        if checkpoint_tracker:
-            callbacks.append(checkpoint_tracker)
+        if self.checkpoint_tracker:
+            callbacks.append(self.checkpoint_tracker)
 
         self.model_.fit(
             dataset_train,
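The reworked fit() now builds a model only when none has been loaded, but always compiles inside the distribution strategy scope. A toy sketch of that pattern with a plain Keras model (MirroredStrategy stands in for get_strategy; nothing here is nobrainer-specific):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
model = None  # would already be set when resuming from a checkpoint

# Create the model only if one was not loaded ...
if model is None:
    with strategy.scope():
        model = tf.keras.Sequential(
            [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)]
        )

# ... but always (re)compile under the strategy scope.
with strategy.scope():
    model.compile(optimizer="adam", loss="mse")
model.summary()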
33 changes: 19 additions & 14 deletions nobrainer/tests/checkpoint_test.py
@@ -1,15 +1,22 @@
 """Tests for `nobrainer.processing.checkpoint`."""
 
-from nobrainer.processing.checkpoint import CheckpointTracker
 from nobrainer.processing.segmentation import Segmentation
 from nobrainer.models import meshnet
 import numpy as np
-from numpy.testing import assert_array_equal
+from numpy.testing import assert_allclose
 import os
 import pytest
 import tensorflow as tf
 
 
+def _assert_model_weights_allclose(model1, model2):
+    for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
+        weights1 = layer1.get_weights()
+        weights2 = layer2.get_weights()
+        assert len(weights1) == len(weights2)
+        for index in range(len(weights1)):
+            assert_allclose(weights1[index], weights2[index], rtol=1e-06, atol=1e-08)
+
 def test_checkpoint(tmp_path):
     data_shape = (8, 8, 8, 8, 1)
     train = tf.data.Dataset.from_tensors(
@@ -20,21 +27,19 @@ def test_checkpoint(tmp_path):
     train.n_volumes = data_shape[0]
     train.volume_shape = data_shape[1:4]
 
-    checkpoint_file_path = os.path.join(tmp_path, 'checkpoint-epoch_{epoch:03d}')
-    model1 = Segmentation(meshnet)
+    checkpoint_filepath = os.path.join(tmp_path, 'checkpoint-epoch_{epoch:03d}')
+    model1 = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
     model1.fit(
         dataset_train=train,
-        checkpoint_file_path=checkpoint_file_path,
         epochs=2,
     )
 
-    model2 = Segmentation(meshnet)
-    checkpoint_tracker = CheckpointTracker(model2, checkpoint_file_path)
-    model2 = checkpoint_tracker.load()
+    model2 = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
+    _assert_model_weights_allclose(model1, model2)
     model2.fit(
         dataset_train=train,
         epochs=3,
     )
 
-    for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
-        weights1 = layer1.get_weights()
-        weights2 = layer2.get_weights()
-        assert len(weights1) == len(weights2)
-        for index in range(len(weights1)):
-            assert_array_equal(weights1[index], weights2[index])
+    model3 = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
+    _assert_model_weights_allclose(model2, model3)
