Improve some tests (#5049)
* Improve some tests

* Add TrainerState asserts

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
carmocca and s-rog committed Dec 13, 2020
1 parent a49291d commit 398f122
Showing 2 changed files with 144 additions and 259 deletions.
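Before the diff, a quick orientation that is not part of the commit: every test touched here exercises the same fit, checkpoint, resume cycle. The sketch below assumes the PyTorch Lightning 1.1-era API that the diff itself uses (ModelCheckpoint(dirpath=...), resume_from_checkpoint=..., weights_summary=None) and that BoringModel can be imported from the repository's test helpers; the import path, the ValLossModel class, and the run_cycle name are illustrative, not taken from this commit.

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
    from tests.base.boring_model import BoringModel  # assumed helper path; may differ by version

    class ValLossModel(BoringModel):
        # mirror the ExtendedBoringModel used in the diff: expose a 'val_loss' metric
        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            return {"val_loss": self.loss(batch, output)}

    def run_cycle(tmpdir):
        # first run: train briefly and save a checkpoint named after the epoch
        ckpt_cb = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}")
        model = ValLossModel()
        model.validation_epoch_end = None
        trainer = pl.Trainer(
            default_root_dir=tmpdir, max_epochs=1,
            limit_train_batches=2, limit_val_batches=2,
            callbacks=[ckpt_cb], weights_summary=None, progress_bar_refresh_rate=0,
        )
        trainer.fit(model)

        # second run: reload the weights and hand the same checkpoint to a fresh Trainer
        model = ValLossModel.load_from_checkpoint(ckpt_cb.best_model_path)
        trainer = pl.Trainer(
            default_root_dir=tmpdir, max_epochs=1,
            resume_from_checkpoint=ckpt_cb.best_model_path,
        )
        trainer.fit(model)
        trainer.test(model, verbose=False)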
tests/checkpointing/test_model_checkpoint.py (142 changes: 30 additions & 112 deletions)
@@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import pickle
import platform
import re
from argparse import Namespace
from distutils.version import LooseVersion
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, Mock
from unittest.mock import Mock

import cloudpickle
import pytest
@@ -641,20 +639,17 @@ def validation_epoch_end(self, outputs):
@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
def test_checkpoint_repeated_strategy(enable_pl_optimizer, tmpdir):
"""
This test validates that the checkpoint can be called when provided to callacks list
This test validates that the checkpoint can be called when provided to callbacks list
"""

checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}")

class ExtendedBoringModel(BoringModel):

def validation_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"val_loss": loss}

model = ExtendedBoringModel()
model.validation_step_end = None
model.validation_epoch_end = None
trainer = Trainer(
max_epochs=1,
@@ -663,92 +658,30 @@ def validation_step(self, batch, batch_idx):
limit_test_batches=2,
callbacks=[checkpoint_callback],
enable_pl_optimizer=enable_pl_optimizer,
weights_summary=None,
progress_bar_refresh_rate=0,
)

trainer.fit(model)
assert os.listdir(tmpdir) == ['epoch=00.ckpt']

def get_last_checkpoint():
ckpts = os.listdir(tmpdir)
ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
num_ckpts = len(ckpts_map) - 1
return ckpts_map[num_ckpts]

for idx in range(1, 5):
for idx in range(4):
# load from checkpoint
chk = get_last_checkpoint()
model = BoringModel.load_from_checkpoint(chk)
trainer = pl.Trainer(
max_epochs=1,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
resume_from_checkpoint=chk,
enable_pl_optimizer=enable_pl_optimizer)
trainer.fit(model)
trainer.test(model)

assert str(os.listdir(tmpdir)) == "['epoch=00.ckpt']"


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
def test_checkpoint_repeated_strategy_tmpdir(enable_pl_optimizer, tmpdir):
"""
This test validates that the checkpoint can be called when provided to callacks list
"""

checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=os.path.join(tmpdir, "{epoch:02d}"))

class ExtendedBoringModel(BoringModel):

def validation_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"val_loss": loss}

model = ExtendedBoringModel()
model.validation_step_end = None
model.validation_epoch_end = None
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
callbacks=[checkpoint_callback],
enable_pl_optimizer=enable_pl_optimizer,
)

trainer.fit(model)
assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])

def get_last_checkpoint():
ckpts = os.listdir(tmpdir)
ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
num_ckpts = len(ckpts_map) - 1
return ckpts_map[num_ckpts]

for idx in range(1, 5):

# load from checkpoint
chk = get_last_checkpoint()
model = LogInTwoMethods.load_from_checkpoint(chk)
model = LogInTwoMethods.load_from_checkpoint(checkpoint_callback.best_model_path)
trainer = pl.Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
resume_from_checkpoint=chk,
enable_pl_optimizer=enable_pl_optimizer)

resume_from_checkpoint=checkpoint_callback.best_model_path,
enable_pl_optimizer=enable_pl_optimizer,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
trainer.test(model)
assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
assert sorted(os.listdir(path_to_lightning_logs)) == sorted([f'version_{i}' for i in range(idx + 1)])
trainer.test(model, verbose=False)
assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'}
assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)}


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -760,86 +693,71 @@ def test_checkpoint_repeated_strategy_extended(enable_pl_optimizer, tmpdir):
"""

class ExtendedBoringModel(BoringModel):

def validation_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"val_loss": loss}

def validation_epoch_end(self, *_):
...

def assert_trainer_init(trainer):
assert not trainer.checkpoint_connector.has_trained
assert trainer.global_step == 0
assert trainer.current_epoch == 0

def get_last_checkpoint(ckpt_dir):
ckpts = os.listdir(ckpt_dir)
ckpts.sort()
return osp.join(ckpt_dir, ckpts[-1])
last = ckpt_dir.listdir(sort=True)[-1]
return str(last)

def assert_checkpoint_content(ckpt_dir):
chk = pl_load(get_last_checkpoint(ckpt_dir))
assert chk["epoch"] == epochs
assert chk["global_step"] == 4

def assert_checkpoint_log_dir(idx):
lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
assert sorted(os.listdir(lightning_logs_path)) == [f'version_{i}' for i in range(idx + 1)]
assert len(os.listdir(ckpt_dir)) == epochs

def get_model():
model = ExtendedBoringModel()
model.validation_step_end = None
model.validation_epoch_end = None
return model
lightning_logs = tmpdir / 'lightning_logs'
actual = [d.basename for d in lightning_logs.listdir(sort=True)]
assert actual == [f'version_{i}' for i in range(idx + 1)]
assert len(ckpt_dir.listdir()) == epochs

ckpt_dir = osp.join(tmpdir, 'checkpoints')
ckpt_dir = tmpdir / 'checkpoints'
checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
epochs = 2
limit_train_batches = 2

model = get_model()

trainer_config = dict(
default_root_dir=tmpdir,
max_epochs=epochs,
limit_train_batches=limit_train_batches,
limit_val_batches=3,
limit_test_batches=4,
enable_pl_optimizer=enable_pl_optimizer,
)

trainer = pl.Trainer(
**trainer_config,
callbacks=[checkpoint_cb],
)
trainer = pl.Trainer(**trainer_config)
assert_trainer_init(trainer)

model = ExtendedBoringModel()
trainer.fit(model)
assert trainer.checkpoint_connector.has_trained
assert trainer.global_step == epochs * limit_train_batches
assert trainer.current_epoch == epochs - 1
assert_checkpoint_log_dir(0)
assert_checkpoint_content(ckpt_dir)

trainer.test(model)
assert trainer.current_epoch == epochs - 1

assert_checkpoint_content(ckpt_dir)

for idx in range(1, 5):
chk = get_last_checkpoint(ckpt_dir)
assert_checkpoint_content(ckpt_dir)

checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
model = get_model()

# load from checkpoint
trainer = pl.Trainer(
**trainer_config,
resume_from_checkpoint=chk,
callbacks=[checkpoint_cb],
)
trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=chk)
assert_trainer_init(trainer)

model = ExtendedBoringModel()
trainer.test(model)
assert not trainer.checkpoint_connector.has_trained
assert trainer.global_step == epochs * limit_train_batches
(diff for the second changed file did not render)
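The commit message's second bullet, "Add TrainerState asserts", presumably lands in the file above that did not render. For context, a minimal sketch of what such an assert looks like against the TrainerState enum available in this era of PyTorch Lightning; the exact call sites are an assumption, since that part of the diff is not shown, and the helper name is illustrative.

    # continuing from the run_cycle() sketch near the top of this page
    from pytorch_lightning.trainer.states import TrainerState

    def fit_and_check_state(trainer, model):
        result = trainer.fit(model)
        # after a clean fit() the trainer should report the FINISHED state
        assert trainer.state == TrainerState.FINISHED, f"fit ended in state {trainer.state}"
        return result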
