diff --git a/CHANGELOG.md b/CHANGELOG.md index bf8d002bce0e8..16cbd586c7a64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562)) +- Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935)) + ### Changed - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f816726ddf1e1..d7a503c07aca4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -168,7 +168,7 @@ def forward(self, batch): """ - def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Tensor]]]]: + def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Union[float, Tensor]]]]]: r""" Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger. @@ -186,8 +186,8 @@ def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, D When implementing :meth:`training_step`, return whatever you need in that step: - loss -> tensor scalar **REQUIRED** - - progress_bar -> Dict for progress bar display. Must have only tensors - - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) + - progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars + - log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc) In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific. @@ -202,14 +202,14 @@ def training_step(self, batch, batch_idx): out = self(x) loss = self.loss(out, x) - logger_logs = {'training_loss': loss} # optional (MUST ALL BE TENSORS) + logger_logs = {'training_loss': loss} # optional # if using TestTubeLogger or TensorBoardLogger you can nest scalars - logger_logs = {'losses': logger_logs} # optional (MUST ALL BE TENSORS) + logger_logs = {'losses': logger_logs} # optional output = { 'loss': loss, # required - 'progress_bar': {'training_loss': loss}, # optional (MUST ALL BE TENSORS) + 'progress_bar': {'training_loss': loss}, # optional 'log': logger_logs } @@ -259,8 +259,8 @@ def training_end(self, *args, **kwargs): """ def training_epoch_end( - self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] - ) -> Dict[str, Dict[str, Tensor]]: + self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Union[float, Tensor]]]]] + ) -> Dict[str, Dict[str, Union[float, Tensor]]]: """Called at the end of the training epoch with the outputs of all training steps. .. code-block:: python @@ -334,7 +334,7 @@ def training_epoch_end(self, outputs): return results """ - def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: + def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str, Union[float, Tensor]]]]: """ Use this when training with dp or ddp2 because :meth:`training_step` will operate on only part of the batch. However, this is still optional @@ -358,8 +358,8 @@ def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str Dict with loss key and optional log or progress bar keys. - loss -> tensor scalar **REQUIRED** - - progress_bar -> Dict for progress bar display. Must have only tensors - - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) + - progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars + - log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc) Examples: .. code-block:: python @@ -396,7 +396,7 @@ def training_step_end(self, outputs): See the :ref:`multi-gpu-training` guide for more details. """ - def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]: + def validation_step(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]: r""" Operates on a single batch of data from the validation set. In this step you'd might generate examples or calculate anything of interest like accuracy. @@ -486,7 +486,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx): the model goes back to training mode and gradients are enabled. """ - def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: + def validation_step_end(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]: """ Use this when validating with dp or ddp2 because :meth:`validation_step` will operate on only part of the batch. However, this is still optional @@ -553,8 +553,8 @@ def validation_end(self, outputs): """ def validation_epoch_end( - self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] - ) -> Dict[str, Dict[str, Tensor]]: + self, outputs: Union[List[Dict[str, Union[float, Tensor]]], List[List[Dict[str, Union[float, Tensor]]]]] + ) -> Dict[str, Dict[str, Union[float, Tensor]]]: """ Called at the end of the validation epoch with the outputs of all validation steps. @@ -575,8 +575,8 @@ def validation_epoch_end( Dict or OrderedDict. May have the following optional keys: - - progress_bar (dict for progress bar display; only tensors) - - log (dict of metrics to add to logger; only tensors). + - progress_bar (dict for progress bar display; either scalar tensors or Python scalars) + - log (dict of metrics to add to logger; either scalar tensors or Python scalars). Note: If you didn't define a :meth:`validation_step`, this won't be called. @@ -630,7 +630,7 @@ def validation_epoch_end(self, outputs): return results """ - def test_step(self, *args, **kwargs) -> Dict[str, Tensor]: + def test_step(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]: r""" Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest @@ -713,7 +713,7 @@ def test_step(self, batch, batch_idx, dataloader_idx): to training mode and gradients are enabled. """ - def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: + def test_step_end(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]: """ Use this when testing with dp or ddp2 because :meth:`test_step` will operate on only part of the batch. However, this is still optional @@ -779,8 +779,8 @@ def test_end(self, outputs): """ def test_epoch_end( - self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] - ) -> Dict[str, Dict[str, Tensor]]: + self, outputs: Union[List[Dict[str, Union[float, Tensor]]], List[List[Dict[str, Union[float, Tensor]]]]] + ) -> Dict[str, Dict[str, Union[float, Tensor]]]: """ Called at the end of a test epoch with the output of all test steps. @@ -800,8 +800,8 @@ def test_epoch_end( Return: Dict or OrderedDict: Dict has the following optional keys: - - progress_bar -> Dict for progress bar display. Must have only tensors. - - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc). + - progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars. + - log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc). Note: If you didn't define a :meth:`test_step`, this won't be called. diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index c9c793cc89a2f..3945d770fe8d4 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -1,11 +1,13 @@ import itertools import threading from itertools import chain +from collections import Mapping, Iterable import torch from torch.cuda._utils import _get_device_index from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel._functions import Gather from pytorch_lightning.core.step_result import Result @@ -68,7 +70,7 @@ def forward(self, *inputs, **kwargs): if isinstance(outputs[0], Result): outputs = self.__gather_structured_result(outputs) else: - outputs = self.gather(outputs, self.output_device) + outputs = self.gather(outputs) return outputs def __gather_structured_result(self, outputs): @@ -81,7 +83,7 @@ def __gather_structured_result(self, outputs): for i, output in enumerate(outputs): del output['meta'] - outputs = self.gather(outputs, self.output_device) + outputs = self.gather(outputs) # pass minimize to constructor for TrainResult if 'minimize' in outputs: @@ -93,6 +95,39 @@ def __gather_structured_result(self, outputs): result['meta'] = meta return result + def gather(self, outputs): + r""" + Override the gather method to support python scalars as well. + """ + def gather_map(outputs): + elem = outputs[0] + elem_type = type(elem) + + if isinstance(elem, torch.Tensor): + return Gather.apply(self.output_device, self.dim, *outputs) + + if elem is None: + return None + + if isinstance(elem, Mapping): + if not all((len(elem) == len(d) for d in outputs)): + raise ValueError('All dicts must have the same number of keys') + return elem_type(((k, gather_map([d[k] for d in outputs])) + for k in elem)) + + if isinstance(elem, Iterable) and not isinstance(elem, str): + return elem_type(map(gather_map, zip(*outputs))) + + return outputs + + # Recursive function calls like this create reference cycles. + # Setting the function to None clears the refcycle. + try: + res = gather_map(outputs) + finally: + gather_map = None + return res + def parallel_apply(self, replicas, inputs, kwargs): return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) @@ -126,9 +161,8 @@ def forward(self, *inputs, **kwargs): # pragma: no-cover outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) output = self.gather(outputs, self.output_device) else: - # normal # output = self.module(*inputs, **kwargs) - # lightning (ddp_cpu) + # normal lightning (ddp_cpu) if self.module.training: output = self.module.training_step(*inputs, **kwargs) elif self.module.testing: diff --git a/pytorch_lightning/trainer/ignored_warnings.py b/pytorch_lightning/trainer/ignored_warnings.py index 9260720ec350f..e697c0c846be5 100644 --- a/pytorch_lightning/trainer/ignored_warnings.py +++ b/pytorch_lightning/trainer/ignored_warnings.py @@ -3,12 +3,9 @@ def ignore_scalar_return_in_dp(): # Users get confused by this warning so we silence it - m_1 = """ - Was asked to gather along dimension 0, but all - input tensors were scalars; will instead unsqueeze - and return a vector. - """ - warnings.filterwarnings('ignore', message=m_1) + warnings.filterwarnings('ignore', message='Was asked to gather along dimension 0, but all' + ' input tensors were scalars; will instead unsqueeze' + ' and return a vector.') ignore_scalar_return_in_dp() diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 3baed4ef9d81d..aced8b64a47ed 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -208,6 +208,10 @@ def reduce_distributed_output(self, output, num_gpus): if isinstance(output[k], dict): output[k] = self.reduce_distributed_output(output[k], num_gpus) + # compute the average of scalars + elif isinstance(output[k], list): + output[k] = sum(output[k]) / len(output[k]) + # do nothing when there's a scalar elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0: pass diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 189e496564da1..ec72f1127bec9 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -10,6 +10,7 @@ class TrainingStepVariations(ABC): """ Houses all variations of training steps """ + test_step_inf_loss = float('inf') def training_step(self, batch, batch_idx, optimizer_idx=None): @@ -17,18 +18,23 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): # forward pass x, y = batch x = x.view(x.size(0), -1) - y_hat = self(x) # calculate loss loss_val = self.loss(y, y_hat) - - # alternate possible outputs to test - output = OrderedDict({ - 'loss': loss_val, - 'progress_bar': {'some_val': loss_val * loss_val}, - 'log': {'train_some_val': loss_val * loss_val}, - }) + log_val = loss_val + + # alternate between tensors and scalars for "log" and "progress_bar" + if batch_idx % 2 == 0: + log_val = log_val.item() + + output = OrderedDict( + { + 'loss': loss_val, + 'progress_bar': {'some_val': log_val * log_val}, + 'log': {'train_some_val': log_val * log_val}, + } + ) return output def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None): diff --git a/tests/base/model_valid_epoch_ends.py b/tests/base/model_valid_epoch_ends.py index f09c382a38c82..8974224409624 100644 --- a/tests/base/model_valid_epoch_ends.py +++ b/tests/base/model_valid_epoch_ends.py @@ -25,6 +25,11 @@ def _mean(res, key): val_loss_mean = _mean(outputs, 'val_loss') val_acc_mean = _mean(outputs, 'val_acc') + # alternate between tensor and scalar + if self.current_epoch % 2 == 0: + val_loss_mean = val_loss_mean.item() + val_acc_mean = val_acc_mean.item() + metrics_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean} results = {'progress_bar': metrics_dict, 'log': metrics_dict} return results @@ -54,6 +59,6 @@ def _mean(res, key): results = { 'val_loss': torch.stack([v for k, v in pbar.items() if k.startswith('val_loss')]).mean(), 'progress_bar': pbar, - 'log': logs + 'log': logs, } return results