Refactor load in checkpoint connector #4593

Merged · 23 commits · Dec 13, 2020
Changes from 11 commits
109 changes: 51 additions & 58 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -14,18 +14,20 @@

import io
import os
from pathlib import Path
import re
import signal
from abc import ABC
from subprocess import call
from typing import Union

import torch
import torch.distributed as torch_distrib

import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities import AMPType, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
@@ -63,16 +65,17 @@ def restore_weights(self, model: LightningModule):
if self.trainer.on_gpu:
torch.cuda.empty_cache()

# if script called from hpc resubmit, load weights
did_restore_hpc_weights = self.restore_hpc_weights_if_needed(model)
# 1. Attempt to restore states from HPC checkpoint
dir_path_hpc = str(self.trainer.weights_save_path)
max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_suffix is not None:
checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
self.hpc_load(checkpoint_path, self.trainer.on_gpu)
rank_zero_info(f'restored hpc model from: {checkpoint_path}')

# clear cache after restore
if self.trainer.on_gpu:
torch.cuda.empty_cache()

if not did_restore_hpc_weights:
if self.trainer.resume_from_checkpoint is not None:
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
# 2. Attempt to restore states from `resume_from_checkpoint` file
elif self.trainer.resume_from_checkpoint is not None:
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

# wait for all to catch up
self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
@@ -83,24 +86,14 @@ def restore_weights(self, model: LightningModule):

def restore(self, checkpoint_path: str, on_gpu: bool):
"""
Load model/training states from the checkpoint file through file-read and state-restore.
Also restores all training state like:
- epoch
- callbacks
- schedulers
- optimizer
In detail, check return value description of `dump_checkpoint`
Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
All restored states are listed in return value description of `dump_checkpoint`.
"""

# if on_gpu:
# checkpoint = torch.load(checkpoint_path)
# else:
# load on CPU first
# read a checkpoint dictionary object from the checkpoint file at `checkpoint_path`
# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# restore states from the checkpoint dictionary object
# load model state
# acquire the model
model = self.trainer.get_model()

# restore model and datamodule state
@@ -117,14 +110,14 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
"""

# give the datamodule a chance to load something
# restore datamodule states
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_load_checkpoint(checkpoint)

# give model a chance to restore something
# hook for arbitrary processing before model restore
model.on_load_checkpoint(checkpoint)

# restore the state_dict on the model
# restore model state_dict
model.load_state_dict(checkpoint['state_dict'])

def restore_training_state(self, checkpoint):
@@ -198,23 +191,6 @@ def restore_training_state(self, checkpoint):
for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
scheduler['scheduler'].load_state_dict(lrs_state)

def restore_hpc_weights_if_needed(self, model: LightningModule):
"""If there is a set of hpc weights, use as signal to restore model."""
did_restore = False

# look for hpc weights
folderpath = str(self.trainer.weights_save_path)
fs = get_filesystem(folderpath)
if fs.exists(folderpath):
files = [os.path.basename(f['name']) for f in fs.listdir(folderpath)]
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

# if hpc weights exist restore model
if len(hpc_weight_paths) > 0:
self.hpc_load(folderpath, self.trainer.on_gpu)
did_restore = True
return did_restore

# ----------------------------------
# PRIVATE OPS
# ----------------------------------
@@ -227,7 +203,8 @@ def hpc_save(self, folderpath: str, logger):
# save logger to make sure we get all the metrics
logger.save()

ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
max_suffix = self.max_ckpt_in_folder(folderpath)
ckpt_number = (max_suffix if max_suffix is not None else 0) + 1

fs.makedirs(folderpath, exist_ok=True)
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
@@ -340,36 +317,52 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

return checkpoint

def hpc_load(self, folderpath, on_gpu):
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
def hpc_load(self, checkpoint_path: str, on_gpu: bool):
"""
Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
All restored states are listed in return value description of `dump_checkpoint`.
"""

# load on CPU first
checkpoint = pl_load(filepath, map_location=lambda storage, loc: storage)
# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# load model state
# acquire the model
model = self.trainer.get_model()

# restore states from 'PyTorch-Lightning checkpoint' dictionary object
# restore model and datamodule state
self.restore_model_state(model, checkpoint)

if self.trainer.root_gpu is not None:
model.cuda(self.trainer.root_gpu)

# load training state (affects trainer only)
# restore training state
self.restore_training_state(checkpoint)

# call model hook
# call hpc specific hook
model.on_hpc_load(checkpoint)

log.info(f'restored hpc model from: {filepath}')
def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Union[None, int]:
"""List up files in `dir_path` with name_key, then yield maximum suffix number.

Args:
dir_path: path of a directory which may contain files whose names include `name_key`

Returns:
None if there are no matching files, else the maximum suffix number
"""

# check directory existence
fs = get_filesystem(dir_path)
if not fs.exists(dir_path):
return None

def max_ckpt_in_folder(self, path, name_key='ckpt_'):
fs = get_filesystem(path)
files = [os.path.basename(f["name"]) for f in fs.listdir(path)]
# check corresponding file existence
files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return 0
return None

# extract suffix number
ckpt_vs = []
for name in files:
name = name.split(name_key)[-1]
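
Taken together, the refactor makes `max_ckpt_in_folder` return `None` when no matching checkpoint exists and makes `hpc_load` take the path of a single checkpoint file instead of a folder. The sketch below shows how the pieces are meant to fit together; it assumes `trainer` is an already-configured `Trainer` and `save_dir` is the directory where `hpc_save` wrote `hpc_ckpt_N.ckpt` files (the setup is illustrative, not part of this diff).

```python
# Minimal sketch of the refactored checkpoint-connector API (assumptions noted above).
connector = trainer.checkpoint_connector

# With this PR, max_ckpt_in_folder returns None when no matching file exists
# (previously it returned the sentinel value 0).
max_suffix = connector.max_ckpt_in_folder(save_dir, name_key='hpc_ckpt_')

if max_suffix is not None:
    # hpc_load now expects the full path of one checkpoint file, not a folder.
    checkpoint_path = f'{save_dir}/hpc_ckpt_{max_suffix}.ckpt'
    connector.hpc_load(checkpoint_path, on_gpu=trainer.on_gpu)
else:
    # No hpc_ckpt_*.ckpt found; training falls back to resume_from_checkpoint
    # or starts from fresh weights.
    pass
```

This mirrors the updated test helpers below, which build the checkpoint path explicitly before calling `hpc_load` rather than handing it a folder.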
8 changes: 6 additions & 2 deletions tests/base/develop_pipelines.py
@@ -86,9 +86,13 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi
trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \
trainer.init_optimizers(pretrained_model)

# test HPC loading / saving
# test HPC saving
trainer.checkpoint_connector.hpc_save(save_dir, logger)
trainer.checkpoint_connector.hpc_load(save_dir, on_gpu=on_gpu)
# test HPC loading
max_suffix = trainer.checkpoint_connector.max_ckpt_in_folder(save_dir)
ckpt_number = max_suffix if max_suffix is not None else 0
checkpoint_path = f'{save_dir}/hpc_ckpt_{ckpt_number}.ckpt'
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)


def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
8 changes: 6 additions & 2 deletions tests/models/data/horovod/train_default_model.py
@@ -79,9 +79,13 @@ def run_test_from_config(trainer_options):
for dataloader in test_loaders:
run_prediction(dataloader, pretrained_model)

# test HPC loading / saving
# test HPC saving
trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
trainer.checkpoint_connector.hpc_load(ckpt_path, on_gpu=args.on_gpu)
# test HPC loading
max_suffix = trainer.checkpoint_connector.max_ckpt_in_folder(ckpt_path)
ckpt_number = max_suffix if max_suffix is not None else 0
checkpoint_path = f'{ckpt_path}/hpc_ckpt_{ckpt_number}.ckpt'
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=args.on_gpu)

if args.on_gpu:
trainer = Trainer(gpus=1, distributed_backend='horovod', max_epochs=1)