diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 9fde434e78579..de37f5839e82b 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -907,6 +907,9 @@ Stop training once this number of epochs is reached # default used by the Trainer trainer = Trainer(max_epochs=1000) +If both ``max_epochs`` and ``max_steps`` aren't specified, ``max_epochs`` will default to ``1000``. +To enable infinite training, set ``max_epochs = -1``. + min_epochs ^^^^^^^^^^ @@ -947,6 +950,9 @@ Training will stop if max_steps or max_epochs have reached (earliest). # Stop after 100 steps trainer = Trainer(max_steps=100) +If ``max_steps`` is not specified, ``max_epochs`` will be used instead (and ``max_epochs`` defaults to +``1000`` if ``max_epochs`` is not specified). To disable this default, set ``max_steps = -1``. + min_steps ^^^^^^^^^ diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 2020ac6cc6564..93628e5114ec5 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -37,6 +37,9 @@ class TrainingEpochLoop(loops.Loop): def __init__(self, min_steps: int, max_steps: int): super().__init__() self.min_steps: int = min_steps + + if max_steps and max_steps < -1: + raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {max_steps}.") self.max_steps: int = max_steps self.global_step: int = 0 diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 47eb50a2ab100..c7d5dad492bef 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -20,6 +20,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum +from pytorch_lightning.utilities.exceptions import MisconfigurationException log = logging.getLogger(__name__) @@ -35,6 +36,12 @@ class FitLoop(Loop): def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None): super().__init__() + # Allow max_epochs or max_steps to be zero, since this will be handled by fit_loop.done + if max_epochs and max_epochs < -1: + raise MisconfigurationException( + f"`max_epochs` must be a positive integer or -1. You passed in {max_epochs}." + ) + self.max_epochs = max_epochs self.min_epochs = min_epochs self.epoch_loop: Optional[TrainingEpochLoop] = None @@ -98,6 +105,8 @@ def max_steps(self) -> int: def max_steps(self, value: int) -> None: """Sets the maximum number of steps (forwards to epoch_loop)""" # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided + if value and value < -1: + raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {value}.") self.epoch_loop.max_steps = value @property @@ -123,6 +132,19 @@ def _results(self) -> ResultCollection: return self.epoch_loop.val_loop._results raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope") + @staticmethod + def _is_max_limit_enabled(max_value: Optional[int]) -> bool: + """Checks whether the max_value is enabled. This can + be used for checking whether max_epochs or max_steps is enabled. + + Args: + max_value: the value to check + + Returns: + whether the limit for this value should be enabled + """ + return max_value not in (None, -1) + @property def done(self) -> bool: """Evaluates when to leave the loop. @@ -131,8 +153,8 @@ def done(self) -> bool: or if the maximum number of steps or epochs is reached. """ # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop - stop_steps = self.max_steps is not None and self.global_step >= self.max_steps - stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs + stop_steps = FitLoop._is_max_limit_enabled(self.max_steps) and self.global_step >= self.max_steps + stop_epochs = FitLoop._is_max_limit_enabled(self.max_epochs) and self.current_epoch >= self.max_epochs should_stop = False if self.trainer.should_stop: diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 73002dc1b9325..ce119d80c24eb 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,6 +21,7 @@ from torchmetrics import Metric import pytorch_lightning as pl +from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -190,7 +191,10 @@ def restore_loops(self) -> None: self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"] # crash if max_epochs is lower then the current epoch from the checkpoint - if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs: + if ( + FitLoop._is_max_limit_enabled(self.trainer.max_epochs) + and self.trainer.current_epoch > self.trainer.max_epochs + ): raise MisconfigurationException( f"You restored a checkpoint with current_epoch={self.trainer.current_epoch}," f" but you have set Trainer(max_epochs={self.trainer.max_epochs})." diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3e84c725b9663..d19d38efed522 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -268,12 +268,15 @@ def __init__( Can be used on CPU, GPU or TPUs. max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). - If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000. + If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``. + To enable infinite training, set ``max_epochs = -1``. min_epochs: Force training for at least these many epochs. Disabled by default (None). - If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1. + If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``. - max_steps: Stop training after this number of steps. Disabled by default (None). + max_steps: Stop training after this number of steps. Disabled by default (None). If ``max_steps = None`` + and ``max_epochs = None``, will default to ``max_epochs = 1000``. To disable this default, set + ``max_steps`` to ``-1``. min_steps: Force training for at least these number of steps. Disabled by default (None). @@ -380,6 +383,7 @@ def __init__( self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) + # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1). fit_loop = FitLoop( min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs), max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs), diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py index c7b636d3f843a..92643ba51b82c 100644 --- a/tests/callbacks/test_timer.py +++ b/tests/callbacks/test_timer.py @@ -42,9 +42,13 @@ def on_fit_start(self): trainer.fit(TestModel()) assert "callbacks list already contains a Timer" in caplog.text - seconds = 1 - trainer = Trainer(max_time=dict(seconds=seconds)) - assert trainer.max_epochs is None + # Make sure max_time still honored even if max_epochs == -1 + trainer = Trainer(max_time=dict(seconds=1), max_epochs=-1) + with pytest.raises(SystemExit): + trainer.fit(TestModel()) + timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0] + assert timer._duration == 1 + assert trainer.max_epochs == -1 assert trainer.max_steps is None @@ -153,7 +157,11 @@ def test_timer_resume_training(tmpdir): checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) # initial training - trainer = Trainer(default_root_dir=tmpdir, max_epochs=100, callbacks=[timer, checkpoint_callback]) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=100, + callbacks=[timer, checkpoint_callback], + ) trainer.fit(model) assert not timer._offset assert timer.time_remaining() <= 0 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9a58351ecaaf5..90b5847fc0057 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -33,7 +33,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Timer from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv from pytorch_lightning.loggers import TensorBoardLogger @@ -492,6 +492,57 @@ def test_trainer_max_steps_and_epochs(tmpdir): assert trainer.global_step == num_train_samples * trainer.max_epochs assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs" + # if max_steps is positive and max_epochs is negative, use max_steps + trainer_kwargs["max_epochs"] = -1 + trainer_kwargs["max_steps"] = 3 + trainer = Trainer(**trainer_kwargs) + trainer.fit(model) + + assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.global_step == 3 + + +@pytest.mark.parametrize( + "max_epochs,max_steps,incorrect_variable,incorrect_value", + [ + (-100, None, "max_epochs", -100), + (1, -2, "max_steps", -2), + ], +) +def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable, incorrect_value): + """Don't allow max_epochs or max_steps to be less than -1 or a float""" + with pytest.raises( + MisconfigurationException, + match=f"`{incorrect_variable}` must be a positive integer or -1. You passed in {incorrect_value}", + ): + Trainer(max_epochs=max_epochs, max_steps=max_steps) + + +@pytest.mark.parametrize( + "max_epochs,max_steps,is_done,correct_trainer_epochs", + [ + (None, None, False, 1000), + (-1, None, False, -1), + (None, -1, False, None), + (5, -1, False, 5), + (-1, 10, False, -1), + (None, 0, True, None), + (0, None, True, 0), + (-1, 0, True, -1), + (0, -1, True, 0), + ], +) +def test_trainer_max_steps_and_epochs_fit_loop_done(max_epochs, max_steps, is_done, correct_trainer_epochs): + trainer = Trainer(max_epochs=max_epochs, max_steps=max_steps) + + assert trainer.max_epochs == correct_trainer_epochs + assert trainer.max_steps == max_steps + assert trainer.fit_loop.done is is_done + + # Make sure there is no timer + timer_callbacks = [c for c in trainer.callbacks if isinstance(c, Timer)] + assert len(timer_callbacks) == 0 + def test_trainer_min_steps_and_epochs(tmpdir): """Verify model trains according to specified min steps"""