From dd0cd52627d8dae39bdb2252c39703f1b26e34ff Mon Sep 17 00:00:00 2001 From: Akira Date: Mon, 15 Jan 2024 23:00:34 +0900 Subject: [PATCH] save model --- pipelines/utils/early_stopping.py | 5 ++--- tests/pipelines/test_experimenter.py | 10 +--------- tests/pipelines/test_trainer.py | 14 ++------------ 3 files changed, 5 insertions(+), 24 deletions(-) diff --git a/pipelines/utils/early_stopping.py b/pipelines/utils/early_stopping.py index 6e811e2..9a3fdd5 100644 --- a/pipelines/utils/early_stopping.py +++ b/pipelines/utils/early_stopping.py @@ -1,10 +1,9 @@ from typing import Callable, Optional import numpy as np +import torch from torch import nn -from pipelines.utils.trainer_utils import save_seq2seq_model - class EarlyStopping: def __init__( @@ -60,5 +59,5 @@ def save_checkpoint(self, val_loss: float, model: nn.Module): f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..." ) - save_seq2seq_model(model, self.model_save_path) + torch.save(model.state_dict(), self.model_save_path) self.val_loss_min = val_loss diff --git a/tests/pipelines/test_experimenter.py b/tests/pipelines/test_experimenter.py index 35b1d2b..6d297a2 100644 --- a/tests/pipelines/test_experimenter.py +++ b/tests/pipelines/test_experimenter.py @@ -1,8 +1,6 @@ import os import tempfile -from unittest.mock import patch -import torch from torch import nn from torch.optim import Adam @@ -13,13 +11,7 @@ from tests.utils import MockMovingMNISTDataLoaders -def mocked_save_model(model: nn.Module, save_path: str): - torch.save({"model_state_dict": model.state_dict()}, save_path) - - -@patch("pipelines.utils.early_stopping.save_seq2seq_model") -def test_run(mocked_save_seq2seq_model): - mocked_save_seq2seq_model.side_effect = mocked_save_model +def test_run(): with tempfile.TemporaryDirectory() as tempdirpath: model = TestModel() training_params: TrainingParams = { diff --git a/tests/pipelines/test_trainer.py b/tests/pipelines/test_trainer.py index 6da8e6e..abe8ce2 100644 --- a/tests/pipelines/test_trainer.py +++ b/tests/pipelines/test_trainer.py @@ -1,8 +1,6 @@ import os import tempfile -from unittest.mock import patch -import torch from torch import nn from torch.optim import Adam @@ -12,13 +10,7 @@ from tests.utils import mock_data_loader -def mocked_save_model(model: nn.Module, save_path: str): - torch.save({"model_state_dict": model.state_dict()}, save_path) - - -@patch("pipelines.utils.early_stopping.save_seq2seq_model") -def test_run(mocked_save_seq2seq_model): - mocked_save_seq2seq_model.side_effect = mocked_save_model +def test_run(): with tempfile.TemporaryDirectory() as tempdirpath: model = TestModel() epochs = 3 @@ -48,9 +40,7 @@ def test_run(mocked_save_seq2seq_model): assert len(metrics) == epochs -@patch("pipelines.utils.early_stopping.save_seq2seq_model") -def test_run_early_stopping(mocked_save_seq2seq_model): - mocked_save_seq2seq_model.side_effect = mocked_save_model +def test_run_early_stopping(): with tempfile.TemporaryDirectory() as tempdirpath: model = TestModel() epochs = 3