Skip to content

Commit

Permalink
Implement resuming behaviour (#25)
Browse files Browse the repository at this point in the history
* Remove unnecessary on save checkpoint hook injection

* Test the run path is stored in the checkpoint

* Implement resume logic

* Test resume behaviour

* Refactor resume with semantic options
  • Loading branch information
lucmos authored Jan 10, 2022
1 parent 560314f commit 78b3d6c
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 22 deletions.
6 changes: 5 additions & 1 deletion conf/train/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ trainer:
gpus: 1
precision: 32
max_epochs: 3
max_steps: 1000
max_steps: 10000
accumulate_grad_batches: 1
num_sanity_val_steps: 2
gradient_clip_val: 10.0
val_check_interval: 1.0
deterministic: ${train.deterministic}

restore:
ckpt_or_run_path: null
mode: null # null, continue, hotstart

monitor:
metric: 'loss/val'
mode: 'min'
Expand Down
69 changes: 59 additions & 10 deletions src/nn_template/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
# See https://github.com/PyTorchLightning/pytorch-lightning/issues/1503
#
# Force the execution of __init__.py if this file is executed directly.
import nn_template # isort:skip # noqa

import logging
from typing import List
from operator import xor
from typing import List, Optional, Tuple

import hydra
import omegaconf
Expand All @@ -17,12 +16,31 @@
from pytorch_lightning.loggers.base import DummyLogger

from nn_core.common import PROJECT_ROOT
from nn_core.hooks import OnSaveCheckpointInjection
from nn_core.model_logging import NNLogger
from nn_core.resume import resolve_ckpt, resolve_run_path, resolve_run_version

import nn_template # isort:skip # noqa


pylogger = logging.getLogger(__name__)


RESUME_MODES = {
"continue": {
"restore_model": True,
"restore_run": True,
},
"hotstart": {
"restore_model": True,
"restore_run": False,
},
None: {
"restore_model": False,
"restore_run": False,
},
}


def build_callbacks(cfg: DictConfig) -> List[Callback]:
callbacks: List[Callback] = []

Expand All @@ -33,6 +51,38 @@ def build_callbacks(cfg: DictConfig) -> List[Callback]:
return callbacks


def parse_restore(restore_cfg: DictConfig) -> Tuple[Optional[str], Optional[str]]:
ckpt_or_run_path = restore_cfg.ckpt_or_run_path
resume_mode = restore_cfg.mode

resume_ckpt_path = None
resume_run_version = None

if xor(bool(ckpt_or_run_path), bool(resume_mode)):
pylogger.warning(f"Inconsistent resume modality {resume_mode} and checkpoint path '{ckpt_or_run_path}'")

if resume_mode not in RESUME_MODES:
message = f"Unsupported resume mode {resume_mode}. Available resume modes are: {RESUME_MODES}"
pylogger.error(message)
raise ValueError(message)

flags = RESUME_MODES[resume_mode]
restore_model = flags["restore_model"]
restore_run = flags["restore_run"]

if ckpt_or_run_path is not None:
if restore_model:
resume_ckpt_path = resolve_ckpt(ckpt_or_run_path)
pylogger.info(f"Resume training from: '{resume_ckpt_path}'")

if restore_run:
run_path = resolve_run_path(ckpt_or_run_path)
resume_run_version = resolve_run_version(run_path=run_path)
pylogger.info(f"Resume logging to: '{run_path}'")

return resume_ckpt_path, resume_run_version


def run(cfg: DictConfig) -> str:
"""Generic train loop.
Expand All @@ -50,26 +100,25 @@ def run(cfg: DictConfig) -> str:
cfg.nn.data.num_workers.val = 0
cfg.nn.data.num_workers.test = 0

resume_ckpt_path, resume_run_version = parse_restore(cfg.train.restore)

# Instantiate datamodule
pylogger.info(f"Instantiating <{cfg.nn.data['_target_']}>")
datamodule: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)

# Instantiate model
pylogger.info(f"Instantiating <{cfg.nn.module['_target_']}>")
model: pl.LightningModule = hydra.utils.instantiate(cfg.nn.module, _recursive_=False)
model.on_save_checkpoint = OnSaveCheckpointInjection(cfg=cfg, on_save_checkpoint=model.on_save_checkpoint)

# Instantiate the callbacks
callbacks: List[Callback] = build_callbacks(cfg=cfg.train.callbacks)

storage_dir: str = cfg.core.storage_dir

# The logger attribute will be filled by the NNLoggerConfiguration callback.
logger: NNLogger = NNLogger(logger=DummyLogger(), storage_dir=storage_dir, cfg=cfg)

pylogger.info("Instantiating the Trainer")
logger: NNLogger = NNLogger(logger=DummyLogger(), storage_dir=storage_dir, cfg=cfg, resume_id=resume_run_version)

# The Lightning core, the Trainer
pylogger.info("Instantiating the <Trainer>")
trainer = pl.Trainer(
default_root_dir=storage_dir,
logger=logger,
Expand All @@ -78,7 +127,7 @@ def run(cfg: DictConfig) -> str:
)

pylogger.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule)
trainer.fit(model=model, datamodule=datamodule, ckpt_path=resume_ckpt_path)

if fast_dev_run:
pylogger.info("Skipping testing in 'fast_dev_run' mode!")
Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import shutil
from pathlib import Path
from typing import Dict, Union

import pytest
import torch
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
Expand Down Expand Up @@ -59,6 +62,10 @@ def cfg_simple_train(cfg: DictConfig) -> DictConfig:
cfg.train.trainer.max_steps = TRAIN_MAX_NSTEPS
cfg.train.trainer.val_check_interval = TRAIN_MAX_NSTEPS

# Ensure the resuming is disabled
cfg.train.restore.ckpt_or_run_path = None
cfg.train.restore.mode = None

return cfg


Expand Down Expand Up @@ -110,3 +117,19 @@ def run_trainings_not_dry(cfg_all_not_dry: DictConfig) -> str:
)
def run_trainings(cfg_all: DictConfig) -> str:
yield run(cfg=cfg_all)


#
# Utility functions
#
def get_checkpoint_path(storagedir: Union[str, Path]) -> Path:
ckpts_path = Path(storagedir) / "checkpoints"
checkpoint_path = next(ckpts_path.glob("*"))
assert checkpoint_path
return checkpoint_path


def load_checkpoint(storagedir: Union[str, Path]) -> Dict:
checkpoint = torch.load(get_checkpoint_path(storagedir))
assert checkpoint
return checkpoint
29 changes: 18 additions & 11 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningModule

from tests.conftest import load_checkpoint

from nn_template.pl_modules.pl_module import MyLightningModule
from nn_template.run import run

Expand All @@ -25,20 +26,23 @@ def test_load_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig
assert sum(p.numel() for p in module.parameters())


def _check_cfg_in_checkpoint(storagedir: Union[str, Path], _cfg: DictConfig) -> Dict:
ckpts_path = Path(storagedir) / "checkpoints"
checkpoint_path = next(ckpts_path.glob("*"))
assert checkpoint_path

checkpoint = torch.load(checkpoint_path)
def _check_cfg_in_checkpoint(checkpoint: Dict, _cfg: DictConfig) -> Dict:
assert "cfg" in checkpoint
assert checkpoint["cfg"] == _cfg

return checkpoint

def _check_run_path_in_checkpoint(checkpoint: Dict) -> Dict:
assert "run_path" in checkpoint
assert checkpoint["run_path"]
checkpoint["run_path"]: str
assert checkpoint["run_path"].startswith("//")


def test_cfg_in_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None:
_check_cfg_in_checkpoint(run_trainings_not_dry, cfg_all_not_dry)
checkpoint = load_checkpoint(run_trainings_not_dry)

_check_cfg_in_checkpoint(checkpoint, cfg_all_not_dry)
_check_run_path_in_checkpoint(checkpoint)


class ModuleWithCustomCheckpoint(MyLightningModule):
Expand All @@ -51,7 +55,10 @@ def test_on_save_checkpoint_hook(cfg_all_not_dry: DictConfig) -> None:
cfg.nn.module._target_ = "tests.test_checkpoint.ModuleWithCustomCheckpoint"
output_path = Path(run(cfg))

checkpoint = _check_cfg_in_checkpoint(output_path, cfg)
checkpoint = load_checkpoint(output_path)

_check_cfg_in_checkpoint(checkpoint, cfg)
_check_run_path_in_checkpoint(checkpoint)

assert "test_key" in checkpoint
assert checkpoint["test_key"] == "test_value"
29 changes: 29 additions & 0 deletions tests/test_resume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
from omegaconf import DictConfig, OmegaConf
from pytest import TempPathFactory

from tests.conftest import TRAIN_MAX_NSTEPS, get_checkpoint_path, load_checkpoint

from nn_template.run import run


def test_resume(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig, tmp_path_factory: TempPathFactory) -> None:
old_checkpoint_path = get_checkpoint_path(run_trainings_not_dry)

new_cfg = OmegaConf.create(cfg_all_not_dry)
new_storage_dir = tmp_path_factory.mktemp("resumed_training")

new_cfg.core.storage_dir = str(new_storage_dir)
new_cfg.train.trainer.max_steps = 2 * TRAIN_MAX_NSTEPS

new_cfg.train.restore.ckpt_or_run_path = str(old_checkpoint_path)
new_cfg.train.restore.mode = "hotstart"

new_training_dir = run(new_cfg)

old_checkpoint = torch.load(old_checkpoint_path)
new_checkpoint = load_checkpoint(new_training_dir)

assert old_checkpoint["run_path"] != new_checkpoint["run_path"]
assert old_checkpoint["global_step"] * 2 == new_checkpoint["global_step"]
assert new_checkpoint["epoch"] == 2

0 comments on commit 78b3d6c

Please sign in to comment.