Test resume behaviour
lucmos committed Jan 9, 2022
1 parent ca286dc commit ba02321
Showing 3 changed files with 59 additions and 14 deletions.
24 changes: 24 additions & 0 deletions tests/conftest.py
@@ -1,7 +1,10 @@
import os
import shutil
from pathlib import Path
from typing import Dict, Union

import pytest
import torch
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
@@ -59,6 +62,11 @@ def cfg_simple_train(cfg: DictConfig) -> DictConfig:
    cfg.train.trainer.max_steps = TRAIN_MAX_NSTEPS
    cfg.train.trainer.val_check_interval = TRAIN_MAX_NSTEPS

    # Ensure that resuming is disabled
    cfg.train.resume.ckpt_or_run_path = None
    cfg.train.resume.training = None
    cfg.train.resume.logging = None

    return cfg


@@ -110,3 +118,19 @@ def run_trainings_not_dry(cfg_all_not_dry: DictConfig) -> str:
)
def run_trainings(cfg_all: DictConfig) -> str:
    yield run(cfg=cfg_all)


#
# Utility functions
#
def get_checkpoint_path(storagedir: Union[str, Path]) -> Path:
    ckpts_path = Path(storagedir) / "checkpoints"
    checkpoint_path = next(ckpts_path.glob("*"))
    assert checkpoint_path
    return checkpoint_path


def load_checkpoint(storagedir: Union[str, Path]) -> Dict:
    checkpoint = torch.load(get_checkpoint_path(storagedir))
    assert checkpoint
    return checkpoint
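
For reference, a minimal sketch of how the two helpers above can be used on their own (the storage path is a placeholder; global_step and epoch are standard fields of a PyTorch Lightning checkpoint and are the fields the new resume test asserts on):

# Hypothetical standalone usage of the helpers; "some_storage_dir" stands in for cfg.core.storage_dir
ckpt_path = get_checkpoint_path("some_storage_dir")  # first file under <storage_dir>/checkpoints
state = load_checkpoint("some_storage_dir")          # checkpoint dict deserialized with torch.load
print(ckpt_path.name, state["global_step"], state["epoch"])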
19 changes: 5 additions & 14 deletions tests/test_checkpoint.py
@@ -1,11 +1,12 @@
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningModule

from tests.conftest import load_checkpoint

from nn_template.pl_modules.pl_module import MyLightningModule
from nn_template.run import run

@@ -25,16 +26,6 @@ def test_load_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig
    assert sum(p.numel() for p in module.parameters())


def _load_checkpoint(storagedir: Union[str, Path]) -> Dict:
    ckpts_path = Path(storagedir) / "checkpoints"
    checkpoint_path = next(ckpts_path.glob("*"))
    assert checkpoint_path

    checkpoint = torch.load(checkpoint_path)
    assert checkpoint
    return checkpoint


def _check_cfg_in_checkpoint(checkpoint: Dict, _cfg: DictConfig) -> Dict:
    assert "cfg" in checkpoint
    assert checkpoint["cfg"] == _cfg
@@ -48,7 +39,7 @@ def _check_run_path_in_checkpoint(checkpoint: Dict) -> Dict:


def test_cfg_in_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None:
    checkpoint = _load_checkpoint(run_trainings_not_dry)
    checkpoint = load_checkpoint(run_trainings_not_dry)

    _check_cfg_in_checkpoint(checkpoint, cfg_all_not_dry)
    _check_run_path_in_checkpoint(checkpoint)
@@ -64,7 +55,7 @@ def test_on_save_checkpoint_hook(cfg_all_not_dry: DictConfig) -> None:
    cfg.nn.module._target_ = "tests.test_checkpoint.ModuleWithCustomCheckpoint"
    output_path = Path(run(cfg))

    checkpoint = _load_checkpoint(output_path)
    checkpoint = load_checkpoint(output_path)

    _check_cfg_in_checkpoint(checkpoint, cfg)
    _check_run_path_in_checkpoint(checkpoint)
30 changes: 30 additions & 0 deletions tests/test_resume.py
@@ -0,0 +1,30 @@
import torch
from omegaconf import DictConfig, OmegaConf
from pytest import TempPathFactory

from tests.conftest import TRAIN_MAX_NSTEPS, get_checkpoint_path, load_checkpoint

from nn_template.run import run


def test_resume(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig, tmp_path_factory: TempPathFactory) -> None:
    old_checkpoint_path = get_checkpoint_path(run_trainings_not_dry)

    new_cfg = OmegaConf.create(cfg_all_not_dry)
    new_storage_dir = tmp_path_factory.mktemp("resumed_training")

    # Resume into a fresh storage dir, doubling the step budget
    new_cfg.core.storage_dir = str(new_storage_dir)
    new_cfg.train.trainer.max_steps = 2 * TRAIN_MAX_NSTEPS

    new_cfg.train.resume.ckpt_or_run_path = str(old_checkpoint_path)
    new_cfg.train.resume.training = True
    new_cfg.train.resume.logging = False

    new_training_dir = run(new_cfg)

    old_checkpoint = torch.load(old_checkpoint_path)
    new_checkpoint = load_checkpoint(new_training_dir)

    # The resumed run must be a new run, but continue from the old trainer state
    assert old_checkpoint["run_path"] != new_checkpoint["run_path"]
    assert old_checkpoint["global_step"] * 2 == new_checkpoint["global_step"]
    assert new_checkpoint["epoch"] == 2
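
As a rough mental model of what run() does with cfg.train.resume (an assumption about the template's resume logic, not part of this diff), resuming in PyTorch Lightning amounts to handing the old checkpoint path back to the trainer:

# Sketch only: model, datamodule and trainer stand in for the objects run() builds from the config
trainer.fit(model, datamodule=datamodule, ckpt_path=str(old_checkpoint_path))
# Lightning restores the optimizer state, global_step and epoch from the checkpoint, so the resumed
# run continues from TRAIN_MAX_NSTEPS up to the new max_steps (2 * TRAIN_MAX_NSTEPS), which is why
# the new checkpoint's global_step is exactly twice the old one.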
