Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loop Refactor 1/N - Training Loop #7871

Merged
merged 543 commits into from
Jun 15, 2021
Merged
Show file tree
Hide file tree
Changes from 250 commits
Commits
Show all changes
543 commits
Select commit Hold shift + click to select a range
5ca5b97
forked_name
carmocca May 31, 2021
0a6b185
is_tensor_and
carmocca May 31, 2021
30eff02
Add missing variable
carmocca May 31, 2021
e4c7d2f
Refactor fwd
carmocca May 31, 2021
17cd9d1
Merge branch 'master' into refactor/logger-connector-poc
carmocca May 31, 2021
b2de630
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli May 31, 2021
a9adba0
integrate #7772
awaelchli May 31, 2021
3881b18
Forgot to pass value
carmocca May 31, 2021
041ec8e
Docstrings
carmocca May 31, 2021
1047836
Revert if condition
carmocca May 31, 2021
b33f5dc
is_train -> training
carmocca May 31, 2021
43777a6
Unnecessary getter/setter
carmocca May 31, 2021
b7b9633
Unnecessary getter/setter
carmocca May 31, 2021
5768b55
Linter
carmocca May 31, 2021
08fac3d
root_device -> device
carmocca May 31, 2021
7886151
Docstrings
carmocca May 31, 2021
be94ec9
Better exceptions
carmocca May 31, 2021
76eb0ee
flake8 and mypy
carmocca May 31, 2021
ae93240
Clean up batch size?
carmocca Jun 1, 2021
8fbbcff
Replace repr with str
carmocca Jun 1, 2021
df0ff74
Refactor reset
carmocca Jun 1, 2021
20bcf0c
Docstring
carmocca Jun 1, 2021
6da9d8f
hook_name -> fx
carmocca Jun 1, 2021
f10084e
Fix reset
carmocca Jun 1, 2021
13785a0
Refactor
carmocca Jun 1, 2021
cce7023
Serialization updates
carmocca Jun 1, 2021
54da8eb
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 1, 2021
f51a2aa
Move is_tensor to ResultMetric
carmocca Jun 1, 2021
5700df4
Fixes
carmocca Jun 1, 2021
bf8bc52
Rename metric attribute
carmocca Jun 1, 2021
89149ac
Comment getstate
carmocca Jun 1, 2021
d7823e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2021
37e7589
Leave out serialization for a different PR
carmocca Jun 1, 2021
c6e3e26
Encapsulate dict collections
carmocca Jun 1, 2021
a4a6903
"extra" changes
carmocca Jun 1, 2021
f9304fc
Fix typing
carmocca Jun 1, 2021
4237307
Enable mypy
carmocca Jun 1, 2021
24eb614
Legacy code
carmocca Jun 1, 2021
092fec0
Stricter type checking
carmocca Jun 1, 2021
935fe5b
Minor changes
carmocca Jun 1, 2021
eac0b82
Performance optimizations
carmocca Jun 1, 2021
e789cef
Fixes and remove support for nested dict
carmocca Jun 2, 2021
1f79212
Make sure _forward_cache gets float values
carmocca Jun 2, 2021
b24b410
Update legacy test
carmocca Jun 2, 2021
92c5ac0
Sync improvements
carmocca Jun 2, 2021
e628b1f
Convert to float outside of sync
carmocca Jun 2, 2021
4efc01a
Fix empty collections
carmocca Jun 3, 2021
8509c4a
Install latest torchmetrics
carmocca Jun 3, 2021
f8b824f
Update CODEOWNERS
carmocca Jun 3, 2021
5ef05e3
Keep no_grad ctx manager if enable graph
carmocca Jun 3, 2021
33b06d0
Move has_reset to ResultMetric
carmocca Jun 3, 2021
4170875
Sync dataclass
carmocca Jun 3, 2021
8612746
Detach if not enable graph
carmocca Jun 3, 2021
0f8e620
Prepend dataclasses with underscore
carmocca Jun 3, 2021
7008fe9
Add test
carmocca Jun 3, 2021
6db7d66
Update code after torchmetrics==0.3.2
carmocca Jun 3, 2021
b4582e3
Merge two ifs
carmocca Jun 3, 2021
fa91cdf
Reorder code
carmocca Jun 3, 2021
03f769c
Error for custom reductions
carmocca Jun 3, 2021
7cf8b8b
improve reduce_fx test
carmocca Jun 3, 2021
23a628d
Minor logger connector refactoring
carmocca Jun 3, 2021
13d57af
Refactor eval loop results
carmocca Jun 3, 2021
fcbf408
Formatting
carmocca Jun 3, 2021
574e713
Typing
carmocca Jun 3, 2021
e74a8ad
Formatting
carmocca Jun 3, 2021
585e62c
Move back call
carmocca Jun 3, 2021
7cce418
Fix FIXME
carmocca Jun 3, 2021
ff71204
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 4, 2021
86d6c11
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 4, 2021
a87c68f
integrate #7631 (logger connector refactor)
awaelchli Jun 4, 2021
d180bb2
call logger connector on_train_split_start at start of train split
awaelchli Jun 7, 2021
af6db00
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 7, 2021
a386347
integrate d180bb2
awaelchli Jun 7, 2021
42210fc
Minor changes
carmocca Jun 7, 2021
e76010a
Refactor loop logic into logger connector
carmocca Jun 7, 2021
305e31b
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
carmocca Jun 7, 2021
d65e152
deletions
awaelchli Jun 7, 2021
4a1218c
restore predict
awaelchli Jun 7, 2021
6acd71c
remove evaluation
awaelchli Jun 7, 2021
0778b36
x
awaelchli Jun 7, 2021
0ef000f
loop
awaelchli Jun 7, 2021
d421b56
logger
awaelchli Jun 7, 2021
e761890
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 7, 2021
fa9ff73
Refactor test
carmocca Jun 7, 2021
a1f839f
Tighter fx validator
carmocca Jun 7, 2021
a587879
Add back split idx
carmocca Jun 7, 2021
631da98
Typing
carmocca Jun 7, 2021
e098631
update
tchaton Jun 7, 2021
2f13234
Conflict
carmocca Jun 7, 2021
502dcbd
Fix tests
carmocca Jun 7, 2021
c716736
resolve grad_norm
tchaton Jun 7, 2021
26f5e03
update
tchaton Jun 7, 2021
89aaaa2
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
badd645
move to train loop
tchaton Jun 7, 2021
aac11a0
Bye grad_norm_dict parameter
carmocca Jun 7, 2021
919cbfb
Fix sync test
carmocca Jun 7, 2021
1c75341
update
tchaton Jun 7, 2021
b00e4a5
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
23a7510
Fix bug when validation is run mid epoch
carmocca Jun 7, 2021
9df0da2
fix grad_norm_dict test
carmocca Jun 7, 2021
0485a98
Fix fx_validator test
carmocca Jun 7, 2021
e0702aa
fix grad_norm_dict test
carmocca Jun 7, 2021
32ca719
Fix order bug
carmocca Jun 7, 2021
c68825c
Detach tensors in test
carmocca Jun 7, 2021
37ed74d
resolve some tests
tchaton Jun 7, 2021
39829b5
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
cabd48b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2021
bc106c3
remove pdb
tchaton Jun 7, 2021
e048417
merge
tchaton Jun 7, 2021
98298e1
resolve flake8
tchaton Jun 7, 2021
d8da3cd
Update test
carmocca Jun 7, 2021
49558bf
more tests
tchaton Jun 7, 2021
bae75e9
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
0fe870a
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
carmocca Jun 7, 2021
f27d0a3
Revert last thomas' changes
carmocca Jun 7, 2021
f4444c4
resolve 1 test
tchaton Jun 7, 2021
fa4bcec
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
d144771
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2021
ffff6f4
Refactor context restoration
carmocca Jun 7, 2021
d0062ec
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 7, 2021
9d72028
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2021
d3302ed
update
awaelchli Jun 7, 2021
4e94f33
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 7, 2021
601aa95
x
awaelchli Jun 7, 2021
e272385
on trainer init
awaelchli Jun 7, 2021
1763d8f
test
awaelchli Jun 7, 2021
d718498
update trainer
awaelchli Jun 7, 2021
6d98a07
integrate latest changes from logger connector refactor poc
awaelchli Jun 7, 2021
7ca1049
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2021
515ad9f
Minor changes
carmocca Jun 7, 2021
b03591c
update changelog
awaelchli Jun 7, 2021
0aa8428
Remove unused argument
carmocca Jun 7, 2021
24b41e3
Update CHANGELOG
carmocca Jun 7, 2021
6d71e6a
Copy call_hook changes
carmocca Jun 7, 2021
44ad4ac
Docs
carmocca Jun 7, 2021
2c74018
Fix ref
carmocca Jun 7, 2021
b15984b
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 7, 2021
e8021bb
merge
tchaton Jun 8, 2021
9747023
move to cpu
tchaton Jun 8, 2021
d9ae37a
Bad merge
carmocca Jun 8, 2021
bad51c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
273bc92
remove pdb
tchaton Jun 8, 2021
f214632
remove pdb
tchaton Jun 8, 2021
5fdf3c5
merge
tchaton Jun 8, 2021
99543a7
Refactor to
carmocca Jun 8, 2021
738c810
Avoid partial
carmocca Jun 8, 2021
6a7637d
trigger ci
carmocca Jun 8, 2021
8077cf9
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
aff9e3d
Bad merge
carmocca Jun 8, 2021
464f581
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 8, 2021
461332b
integrate latest logger connector changes
awaelchli Jun 8, 2021
417ad31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
9321b11
remove grad norm dicts list
awaelchli Jun 8, 2021
e75a958
Diff
carmocca Jun 8, 2021
007dcac
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
2e4bb24
Bad merge
carmocca Jun 8, 2021
f5154ae
Reuse metrics_to_scalars
carmocca Jun 8, 2021
558cdf4
Use active loop
carmocca Jun 8, 2021
90d71bf
Move to device
carmocca Jun 8, 2021
d7f1761
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
6ce6762
resolve test
tchaton Jun 8, 2021
fba9a87
properties first
awaelchli Jun 8, 2021
fd967af
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
79c73b9
define union
awaelchli Jun 8, 2021
37a0b9d
Update logger connector
carmocca Jun 8, 2021
aaea387
Update result
carmocca Jun 8, 2021
e2f69ce
Update imports
carmocca Jun 8, 2021
6037833
Update after rename
carmocca Jun 8, 2021
3804963
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
carmocca Jun 8, 2021
499da76
Refactor reduce_fx and op
carmocca Jun 8, 2021
6eb448a
Fix test after rename
carmocca Jun 8, 2021
f871cbd
mypy
carmocca Jun 8, 2021
5631b53
manual merge poc changes
awaelchli Jun 9, 2021
d10d5c7
integrate latest changes from logger connector poc
awaelchli Jun 9, 2021
7b6803a
Fix test
carmocca Jun 9, 2021
9bfedc9
Refactor test
carmocca Jun 9, 2021
c9c7829
Deprecate `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`
carmocca Jun 9, 2021
e3dde0b
Undo field
carmocca Jun 9, 2021
bae2139
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 9, 2021
2c167cc
rename
awaelchli Jun 9, 2021
99db497
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2021
832dfb9
rename
awaelchli Jun 9, 2021
f92e01d
imports
awaelchli Jun 9, 2021
b15fc34
loop hygiene
awaelchli Jun 9, 2021
7175a50
yapf on loops
awaelchli Jun 9, 2021
59d6227
protected new loop trigger
awaelchli Jun 9, 2021
e1d4fd2
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 9, 2021
a7c3555
Replace code
carmocca Jun 9, 2021
501224d
Fix names and imports
carmocca Jun 9, 2021
dee7e5f
Remove metric_attribute
carmocca Jun 9, 2021
4eb9757
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 9, 2021
d4bb357
integrate latest logger connector changes
awaelchli Jun 9, 2021
c9b4e9e
resolve todo dataloading reset
awaelchli Jun 10, 2021
a3ef0aa
re-add notebooks
awaelchli Jun 10, 2021
b071532
Merge branch 'master' into refactor/logger-connector-poc
awaelchli Jun 10, 2021
53deef8
add missing init
awaelchli Jun 10, 2021
93fd682
bad merge
awaelchli Jun 10, 2021
80c406e
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 10, 2021
a041b6f
remove iteration count method
awaelchli Jun 10, 2021
e080be8
todo for a fix in #5007
awaelchli Jun 10, 2021
4950821
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 10, 2021
c56adc1
remove NEW_LOOP guard
awaelchli Jun 10, 2021
5e72d1d
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 10, 2021
bace4a2
flake8
awaelchli Jun 10, 2021
71bfb6f
exclude coverage
awaelchli Jun 10, 2021
acc6d4f
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 10, 2021
41e0e64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2021
643bef0
flake8 vs yapf wars
awaelchli Jun 10, 2021
4b6bd18
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 10, 2021
536574a
integrate #7917, remove teardown from training loop
awaelchli Jun 10, 2021
b28fb09
update "accumulated_batches_reached" condition
awaelchli Jun 11, 2021
6f17688
remove public loop properties
awaelchli Jun 11, 2021
6dd4e1d
make skip backward protected again
awaelchli Jun 11, 2021
c394267
typing base loop
awaelchli Jun 11, 2021
4adae06
typing fit loop
awaelchli Jun 11, 2021
c49875d
typing training_batch_loop
awaelchli Jun 11, 2021
80edb75
typing training epoch loop
awaelchli Jun 11, 2021
8b54505
fix merge error
justusschock Jun 11, 2021
9fd8ed1
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 11, 2021
e4ffa6c
integrate train loop changes from master
awaelchli Jun 11, 2021
69ed0e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2021
eeebc9a
fix tpipes moving model to cpu and leaving it there.
awaelchli Jun 12, 2021
ce9dd2a
don't reset fit loop
awaelchli Jun 12, 2021
80e225a
fix test iteration count <-> batch_idx reset
awaelchli Jun 14, 2021
4880b26
replace torch.Tensor -> Tensor
awaelchli Jun 14, 2021
5461f73
fix attribute error to block_ddp_sync_behaviour
awaelchli Jun 14, 2021
a2d3f0d
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 14, 2021
0fe6d9f
ignore mypy errors
awaelchli Jun 14, 2021
5497fc0
fix flake8 and yapf conflict
awaelchli Jun 14, 2021
4c51c45
remove redundant override
awaelchli Jun 14, 2021
8f68b61
Apply suggestions from code review
awaelchli Jun 14, 2021
0150f6c
Apply suggestions from code review
awaelchli Jun 14, 2021
fd90c10
Apply suggestions from code review
awaelchli Jun 14, 2021
153d264
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
4eb0eb1
remove all empty space between atoms
awaelchli Jun 14, 2021
70cdb14
carlos
awaelchli Jun 14, 2021
bf26aa3
Apply suggestions from code review
justusschock Jun 14, 2021
ffc4f45
Apply suggestions from code review
justusschock Jun 14, 2021
79f8c18
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 14, 2021
3373cc8
resolve a todo integrating on_train_batch_end with on_advance_end
awaelchli Jun 14, 2021
e1a40c0
clarify what is todo and what is fixme
awaelchli Jun 14, 2021
b5bb08a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
5d98009
shorten a docstring
awaelchli Jun 14, 2021
03bce7a
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 14, 2021
f001f81
move on_epoch_start to on_run_start of training loop
awaelchli Jun 14, 2021
24fa859
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 14, 2021
d191fe1
Update pytorch_lightning/loops/base.py
awaelchli Jun 15, 2021
1d21065
update class names in changelog
awaelchli Jun 15, 2021
2d8c441
add empty teardown method
awaelchli Jun 15, 2021
f874182
added skip property
awaelchli Jun 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))
* Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526))
* Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
* Refactored "should run validation" logic when the trainer is signaled to stop ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))

* Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
* Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))
...
* Refactored internal loop interface; added new classes `EpochLoop`, `TrainingLoop`, `BatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871))
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

- Refactored logging
* Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ def training_step(...):

# backward
self._running_manual_backward = True
self.trainer.train_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
self.trainer.fit_loop.training_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
self._running_manual_backward = False

def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
Expand Down Expand Up @@ -1445,7 +1445,7 @@ def optimizer_step(
If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
to ``optimizer.step()`` function as shown in the examples. This ensures that
``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within
:meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`.
:meth:`~pytorch_lightning.trainer.fit_loop.training_loop.batch_loop.TrainingBatchLoop.advance`.

Args:
epoch: Current epoch
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def toggle_model(self, sync_grad: bool = True):
during the accumulation phase.
Setting `sync_grad` to False will block this synchronization and improve performance.
"""
with self._trainer.train_loop.block_ddp_sync_behaviour(not sync_grad):
with self._trainer.train_loop.training_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
justusschock marked this conversation as resolved.
Show resolved Hide resolved
self._toggle_model()
yield
self._untoggle_model()
Expand Down
18 changes: 18 additions & 0 deletions pytorch_lightning/loops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.base import Loop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop # noqa: F401
115 changes: 115 additions & 0 deletions pytorch_lightning/loops/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Optional
from weakref import proxy

from deprecate import void

import pytorch_lightning as pl


class Loop(ABC):
"""
Basic Loops interface. All classes derived from this must implement the following properties and methods:

* :attr`done` (property): Condition to break the loop
* :attr`reset` (method): Resets the internal state between multiple calls of :attr`run`
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
* :attr`advance` (method): Implements one step of the loop

This class implements the following loop structure:

.. codeblock:: python

on_run_start()

while not done:
on_advance_start()
advance()
on_advance_end()

on_run_end()

"""

def __init__(self) -> None:
self.iteration_count: int = 0
tchaton marked this conversation as resolved.
Show resolved Hide resolved
self.trainer: Optional['pl.Trainer'] = None

@property
@abstractmethod
def done(self) -> bool:
"""Property indicating when loop is finished"""

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
"""Connects Loop with all the necessary things like connectors and accelerators."""
self.trainer = proxy(trainer)

@abstractmethod
def reset(self) -> None:
"""Resets the internal state of the loop at the beginning of each call to :attr:`run`."""

def run(self, *args: Any, **kwargs: Any) -> Any:
"""
The main entry point to the loop.

Will frequently check the :attr:`done` condition and calls :attr:`advance`
until :attr`done` evaluates to ``True``.
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

Returns:
the output of :attr`on_run_end` (often outputs collected from each step of the loop)
"""
self.reset()
self.on_run_start(*args, **kwargs)

while not self.done:
try:
self.on_advance_start(*args, **kwargs)
self.advance(*args, **kwargs)
self.on_advance_end()
self.iteration_count += 1
except StopIteration:
break

return self.on_run_end()

def on_run_start(self, *args: Any, **kwargs: Any) -> None:
"""
Hook to be called as the first thing after entering :attr:`run` (except the state reset).

Accepts all arguments passed to :attr:`run`.

awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""
void(*args, **kwargs)

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
"""
Hook to be called each time before :attr:`advance` is called. Accepts all arguments passed to :attr`run`.

"""
void(*args, **kwargs)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

@abstractmethod
def advance(self, *args: Any, **kwargs: Any) -> None:
"""
Performs a single step. Accepts all arguments passed to :attr:`run`.

"""

def on_advance_end(self) -> None:
"""Hook to be called each time after :attr:`advance` is called."""
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

def on_run_end(self) -> Any:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""Hook to be called at the end of the run. Its return argument is returned from :attr:`run`."""
Loading