diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d6fc6ce5fe64e..c25ffb40cb505 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -21,10 +21,10 @@ # Packages /pytorch_lightning/accelerators @williamfalcon @tchaton @SeanNaren @awaelchli @justusschock @kaushikb11 /pytorch_lightning/callbacks @williamfalcon @tchaton @carmocca @borda @kaushikb11 -/pytorch_lightning/cluster_environments @borda @tchaton @SeanNaren @carmocca @kaushikb11 /pytorch_lightning/core @tchaton @SeanNaren @borda @carmocca @justusschock @kaushikb11 /pytorch_lightning/distributed @williamfalcon @tchaton @awaelchli @kaushikb11 /pytorch_lightning/loggers @tchaton @awaelchli @borda +/pytorch_lightning/loggers/wandb.py @borisdayma /pytorch_lightning/loops @tchaton @awaelchli @justusschock @carmocca /pytorch_lightning/overrides @tchaton @SeanNaren @borda /pytorch_lightning/plugins @tchaton @SeanNaren @awaelchli @justusschock @@ -38,11 +38,6 @@ /pytorch_lightning/trainer/connectors/logger_connector @tchaton @carmocca /pytorch_lightning/trainer/progress.py @tchaton @awaelchli @carmocca -# Metrics -/pytorch_lightning/metrics/ @SkafteNicki @ananyahjha93 @justusschock -/tests/metrics/ @SkafteNicki @ananyahjha93 @justusschock -/docs/source/metrics.rst @SkafteNicki @ananyahjha93 @justusschock - # API /pytorch_lightning/callbacks/base.py @williamfalcon @awaelchli @ananthsub @carmocca /pytorch_lightning/core/datamodule.py @williamFalcon @awaelchli @ananthsub @carmocca diff --git a/CHANGELOG.md b/CHANGELOG.md index 56dd27e64fccc..662f0435892b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953)) * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950)) + * Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401)) + - Checkpoint saving & loading extensibility: * Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) @@ -107,9 +109,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) + - Add a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221)) +- Added `inference_mode` for evaluation and prediction ([#8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813)) + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) @@ -173,6 +179,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851)) +- Deprecated `LightningModule.get_progress_bar_dict` and `Trainer.progress_bar_dict` in favor of `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and `ProgressBarBase.get_metrics` ([#8985](https://github.com/PyTorchLightning/pytorch-lightning/pull/8985)) + + - Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` ([#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958)) @@ -289,7 +298,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `EarlyStopping` running on train epoch end when `check_val_every_n_epoch>1` is set ([#9156](https://github.com/PyTorchLightning/pytorch-lightning/pull/9156)) -- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333)) +- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8685](https://github.com/PyTorchLightning/pytorch-lightning/pull/8685)) - Fixed the Apex and DeepSpeed plugin closure running after the `on_before_optimizer_step` hook ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288)) @@ -319,12 +328,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310)) +- Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364)) + + - Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367)) - Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349)) +- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386)) + + ## [1.4.5] - 2021-08-31 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142)) diff --git a/CITATION.cff b/CITATION.cff index 978873013c19c..b1c572984ef2e 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -4,8 +4,8 @@ title: "PyTorch Lightning" abstract: "The lightweight PyTorch wrapper for high-performance AI research. Scale your models, not the boilerplate." date-released: 2019-03-30 authors: - - family-names: "William" - given-names: "Falcon" + - family-names: "Falcon" + given-names: "William" - name: "The PyTorch Lightning team" version: 1.4 doi: 10.5281/zenodo.3828935 diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 5158fc8dde788..8cdcdf106aa2f 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1242,12 +1242,6 @@ backward .. automethod:: pytorch_lightning.core.lightning.LightningModule.backward :noindex: -get_progress_bar_dict -~~~~~~~~~~~~~~~~~~~~~ - -.. 
automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict - :noindex: - on_before_backward ~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 9a382a7dac1ea..39af8752fb5b9 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -245,13 +245,13 @@ Modifying the progress bar The progress bar by default already includes the training loss and version number of the experiment if you are using a logger. These defaults can be customized by overriding the -:func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module. +:func:`~pytorch_lightning.callbacks.progress.base.ProgressBarBase.get_metrics` hook in your progress bar callback. .. code-block:: python - def get_progress_bar_dict(self): + def get_metrics(self, trainer, model): # don't show the version number - items = super().get_progress_bar_dict() + items = super().get_metrics(trainer, model) items.pop("v_num", None) return items diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index f40dc9e1576cf..93915ac946ae9 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -173,15 +173,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: """The actual training step. - Args: - step_kwargs: the arguments for the models training step. Can consist of the following: - - - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - - batch_idx (int): Integer displaying index of this batch - - optimizer_idx (int): When using multiple optimizers, this argument will also be present. - - hiddens(:class:`~torch.Tensor`): Passed in if - :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. + See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details """ with self.precision_plugin.train_step_context(): return self.training_type_plugin.training_step(*step_kwargs.values()) @@ -192,14 +184,7 @@ def post_training_step(self) -> None: def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: """The actual validation step. - Args: - step_kwargs: the arguments for the models validation step. Can consist of the following: - - - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - - batch_idx (int): The index of this batch - - dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple val dataloaders used) + See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details """ with self.precision_plugin.val_step_context(): return self.training_type_plugin.validation_step(*step_kwargs.values()) @@ -207,14 +192,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: """The actual test step. - Args: - step_kwargs: the arguments for the models test step. Can consist of the following: - - - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...)
| [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - - batch_idx (int): The index of this batch. - - dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple test dataloaders used). + See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details """ with self.precision_plugin.test_step_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) @@ -222,14 +200,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: """The actual predict step. - Args: - step_kwargs: the arguments for the models predict step. Can consist of the following: - - - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - - batch_idx (int): The index of this batch. - - dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple predict dataloaders used). + See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details """ with self.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index fe4c89a66d665..334dd05ab3cab 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -11,7 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Union + +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities import rank_zero_warn class ProgressBarBase(Callback): @@ -177,3 +181,70 @@ def on_predict_epoch_start(self, trainer, pl_module): def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self._predict_batch_idx += 1 + + def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]: + r""" + Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. + Implement this to override the items displayed in the progress bar. + + Here is an example of how to override the defaults: + + .. code-block:: python + + def get_metrics(self, trainer, model): + # don't show the version number + items = super().get_metrics(trainer, model) + items.pop("v_num", None) + return items + + Return: + Dictionary with the items to be displayed in the progress bar. + """ + standard_metrics = pl_module.get_progress_bar_dict() + pbar_metrics = trainer.progress_bar_metrics + duplicates = list(standard_metrics.keys() & pbar_metrics.keys()) + if duplicates: + rank_zero_warn( + f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and" + f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. 
" + " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.", + UserWarning, + ) + + return {**standard_metrics, **pbar_metrics} + + +def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]: + r""" + Returns several standard metrics displayed in the progress bar, including the average loss value, + split index of BPTT (if used) and the version of the experiment when using a logger. + + .. code-block:: + + Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10] + + Return: + Dictionary with the standard metrics to be displayed in the progress bar. + """ + # call .item() only once but store elements without graphs + running_train_loss = trainer.fit_loop.running_loss.mean() + avg_training_loss = None + if running_train_loss is not None: + avg_training_loss = running_train_loss.cpu().item() + elif pl_module.automatic_optimization: + avg_training_loss = float("NaN") + + items_dict = {} + if avg_training_loss is not None: + items_dict["loss"] = f"{avg_training_loss:.3g}" + + if pl_module.truncated_bptt_steps > 0: + items_dict["split_idx"] = trainer.fit_loop.split_idx + + if trainer.logger is not None and trainer.logger.version is not None: + version = trainer.logger.version + # show last 4 places of long version strings + version = version[-4:] if isinstance(version, str) else version + items_dict["v_num"] = version + + return items_dict diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index cfd0b3a36f2ce..1f36e86978e92 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -46,8 +46,9 @@ def render(self, task) -> RenderableType: class MetricsTextColumn(ProgressColumn): """A column containing text.""" - def __init__(self, trainer, stage): + def __init__(self, trainer, pl_module, stage): self._trainer = trainer + self._pl_module = pl_module self._stage = stage self._tasks = {} self._current_task_id = 0 @@ -64,7 +65,13 @@ def render(self, task) -> Text: if self._trainer.training and task.id != self._current_task_id: return self._tasks[task.id] _text = "" - for k, v in self._trainer.progress_bar_dict.items(): + # TODO(@daniellepintz): make this code cleaner + progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None) + if progress_bar_callback: + metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module) + else: + metrics = self._trainer.progress_bar_metrics + for k, v in metrics.items(): _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " text = Text.from_markup(_text, style=None, justify="left") return text @@ -163,7 +170,7 @@ def setup(self, trainer, pl_module, stage): "[", CustomTimeColumn(), ProcessingSpeedColumn(), - MetricsTextColumn(trainer, stage), + MetricsTextColumn(trainer, pl_module, stage), "]", console=self.console, refresh_per_second=self.refresh_rate, diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 51bee9f624b87..7f6911588388d 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -237,7 +237,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data total_batches = convert_inf(total_batches) if self._should_update(self.train_batch_idx, total_batches): 
self._update_bar(self.main_progress_bar) - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) @@ -257,7 +257,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, def on_validation_end(self, trainer, pl_module): super().on_validation_end(trainer, pl_module) if self.main_progress_bar is not None: - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() def on_train_end(self, trainer, pl_module): diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e3c7402242a3b..b6a00e3168340 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -31,6 +31,7 @@ from torch.optim.optimizer import Optimizer from torchmetrics import Metric +from pytorch_lightning.callbacks.progress import base as progress_base from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin from pytorch_lightning.core.optimizer import LightningOptimizer @@ -620,9 +621,9 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: Args: batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - batch_idx (int): Integer displaying index of this batch - optimizer_idx (int): When using multiple optimizers, this argument will also be present. - hiddens(:class:`~torch.Tensor`): Passed in if + batch_idx (``int``): Integer displaying index of this batch + optimizer_idx (``int``): When using multiple optimizers, this argument will also be present. + hiddens (``Any``): Passed in if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. Return: @@ -667,9 +668,8 @@ def training_step(self, batch, batch_idx, optimizer_idx): # Truncated back-propagation through time def training_step(self, batch, batch_idx, hiddens): # hiddens are the hidden states from the previous truncated backprop step - ... out, hiddens = self.lstm(data, hiddens) - ... + loss = ... return {"loss": loss, "hiddens": hiddens} Note: @@ -1585,7 +1585,7 @@ def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): """ optimizer.zero_grad() - def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: + def tbptt_split_batch(self, batch: Any, split_size: int) -> List[Any]: r""" When using truncated backpropagation through time, each batch must be split along the time dimension. 
Lightning handles this by default, but for custom behavior override @@ -1603,29 +1603,25 @@ def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: Examples:: def tbptt_split_batch(self, batch, split_size): - splits = [] - for t in range(0, time_dims[0], split_size): - batch_split = [] - for i, x in enumerate(batch): - if isinstance(x, torch.Tensor): - split_x = x[:, t:t + split_size] - elif isinstance(x, collections.Sequence): - split_x = [None] * len(x) - for batch_idx in range(len(x)): + splits = [] + for t in range(0, time_dims[0], split_size): + batch_split = [] + for i, x in enumerate(batch): + if isinstance(x, torch.Tensor): + split_x = x[:, t:t + split_size] + elif isinstance(x, collections.Sequence): + split_x = [None] * len(x) + for batch_idx in range(len(x)): split_x[batch_idx] = x[batch_idx][t:t + split_size] - - batch_split.append(split_x) - - splits.append(batch_split) - - return splits + batch_split.append(split_x) + splits.append(batch_split) + return splits Note: Called in the training loop after :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start` if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. Each returned batch split is passed separately to :meth:`training_step`. - """ time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))] assert len(time_dims) >= 1, "Unable to determine batch time dimension" @@ -1705,6 +1701,10 @@ def unfreeze(self) -> None: def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: r""" + .. deprecated:: v1.5 + This method was deprecated in v1.5 in favor of + `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and will be removed in v1.7. + Implement this to override the default items displayed in the progress bar. By default it includes the average loss value, split index of BPTT (if used) and the version of the experiment when using a logger. @@ -1726,28 +1726,7 @@ def get_progress_bar_dict(self): Return: Dictionary with the items to be displayed in the progress bar. 
""" - # call .item() only once but store elements without graphs - running_train_loss = self.trainer.fit_loop.running_loss.mean() - avg_training_loss = None - if running_train_loss is not None: - avg_training_loss = running_train_loss.cpu().item() - elif self.automatic_optimization: - avg_training_loss = float("NaN") - - tqdm_dict = {} - if avg_training_loss is not None: - tqdm_dict["loss"] = f"{avg_training_loss:.3g}" - - if self.truncated_bptt_steps > 0: - tqdm_dict["split_idx"] = self.trainer.fit_loop.split_idx - - if self.trainer.logger is not None and self.trainer.logger.version is not None: - version = self.trainer.logger.version - # show last 4 places of long version strings - version = version[-4:] if isinstance(version, str) else version - tqdm_dict["v_num"] = version - - return tqdm_dict + return progress_base.get_standard_metrics(self.trainer, self) def _verify_is_manual_optimization(self, fn_name): if self.automatic_optimization: diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 4d6c654c93816..a6a19f5f5fb38 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -135,6 +135,7 @@ def on_run_end(self) -> EPOCH_OUTPUT: outputs = self.outputs # free memory self.outputs = [] + self.dataloader_iter = None return outputs def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index f8040c9686aba..83fa9009311e9 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -233,6 +233,8 @@ def on_run_end(self) -> None: if self._num_training_batches_reached(self.is_last_batch): self.update_lr_schedulers("epoch", update_plateau_schedulers=True) + self.dataloader_iter = None + def teardown(self) -> None: self._results.cpu() self.batch_loop.teardown() diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index c9310ab524ad0..54d505a56653c 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -16,7 +16,6 @@ from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence import torch -from torch import Tensor from torch.optim import Optimizer import pytorch_lightning as pl @@ -92,7 +91,7 @@ def _build_training_step_kwargs( batch: Any, batch_idx: int, opt_idx: Optional[int], - hiddens: Optional[Tensor], + hiddens: Optional[Any], ) -> Dict[str, Any]: """Builds the keyword arguments for training_step. diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 54abc06da696f..d249e4dc76440 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -326,7 +326,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int): optimizers = self.lightning_module.trainer.optimizers if self._model_averaging_period is None: raise ValueError( - "Post-localSGD algorithm is used, " "but model averaging period is not provided to DDP plugin." + "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin." 
) averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps) for x, optimizer in enumerate(optimizers): diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 1c30768285fc1..603ec88ef41e9 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -43,6 +43,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: elif self.trainer.state.fn == TrainerFn.PREDICTING: self.__verify_predict_loop_configuration(model) self.__verify_dp_batch_transfer_support(model) + # TODO(@daniellepintz): Delete _check_progress_bar in v1.7 + self._check_progress_bar(model) # TODO: Delete _check_on_keyboard_interrupt in v1.7 self._check_on_keyboard_interrupt() @@ -111,6 +113,19 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None "(rather, they are called on every optimization step)." ) + def _check_progress_bar(self, model: "pl.LightningModule") -> None: + r""" + Checks if `get_progress_bar_dict` is overridden and sends a deprecation warning. + + Args: + model: The model to check for a `get_progress_bar_dict` override. + """ + if is_overridden("get_progress_bar_dict", model): + rank_zero_deprecation( + "The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7." + " Please use `ProgressBarBase.get_metrics` instead." + ) + def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: str) -> None: loader_name = f"{stage}_dataloader" step_name = "validation_step" if stage == "val" else "test_step" diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 22d8ee01d3871..fd72e6d4397fe 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -39,13 +39,7 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus -from pytorch_lightning.utilities import ( - DeviceType, - DistributedType, - GradClipAlgorithmType, - rank_zero_deprecation, - rank_zero_warn, -) +from pytorch_lightning.utilities import DeviceType, DistributedType, GradClipAlgorithmType, rank_zero_deprecation from pytorch_lightning.utilities.argparse import ( add_argparse_args, from_argparse_args, @@ -306,21 +300,15 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]: @property def progress_bar_dict(self) -> dict: """Read-only for progress bar metrics.""" + rank_zero_deprecation( + "`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7." + " Use `ProgressBarBase.get_metrics` instead." + ) ref_model = self.lightning_module ref_model = cast(pl.LightningModule, ref_model) - - standard_metrics = ref_model.get_progress_bar_dict() - pbar_metrics = self.progress_bar_metrics - duplicates = list(standard_metrics.keys() & pbar_metrics.keys()) - if duplicates: - rank_zero_warn( - f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and" - f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value.
" - " If this is undesired, change the name or override `get_progress_bar_dict()`" - " in `LightingModule`.", - UserWarning, - ) - return {**standard_metrics, **pbar_metrics} + if self.progress_bar_callback: + return self.progress_bar_callback.get_metrics(self, ref_model) + return self.progress_bar_metrics @property def _should_reload_dl_epoch(self) -> bool: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index decef9cf43218..963741b995272 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -16,9 +16,10 @@ import os import traceback import warnings +from contextlib import contextmanager from datetime import timedelta from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union from weakref import proxy import torch @@ -76,7 +77,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training +from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9 from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.seed import reset_seed @@ -1146,7 +1147,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: # reset trainer on this loop and all child loops in case user connected a custom loop self._evaluation_loop.trainer = self - with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad(): + with self.profiler.profile(f"run_{self.state.stage}_evaluation"), self._evaluation_context(): eval_loop_results = self._evaluation_loop.run() # remove the tensors from the eval results @@ -1162,7 +1163,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: self.reset_predict_dataloader(self.lightning_module) # reset trainer on this loop and all child loops in case user connected a custom loop self.predict_loop.trainer = self - with torch.no_grad(): + with self._evaluation_context(): return self.predict_loop.run() def _run_sanity_check(self, ref_model): @@ -1391,3 +1392,8 @@ def _on_exception(self): # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. 
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") self.save_checkpoint(file_path) + + @contextmanager + def _evaluation_context(self) -> Generator: + with torch.inference_mode() if _TORCH_GREATER_EQUAL_1_9 else torch.no_grad(): + yield diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index d67c7d74231e3..0cb8f522e2189 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -263,7 +263,8 @@ def set_rng_states(rng_state_dict: Dict[str, Any]) -> None: """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process.""" torch.set_rng_state(rng_state_dict.get("torch")) np.random.set_state(rng_state_dict.get("numpy")) - python_set_rng_state(rng_state_dict.get("python")) + version, state, gauss = rng_state_dict.get("python") + python_set_rng_state((version, tuple(state), gauss)) class CaptureIterableDataset(IterableDataset): diff --git a/pytorch_lightning/utilities/finite_checks.py b/pytorch_lightning/utilities/finite_checks.py index 4dfc5843de8c2..27ba78373f1ab 100644 --- a/pytorch_lightning/utilities/finite_checks.py +++ b/pytorch_lightning/utilities/finite_checks.py @@ -25,7 +25,7 @@ def print_nan_gradients(model: nn.Module) -> None: """Iterates over model parameters and prints out parameter + gradient information if NaN.""" for param in model.parameters(): if (param.grad is not None) and torch.isnan(param.grad.float()).any(): - log.info(param, param.grad) + log.info(f"{param}, {param.grad}") def detect_nan_parameters(model: nn.Module) -> None: diff --git a/requirements/extra.txt b/requirements/extra.txt index f4e1e6a807091..dfffc6fce8428 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -9,3 +9,4 @@ onnxruntime>=1.3.0 hydra-core>=1.0 jsonargparse[signatures]>=3.19.0 gcsfs>=2021.5.0 +rich>=10.2.2 diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 8356994b7b018..8cf7c0c5fd2b7 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -25,6 +25,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -557,6 +558,28 @@ def _test_progress_bar_max_val_check_interval( assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches +def test_get_progress_bar_metrics(tmpdir: str): + class TestProgressBar(ProgressBar): + def get_metrics(self, trainer: Trainer, model: LightningModule): + items = super().get_metrics(trainer, model) + items.pop("v_num", None) + return items + + progress_bar = TestProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[progress_bar], + fast_dev_run=True, + ) + model = BoringModel() + trainer.fit(model) + model.truncated_bptt_steps = 2 + standard_metrics = progress_bar.get_metrics(trainer, model) + assert "loss" in standard_metrics.keys() + assert "split_idx" in standard_metrics.keys() + assert "v_num" not in standard_metrics.keys() + + def test_progress_bar_main_bar_resume(): """Test that the progress bar can resume its counters based on the Trainer state.""" bar = ProgressBar() diff 
--git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index c6f44759ba371..687f1f679b858 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -17,6 +17,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar +from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -58,5 +59,6 @@ def test_rich_progress_bar(progress_update, tmpdir): def test_rich_progress_bar_import_error(): - with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` to be installed."): - Trainer(callbacks=RichProgressBar()) + if not _RICH_AVAILABLE: + with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` to be installed."): + Trainer(callbacks=RichProgressBar()) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 488e14a498f3d..822e65bd2ef33 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -84,6 +84,29 @@ def test_v1_7_0_datamodule_dims_property(tmpdir): _ = LightningDataModule(dims=(1, 1, 1)) +def test_v1_7_0_moved_get_progress_bar_dict(tmpdir): + class TestModel(BoringModel): + def get_progress_bar_dict(self): + items = super().get_progress_bar_dict() + items.pop("v_num", None) + return items + + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=None, + fast_dev_run=True, + ) + test_model = TestModel() + with pytest.deprecated_call(match=r"`LightningModule.get_progress_bar_dict` method was deprecated in v1.5"): + trainer.fit(test_model) + standard_metrics_postfix = trainer.progress_bar_callback.main_progress_bar.postfix + assert "loss" in standard_metrics_postfix + assert "v_num" not in standard_metrics_postfix + + with pytest.deprecated_call(match=r"`trainer.progress_bar_dict` is deprecated in v1.5"): + _ = trainer.progress_bar_dict + + def test_v1_7_0_trainer_prepare_data_per_node(tmpdir): with pytest.deprecated_call( match="Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0!" diff --git a/tests/loops/batch/__init__.py b/tests/loops/batch/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/loops/batch/test_truncated_bptt.py b/tests/loops/batch/test_truncated_bptt.py new file mode 100644 index 0000000000000..874a621f8a485 --- /dev/null +++ b/tests/loops/batch/test_truncated_bptt.py @@ -0,0 +1,172 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math + +import pytest +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, TensorDataset + +from pytorch_lightning import LightningModule, Trainer + + +class LSTMModel(LightningModule): + """LSTM sequence-to-sequence model for testing TBPTT with automatic optimization.""" + + def __init__(self, truncated_bptt_steps=2, input_size=1, hidden_size=8): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True) + self.truncated_bptt_steps = truncated_bptt_steps + self.automatic_optimization = True + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + def training_step(self, batch, batch_idx, hiddens): + x, y = batch + pred, hiddens = self.lstm(x, hiddens) + loss = F.mse_loss(pred, y) + return {"loss": loss, "hiddens": hiddens} + + def train_dataloader(self): + dataset = TensorDataset(torch.rand(16, 8, self.input_size), torch.rand(16, 8, self.input_size)) + return DataLoader(dataset=dataset, batch_size=4) + + +class ManualLSTMModel(LSTMModel): + """LSTM sequence-to-sequence model for testing TBPTT with manual optimization.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.automatic_optimization = False + + def training_step(self, batch, batch_idx, hiddens): + out = super().training_step(batch, batch_idx, hiddens) + loss, hiddens = out["loss"], out["hiddens"] + opt = self.optimizers() + opt.zero_grad() + self.manual_backward(loss) + opt.step() + return {"loss": loss, "hiddens": hiddens} + + +@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel)) +def test_persistent_hidden_state_transfer(tmpdir, model_class): + """Test that the hidden state reference gets passed through from one training_step to the next and remains + unmodified apart from detached grad_fn.""" + + class TBPTTModel(model_class): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.test_hidden = None + + def training_step(self, batch, batch_idx, hiddens): + split_idx = self.trainer.fit_loop.split_idx + # the hidden state may only be None for the first split_idx + assert not ((split_idx == 0) ^ (hiddens is None)) + # test_hiddens is None when hiddens is None + assert not ((hiddens is None) ^ (self.test_hidden is None)) + # the states are equal (persistent) + assert hiddens is None or all(torch.equal(h, th) for h, th in zip(hiddens, self.test_hidden)) + # the incoming hidden state never has a grad_fn (gets automatically detached) + assert hiddens is None or all(h.grad_fn is None for h in hiddens) + out = super().training_step(batch, batch_idx, hiddens) + + # store hiddens, assert persistence in next training_step + self.test_hidden = out["hiddens"] + + # hiddens may have grad_fn when returning, gets automatically detached + assert all(h.grad_fn is not None for h in self.test_hidden) + return out + + def on_train_batch_start(self, *_, **__) -> None: + self.test_hidden = None + + model = TBPTTModel(truncated_bptt_steps=2, input_size=1, hidden_size=8) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + weights_summary=None, + logger=False, + checkpoint_callback=False, + ) + trainer.fit(model) + + +@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel)) +def test_tbptt_split_shapes(tmpdir, model_class): + """Test that the sequence data gets split correctly and that the outputs are correctly passed from hook to + hook.""" + 
batch_size = 10 + truncated_bptt_steps = 2 + n, t, f = 32, 15, 1 # (num samples, sequence size, input size) + assert t % truncated_bptt_steps != 0, "test must run with sequence length not divisible by tbptt steps" + + seq2seq_dataset = TensorDataset(torch.rand(n, t, f), torch.rand(n, t, f)) + train_dataloader = DataLoader(dataset=seq2seq_dataset, batch_size=batch_size) + + class TBPTTModel(model_class): + def training_step(self, batch, batch_idx, hiddens): + x, y = batch + if self.trainer.fit_loop.epoch_loop.batch_loop.done: + # last split idx, not aligned + assert x.shape[1] == t % truncated_bptt_steps + assert y.shape[1] == t % truncated_bptt_steps + else: + assert x.shape[1] == truncated_bptt_steps + assert y.shape[1] == truncated_bptt_steps + return super().training_step(batch, batch_idx, hiddens) + + def training_epoch_end(self, training_step_outputs): + training_step_outputs = training_step_outputs[0] + assert len(training_step_outputs) == math.ceil(t / self.truncated_bptt_steps) + assert all(out["loss"].grad_fn is None for out in training_step_outputs) + assert all("hiddens" not in out for out in training_step_outputs) + + model = TBPTTModel(truncated_bptt_steps=truncated_bptt_steps, input_size=f, hidden_size=8) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + weights_summary=None, + logger=False, + checkpoint_callback=False, + ) + trainer.fit(model, train_dataloaders=train_dataloader) + + assert trainer.fit_loop.batch_idx == n // batch_size + assert trainer.fit_loop.split_idx == t // truncated_bptt_steps + + +@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel)) +def test_tbptt_logging(tmpdir, model_class): + """Test step-level and epoch-level logging works with TBPTT.""" + + class TBPTTModel(model_class): + def training_step(self, *args, **kwargs): + out = super().training_step(*args, **kwargs) + self.log("loss", out["loss"], on_step=True, on_epoch=True) + return out + + model = TBPTTModel(truncated_bptt_steps=2) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + log_every_n_steps=2, + weights_summary=None, + checkpoint_callback=False, + ) + trainer.fit(model) + assert set(trainer.logged_metrics) == {"loss_step", "loss_epoch", "epoch"} diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 77d9fce06d6ca..6ea8b76d253fa 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -66,7 +66,7 @@ def training_epoch_end(self, outputs): trainer = Trainer(max_epochs=num_epochs, default_root_dir=tmpdir, overfit_batches=2) trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - metrics = trainer.progress_bar_dict + metrics = trainer.progress_bar_callback.get_metrics(trainer, model) # metrics added in training step should be unchanged by epoch end method assert metrics["step_metric"] == -1 diff --git a/tests/models/test_truncated_bptt.py b/tests/models/test_truncated_bptt.py deleted file mode 100644 index d7a2dfb31652d..0000000000000 --- a/tests/models/test_truncated_bptt.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import torch -from torch.utils.data import DataLoader, Dataset - -from pytorch_lightning import Trainer -from tests.helpers import BoringModel - - -class LinearModel(BoringModel): - """Linear model for testing TBPTT with automatic optimization.""" - - def __init__(self, truncated_bptt_steps=2, n_hidden_states=1, sequence_size=30, batch_size=30): - super().__init__() - self.truncated_bptt_steps = truncated_bptt_steps - self.n_hidden_states = n_hidden_states - self.sequence_size = sequence_size - self.batch_size = batch_size - self.automatic_optimization = True - - self.example_input_array = torch.randn(5, truncated_bptt_steps) - self.layer = torch.nn.Linear(in_features=truncated_bptt_steps, out_features=truncated_bptt_steps) - self.test_hidden = None - - def training_step(self, batch, batch_idx, hiddens): - assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" - if self.n_hidden_states == 1: - self.test_hidden = torch.rand(1) - else: - self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states) - - x_tensor, y_list = batch - assert x_tensor.shape[1] == self.truncated_bptt_steps, "tbptt split Tensor failed" - - y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) - assert y_tensor.shape[1] == self.truncated_bptt_steps, "tbptt split list failed" - - pred = self(x_tensor.view(self.batch_size, self.truncated_bptt_steps)) - loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(self.batch_size, self.truncated_bptt_steps)) - return {"loss": loss_val, "hiddens": self.test_hidden} - - def training_epoch_end(self, training_step_outputs): - training_step_outputs = training_step_outputs[0] - assert len(training_step_outputs) == (self.sequence_size / self.truncated_bptt_steps) - loss = torch.stack([x["loss"] for x in training_step_outputs]).mean() - assert loss.grad_fn is None - self.log("train_loss", loss) - - -class ManualLinearModel(LinearModel): - """Linear model for testing TBPTT with manual optimization.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.automatic_optimization = False - - def training_step(self, batch, batch_idx, hiddens): - out = super().training_step(batch, batch_idx, hiddens) - loss, hiddens = out["loss"], out["hiddens"] - opt = self.optimizers() - opt.zero_grad() - self.manual_backward(loss) - opt.step() - assert loss.grad_fn is not None - return {"loss": loss, "hiddens": hiddens} - - -@pytest.mark.parametrize("model_class", (LinearModel, ManualLinearModel)) -@pytest.mark.parametrize("n_hidden_states", (1, 2)) -def test_tbptt_cpu_model_manual(tmpdir, n_hidden_states, model_class): - """Test truncated back propagation through time works with automatic and manual optimization.""" - - sequence_size = 30 - batch_size = 30 - - x_seq = torch.rand(batch_size, sequence_size, 1) - y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() - - class MockSeq2SeqDataset(Dataset): - def __getitem__(self, i): - return x_seq, y_seq_list - - def __len__(self): - return 1 - - train_dataloader = DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False) - model = 
model_class(n_hidden_states=n_hidden_states, sequence_size=sequence_size, batch_size=batch_size) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0, - weights_summary=None, - ) - trainer.fit(model, train_dataloader) - - -def test_tbptt_log(tmpdir): - truncated_bptt_steps = 2 - N, T, F = 32, 15, 1 # batches x timesteps (sequence size) x features - batch_size = 10 - assert T % truncated_bptt_steps != 0, "Should test leftover time steps" - - class MockSeq2SeqDataset(Dataset): - def __init__(self): - self.x_seq = torch.randn(N, T, F) - self.y_seq = torch.randn(N, T, F) - - def __getitem__(self, index): - return self.x_seq[index], self.y_seq[index] - - def __len__(self): - return N - - class TestModel(BoringModel): - def __init__(self): - super().__init__() - self.test_hidden = None - self.layer = torch.nn.LSTM(input_size=F, hidden_size=T, batch_first=True) - self.truncated_bptt_steps = truncated_bptt_steps - - def training_step(self, batch, batch_idx, hiddens): - assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" - if hiddens is not None: - assert hiddens.grad_fn is None - split_idx = self.trainer.fit_loop.split_idx - self.test_hidden = torch.tensor(split_idx, requires_grad=True, dtype=torch.float).pow(2) - - x, y = batch - if self.trainer.fit_loop.epoch_loop.batch_loop.done: - # last split idx, not aligned - assert x.shape[1] == T % truncated_bptt_steps - assert y.shape[1] == T % truncated_bptt_steps - else: - assert x.shape[1] == truncated_bptt_steps - assert y.shape[1] == truncated_bptt_steps - - pred, _ = self(x) - loss = torch.nn.functional.mse_loss(pred, y) - - self.log("a", loss, on_epoch=True) - - return {"loss": loss, "hiddens": self.test_hidden} - - def on_train_batch_start(self, *args, **kwargs) -> None: - self.test_hidden = None - - def train_dataloader(self): - return DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size) - - model = TestModel() - model.training_epoch_end = None - - trainer = Trainer( - default_root_dir=tmpdir, - limit_val_batches=0, - max_epochs=2, - log_every_n_steps=2, - weights_summary=None, - ) - trainer.fit(model) - - assert trainer.fit_loop.batch_idx == N // batch_size - assert trainer.fit_loop.split_idx == T // truncated_bptt_steps - assert set(trainer.logged_metrics) == {"a_step", "a_epoch", "epoch"} diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 7bc10d564fd07..3c72a78331720 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -23,7 +23,8 @@ from torchmetrics import Accuracy from pytorch_lightning import callbacks, Trainer -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDictDataset from tests.helpers.runif import RunIf @@ -315,7 +316,7 @@ def training_step(self, batch, batch_idx): trainer.fit(model) # Make sure the func_name output equals the average from all logged values when on_epoch true - assert trainer.progress_bar_dict["train_loss"] == model.seen_losses[-1] + assert trainer.progress_bar_callback.get_metrics(trainer, model)["train_loss"] == model.seen_losses[-1] assert trainer.callback_metrics["train_loss"] == model.seen_losses[-1] 
assert cb.call_counter == { @@ -449,7 +450,7 @@ def validation_step(self, batch, batch_idx): assert trainer.logged_metrics["bar"] == 2 -def test_progress_bar_dict_contains_values_on_train_epoch_end(tmpdir): +def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str): class TestModel(BoringModel): def training_step(self, *args): self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True) @@ -461,20 +462,28 @@ def on_train_epoch_end(self, *_): ) self.on_train_epoch_end_called = True - def on_epoch_end(self): - assert self.trainer.progress_bar_dict["foo"] == self.current_epoch - assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch - self.on_epoch_end_called = True + class TestProgressBar(ProgressBar): + def get_metrics(self, trainer: Trainer, model: LightningModule): + items = super().get_metrics(trainer, model) + items.pop("v_num", None) + return items + def on_epoch_end(self, trainer: Trainer, model: LightningModule): + metrics = self.get_metrics(trainer, model) + assert metrics["foo"] == self.trainer.current_epoch + assert metrics["foo_2"] == self.trainer.current_epoch + model.on_epoch_end_called = True + + progress_bar = TestProgressBar() trainer = Trainer( default_root_dir=tmpdir, + callbacks=[progress_bar], max_epochs=2, limit_train_batches=1, limit_val_batches=0, checkpoint_callback=False, logger=False, weights_summary=None, - progress_bar_refresh_rate=0, ) model = TestModel() trainer.fit(model)
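For reference, the new `ProgressBarBase.get_metrics` hook introduced by this patch is used as follows. This is a minimal sketch mirroring the docstring example and the `TestProgressBar` helpers in the tests above; the `LeanProgressBar` name is illustrative only and not part of the change set.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ProgressBar


    class LeanProgressBar(ProgressBar):
        def get_metrics(self, trainer, model):
            # combined standard metrics (loss, split_idx, v_num) and metrics logged with prog_bar=True
            items = super().get_metrics(trainer, model)
            # drop the experiment version number from the bar
            items.pop("v_num", None)
            return items


    # attach it like any other callback
    trainer = Trainer(callbacks=[LeanProgressBar()], max_epochs=1)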