Support state restoration of logged results 2/2 (#7966)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
4 people authored Jun 25, 2021
1 parent ad95710 commit 24db914
Showing 34 changed files with 294 additions and 172 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -85,6 +85,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fault-tolerant training
* Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
* Checkpoint the loop results ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))


- Add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))


- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
17 changes: 17 additions & 0 deletions docs/source/advanced/multi_gpu.rst
@@ -106,6 +106,23 @@ Note if you use any built in metrics or custom metrics that use the :doc:`Metric
# Add sync_dist=True to sync logging across all GPU workers
self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True)

It is possible to perform some computation manually and log the reduced result on rank 0 as follows:

.. testcode::

    def test_step(self, batch, batch_idx):
        x, y = batch
        tensors = self(x)
        return tensors

    def test_epoch_end(self, outputs):
        mean = torch.mean(self.all_gather(outputs))

        # When logging only on rank 0, don't forget to add
        # ``rank_zero_only=True`` to avoid deadlocks on synchronization.
        if self.trainer.is_global_zero:
            self.log("my_reduced_metric", mean, rank_zero_only=True)


Make models pickleable
^^^^^^^^^^^^^^^^^^^^^^
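For context, here is a self-contained sketch of the two logging patterns the new documentation contrasts: letting Lightning synchronize the value with `sync_dist=True`, or reducing manually and logging only on rank 0 with the new `rank_zero_only=True` flag. The module, layer sizes, and metric names below are illustrative and not part of this diff.

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def test_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self(x), y)
        # Pattern 1: every rank logs and Lightning reduces across workers.
        self.log("test_loss_synced", loss, sync_dist=True)
        return loss

    def test_epoch_end(self, outputs):
        # Pattern 2: gather and reduce manually, then log on rank 0 only.
        mean = self.all_gather(torch.stack(outputs)).mean()
        if self.trainer.is_global_zero:
            # Without rank_zero_only=True the other ranks would be expected
            # to join a sync for this key and the job could deadlock.
            self.log("test_loss_manual", mean, rank_zero_only=True)
```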
39 changes: 35 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -39,11 +39,12 @@
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
@@ -112,6 +113,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._automatic_optimization: bool = True
self._truncated_bptt_steps: int = 0
self._param_requires_grad_state = dict()
self._metric_attributes: Optional[Dict[int, str]] = None

def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
if use_pl_optimizer:
@@ -273,6 +275,8 @@ def log(
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
rank_zero_only: Optional[bool] = None,
) -> None:
"""
Log a key, value
@@ -310,6 +314,10 @@ def log(
each dataloader to not mix values
batch_size: Current batch_size. This will be directly inferred from the loaded batch,
but some data structures might need to explicitly provide it.
metric_attribute: To restore the metric state, Lightning requires the reference of the
:class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
if tbptt_reduce_fx is not None:
rank_zero_deprecation(
@@ -346,7 +354,7 @@ def log(
results = self.trainer._results
assert results is not None
assert self._current_fx_name is not None
results.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
@@ -362,6 +370,27 @@ def log(
# reset any tensors for the new hook name
results.reset(metrics=False, fx=self._current_fx_name)

if metric_attribute is None and isinstance(value, Metric):
if self._metric_attributes is None:
# compute once
self._metric_attributes = {
id(module): name
for name, module in self.named_children() if isinstance(module, Metric)
}
if not self._metric_attributes:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
" You can fix this by setting an attribute for the metric in your `LightningModule`."
)
# try to find the passed metric in the LightningModule
metric_attribute = self._metric_attributes.get(id(value))
if metric_attribute is None:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one"
f" of {list(self._metric_attributes.values())}"
)

results.log(
self._current_fx_name,
name,
@@ -374,9 +403,11 @@ def log(
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
batch_size=batch_size,
sync_dist=sync_dist,
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
sync_dist_group=sync_dist_group,
metric_attribute=metric_attribute,
rank_zero_only=rank_zero_only,
)

self.trainer.logger_connector._current_fx = self._current_fx_name
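A hedged sketch of how the new `metric_attribute` resolution in `LightningModule.log` is meant to be used: when a `torchmetrics.Metric` is logged, the attribute name is looked up via `named_children()` so the metric state can be checkpointed, and a `MisconfigurationException` is raised if no matching attribute is found. The class, attribute names, and optimizer below are illustrative, and the snippet assumes a torchmetrics version where `Accuracy()` takes no required arguments.

```python
import torch
import torch.nn.functional as F
import torchmetrics
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 10)
        # Assigned as a direct attribute, so `named_children()` can find it
        # and `metric_attribute` is resolved automatically.
        self.train_acc = torchmetrics.Accuracy()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.train_acc(logits.softmax(dim=-1), y)
        # Passing metric_attribute explicitly is only needed when the metric
        # is nested somewhere Lightning cannot discover as a child module.
        self.log("train_acc", self.train_acc, on_epoch=True,
                 metric_attribute="train_acc")
        return F.cross_entropy(logits, y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

Note also that the diff now passes `sync_dist and distributed_available()` to `ResultCollection`, so a single-process run never attempts a collective reduction.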
6 changes: 3 additions & 3 deletions pytorch_lightning/loggers/base.py
@@ -25,8 +25,8 @@
import numpy as np
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only


@@ -300,7 +300,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
kwargs: Optional keyword arguments, depends on the specific logger being used
"""

def log_graph(self, model: LightningModule, input_array=None) -> None:
def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None:
"""
Record model graph
@@ -396,7 +396,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
for logger in self._logger_iterable:
logger.log_hyperparams(params)

def log_graph(self, model: LightningModule, input_array=None) -> None:
def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None:
for logger in self._logger_iterable:
logger.log_graph(model, input_array)

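The same mechanical change repeats across the loggers and plugins in this commit: the direct `from pytorch_lightning.core.lightning import LightningModule` import is replaced by `import pytorch_lightning as pl` plus a string annotation, which avoids a circular import because the annotation is never evaluated while the module loads. A generic sketch of the idiom follows; the module and function names are illustrative and not taken from the diff.

```python
# hypothetical_plugin.py, an illustrative module, not a file in this commit
import pytorch_lightning as pl  # no direct LightningModule import


def unwrap_module(wrapped_model) -> 'pl.LightningModule':
    """Return the wrapped LightningModule, if any.

    The quoted return annotation is stored as a string, so `pl.LightningModule`
    only has to resolve when a type checker (or `typing.get_type_hints`)
    inspects it, not while this module is being imported.
    """
    # Runtime checks still work because `pl` is fully initialized by the
    # time this function is actually called.
    if isinstance(wrapped_model, pl.LightningModule):
        return wrapped_model
    return getattr(wrapped_model, "module", wrapped_model)
```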
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/comet.py
@@ -24,7 +24,7 @@
import torch
from torch import is_tensor

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -318,6 +318,6 @@ def __getstate__(self):
state["_experiment"] = None
return state

def log_graph(self, model: LightningModule, input_array=None) -> None:
def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None:
if self._experiment is not None:
self._experiment.set_model_graph(model)
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
@@ -25,7 +25,7 @@
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn
@@ -223,7 +223,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
raise ValueError(m) from ex

@rank_zero_only
def log_graph(self, model: LightningModule, input_array=None):
def log_graph(self, model: 'pl.LightningModule', input_array=None):
if self._log_graph:
if input_array is None:
input_array = model.example_input_array
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/test_tube.py
@@ -18,7 +18,7 @@
from argparse import Namespace
from typing import Any, Dict, Optional, Union

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -153,7 +153,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
self.experiment.log(metrics, global_step=step)

@rank_zero_only
def log_graph(self, model: LightningModule, input_array=None):
def log_graph(self, model: 'pl.LightningModule', input_array=None):
if self._log_graph:
if input_array is None:
input_array = model.example_input_array
4 changes: 2 additions & 2 deletions pytorch_lightning/overrides/fairscale.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE

@@ -23,7 +23,7 @@ class LightningShardedDataParallel(_LightningModuleWrapperBase):
# Just do this for later docstrings
pass

def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule:
def unwrap_lightning_module_sharded(wrapped_model) -> 'pl.LightningModule':
model = wrapped_model
if isinstance(model, ShardedDataParallel):
model = model.module
7 changes: 3 additions & 4 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -18,7 +18,6 @@
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.types import _PARAMETERS
@@ -50,7 +49,7 @@ def dispatch(self, trainer: 'pl.Trainer') -> None:

def backward(
self,
model: LightningModule,
model: 'pl.LightningModule',
closure_loss: Tensor,
optimizer: Optimizer,
opt_idx: int,
@@ -76,7 +75,7 @@ def backward(

# do backward pass
# TODO: not entirely sure, why we need this
if model is not None and isinstance(model, LightningModule):
if model is not None and isinstance(model, pl.LightningModule):
model.backward(closure_loss, optimizer, opt_idx, **kwargs)

# TODO: avoid dev_debugger and track these calls with mock
@@ -118,7 +117,7 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq

def pre_optimizer_step(
self,
pl_module: LightningModule,
pl_module: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/precision/double.py
@@ -18,7 +18,7 @@
import torch.nn as nn
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -33,7 +33,7 @@ class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
pl_module: the model to wrap
"""

def __init__(self, pl_module: LightningModule):
def __init__(self, pl_module: 'pl.LightningModule'):
super().__init__(pl_module)

@staticmethod
@@ -96,7 +96,7 @@ def connect(
incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
`lr_schedulers`.
"""
model = cast(LightningModule, model.to(dtype=torch.float64))
model = cast(pl.LightningModule, model.to(dtype=torch.float64))
model = LightningDoublePrecisionModule(model)

return super().connect(model, optimizers, lr_schedulers)
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -276,6 +276,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

# requires to compute the state_dict on all processes in case Metrics are present
state_dict = self.lightning_module.state_dict()

if self.global_rank == 0 and self.mp_queue is not None:
rank_zero_warn("cleaning up ddp environment...")

@@ -286,7 +289,7 @@
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)
atomic_save(self.on_save(state_dict), last_path)

# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
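The `ddp_spawn` change above (and the matching `tpu_spawn` change further down) moves the `state_dict()` call out of the rank-0 branch: with metric states now checkpointed, a `torchmetrics.Metric` may run a collective sync inside `state_dict()`, so every process has to make the call or the ranks that skip it would leave rank 0 waiting. A schematic of the pattern, not the actual plugin code:

```python
import torch


def save_weights_on_rank_zero(module: torch.nn.Module, path: str, global_rank: int) -> None:
    # Every rank computes the state dict: metric states may be reduced with a
    # collective op here, which would deadlock if only rank 0 entered.
    state_dict = module.state_dict()

    # Only rank 0 actually writes the file.
    if global_rank == 0:
        torch.save(state_dict, path)
```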
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -19,8 +19,8 @@
import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -37,7 +37,7 @@

class LightningIPUModule(_LightningModuleWrapperBase):

def __init__(self, pl_module: LightningModule, precision: Union[str, int]):
def __init__(self, pl_module: 'pl.LightningModule', precision: Union[str, int]):
super().__init__(pl_module)
self.precision = precision

@@ -184,7 +184,7 @@ def _validate_opts(self, opts: 'poptorch.Options', training: bool) -> None:
opts.Training.set(gradient_accumulation=1)

@property
def lightning_module(self) -> Optional[LightningModule]:
def lightning_module(self) -> Optional['pl.LightningModule']:
return self.model.module if isinstance(self.model, LightningIPUModule) else self.model

def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -16,7 +16,7 @@
import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerFn
@@ -86,7 +86,7 @@ def _optim_state_dict(self, optimizer):
return optimizer.state_dict()

@property
def lightning_module(self) -> LightningModule:
def lightning_module(self) -> 'pl.LightningModule':
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
raise MisconfigurationException(
"`DDPShardedPlugin` requires `fairscale` to be installed."
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -16,7 +16,7 @@
import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
@@ -71,7 +71,7 @@ def _optim_state_dict(self, optimizer):
return optimizer.state_dict()

@property
def lightning_module(self) -> LightningModule:
def lightning_module(self) -> 'pl.LightningModule':
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
raise MisconfigurationException(
"`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -185,6 +185,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

# requires to compute the state_dict on all processes in case Metrics are present
state_dict = self.lightning_module.state_dict()

if self.mp_queue is not None:
rank_zero_warn("cleaning up tpu spawn environment...")

@@ -195,7 +198,7 @@
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
self.save(self.lightning_module.state_dict(), last_path)
self.save(state_dict, last_path)

if self.local_rank == 0:
# todo, pass complete checkpoint as state dictionary
