fixed model checkpoint frequency (#3852)
* fixed model checkpoint frequency

* fixed model checkpoint frequency

* fixed model checkpoint frequency

* fixed model checkpoint frequency

* merged
williamFalcon authored and awaelchli committed Oct 5, 2020
1 parent dcb5b61 commit 49dd5a4
Showing 6 changed files with 65 additions and 7 deletions.
7 changes: 4 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -137,7 +137,7 @@ def __init__(
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
         self.period = period
-        self.epoch_last_check = None
+        self.last_global_step_saved = -1
         self.prefix = prefix
         self.best_k_models = {}
         self.kth_best_model_path = ""
@@ -183,21 +183,22 @@ def save_checkpoint(self, trainer, pl_module):
         to handle correct behaviour in distributed training, i.e., saving only on rank 0.
         """
         epoch = trainer.current_epoch
+        global_step = trainer.global_step
 
         if (
             self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
-            or self.epoch_last_check == epoch  # already saved
+            or self.last_global_step_saved == global_step  # already saved at the last step
         ):
             return
 
         self._add_backward_monitor_support(trainer)
         self._validate_monitor_key(trainer)
 
         # track epoch when ckpt was last checked
-        self.epoch_last_check = trainer.current_epoch
+        self.last_global_step_saved = global_step
 
         # what can be monitored
         monitor_candidates = self._monitor_candidates(trainer)
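For readers skimming the diff: the key change above is that `ModelCheckpoint` now dedupes saves on the trainer's `global_step` rather than on the epoch, so checkpointing several times within one epoch (e.g. with `val_check_interval < 1.0`) is no longer suppressed. A minimal, self-contained sketch of that guard logic follows; `should_skip_save` and its arguments are illustrative stand-ins, not part of the library API:

```python
# Hypothetical standalone sketch of the skip condition above (not the library code).
def should_skip_save(save_top_k, period, epoch, global_step,
                     running_sanity_check, last_global_step_saved):
    return bool(
        save_top_k == 0                           # no models are saved
        or period < 1                             # no models are saved
        or (epoch + 1) % period                   # skip this epoch entirely
        or running_sanity_check                   # never save during the sanity check
        or last_global_step_saved == global_step  # already saved at this exact step
    )

# Two triggers inside the same epoch now both save (different global steps),
# while a second trigger at the same step (e.g. at train end) is skipped.
assert not should_skip_save(1, 1, 0, 50, False, -1)
assert not should_skip_save(1, 1, 0, 99, False, 50)
assert should_skip_save(1, 1, 0, 99, False, 99)
```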
7 changes: 5 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -167,8 +167,11 @@ def on_train_end(self):

         self._teardown_already_run = True
 
-        # maybe save checkpoint
+        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
+        # when a checkpoint was saved at the last step
+        self.trainer.global_step -= 1
         self.check_checkpoint_callback(should_save=True, is_last=True)
+        self.trainer.global_step += 1
 
         # hook
         self.trainer.call_hook('on_train_end')
@@ -706,7 +709,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
 
         # track all metrics for callbacks
-        # TODO: is this needed?
+        self.trainer.logger_connector.callback_metrics.update(batch_log_metrics)
         self.trainer.logger_connector.callback_metrics.update(
             {k: v for d in batch_callback_metrics for k, v in d.items() if v is not None}
         )
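The training-loop side of the fix: when training ends, `global_step` has already advanced past the last optimizer step, so the loop steps it back by one before triggering the checkpoint check; if a checkpoint was already written at that step (for example by a validation run on the final batch), the `last_global_step_saved == global_step` guard skips the duplicate. A toy illustration of that interaction, using made-up classes rather than the actual Trainer:

```python
# Toy stand-ins for the real Trainer/ModelCheckpoint, only to show the dedupe dance.
class ToyCheckpoint:
    def __init__(self):
        self.last_global_step_saved = -1
        self.saves = 0

    def save_checkpoint(self, global_step):
        if self.last_global_step_saved == global_step:  # already saved at this step
            return
        self.last_global_step_saved = global_step
        self.saves += 1

ckpt = ToyCheckpoint()
global_step = 100                  # a validation run on the final batch...
ckpt.save_checkpoint(global_step)  # ...saved a checkpoint there
global_step += 1                   # the loop then increments past the last step

# on_train_end equivalent: step back, check, step forward again
global_step -= 1
ckpt.save_checkpoint(global_step)  # duplicate at step 100 is skipped
global_step += 1

assert ckpt.saves == 1
```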
Empty file added tests/checkpointing/__init__.py
Empty file.
@@ -1,6 +1,9 @@
 import os
-from pytorch_lightning import Trainer, seed_everything
-from tests.base import EvalModelTemplate
+from pytorch_lightning import Trainer, seed_everything, callbacks
+from tests.base import EvalModelTemplate, BoringModel
+from unittest import mock
+import pytest
+import torch
 
 
 def test_mc_called_on_fastdevrun(tmpdir):
@@ -60,3 +63,53 @@ def test_mc_called(tmpdir):
     trainer = Trainer(max_epochs=3, checkpoint_callback=False)
     trainer.fit(val_train_model)
     assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
+
+
+@mock.patch('torch.save')
+@pytest.mark.parametrize(['epochs', 'val_check_interval', 'expected'],
+                         [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)])
+def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval, expected):
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=epochs,
+        weights_summary=None,
+        val_check_interval=val_check_interval
+    )
+    trainer.fit(model)
+
+    # make sure the expected number of checkpoints was saved
+    assert save_mock.call_count == expected
+
+
+@mock.patch('torch.save')
+@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'],
+                         [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 7)])
+def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected):
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.last_coeff = 10.0
+
+        def training_step(self, batch, batch_idx):
+            loss = self.step(torch.ones(32))
+            loss = loss / (loss + 0.0000001)
+            loss += self.last_coeff
+            self.log('my_loss', loss)
+            self.last_coeff *= 0.999
+            return loss
+
+    model = TestModel()
+    trainer = Trainer(
+        checkpoint_callback=callbacks.ModelCheckpoint(monitor='my_loss', save_top_k=k),
+        default_root_dir=tmpdir,
+        max_epochs=epochs,
+        weights_summary=None,
+        val_check_interval=val_check_interval
+    )
+    trainer.fit(model)
+
+    # make sure the expected number of checkpoints was saved
+    assert save_mock.call_count == expected
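One informal way to read the `expected` counts in the parametrizations above: validation (and with it checkpointing) runs roughly `int(1 / val_check_interval)` times per epoch, and the end-of-training checkpoint adds one more save only when the last validation run did not already land on the final step. A rough sanity-check sketch, where `expected_saves` is a made-up helper and not part of the test suite:

```python
def expected_saves(epochs: int, val_check_interval: float) -> int:
    per_epoch = int(1 / val_check_interval)   # validation loops per epoch
    saves = epochs * per_epoch
    # the on_train_end checkpoint only counts if the last validation loop
    # did not already save at the final global step
    if per_epoch * val_check_interval < 1.0:
        saves += 1
    return saves

# matches the parametrized expectations above
assert expected_saves(1, 1.0) == 1
assert expected_saves(2, 1.0) == 2
assert expected_saves(1, 0.25) == 4
assert expected_saves(2, 0.3) == 7
```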
File renamed without changes.
1 change: 1 addition & 0 deletions tests/trainer/test_trainer.py
@@ -409,6 +409,7 @@ def mock_save_function(filepath, *args):
     # emulate callback's calls during the training
     for i, loss in enumerate(losses):
         trainer.current_epoch = i
+        trainer.global_step = i
         trainer.logger_connector.callback_metrics = {'checkpoint_on': torch.tensor(loss)}
         checkpoint_callback.on_validation_end(trainer, trainer.get_model())
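The added `trainer.global_step = i` matters because the callback now skips any call whose global step equals `last_global_step_saved`; without advancing the step between the emulated calls, every iteration after the first would be treated as a duplicate and the test's save counts would be wrong.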

