Skip to content

Commit

Permalink
save model
Browse files Browse the repository at this point in the history
  • Loading branch information
tsugumi-sys committed Jan 15, 2024
1 parent c2d533a commit dd0cd52
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 24 deletions.
5 changes: 2 additions & 3 deletions pipelines/utils/early_stopping.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Callable, Optional

import numpy as np
import torch
from torch import nn

from pipelines.utils.trainer_utils import save_seq2seq_model


class EarlyStopping:
def __init__(
Expand Down Expand Up @@ -60,5 +59,5 @@ def save_checkpoint(self, val_loss: float, model: nn.Module):
f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
)

save_seq2seq_model(model, self.model_save_path)
torch.save(model.state_dict(), self.model_save_path)
self.val_loss_min = val_loss
10 changes: 1 addition & 9 deletions tests/pipelines/test_experimenter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
import tempfile
from unittest.mock import patch

import torch
from torch import nn
from torch.optim import Adam

Expand All @@ -13,13 +11,7 @@
from tests.utils import MockMovingMNISTDataLoaders


def mocked_save_model(model: nn.Module, save_path: str):
torch.save({"model_state_dict": model.state_dict()}, save_path)


@patch("pipelines.utils.early_stopping.save_seq2seq_model")
def test_run(mocked_save_seq2seq_model):
mocked_save_seq2seq_model.side_effect = mocked_save_model
def test_run():
with tempfile.TemporaryDirectory() as tempdirpath:
model = TestModel()
training_params: TrainingParams = {
Expand Down
14 changes: 2 additions & 12 deletions tests/pipelines/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
import tempfile
from unittest.mock import patch

import torch
from torch import nn
from torch.optim import Adam

Expand All @@ -12,13 +10,7 @@
from tests.utils import mock_data_loader


def mocked_save_model(model: nn.Module, save_path: str):
torch.save({"model_state_dict": model.state_dict()}, save_path)


@patch("pipelines.utils.early_stopping.save_seq2seq_model")
def test_run(mocked_save_seq2seq_model):
mocked_save_seq2seq_model.side_effect = mocked_save_model
def test_run():
with tempfile.TemporaryDirectory() as tempdirpath:
model = TestModel()
epochs = 3
Expand Down Expand Up @@ -48,9 +40,7 @@ def test_run(mocked_save_seq2seq_model):
assert len(metrics) == epochs


@patch("pipelines.utils.early_stopping.save_seq2seq_model")
def test_run_early_stopping(mocked_save_seq2seq_model):
mocked_save_seq2seq_model.side_effect = mocked_save_model
def test_run_early_stopping():
with tempfile.TemporaryDirectory() as tempdirpath:
model = TestModel()
epochs = 3
Expand Down

0 comments on commit dd0cd52

Please sign in to comment.