From fbbc4c2ebafc11a933a54ff492b068cf2895f998 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 21 Aug 2023 23:29:19 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 nobrainer/processing/checkpoint.py   |  7 +++++--
 nobrainer/processing/segmentation.py |  5 +++--
 nobrainer/tests/checkpoint_test.py   | 15 ++++++++-------
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/nobrainer/processing/checkpoint.py b/nobrainer/processing/checkpoint.py
index b576037e..98a7fa0b 100644
--- a/nobrainer/processing/checkpoint.py
+++ b/nobrainer/processing/checkpoint.py
@@ -3,6 +3,7 @@
 from glob import glob
 import logging
 import os
+
 import tensorflow as tf
 
 from .base import BaseEstimator
@@ -40,8 +41,10 @@ def load(self):
         """Loads the most-recently created checkpoint from the checkpoint
         directory.
         """
-        latest = max(glob(os.path.join(os.path.dirname(self.filepath), '*')),
-                     key=os.path.getctime)
+        latest = max(
+            glob(os.path.join(os.path.dirname(self.filepath), "*")),
+            key=os.path.getctime,
+        )
         self.estimator = self.estimator.load(latest)
         logging.info(f"Loaded estimator from {latest}.")
         return self.estimator
diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py
index 4f42c567..176ef79f 100644
--- a/nobrainer/processing/segmentation.py
+++ b/nobrainer/processing/segmentation.py
@@ -9,7 +9,6 @@
 from .. import losses, metrics
 from ..dataset import get_steps_per_epoch
 
-
 logging.getLogger().setLevel(logging.INFO)
 
 
@@ -87,7 +86,9 @@ def _compile():
                 self = checkpoint_tracker.load()
 
             if self.model is None:
-                raise ValueError("warm_start requested, but model is undefined and no checkpoints were found")
+                raise ValueError(
+                    "warm_start requested, but model is undefined and no checkpoints were found"
+                )
             with self.strategy.scope():
                 _compile()
         else:
diff --git a/nobrainer/tests/checkpoint_test.py b/nobrainer/tests/checkpoint_test.py
index a3f2a506..d6935cf1 100644
--- a/nobrainer/tests/checkpoint_test.py
+++ b/nobrainer/tests/checkpoint_test.py
@@ -1,26 +1,27 @@
 """Tests for `nobrainer.processing.checkpoint`."""
 
-from nobrainer.processing.checkpoint import CheckpointTracker
-from nobrainer.processing.segmentation import Segmentation
-from nobrainer.models import meshnet
+import os
+
 import numpy as np
 from numpy.testing import assert_array_equal
-import os
 import pytest
 import tensorflow as tf
 
+from nobrainer.models import meshnet
+from nobrainer.processing.checkpoint import CheckpointTracker
+from nobrainer.processing.segmentation import Segmentation
+
 
 def test_checkpoint(tmp_path):
     data_shape = (8, 8, 8, 8, 1)
     train = tf.data.Dataset.from_tensors(
-        (np.random.rand(*data_shape),
-         np.random.randint(0, 1, data_shape))
+        (np.random.rand(*data_shape), np.random.randint(0, 1, data_shape))
     )
     train.scalar_labels = False
    train.n_volumes = data_shape[0]
     train.volume_shape = data_shape[1:4]
-    checkpoint_file_path = os.path.join(tmp_path, 'checkpoint-epoch_{epoch:03d}')
+    checkpoint_file_path = os.path.join(tmp_path, "checkpoint-epoch_{epoch:03d}")
     model1 = Segmentation(meshnet)
     model1.fit(
         dataset_train=train,
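
The one piece of logic touched here (only reformatted, not changed) is the checkpoint-selection expression in CheckpointTracker.load(): it globs every entry in the directory containing self.filepath and takes the one with the newest creation time. Below is a minimal standalone sketch of that pattern; the function name find_latest_checkpoint and the empty-directory guard are illustrative assumptions, not part of nobrainer.

import os
from glob import glob


def find_latest_checkpoint(checkpoint_filepath):
    """Return the newest entry in the directory containing checkpoint_filepath.

    Sketch of the glob/getctime pattern used in CheckpointTracker.load();
    the name and the FileNotFoundError guard are assumptions for illustration.
    """
    candidates = glob(os.path.join(os.path.dirname(checkpoint_filepath), "*"))
    if not candidates:
        raise FileNotFoundError("no checkpoints found to load")
    # os.path.getctime is creation time on Windows and metadata-change time on
    # most Unix filesystems, so the newest entry is treated as the latest
    # checkpoint either way.
    return max(candidates, key=os.path.getctime)

Note that on most Unix filesystems getctime reflects the last metadata change rather than true creation time, so "most recently created" effectively means "most recently written" there.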