Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loop Refactor 5/N - Prediction Loop #7700

Merged
merged 616 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
616 commits
Select commit Hold shift + click to select a range
42210fc
Minor changes
carmocca Jun 7, 2021
e76010a
Refactor loop logic into logger connector
carmocca Jun 7, 2021
305e31b
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
carmocca Jun 7, 2021
e761890
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 7, 2021
fa9ff73
Refactor test
carmocca Jun 7, 2021
a1f839f
Tighter fx validator
carmocca Jun 7, 2021
a587879
Add back split idx
carmocca Jun 7, 2021
631da98
Typing
carmocca Jun 7, 2021
e098631
update
tchaton Jun 7, 2021
2f13234
Conflict
carmocca Jun 7, 2021
502dcbd
Fix tests
carmocca Jun 7, 2021
c716736
resolve grad_norm
tchaton Jun 7, 2021
26f5e03
update
tchaton Jun 7, 2021
89aaaa2
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
badd645
move to train loop
tchaton Jun 7, 2021
aac11a0
Bye grad_norm_dict parameter
carmocca Jun 7, 2021
919cbfb
Fix sync test
carmocca Jun 7, 2021
1c75341
update
tchaton Jun 7, 2021
b00e4a5
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
23a7510
Fix bug when validation is run mid epoch
carmocca Jun 7, 2021
9df0da2
fix grad_norm_dict test
carmocca Jun 7, 2021
0485a98
Fix fx_validator test
carmocca Jun 7, 2021
e0702aa
fix grad_norm_dict test
carmocca Jun 7, 2021
32ca719
Fix order bug
carmocca Jun 7, 2021
c68825c
Detach tensors in test
carmocca Jun 7, 2021
37ed74d
resolve some tests
tchaton Jun 7, 2021
39829b5
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
cabd48b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2021
bc106c3
remove pdb
tchaton Jun 7, 2021
e048417
merge
tchaton Jun 7, 2021
98298e1
resolve flake8
tchaton Jun 7, 2021
d8da3cd
Update test
carmocca Jun 7, 2021
49558bf
more tests
tchaton Jun 7, 2021
bae75e9
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
0fe870a
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
carmocca Jun 7, 2021
f27d0a3
Revert last thomas' changes
carmocca Jun 7, 2021
f4444c4
resolve 1 test
tchaton Jun 7, 2021
fa4bcec
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
tchaton Jun 7, 2021
d144771
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2021
ffff6f4
Refactor context restoration
carmocca Jun 7, 2021
43e8873
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 7, 2021
54b6c88
integrate latest changes from logger connector refactor poc
awaelchli Jun 7, 2021
b7edf27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2021
766ad99
integrate latest changes from logger connector refactor poc
awaelchli Jun 7, 2021
515ad9f
Minor changes
carmocca Jun 7, 2021
7df8ddb
update changelog
awaelchli Jun 7, 2021
0aa8428
Remove unused argument
carmocca Jun 7, 2021
24b41e3
Update CHANGELOG
carmocca Jun 7, 2021
6d71e6a
Copy call_hook changes
carmocca Jun 7, 2021
44ad4ac
Docs
carmocca Jun 7, 2021
2c74018
Fix ref
carmocca Jun 7, 2021
b15984b
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 7, 2021
e8021bb
merge
tchaton Jun 8, 2021
9747023
move to cpu
tchaton Jun 8, 2021
d9ae37a
Bad merge
carmocca Jun 8, 2021
bad51c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
273bc92
remove pdb
tchaton Jun 8, 2021
f214632
remove pdb
tchaton Jun 8, 2021
5fdf3c5
merge
tchaton Jun 8, 2021
99543a7
Refactor to
carmocca Jun 8, 2021
738c810
Avoid partial
carmocca Jun 8, 2021
6a7637d
trigger ci
carmocca Jun 8, 2021
8077cf9
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
aff9e3d
Bad merge
carmocca Jun 8, 2021
c55267d
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 8, 2021
33a7d5a
integrate latest logger connector changes
awaelchli Jun 8, 2021
0e59b09
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
d63f6bb
remove grad norm dicts list
awaelchli Jun 8, 2021
0fd3eee
Merge remote-tracking branch 'origin/refactor/loops_everywhere' into …
awaelchli Jun 8, 2021
e75a958
Diff
carmocca Jun 8, 2021
cb076d0
properties first
awaelchli Jun 8, 2021
007dcac
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
2e4bb24
Bad merge
carmocca Jun 8, 2021
f5154ae
Reuse metrics_to_scalars
carmocca Jun 8, 2021
558cdf4
Use active loop
carmocca Jun 8, 2021
90d71bf
Move to device
carmocca Jun 8, 2021
d7f1761
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
6ce6762
resolve test
tchaton Jun 8, 2021
08e103e
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 8, 2021
e1d30d6
integrate latest changes from logger connector poc
awaelchli Jun 8, 2021
fd967af
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
ec0a7a2
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 8, 2021
05eddd0
define union
awaelchli Jun 8, 2021
79c73b9
define union
awaelchli Jun 8, 2021
37a0b9d
Update logger connector
carmocca Jun 8, 2021
aaea387
Update result
carmocca Jun 8, 2021
e2f69ce
Update imports
carmocca Jun 8, 2021
6037833
Update after rename
carmocca Jun 8, 2021
3804963
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
carmocca Jun 8, 2021
499da76
Refactor reduce_fx and op
carmocca Jun 8, 2021
6eb448a
Fix test after rename
carmocca Jun 8, 2021
f871cbd
mypy
carmocca Jun 8, 2021
a32f8ea
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 8, 2021
3b63f44
integrate latest logger connector refactor poc changes
awaelchli Jun 9, 2021
7b6803a
Fix test
carmocca Jun 9, 2021
9bfedc9
Refactor test
carmocca Jun 9, 2021
c9c7829
Deprecate `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`
carmocca Jun 9, 2021
e3dde0b
Undo field
carmocca Jun 9, 2021
9891f8f
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 9, 2021
21e7637
add redundant return
awaelchli Jun 9, 2021
54183e5
rename
awaelchli Jun 9, 2021
b562550
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2021
e374fe4
rename
awaelchli Jun 9, 2021
e1d4fd2
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 9, 2021
a7c3555
Replace code
carmocca Jun 9, 2021
501224d
Fix names and imports
carmocca Jun 9, 2021
dee7e5f
Remove metric_attribute
carmocca Jun 9, 2021
7628326
imports
awaelchli Jun 9, 2021
c20de54
loop hygiene
awaelchli Jun 9, 2021
0477437
yapf on loops
awaelchli Jun 9, 2021
b8c8cdd
protected new loop trigger
awaelchli Jun 9, 2021
ec4cb49
rename NEW LOOP guard
awaelchli Jun 9, 2021
4627984
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 9, 2021
d254b38
integrate latest logger connector changes
awaelchli Jun 9, 2021
75848b2
integrate latest logger connector changes (eval loop)
awaelchli Jun 9, 2021
deb4790
resolve todo dataloading reset
awaelchli Jun 10, 2021
2b183e6
re-add notebooks
awaelchli Jun 10, 2021
b071532
Merge branch 'master' into refactor/logger-connector-poc
awaelchli Jun 10, 2021
53deef8
add missing init
awaelchli Jun 10, 2021
dcac700
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 10, 2021
93fd682
bad merge
awaelchli Jun 10, 2021
0baee04
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 10, 2021
4950821
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 10, 2021
ecac33b
remove NEW_LOOP guard
awaelchli Jun 10, 2021
32cfcb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2021
4c55ab5
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 10, 2021
5e99672
flake8
awaelchli Jun 10, 2021
4701e85
exclude coverage
awaelchli Jun 10, 2021
2b6cc28
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 10, 2021
d06a8e8
integrate #7917, remove teardown from training loop
awaelchli Jun 10, 2021
e7b4174
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2021
f3e55c0
update "accumulated_batches_reached" condition
awaelchli Jun 11, 2021
ef63f1f
remove public loop properties
awaelchli Jun 11, 2021
7de848d
make skip backward protected again
awaelchli Jun 11, 2021
7a36e28
typing base loop
justusschock Jun 11, 2021
ece57fd
typing fit loop
justusschock Jun 11, 2021
a28d62d
typing training_batch_loop
justusschock Jun 11, 2021
5bb3e04
typing evaluation loop
justusschock Jun 11, 2021
21fd993
typing prediction loop
justusschock Jun 11, 2021
6824814
typing training epoch loop
justusschock Jun 11, 2021
d84e57a
dataloader_loop
justusschock Jun 11, 2021
812661c
evaluation_dataloader_loop
justusschock Jun 11, 2021
61f8495
prediction_dataloader_loop
justusschock Jun 11, 2021
f1e537a
merge
justusschock Jun 11, 2021
b447a6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2021
7851e26
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 11, 2021
65d85a3
integrate train loop changes from master
awaelchli Jun 11, 2021
9194d99
integrate eval loop changes from master
awaelchli Jun 11, 2021
6713126
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2021
8b905db
fix tpipes moving model to cpu and leaving it there.
awaelchli Jun 12, 2021
e5eac72
don't reset fit loop
awaelchli Jun 12, 2021
7a377be
fix test iteration count <-> batch_idx reset
awaelchli Jun 14, 2021
ccafaa9
replace torch.Tensor -> Tensor
awaelchli Jun 14, 2021
50f8bb8
fix attribute error to block_ddp_sync_behaviour
awaelchli Jun 14, 2021
a864a25
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 14, 2021
3ff21b5
fix flake8 and yapf conflict
awaelchli Jun 14, 2021
36f4fae
remove redundant override
awaelchli Jun 14, 2021
11bbbf6
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 15, 2021
d1ab532
add classes
awaelchli Jun 15, 2021
781f26b
trainer changes
awaelchli Jun 15, 2021
7c0f96e
connect
awaelchli Jun 15, 2021
4533b2c
clean up
awaelchli Jun 15, 2021
ea48342
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2021
9a3a908
update test renaming
awaelchli Jun 15, 2021
a7d2d86
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 15, 2021
d711a49
rename evaluation loop to evaluation epoch loop
awaelchli Jun 15, 2021
e592423
minor docstring improvements
awaelchli Jun 15, 2021
e3c0512
update chlog
awaelchli Jun 15, 2021
9044c19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2021
c5c02ec
try ci fix
awaelchli Jun 15, 2021
cc4c57c
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 15, 2021
6d8af9f
Merge branch 'master' into refactor/loops/loops_everywhere_eval
awaelchli Jun 16, 2021
7e030e5
update code owners for pl/loops
awaelchli Jun 16, 2021
d174d04
update mock path
awaelchli Jun 16, 2021
d135da8
re-order
awaelchli Jun 16, 2021
ed6352c
simplify dataloader reset
awaelchli Jun 16, 2021
0fb4e90
simplify get_dataloaders()
awaelchli Jun 16, 2021
ca73fc4
save predictions on_run_end()
awaelchli Jun 16, 2021
a6b9798
improve skip condition re-routing
awaelchli Jun 17, 2021
4ff7084
re-order
awaelchli Jun 17, 2021
0ecbd55
Merge branch 'master' into refactor/loops/loops_everywhere_eval
awaelchli Jun 17, 2021
cab35a6
remove unused type import
awaelchli Jun 17, 2021
b7483b4
check which assert is failing
awaelchli Jun 17, 2021
33d89e0
pig
awaelchli Jun 17, 2021
e81b0db
hobbit
awaelchli Jun 17, 2021
cfbdc8b
teardown for evaluation
awaelchli Jun 17, 2021
7f70bfa
Revert "hobbit"
awaelchli Jun 17, 2021
f700ab4
Revert "pig"
awaelchli Jun 17, 2021
4f88695
Revert "check which assert is failing"
awaelchli Jun 17, 2021
9769e5a
free memory in fit loop teardown
awaelchli Jun 17, 2021
71c1f6e
update docstring
awaelchli Jun 17, 2021
a6229f5
period
awaelchli Jun 17, 2021
aa27b9f
remove dead code
awaelchli Jun 17, 2021
e082524
else carlos
awaelchli Jun 17, 2021
ea8fcf4
Update pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py
awaelchli Jun 17, 2021
4650461
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 17, 2021
424a1a6
update chlog
awaelchli Jun 18, 2021
0e110b6
unused imp
awaelchli Jun 18, 2021
8fd15c6
move default construction in run_evaluation
awaelchli Jun 18, 2021
61b3837
add something for lawyer to read
awaelchli Jun 18, 2021
07d605e
switch typehint for eval loop trainer property
awaelchli Jun 18, 2021
c5a2511
add missing imports
awaelchli Jun 18, 2021
0f4d536
remove a todo that needs more discussion
awaelchli Jun 18, 2021
9db3ddc
combine _get_num_dataloaders with the property
awaelchli Jun 18, 2021
90da366
Update pytorch_lightning/loops/dataloader/dataloader_loop.py
awaelchli Jun 18, 2021
891a429
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 18, 2021
e583d6e
black + yapf
awaelchli Jun 18, 2021
c7579ab
avoid coverage on old unused eval loop
awaelchli Jun 18, 2021
7f785f4
empty space in docstring
awaelchli Jun 18, 2021
89ba6fc
resolve todo for args forwarding
awaelchli Jun 18, 2021
f34b68f
Merge branch 'refactor/loops/loops_everywhere_eval' into refactor/loo…
awaelchli Jun 18, 2021
d4e6969
weekproxy trainer
awaelchli Jun 18, 2021
e95f36d
fix check for num dataloaders kwargs
awaelchli Jun 18, 2021
deb5bb2
clean up num prediction dataloaders property
awaelchli Jun 18, 2021
e806bfc
free memory
awaelchli Jun 18, 2021
ceb7d3f
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 18, 2021
90e4f9a
rm notebooks folder
awaelchli Jun 18, 2021
238b830
rm old file
awaelchli Jun 18, 2021
c6c580a
revert changes to old eval loop
awaelchli Jun 18, 2021
f4e13f3
bad merge
awaelchli Jun 18, 2021
46dba5a
undo teardown
awaelchli Jun 18, 2021
e7f42d8
setup signature
awaelchli Jun 18, 2021
0b8e0cf
remove file for notes
awaelchli Jun 18, 2021
be8e303
free memory
awaelchli Jun 18, 2021
e061bad
chlog
awaelchli Jun 18, 2021
ce1d1e7
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 21, 2021
4910b61
Revert "weekproxy trainer"
awaelchli Jun 18, 2021
d472357
connect trainer
awaelchli Jun 21, 2021
27914dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
49d2ada
clean up max batches and dataloaders
awaelchli Jun 21, 2021
126b136
max batches handling
awaelchli Jun 21, 2021
1f34677
no grad handling
awaelchli Jun 21, 2021
45b1f4a
Merge remote-tracking branch 'origin/refactor/loops_everywhere' into …
awaelchli Jun 21, 2021
f219fdb
unused argument
awaelchli Jun 21, 2021
e0463eb
protected attrs
awaelchli Jun 21, 2021
06dcd39
unused imports
awaelchli Jun 21, 2021
d7d88e7
undo unintentional rename
awaelchli Jun 21, 2021
9094ae4
consistent naming
awaelchli Jun 21, 2021
0d47b93
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 22, 2021
9f26850
capitalization in docstring
awaelchli Jun 22, 2021
6f06f3b
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 22, 2021
2790175
Merge branch 'master' into refactor/loops/loops_everywhere
awaelchli Jun 22, 2021
2639c78
list all args
awaelchli Jun 22, 2021
3a094de
Update pytorch_lightning/loops/prediction_epoch_loop.py
awaelchli Jun 22, 2021
dddfe80
Update pytorch_lightning/loops/prediction_epoch_loop.py
awaelchli Jun 22, 2021
36faf76
Update pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py
awaelchli Jun 22, 2021
286982b
Update pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py
awaelchli Jun 22, 2021
935dea2
Update pytorch_lightning/loops/prediction_epoch_loop.py
awaelchli Jun 23, 2021
e707aff
Merge branch 'master' into refactor/loops_everywhere
justusschock Jun 23, 2021
5ca00bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationDataLoaderLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990))
* Removed `pytorch_lightning/trainer/evaluation_loop.py` ([#8056](https://github.com/PyTorchLightning/pytorch-lightning/pull/8056))
* Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
* Refactored prediction loop interface; added new classes `PredictionDataLoaderLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700))


- Refactored logging
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/prediction_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def on_predict_batch_end(
if not self.interval.on_batch:
return
is_distributed = trainer.accelerator_connector.is_distributed
batch_indices = trainer.predict_loop.batch_indices if is_distributed else None
batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices if is_distributed else None
self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)

def on_predict_epoch_end(
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from abc import ABC, abstractmethod
from typing import Any, Optional
from weakref import proxy

from deprecate import void

Expand Down Expand Up @@ -59,7 +58,8 @@ def skip(self) -> bool:

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
"""Connects Loop with all the necessary things like connectors and accelerators."""
self.trainer = proxy(trainer)
# TODO(@justusschock): Make the trainer a weakref/proxy
self.trainer = trainer
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

def on_skip(self) -> Optional[Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def predictions(self):
def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
"""Connects the loop to everything necessary (like trainer and accelerators)"""
super().connect(trainer, *args, **kwargs)
# TODO: Make the trainer a weakref/proxy
self.epoch_loop.connect(trainer)

@property
Expand Down
148 changes: 148 additions & 0 deletions pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from typing import Any, List, Optional, Sequence, Union

from deprecate.utils import void
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT


class PredictionDataLoaderLoop(DataLoaderLoop):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""Loop to run over dataloaders for prediction"""

def __init__(self):
super().__init__()
self.epoch_loop: PredictionEpochLoop = PredictionEpochLoop()
self.predictions: Optional[List[List[Any]]] = None
self.epoch_batch_indices: Optional[List[List[int]]] = None
self._return_predictions: bool = False

@property
def return_predictions(self) -> bool:
"""Whether to return the predictions or not"""
return self._return_predictions

@return_predictions.setter
def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
# ``DDPSpawnPlugin`` plugins and derivate don't support return predictions.
is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin)
if return_predictions and is_ddp_spawn:
raise MisconfigurationException(
"`return_predictions` should be set to `False` when using the `DDPSpawnPlugin` or children class. "
f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}."
)
# For non ``DDPSpawnPlugin`` plugin, the `return_predictions` is True by default unless user decide otherwise.
self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions

@property
def num_dataloaders(self) -> int:
"""Returns the number of prediction dataloaders"""
# case where user does:
# return dl1, dl2
dataloaders = self.dataloaders
length = len(dataloaders)
if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
length = len(dataloaders[0])
return length

@property
def max_batches(self) -> List[int]:
"""The max number of batches this loop will run for each dataloader."""
max_batches = self.trainer.num_predict_batches
if isinstance(max_batches, int):
max_batches = [max_batches] * len(self.dataloaders)
return max_batches

@property
def dataloaders(self) -> Sequence[DataLoader]:
"""Returns all prediction dataloaders"""
return self.trainer.predict_dataloaders

@property
def done(self) -> bool:
"""Whether prediction is finished: Max batches run or all dataloaders processed"""
return self.current_dataloader_idx >= len(self.dataloaders)

@property
def skip(self) -> bool:
return sum(self.max_batches) == 0

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
"""Connects the loop with all necessary things (like trainer)"""
super().connect(trainer, *args, **kwargs)
self.epoch_loop.connect(trainer, *args, **kwargs)

def reset(self) -> None:
"""Resets the internal state of the loop for a new run"""
super().reset()
self.predictions = []
self.epoch_batch_indices = []

def on_run_start(self) -> None:
"""Calls ``on_predict_start`` hook"""
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
self.on_predict_start()

def advance(self, *args: Any, **kwargs: Any) -> None:
"""Predicts one entire dataloader"""
void(*args, **kwargs)
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader_iter = enumerate(dataloader)
dl_max_batches = self.max_batches[self.current_dataloader_idx]

dl_predictions, dl_batch_indices = self.epoch_loop.run(
dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders, self.return_predictions
)
self.predictions.append(dl_predictions)
self.epoch_batch_indices.append(dl_batch_indices)

def on_run_end(self) -> Union[List[Any], List[List[Any]]]:
"""Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders"""
results = self.on_predict_epoch_end()
self.on_predict_end()
return results

def on_predict_start(self) -> None:
"""
Sets model to eval mode and disables gradients. Also calls ``on_predict_start`` and
``on_predict_epoch_start`` hooks.
"""
# enable eval mode + no grads
self.on_predict_model_eval()
self.trainer.lightning_module.zero_grad()

# hook
self.trainer.call_hook("on_predict_start")
self.trainer.call_hook("on_predict_epoch_start")

def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
"""Calls ``on_predict_epoch_end`` hook.

Returns:
the results for all dataloaders
"""
self.trainer.profiler.describe()

results = self.predictions

self.trainer.call_hook("on_predict_epoch_end", results)

if self.return_predictions:
return results[0] if self.num_dataloaders == 1 else results

def on_predict_end(self) -> None:
"""Resets previous gradient status and calls ``on_predict_end`` hook"""
# clear memory. the predictions are extracted in `on_predict_epoch_end`.
self.predictions = []
self.epoch_batch_indices = []

# hook
self.trainer.call_hook("on_predict_end")

def on_predict_model_eval(self):
"""Calls ``on_predict_model_eval`` hook"""
model_ref = self.trainer.lightning_module
model_ref.on_predict_model_eval()
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from contextlib import suppress
from typing import Any, List, Optional, Tuple

from deprecate import void
from torch.optim import Optimizer

import pytorch_lightning as pl
Expand Down Expand Up @@ -167,10 +166,7 @@ def skip(self) -> bool:

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
"""Connects the loop with necessary arguments like the trainer"""
# TODO(@justusschock): Do we want to forward *args and **kwargs to the inner loop here?
# TODO(@justusschock): Can we make the trainer a weakref/proxy?
void(*args, **kwargs)
self.trainer = trainer
super().connect(trainer, *args, **kwargs)
self.training_loop.connect(trainer)
self.validation_loop.connect(trainer)

Expand Down
151 changes: 151 additions & 0 deletions pytorch_lightning/loops/prediction_epoch_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple

from deprecate import void

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.utilities.warnings import WarningCache


class PredictionEpochLoop(Loop):
"""Loop performing prediction on arbitrary sequentially used dataloaders."""

def __init__(self) -> None:
super().__init__()
self.return_predictions: bool = False
self.predictions: List[Any] = []
self.current_batch_indices: List[int] = []
self._dl_max_batches: Optional[int] = None
self._num_dataloaders: Optional[int] = None
self._warning_cache = WarningCache()
self._all_batch_indices: List[int] = []

@property
def done(self) -> bool:
"""Ends prediction when the iteration count exceeds the total number of available batches"""
return self.iteration_count >= self._dl_max_batches

@property
def should_store_predictions(self) -> bool:
"""Whether the predictions should be stored for later usage (e.g. aggregation or returning)"""
any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
return self.return_predictions or any_pred

def reset(self) -> None:
"""Resets the loops internal state"""
self.iteration_count = 0
self._all_batch_indices: List[int] = []
self.predictions: List[Any] = []

def on_run_start(
self,
dataloader_iter: Iterator,
dataloader_idx: int,
dl_max_batches: int,
num_dataloaders: int,
return_predictions: bool = False
) -> None:
"""
Prepares the loops internal state

Args:
dataloader_iter: the iterator over the current dataloader
dataloader_idx: the index of the current dataloader
dl_max_batches: the maximum number of batches the current loader can produce
num_dataloaders: the total number of dataloaders
return_predictions: whether to return the obtained predictions
"""
void(dataloader_iter, dataloader_idx)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self._dl_max_batches = dl_max_batches
self._num_dataloaders = num_dataloaders
self.return_predictions = return_predictions

def advance(
self,
dataloader_iter: Iterator,
dataloader_idx: int,
dl_max_batches: int,
num_dataloaders: int,
return_predictions: bool = False
) -> None:
"""
Runs one prediction step.

Args:
dataloader_iter: the iterator over the current dataloader
dataloader_idx: the index of the current dataloader
dl_max_batches: the maximum number of batches the current loader can produce
num_dataloaders: the total number of dataloaders
return_predictions: whether to return the obtained predictions
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""
batch_idx, batch = next(dataloader_iter)
if batch is None:
raise StopIteration

with self.trainer.profiler.profile("predict_step"):
self._predict_step(batch, batch_idx, dataloader_idx)

def on_run_end(self) -> Tuple[Any, Any]:
"""Returns the predictions and the corresponding batch indices"""
return self.predictions, self._all_batch_indices

def teardown(self) -> None:
"""Frees memory of collected predictions."""
self.predictions = []
self._all_batch_indices = []

def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""Runs the actual predict step together with all the
necessary bookkeeping and the hooks tied to the predict step.

Args:
batch: the current batch to run the prediction on
batch_idx: the index of the current batch
dataloader_idx: the index of the dataloader producing the current batch
"""
# configure step_kwargs
step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

# extract batch_indices and store them
self._store_batch_indices(dataloader_idx)

model_ref = self.trainer.lightning_module

self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)

model_ref._current_fx_name = "predict_step"
predictions = self.trainer.accelerator.predict_step(step_kwargs)

if predictions is None:
self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)

if self.should_store_predictions:
self.predictions.append(predictions)

def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]:
"""
Assembles the keyword arguments for the ``predict_step``

Args:
batch: the current batch to run the prediction on
batch_idx: the index of the current batch
dataloader_idx: the index of the dataloader producing the current batch

Returns:
the dictionary containing all the keyboard arguments for the predict step
"""
step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
if self._num_dataloaders > 1:
step_kwargs['dataloader_idx'] = dataloader_idx
return step_kwargs

def _store_batch_indices(self, dataloader_idx: int) -> None:
"""Stores the batch indices if the predictions should be stored"""
batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
if isinstance(batch_sampler, IndexBatchSamplerWrapper):
self.current_batch_indices = batch_sampler.batch_indices
if self.should_store_predictions:
self._all_batch_indices.append(batch_sampler.batch_indices)
11 changes: 3 additions & 8 deletions pytorch_lightning/loops/training_batch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.plugins import ParallelPlugin
Expand Down Expand Up @@ -67,11 +66,6 @@ def optimizer_freq_cumsum(self) -> int:
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
# TODO(@justusschock): can we make this a weakref/proxy?
void(*args, **kwargs)
self.trainer = trainer

def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
"""Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks

Expand All @@ -96,8 +90,9 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
return AttributeDict(signal=-1)

super().run(batch, batch_idx, dataloader_idx)

return AttributeDict(signal=0, training_step_output=self.batch_outputs)
output = AttributeDict(signal=0, training_step_output=self.batch_outputs)
self.batch_outputs = None # free memory
return output

def reset(self) -> None:
"""Resets the loop state"""
Expand Down
Loading