
Commit

Merge branch 'master' into fix-exception-chaining
akihironitta committed Oct 1, 2020
2 parents 9b19705 + e4e60e9 commit 73908d0
Showing 23 changed files with 247 additions and 88 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for datamodules to save and load checkpoints when training ([#3563](https://github.com/PyTorchLightning/pytorch-lightning/pull/3563))

- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -141,3 +141,4 @@ Indices and tables
api/pytorch_lightning.utilities
api/pytorch_lightning.tuner
api/pytorch_lightning.plugins
api/pytorch_lightning.distributed
3 changes: 2 additions & 1 deletion pl_examples/basic_examples/image_classifier.py
@@ -54,13 +54,14 @@ def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.backbone(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss, on_epoch=True)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.backbone(x)
loss = F.cross_entropy(y_hat, y)
self.log('valid_loss', loss)
self.log('valid_loss', loss, on_step=True)

def test_step(self, batch, batch_idx):
x, y = batch
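For reference, `self.log` accepts `on_step`/`on_epoch` flags that control when a metric is recorded, and the defaults differ between training and validation hooks. A minimal sketch of the API as used above (the model itself is illustrative, not part of the example file):

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class TinyClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x), y)
        # per-step value plus an epoch-level aggregate
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x), y)
        # also record the per-step value during validation, as in the change above
        self.log('valid_loss', loss, on_step=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```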
5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
@@ -7,6 +7,7 @@
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict

try:
from apex import amp
@@ -21,6 +22,7 @@ class Accelerator(object):

def __init__(self, trainer):
self.trainer = trainer
self.dist = AttributeDict(rank=0, device=None)

def setup(self, model):
pass
@@ -31,6 +33,9 @@ def teardown(self):
def barrier(self, name: str = None):
pass

def broadcast(self, obj, src=0):
return obj

def train_or_test(self):
if self.trainer.testing:
results = self.trainer.run_test()
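The `broadcast` hook added to the base `Accelerator` is deliberately a no-op: backends that run a single process inherit it and simply hand the object back, while distributed backends override it below. A minimal sketch of that contract, using a hypothetical stand-in class for illustration:

```python
from typing import Any


class SingleDeviceAccelerator:
    """Illustrative stand-in for a non-distributed backend (not part of this commit)."""

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        # with only one process there is nothing to synchronise,
        # so the object is returned unchanged
        return obj


accelerator = SingleDeviceAccelerator()
version, name = accelerator.broadcast(("version_0", "lightning_logs"))
assert (version, name) == ("version_0", "lightning_logs")
```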
5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/ddp_base_backend.py
@@ -24,6 +24,7 @@
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed

try:
from hydra.core.hydra_config import HydraConfig
@@ -38,6 +39,7 @@ class DDPBase(Accelerator):

def __init__(self, trainer):
super().__init__(trainer)
self.dist = LightningDistributed()

def training_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
@@ -177,6 +179,9 @@ def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offs
if self.trainer.global_rank == 0:
return results

def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)

def set_world_ranks(self, process_idx):
raise NotImplementedError('to create a ddp backend, please implement set_world_ranks')

5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
@@ -25,6 +25,7 @@
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.distributed.dist import LightningDistributed

try:
from hydra.core.hydra_config import HydraConfig
@@ -41,6 +42,7 @@ def __init__(self, trainer, nprocs):
super().__init__(trainer)
self.mp_queue = None
self.nprocs = nprocs
self.dist = LightningDistributed()

def setup(self, model):
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
@@ -174,6 +176,9 @@ def test_step(self, args):
def barrier(self, name: str = None):
torch_distrib.barrier()

def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/dp_backend.py
@@ -16,7 +16,7 @@
from torch import optim

from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.overrides.data_parallel import LightningDataParallel
from pytorch_lightning.utilities import AMPType
@@ -28,6 +28,7 @@ class DataParallelBackend(Accelerator):
def __init__(self, trainer):
super().__init__(trainer)
self.model_autocast_original_forward = None
self.dist = LightningDistributed()

def setup(self, model):
# call setup after the ddp process has connected
2 changes: 2 additions & 0 deletions pytorch_lightning/accelerators/gpu_backend.py
@@ -16,13 +16,15 @@

from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.distributed.dist import LightningDistributed


class GPUBackend(Accelerator):
amp_backend: AMPType

def __init__(self, trainer):
super().__init__(trainer)
self.dist = LightningDistributed()

def setup(self, model):

4 changes: 4 additions & 0 deletions pytorch_lightning/accelerators/horovod_backend.py
@@ -158,3 +158,7 @@ def on_train_epoch_end(self):

def barrier(self, name: str = None):
hvd.join()

def broadcast(self, obj, src=0):
obj = hvd.broadcast_object(obj, src)
return obj
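Horovod ships its own object broadcast, so this backend delegates directly to it. A rough usage sketch, assuming Horovod is installed and the script is launched with horovodrun (names are illustrative):

```python
import horovod.torch as hvd

hvd.init()

# every rank proposes a value, but only rank 0's survives the broadcast
proposed = f"version_{hvd.rank()}"
agreed = hvd.broadcast_object(proposed, 0)  # root rank 0, matching the call above
assert agreed == "version_0"
```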
70 changes: 36 additions & 34 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -31,7 +31,7 @@
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -176,16 +176,16 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
self.best_model_score = checkpointed_state["best_model_score"]
self.best_model_path = checkpointed_state["best_model_path"]

@rank_zero_only
def save_checkpoint(self, trainer, pl_module):
"""
Performs the main logic around saving a checkpoint
Performs the main logic around saving a checkpoint.
This method runs on all ranks, it is the responsibility of `self.save_function`
to handle correct behaviour in distributed training, i.e., saving only on rank 0.
"""
epoch = trainer.current_epoch

if (
trainer.global_rank != 0 # only run on main process
or self.save_top_k == 0 # no models are saved
self.save_top_k == 0 # no models are saved
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
@@ -207,13 +207,13 @@ def save_checkpoint(self, trainer, pl_module):

# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save the last checkpoint
self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)

# Mode 2: save all checkpoints OR only the top k
# Mode 1: save all checkpoints OR only the top k
if self.save_top_k:
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)

# Mode 2: save the last checkpoint
self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
raise MisconfigurationException(
@@ -255,7 +255,6 @@ def __init_ckpt_dir(self, filepath, save_top_k):
if self._fs.protocol == "file": # dont normalize remote paths
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
self._fs.makedirs(self.dirpath, exist_ok=True)

def __init_monitor_mode(self, monitor, mode):
torch_inf = torch.tensor(np.Inf)
@@ -276,24 +275,30 @@ def __init_monitor_mode(self, monitor, mode):

self.kth_value, self.mode = mode_dict[mode]

@rank_zero_only
def _del_model(self, filepath: str):
if self._fs.exists(filepath):
self._fs.rm(filepath)
log.debug(f"Removed checkpoint: {filepath}")

def _save_model(self, filepath: str, trainer, pl_module):
# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

# make paths
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
if trainer.is_global_zero:
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

# delegate the saving to the model
# delegate the saving to the trainer
if self.save_function is not None:
self.save_function(filepath, self.save_weights_only)
else:
raise ValueError(".save_function() not set")

def check_monitor_top_k(self, current) -> bool:
if current is None:
return False

if self.save_top_k == -1:
return True

@@ -325,7 +330,7 @@ def _format_checkpoint_name(
filename = "{epoch}"
# check and parse user passed keys in the string
groups = re.findall(r"(\{.*?)[:\}]", filename)
if groups:
if len(groups) >= 0:
metrics["epoch"] = epoch
for group in groups:
name = group[1:]
@@ -364,7 +369,6 @@ def format_checkpoint_name(
ckpt_name = f"{filename}.ckpt"
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

@rank_zero_only
def __resolve_ckpt_dir(self, trainer, pl_module):
"""
Determines model checkpoint save directory at runtime. References attributes from the
@@ -396,18 +400,19 @@ def __resolve_ckpt_dir(self, trainer, pl_module):
if isinstance(trainer.logger.version, str)
else f"version_{trainer.logger.version}"
)

version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name))

ckpt_path = os.path.join(
save_dir, trainer.logger.name, version, "checkpoints"
save_dir, name, version, "checkpoints"
)
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

self.dirpath = ckpt_path

assert (
trainer.global_rank == 0
), "tried to make a checkpoint from non global_rank=0"
self._fs.makedirs(self.dirpath, exist_ok=True)
if trainer.is_global_zero:
self._fs.makedirs(self.dirpath, exist_ok=True)

def _add_backward_monitor_support(self, trainer):
metrics = trainer.logger_connector.callback_metrics
@@ -419,7 +424,7 @@ def _add_backward_monitor_support(self, trainer):
if self.monitor is None and 'checkpoint_on' in metrics:
self.monitor = 'checkpoint_on'

if self.save_top_k is None:
if self.save_top_k is None and self.monitor is not None:
self.save_top_k = 1

def _validate_monitor_key(self, trainer):
@@ -460,13 +465,18 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi

# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
filename = self._format_checkpoint_name(
last_filepath = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{filename}.ckpt")
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")

self._save_model(last_filepath, trainer, pl_module)
if self.last_model_path and self.last_model_path != last_filepath and (self.save_top_k != -1 or self.save_last):
if (
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
):
self._del_model(self.last_model_path)
self.last_model_path = last_filepath

@@ -479,18 +489,10 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
if not isinstance(current, torch.Tensor) and current is not None:
current = torch.tensor(current, device=pl_module.device)

if current is None:
m = f"Can save best model only with {self.monitor} available, skipping."
if self.monitor == 'checkpoint_on':
m = (
'No checkpoint_on found. HINT: Did you set it in '
'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?'
)
rank_zero_warn(m, RuntimeWarning)
elif self.check_monitor_top_k(current):
if self.check_monitor_top_k(current):
self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
elif self.verbose:
log.info(
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
)

@@ -528,7 +530,7 @@ def _update_best_and_save(
self.best_model_score = self.best_k_models[self.best_model_path]

if self.verbose:
log.info(
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} reached"
f" {current:0.5f} (best {self.best_model_score:0.5f}),"
f" saving model to {filepath} as top {k}"
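The key change in `__resolve_ckpt_dir` above is that the logger version and name are now broadcast from rank 0 before the checkpoint path is assembled, so every process resolves the same directory even if its local logger state differs. A simplified sketch of that idea, assuming a hypothetical `backend` object exposing the new `broadcast` hook:

```python
import os


def resolve_ckpt_dir(backend, save_dir, logger_name, logger_version):
    """Build a checkpoint directory that is identical on all ranks (illustrative)."""
    version = (
        logger_version
        if isinstance(logger_version, str)
        else f"version_{logger_version}"
    )
    # rank 0's (version, name) pair wins; other ranks overwrite their local values
    version, name = backend.broadcast((version, logger_name))
    return os.path.join(save_dir, name, version, "checkpoints")
```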
3 changes: 3 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -147,6 +147,9 @@ def __init__(
self._test_transforms = test_transforms
self._dims = dims if dims is not None else ()

# Pointer to the trainer object
self.trainer = None

# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False
self._has_setup_fit = False
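With the new `self.trainer` pointer, a datamodule can check whether a trainer has attached itself before using trainer state inside its hooks. A hedged sketch (the dataset and the check are illustrative, not part of this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class RandomDataModule(pl.LightningDataModule):
    def __init__(self, num_samples: int = 64):
        super().__init__()
        self.num_samples = num_samples
        self.dataset = None

    def setup(self, stage=None):
        # self.trainer is None until a Trainer picks up this datamodule;
        # once attached, trainer attributes become reachable from the hooks
        if self.trainer is not None:
            print("setup called with an attached trainer")
        self.dataset = TensorDataset(torch.randn(self.num_samples, 32))

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=8)
```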
1 change: 1 addition & 0 deletions pytorch_lightning/distributed/__init__.py
@@ -0,0 +1 @@
from pytorch_lightning.distributed.dist import LightningDistributed
36 changes: 36 additions & 0 deletions pytorch_lightning/distributed/dist.py
@@ -0,0 +1,36 @@
import io
import torch
from typing import Any
from torch import distributed as torch_distrib


class LightningDistributed:

def __init__(self, rank=None, device=None):
self.rank = rank
self.device = device

def broadcast(self, obj: Any):
if self.rank == 0:
self._emit(obj)
else:
obj = self._receive()
return obj

def _emit(self, obj):
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.tensor([len(data)]).long().to(self.device)
length_tensor = torch_distrib.broadcast(length_tensor, src=0)
data_tensor = torch.ByteTensor(data).to(self.device)
data_tensor = torch_distrib.broadcast(data_tensor, src=0)

def _receive(self):
length_tensor = torch.tensor([0]).long().to(self.device)
torch_distrib.broadcast(length_tensor, src=0)
data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device)
torch_distrib.broadcast(data_tensor, src=0)
buffer = io.BytesIO(data_tensor.cpu().numpy())
obj = torch.load(buffer)
return obj
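The new `LightningDistributed.broadcast` serialises an arbitrary picklable object with `torch.save`, ships its length and bytes via two in-place `torch.distributed.broadcast` calls, and deserialises it on the receiving ranks. A rough usage sketch inside an already-initialised process group (the group setup and function name are assumptions for illustration):

```python
import torch.distributed as torch_distrib
from pytorch_lightning.distributed.dist import LightningDistributed

# assumes torch.distributed has been initialised elsewhere, e.g.
# torch_distrib.init_process_group("gloo", rank=rank, world_size=world_size)


def sync_logger_metadata(rank: int, version: str, name: str):
    dist = LightningDistributed(rank=rank, device="cpu")
    # rank 0 emits its (version, name); every other rank receives rank 0's copy
    return dist.broadcast((version, name))
```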
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/connectors/logger_connector.py
@@ -157,6 +157,9 @@ def _log_on_evaluation_epoch_end_metrics(self):
# track the final results for the dataloader
self.eval_loop_results.append(deepcopy(self.callback_metrics))

# actually log
self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)

def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
if num_loaders == 1:
return metrics