Support infinite training (#8877)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
4 people authored Sep 4, 2021
1 parent c30d9b9 commit cf1a589
Showing 7 changed files with 109 additions and 11 deletions.
6 changes: 6 additions & 0 deletions docs/source/common/trainer.rst
@@ -907,6 +907,9 @@ Stop training once this number of epochs is reached
# default used by the Trainer
trainer = Trainer(max_epochs=1000)

If both ``max_epochs`` and ``max_steps`` aren't specified, ``max_epochs`` will default to ``1000``.
To enable infinite training, set ``max_epochs = -1``.

min_epochs
^^^^^^^^^^

@@ -947,6 +950,9 @@ Training will stop if max_steps or max_epochs have been reached (whichever comes first).
# Stop after 100 steps
trainer = Trainer(max_steps=100)

If ``max_steps`` is not specified, ``max_epochs`` will be used instead (and ``max_epochs`` defaults to
``1000`` if ``max_epochs`` is not specified). To disable this default, set ``max_steps = -1``.

min_steps
^^^^^^^^^

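A minimal usage sketch of the two documented settings (model and data setup omitted; `Trainer` is the class patched in this commit):

from pytorch_lightning import Trainer

# Train indefinitely; stopping is left to callbacks, max_time, or manual interruption.
trainer = Trainer(max_epochs=-1)

# Opt out of the implicit 1000-epoch default without setting a step budget.
trainer = Trainer(max_steps=-1)
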
3 changes: 3 additions & 0 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -37,6 +37,9 @@ class TrainingEpochLoop(loops.Loop):
    def __init__(self, min_steps: int, max_steps: int):
        super().__init__()
        self.min_steps: int = min_steps

        if max_steps and max_steps < -1:
            raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {max_steps}.")
        self.max_steps: int = max_steps

        self.global_step: int = 0
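
A standalone sketch of the guard's semantics, not the library code verbatim: falsy values (`None`, `0`) skip the check entirely, `-1` passes as the explicit no-limit sentinel, and anything below `-1` is rejected:

from pytorch_lightning.utilities.exceptions import MisconfigurationException

def check_max_steps(max_steps):
    # Mirrors the guard above: `max_steps and ...` lets None and 0 through.
    if max_steps and max_steps < -1:
        raise MisconfigurationException(
            f"`max_steps` must be a positive integer or -1. You passed in {max_steps}."
        )

for value in (None, 0, -1, 100):
    check_max_steps(value)  # no error
# check_max_steps(-2)  # would raise MisconfigurationException
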
26 changes: 24 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -20,6 +20,7 @@
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)

@@ -35,6 +36,12 @@ class FitLoop(Loop):

    def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
        super().__init__()
        # Allow max_epochs or max_steps to be zero, since this will be handled by fit_loop.done
        if max_epochs and max_epochs < -1:
            raise MisconfigurationException(
                f"`max_epochs` must be a positive integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop: Optional[TrainingEpochLoop] = None
@@ -98,6 +105,8 @@ def max_steps(self) -> int:
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        if value and value < -1:
            raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {value}.")
        self.epoch_loop.max_steps = value

    @property
@@ -123,6 +132,19 @@ def _results(self) -> ResultCollection:
            return self.epoch_loop.val_loop._results
        raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")

    @staticmethod
    def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
        """Checks whether the max_value is enabled. This can
        be used for checking whether max_epochs or max_steps is enabled.

        Args:
            max_value: the value to check

        Returns:
            whether the limit for this value should be enabled
        """
        return max_value not in (None, -1)

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop.
@@ -131,8 +153,8 @@ def done(self) -> bool:
        or if the maximum number of steps or epochs is reached.
        """
        # TODO(@awaelchli): Move step tracking inside the training loop and move part of these conditions inside the training loop
-       stop_steps = self.max_steps is not None and self.global_step >= self.max_steps
-       stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs
+       stop_steps = FitLoop._is_max_limit_enabled(self.max_steps) and self.global_step >= self.max_steps
+       stop_epochs = FitLoop._is_max_limit_enabled(self.max_epochs) and self.current_epoch >= self.max_epochs

        should_stop = False
        if self.trainer.should_stop:
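
The heart of the change, restated as a self-contained sketch: `None` and `-1` both disable a limit, so `done` can never trip on that counter, while every other value (including `0`) is compared against the running count:

def is_max_limit_enabled(max_value):
    return max_value not in (None, -1)

def fit_loop_done(global_step, max_steps, current_epoch, max_epochs):
    stop_steps = is_max_limit_enabled(max_steps) and global_step >= max_steps
    stop_epochs = is_max_limit_enabled(max_epochs) and current_epoch >= max_epochs
    return bool(stop_steps or stop_epochs)

assert fit_loop_done(10**9, -1, 10**6, -1) is False  # infinite training never stops here
assert fit_loop_done(0, 0, 0, -1) is True  # max_steps=0 is done immediately
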
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -21,6 +21,7 @@
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -190,7 +191,10 @@ def restore_loops(self) -> None:
        self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"]

        # crash if max_epochs is lower than the current epoch from the checkpoint
-       if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
+       if (
+           FitLoop._is_max_limit_enabled(self.trainer.max_epochs)
+           and self.trainer.current_epoch > self.trainer.max_epochs
+       ):
            raise MisconfigurationException(
                f"You restored a checkpoint with current_epoch={self.trainer.current_epoch},"
                f" but you have set Trainer(max_epochs={self.trainer.max_epochs})."
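
One consequence of routing this check through `FitLoop._is_max_limit_enabled`, shown with made-up numbers: a checkpoint recorded at any epoch can now be restored into an infinite run, because `max_epochs=-1` disables the limit just as `None` always did:

current_epoch = 1500  # hypothetical value read from a checkpoint
for max_epochs in (-1, None):  # both mean "no limit" after this change
    limit_enabled = max_epochs not in (None, -1)
    assert not (limit_enabled and current_epoch > max_epochs)  # restore is allowed
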
10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -268,12 +268,15 @@
            Can be used on CPU, GPU or TPUs.

        max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
-           If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.
+           If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
+           To enable infinite training, set ``max_epochs = -1``.

        min_epochs: Force training for at least this many epochs. Disabled by default (None).
-           If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1.
+           If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.

-       max_steps: Stop training after this number of steps. Disabled by default (None).
+       max_steps: Stop training after this number of steps. Disabled by default (None). If ``max_steps = None``
+           and ``max_epochs = None``, will default to ``max_epochs = 1000``. To disable this default, set
+           ``max_steps`` to ``-1``.

        min_steps: Force training for at least this number of steps. Disabled by default (None).
@@ -380,6 +383,7 @@ def __init__(
        self.slurm_connector = SLURMConnector(self)
        self.tuner = Tuner(self)

        # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
        fit_loop = FitLoop(
            min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
            max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs),
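
The conditional expression above, restated as a sketch of the defaulting rule (the `min_epochs` line is symmetric):

def resolve_max_epochs(max_epochs, max_steps, max_time):
    # The 1000-epoch fallback applies only when no stopping criterion at all
    # was given; an explicit value, including -1, is forwarded unchanged.
    if max_epochs is None and max_steps is None and max_time is None:
        return 1000
    return max_epochs

assert resolve_max_epochs(None, None, None) == 1000
assert resolve_max_epochs(-1, None, None) == -1
assert resolve_max_epochs(None, 100, None) is None
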
16 changes: 12 additions & 4 deletions tests/callbacks/test_timer.py
@@ -42,9 +42,13 @@ def on_fit_start(self):
    trainer.fit(TestModel())
    assert "callbacks list already contains a Timer" in caplog.text

-   seconds = 1
-   trainer = Trainer(max_time=dict(seconds=seconds))
-   assert trainer.max_epochs is None
+   # Make sure max_time still honored even if max_epochs == -1
+   trainer = Trainer(max_time=dict(seconds=1), max_epochs=-1)
+   with pytest.raises(SystemExit):
+       trainer.fit(TestModel())
+   timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
+   assert timer._duration == 1
+   assert trainer.max_epochs == -1
    assert trainer.max_steps is None


@@ -153,7 +157,11 @@ def test_timer_resume_training(tmpdir):
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)

    # initial training
-   trainer = Trainer(default_root_dir=tmpdir, max_epochs=100, callbacks=[timer, checkpoint_callback])
+   trainer = Trainer(
+       default_root_dir=tmpdir,
+       max_epochs=100,
+       callbacks=[timer, checkpoint_callback],
+   )
    trainer.fit(model)
    assert not timer._offset
    assert timer.time_remaining() <= 0
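
The user-facing pattern this test pins down, assuming any LightningModule `model`: wall-clock time bounds a run even when the epoch limit is infinite:

from pytorch_lightning import Trainer

trainer = Trainer(max_time={"seconds": 1}, max_epochs=-1)
# trainer.fit(model)  # the Timer callback ends the run after ~1 second despite max_epochs=-1
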
53 changes: 52 additions & 1 deletion tests/trainer/test_trainer.py
@@ -33,7 +33,7 @@

import tests.helpers.utils as tutils
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Timer
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.loggers import TensorBoardLogger
@@ -492,6 +492,57 @@ def test_trainer_max_steps_and_epochs(tmpdir):
    assert trainer.global_step == num_train_samples * trainer.max_epochs
    assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs"

    # if max_steps is positive and max_epochs is negative, use max_steps
    trainer_kwargs["max_epochs"] = -1
    trainer_kwargs["max_steps"] = 3
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.global_step == 3


@pytest.mark.parametrize(
    "max_epochs,max_steps,incorrect_variable,incorrect_value",
    [
        (-100, None, "max_epochs", -100),
        (1, -2, "max_steps", -2),
    ],
)
def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable, incorrect_value):
    """Don't allow max_epochs or max_steps to be less than -1 or a float"""
    with pytest.raises(
        MisconfigurationException,
        match=f"`{incorrect_variable}` must be a positive integer or -1. You passed in {incorrect_value}",
    ):
        Trainer(max_epochs=max_epochs, max_steps=max_steps)


@pytest.mark.parametrize(
    "max_epochs,max_steps,is_done,correct_trainer_epochs",
    [
        (None, None, False, 1000),
        (-1, None, False, -1),
        (None, -1, False, None),
        (5, -1, False, 5),
        (-1, 10, False, -1),
        (None, 0, True, None),
        (0, None, True, 0),
        (-1, 0, True, -1),
        (0, -1, True, 0),
    ],
)
def test_trainer_max_steps_and_epochs_fit_loop_done(max_epochs, max_steps, is_done, correct_trainer_epochs):
    trainer = Trainer(max_epochs=max_epochs, max_steps=max_steps)

    assert trainer.max_epochs == correct_trainer_epochs
    assert trainer.max_steps == max_steps
    assert trainer.fit_loop.done is is_done

    # Make sure there is no timer
    timer_callbacks = [c for c in trainer.callbacks if isinstance(c, Timer)]
    assert len(timer_callbacks) == 0


def test_trainer_min_steps_and_epochs(tmpdir):
"""Verify model trains according to specified min steps"""
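
A compact restatement of the sentinel semantics the parametrization above encodes; both assertions follow directly from rows of that table:

from pytorch_lightning import Trainer

assert Trainer(max_steps=0).fit_loop.done is True  # 0 is an enabled limit that is already reached
assert Trainer(max_epochs=-1).fit_loop.done is False  # -1 means no limit at all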
