added UserWarnings if max_epochs not set in the Trainer class (#10700)
Rajathbharadwaj authored Dec 6, 2021
1 parent 99bb62a commit 7914e5c
Showing 5 changed files with 56 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -39,6 +39,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860))


- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))


### Changed

- Raised exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
39 changes: 38 additions & 1 deletion pytorch_lightning/loops/utilities.py
@@ -13,20 +13,23 @@
# limitations under the License.
from collections import OrderedDict
from contextlib import contextmanager
from datetime import timedelta
from functools import lru_cache
from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple
from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import PossibleUserWarning


def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
@@ -61,6 +64,40 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: in
return hiddens


def _parse_loop_limits(
    min_steps: Optional[int],
    max_steps: int,
    min_epochs: Optional[int],
    max_epochs: int,
    max_time: Optional[Union[str, timedelta, Dict[str, int]]],
) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]:
    """This utility computes the default values for the minimum and maximum number of steps and epochs given the
    values the user has selected.

    Args:
        min_steps: Minimum number of steps.
        max_steps: Maximum number of steps.
        min_epochs: Minimum number of epochs.
        max_epochs: Maximum number of epochs.
        max_time: Maximum time for the training.

    Returns:
        The parsed limits, with default values being set for the ones that the user did not specify.
    """
    if max_epochs is None:
        if max_steps == -1 and max_time is None:
            rank_zero_warn(
                "`max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit,"
                " set `max_epochs=-1`.",
                category=PossibleUserWarning,
            )
            max_epochs = 1000
        else:
            max_epochs = -1
    min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
    return min_steps, max_steps, min_epochs, max_epochs, max_time


def _build_training_step_kwargs(
lightning_module: "pl.LightningModule",
optimizers: Sequence[Optimizer],
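For orientation, here is a minimal usage sketch of the new helper (not part of the commit; the expected tuples are inferred from the logic above and assume this version of the package is installed):

from pytorch_lightning.loops.utilities import _parse_loop_limits

# No epoch, step, or time limit given: emits the PossibleUserWarning and caps training at 1000 epochs.
limits = _parse_loop_limits(min_steps=None, max_steps=-1, min_epochs=None, max_epochs=None, max_time=None)
# expected: (None, -1, 1, 1000, None)

# A step limit is given: no warning, and the epoch limit falls back to -1 (unlimited epochs).
limits = _parse_loop_limits(min_steps=None, max_steps=100, min_epochs=None, max_epochs=None, max_time=None)
# expected: (None, 100, 1, -1, None)

# An explicit max_epochs passes through unchanged.
limits = _parse_loop_limits(min_steps=None, max_steps=-1, min_epochs=None, max_epochs=5, max_time=None)
# expected: (None, -1, 1, 5, None)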
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -38,6 +38,7 @@
from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.loops.utilities import _parse_loop_limits
from pytorch_lightning.plugins import (
ApexMixedPrecisionPlugin,
DDPSpawnPlugin,
@@ -455,13 +456,11 @@ def __init__(
self.signal_connector = SignalConnector(self)
self.tuner = Tuner(self)

fit_loop = FitLoop(
min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
max_epochs=(
max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)
),
min_steps, max_steps, min_epochs, max_epochs, max_time = _parse_loop_limits(
min_steps, max_steps, min_epochs, max_epochs, max_time
)
training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
training_batch_loop = TrainingBatchLoop()
training_validation_loop = EvaluationLoop()
training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop)
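In user-facing terms, the refactor above keeps the previous defaults while adding the new warning; a hedged sketch of the resulting behavior (assuming this package version):

from pytorch_lightning import Trainer

t1 = Trainer()               # no max_epochs, max_steps, or max_time -> PossibleUserWarning, capped at 1000 epochs
assert t1.max_epochs == 1000

t2 = Trainer(max_steps=100)  # a step limit is set -> no warning, epoch count left unlimited
assert t2.max_epochs == -1

t3 = Trainer(max_epochs=-1)  # explicit opt-out of the epoch limit -> no warning
assert t3.max_epochs == -1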
2 changes: 1 addition & 1 deletion tests/accelerators/test_ipu.py
@@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir):
@RunIf(ipu=True)
def test_no_warning_plugin(tmpdir):
with pytest.warns(None) as record:
Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options()))
Trainer(default_root_dir=tmpdir, max_epochs=1, strategy=IPUPlugin(training_opts=poptorch.Options()))
assert len(record) == 0


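The added `max_epochs=1` keeps this `pytest.warns(None)` test silent now that constructing a `Trainer` without `max_epochs` emits a `PossibleUserWarning`. Users who intentionally rely on the 1000-epoch default could silence just this warning with a standard filter; a hypothetical usage sketch, not part of the commit:

import warnings

from pytorch_lightning.utilities.warnings import PossibleUserWarning

# Ignore only the new max_epochs warning; other PossibleUserWarnings still surface.
warnings.filterwarnings("ignore", message=r"`max_epochs` was not set.*", category=PossibleUserWarning)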
8 changes: 8 additions & 0 deletions tests/trainer/flags/test_min_max_epochs.py
@@ -1,6 +1,7 @@
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests.helpers import BoringModel


@@ -33,3 +34,10 @@ def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_ste
    # check training stopped at max_epochs or max_steps
    if trainer.max_steps and not trainer.max_epochs:
        assert trainer.global_step == trainer.max_steps


def test_max_epochs_not_set_warning():
    """Test that a warning is emitted when `max_epochs` was not set by the user."""
    with pytest.warns(PossibleUserWarning, match="`max_epochs` was not set. Setting it to 1000 epochs."):
        trainer = Trainer(max_epochs=None)
    assert trainer.max_epochs == 1000
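
A complementary negative test (hypothetical, not included in this commit) could live alongside the test above, reusing its imports, and assert that another run-length limit keeps the warning silent:

def test_max_epochs_not_set_no_warning_when_max_steps_set():
    """Hypothetical companion test: no warning when `max_steps` provides the run-length limit."""
    with pytest.warns(None) as record:
        trainer = Trainer(max_epochs=None, max_steps=100)
    assert trainer.max_epochs == -1
    assert not any("`max_epochs` was not set" in str(w.message) for w in record)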
