Allow disabling automatic stopping after max_steps or max_epochs #8877

Merged
37 commits (the diff below shows changes from 6 of them):
379dda2  Update docstring + logic to disable automatic stopping  (EricWiener, Aug 12, 2021)
d506881  Add a test to check passing negative max_epochs  (EricWiener, Aug 12, 2021)
8d9e707  Updated logic for disabling automatic stopping  (EricWiener, Aug 27, 2021)
3a7cc8a  Updated test cases for max_epochs/max_steps + max_time  (EricWiener, Aug 27, 2021)
a0b8c61  Change brackets to parentheses  (EricWiener, Aug 27, 2021)
a239358  Corrected max_epoch error checking restore_loops  (EricWiener, Aug 27, 2021)
d201464  Validating max_epochs and max_steps  (EricWiener, Aug 27, 2021)
91422f5  Added parameterized tests for max_epochs + max_steps  (EricWiener, Aug 27, 2021)
c6562b4  Shortened timer to 1 sec from 10 sec  (EricWiener, Aug 27, 2021)
f7e8176  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Aug 27, 2021)
78a1eb8  Fix type error comparing to None  (EricWiener, Aug 28, 2021)
5558917  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Aug 28, 2021)
6c5439c  Added FitLoop._is_max_limit_enabled  (EricWiener, Aug 28, 2021)
cd7732a  Removed mentioning max_epochs in max_steps docstring  (EricWiener, Aug 28, 2021)
7209976  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Aug 28, 2021)
794bf31  Remove type signature on `max_value`  (EricWiener, Aug 29, 2021)
945b731  Remove type signature on return value  (EricWiener, Aug 29, 2021)
1b70637  Now checking that max vals are int (vs. not float)  (EricWiener, Aug 29, 2021)
6e53053  Condensed test_timer test  (EricWiener, Aug 29, 2021)
fcada92  Moved details desc of max_epochs/steps to trainer.rst  (EricWiener, Aug 29, 2021)
769c7fc  Shortened max_* desc in trainer.rst  (EricWiener, Aug 29, 2021)
4e6bed4  Update pytorch_lightning/trainer/connectors/checkpoint_connector.py  (awaelchli, Sep 1, 2021)
fe11371  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Sep 1, 2021)
e8a3174  Change brackets to paranthesis  (EricWiener, Sep 2, 2021)
1e13018  Update docs/source/common/trainer.rst  (EricWiener, Sep 2, 2021)
db53b8c  Update pytorch_lightning/trainer/trainer.py  (EricWiener, Sep 2, 2021)
b805ff5  No longer checking if max_epochs/steps is an int  (EricWiener, Sep 2, 2021)
4c5f5d8  Fixed test_trainer_max_steps_and_epochs_fit_loop_done  (EricWiener, Sep 3, 2021)
05ed3a3  Fix test_timer.py::test_trainer_flag  (EricWiener, Sep 3, 2021)
d21b009  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Sep 3, 2021)
4594284  Fixed test_trainer_max_steps_and_epochs_validation  (EricWiener, Sep 3, 2021)
318fec8  Decrease global step in tests/trainer/test_trainer.py  (EricWiener, Sep 3, 2021)
0c322ea  Change EvalModelTemplate to BoringModel  (EricWiener, Sep 3, 2021)
145c27b  Moved max_* validation into constructors  (EricWiener, Sep 3, 2021)
3b3f29f  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Sep 3, 2021)
38b4436  Fix pre-commit  (carmocca, Sep 4, 2021)
1a17f87  Keep TODO at the top  (carmocca, Sep 4, 2021)
Files changed:
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -132,8 +132,8 @@ def done(self) -> bool:
             or if the maximum number of steps or epochs is reached.
         """
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
-        stop_steps = self.max_steps is not None and self.global_step >= self.max_steps
-        stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs
+        stop_steps = self.max_steps not in (None, -1) and self.global_step >= self.max_steps
+        stop_epochs = self.max_epochs not in (None, -1) and self.current_epoch >= self.max_epochs

         should_stop = False
         if self.trainer.should_stop:
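To make the new semantics concrete: a max_* limit now participates in the ``done`` check only when it is neither ``None`` nor ``-1``, so ``-1`` becomes an explicit "train forever" sentinel. A minimal standalone sketch of that rule (illustrative names; per the commit list, the PR later factors the check into ``FitLoop._is_max_limit_enabled``):

def is_limit_enabled(max_value):
    # A limit of None (unset) or -1 (explicitly disabled) never triggers stopping.
    return max_value not in (None, -1)

def should_stop(global_step, current_epoch, max_steps=None, max_epochs=None):
    stop_steps = is_limit_enabled(max_steps) and global_step >= max_steps
    stop_epochs = is_limit_enabled(max_epochs) and current_epoch >= max_epochs
    return stop_steps or stop_epochs

assert should_stop(100, 0, max_steps=100)                          # step limit reached
assert not should_stop(10**9, 10**6, max_steps=-1, max_epochs=-1)  # runs "forever"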
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -190,7 +190,7 @@ def restore_loops(self) -> None:
         self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"]

         # crash if max_epochs is lower than the current epoch from the checkpoint
-        if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
+        if self.trainer.max_epochs not in (None, -1) and self.trainer.current_epoch > self.trainer.max_epochs:
             raise MisconfigurationException(
                 f"You restored a checkpoint with current_epoch={self.trainer.current_epoch},"
                 f" but you have set Trainer(max_epochs={self.trainer.max_epochs})."
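Isolated from the Trainer, the restored-checkpoint guard above behaves like this sketch (illustrative helper; only the error message and the None/-1 rule are taken from the diff):

def validate_restored_epoch(current_epoch, max_epochs):
    # With max_epochs=-1 ("train forever"), any restored epoch is acceptable;
    # only a real, finite limit can already be exceeded by the checkpoint.
    if max_epochs not in (None, -1) and current_epoch > max_epochs:
        raise ValueError(  # Lightning raises MisconfigurationException here
            f"You restored a checkpoint with current_epoch={current_epoch},"
            f" but you have set Trainer(max_epochs={max_epochs})."
        )

validate_restored_epoch(current_epoch=5, max_epochs=-1)   # ok: no limit
validate_restored_epoch(current_epoch=5, max_epochs=10)   # ok: under the limit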
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -262,11 +262,15 @@ def __init__(

             max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
                 If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.
+                To disable automatic stopping, set ``max_epochs = -1`` and set ``max_steps`` to ``None``
+                or ``-1``. Note that if the ``max_time`` limit is specified, it will still be observed.

             min_epochs: Force training for at least these many epochs. Disabled by default (None).
                 If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1.

-            max_steps: Stop training after this number of steps. Disabled by default (None).
+            max_steps: Stop training after this number of steps. Disabled by default (None). If ``max_steps = None``
+                and ``max_epochs = None``, will default to ``max_epochs = 1000``. To override this
+                behavior, see ``max_epochs``.

             min_steps: Force training for at least these number of steps. Disabled by default (None).
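In user code, the documented combinations look roughly like this (a usage sketch based on the docstring above, not code from the PR itself):

from pytorch_lightning import Trainer

# Disable automatic stopping entirely: train until interrupted.
trainer = Trainer(max_epochs=-1, max_steps=None)

# No epoch/step limits, but keep a wall-clock budget; max_time is still observed.
trainer = Trainer(max_epochs=-1, max_time={"hours": 12})

# Neither limit specified: falls back to max_epochs=1000.
trainer = Trainer()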

@@ -374,6 +378,7 @@ def __init__(
         self.slurm_connector = SLURMConnector(self)
         self.tuner = Tuner(self)

+        # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
         fit_loop = FitLoop(
             min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
             max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs),
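The defaulting rule encoded in those two ternaries can be read as the following sketch (illustrative helper, not part of the PR):

def resolve_limits(min_epochs, min_steps, max_epochs, max_steps, max_time):
    # min_epochs defaults to 1 only when no minimum and no time budget is given.
    if min_epochs is None and min_steps is None and max_time is None:
        min_epochs = 1
    # max_epochs defaults to 1000 only when no other stopping limit exists;
    # an explicit -1, or any max_steps/max_time, suppresses the default.
    if max_epochs is None and max_steps is None and max_time is None:
        max_epochs = 1000
    return min_epochs, max_epochs

assert resolve_limits(None, None, None, None, None) == (1, 1000)
assert resolve_limits(None, None, -1, None, None) == (1, -1)               # "forever" kept
assert resolve_limits(None, None, None, None, {"hours": 1}) == (None, None)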
7 changes: 7 additions & 0 deletions tests/callbacks/test_timer.py
@@ -47,6 +47,13 @@ def on_fit_start(self):
     assert trainer.max_epochs is None
     assert trainer.max_steps is None

+    # Make sure max_time is still honored even if max_epochs == -1
+    trainer = Trainer(max_time=dict(seconds=10), max_epochs=-1)
+    with pytest.raises(SystemExit):
+        trainer.fit(TestModel())
+    timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
+    assert timer._duration == 10


 @pytest.mark.parametrize(
     "duration,expected",
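For reference, the ``max_time`` flag exercised here configures Lightning's ``Timer`` callback under the hood, so an equivalent explicit setup looks like this (a sketch; ``Timer`` accepts a ``timedelta``, dict, or string duration):

from datetime import timedelta

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Timer

# Same behavior as Trainer(max_time=dict(seconds=10), max_epochs=-1):
# no epoch limit, but the Timer stops training after 10 seconds.
trainer = Trainer(max_epochs=-1, callbacks=[Timer(duration=timedelta(seconds=10))])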
34 changes: 34 additions & 0 deletions tests/trainer/test_trainer.py
@@ -491,6 +491,40 @@ def test_trainer_max_steps_and_epochs(tmpdir):
     assert trainer.global_step == num_train_samples * trainer.max_epochs
     assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs"

+    # if max_steps is positive and max_epochs is negative, use max_steps
+    trainer_kwargs["max_epochs"] = -1
+    trainer_kwargs["max_steps"] = 3 * 2 * num_train_samples
+    trainer = Trainer(**trainer_kwargs)
+    trainer.fit(model)
+
+    assert trainer.state.finished, f"Training failed with {trainer.state}"
+    assert trainer.global_step == 3 * 2 * num_train_samples
+
+    # if max_steps is 0 and max_epochs is negative, use max_steps
+    trainer_kwargs["max_epochs"] = -1
+    trainer_kwargs["max_steps"] = 0
+    trainer = Trainer(**trainer_kwargs)
+
+    assert trainer.done is True
+
+    # allow specifying max_epochs < 0 and max_steps = None. This should immediately stop
+    trainer_kwargs["max_epochs"] = -100
+    trainer_kwargs["max_steps"] = None
+    trainer = Trainer(**trainer_kwargs)
+
+    assert trainer.done is True
+
+    # Make sure various combinations work to disable automatic stopping
+    for x, y in [(-1, None), (None, -1), (None, None)]:
+        trainer_kwargs["max_epochs"] = x
+        trainer_kwargs["max_steps"] = y
+        trainer = Trainer(**trainer_kwargs)
+
+        assert trainer.max_epochs == x
+        assert trainer.max_steps == y
+        assert trainer.max_time is None
+        assert trainer.done is False


 def test_trainer_min_steps_and_epochs(tmpdir):
     """Verify model trains according to specified min steps"""