Add progress tracking on Loops - 2/n #8362

Merged: 112 commits merged into master from add_progress_tracking_on_loops on Jul 19, 2021.
Commits (112)
fe55d6e
resolve issues
tchaton Jul 8, 2021
4ee4a73
update
tchaton Jul 8, 2021
1291418
update
tchaton Jul 8, 2021
0e69bea
update
tchaton Jul 8, 2021
368e179
add more exceptions
tchaton Jul 8, 2021
eb4475c
resolve bug
tchaton Jul 8, 2021
449ca62
update
tchaton Jul 8, 2021
cdf38f0
update
tchaton Jul 8, 2021
88bafaf
update changelog
tchaton Jul 8, 2021
0981e94
resolve bug
tchaton Jul 8, 2021
7f8000f
resolve comments
tchaton Jul 10, 2021
4153b81
update
tchaton Jul 10, 2021
d6280e0
update
tchaton Jul 10, 2021
c499c24
update changelog
tchaton Jul 10, 2021
3cb6df2
update
tchaton Jul 10, 2021
e8c12e9
update
tchaton Jul 10, 2021
df4b1ba
remove space
tchaton Jul 10, 2021
ee8d9b8
update
tchaton Jul 10, 2021
65540a8
add progress tracking to loops
tchaton Jul 10, 2021
22fa5fb
validate json
tchaton Jul 10, 2021
6d45fe2
update
tchaton Jul 10, 2021
71d01d6
convert to dict for better readability
tchaton Jul 10, 2021
1c6c566
validate reload
tchaton Jul 10, 2021
bc49cc7
update
tchaton Jul 10, 2021
0a0b5e3
update
tchaton Jul 10, 2021
45fb657
update on comments
tchaton Jul 12, 2021
335caa7
Merge branch 'master' into add_progress_tracking_on_loops
tchaton Jul 12, 2021
65821c9
remove deadcode
tchaton Jul 12, 2021
d0492b5
clean changelog
tchaton Jul 12, 2021
462b357
clean changelog
tchaton Jul 12, 2021
8c0426b
update
tchaton Jul 12, 2021
b7c4113
update on comments
tchaton Jul 12, 2021
7e0456b
CHANGELOG
carmocca Jul 12, 2021
c266532
CHANGELOG
carmocca Jul 12, 2021
30ddd10
Update pytorch_lightning/loops/base.py
tchaton Jul 12, 2021
ffc6ca7
whitespace suggestions
awaelchli Jul 12, 2021
9ac0b61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
8ddb020
make fault_tolerant_enabled protected
awaelchli Jul 12, 2021
50b6f49
whitespace fixes around Args
awaelchli Jul 12, 2021
2133355
Merge remote-tracking branch 'origin/add_progress_tracking_on_loops' …
awaelchli Jul 12, 2021
8e9682e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
9a7f0a0
Merge branch 'master' into add_progress_tracking_on_loops
justusschock Jul 12, 2021
0838d7a
update
tchaton Jul 12, 2021
8204cb7
Merge branch 'add_progress_tracking_on_loops' of https://github.com/P…
tchaton Jul 12, 2021
107e143
typo it's -> its
awaelchli Jul 12, 2021
e49cd50
fix copy-paste typo in progress docstring
awaelchli Jul 12, 2021
2e0423a
Delete classes
carmocca Jul 13, 2021
7caca87
Minor change
carmocca Jul 13, 2021
2800eae
docs
carmocca Jul 13, 2021
feec34f
protected get_loops_state
awaelchli Jul 13, 2021
ccdd09d
merge restore_loops with restore_progress
awaelchli Jul 13, 2021
01768cb
Fix tests after removals
carmocca Jul 13, 2021
71e05d3
explicit save with trainer.save_checkpoint()
awaelchli Jul 14, 2021
3d13b64
handle optimization restart based on optimizer_idx
awaelchli Jul 14, 2021
78d13e2
update increments
awaelchli Jul 14, 2021
1048259
update val batch progress and remove iteration count
awaelchli Jul 14, 2021
668a4cf
update progress tracking for dataloader loops
awaelchli Jul 14, 2021
ad8b342
remove self.dataloader_idx from eval_epoch_loop
awaelchli Jul 14, 2021
512ee0d
add batch progress to predict loop
awaelchli Jul 14, 2021
2633d51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 14, 2021
4bbc7ac
incorporate progress tracking for current_epoch
awaelchli Jul 14, 2021
01f8714
Fix test
carmocca Jul 14, 2021
65405b8
Actually remove it
carmocca Jul 14, 2021
6dd2182
Remove unused TrainingEpochProgress
carmocca Jul 14, 2021
b71e151
Fix optimization progress - missing scheduler
carmocca Jul 14, 2021
e5a392a
Restarting changes
carmocca Jul 14, 2021
49c5112
Scheduler progress
carmocca Jul 14, 2021
018da6a
Unused property, reset on epoch
carmocca Jul 14, 2021
0b1834c
Resolve FIXME
carmocca Jul 14, 2021
d7bcafa
Remove FIXME
carmocca Jul 14, 2021
e794fbe
fix test_progress (wip)
awaelchli Jul 14, 2021
c98bd29
fix batch_progress.current.reset
awaelchli Jul 14, 2021
f90334c
Hold off on split progress. Out of scope of this PR
carmocca Jul 14, 2021
7fb78de
Unnecessary if
carmocca Jul 14, 2021
8130a47
fix structure in test_progress
awaelchli Jul 14, 2021
b6b9ea4
structure
awaelchli Jul 14, 2021
4780b19
clean up unused variables in test_progress
awaelchli Jul 14, 2021
7eee718
refactor naming and organization in test_progress
awaelchli Jul 14, 2021
a1bd989
Unnecessary variable
carmocca Jul 14, 2021
f6d3a5f
Remove unnecessary diff
carmocca Jul 14, 2021
d57bddf
Improve comment
carmocca Jul 14, 2021
099edd0
Undo typing change to avoid polluting everything with mypy fixes
carmocca Jul 14, 2021
9145c82
Fix and improve test_loops.py
carmocca Jul 14, 2021
b0fc845
Fix and organize `test_loop_state_dict`
carmocca Jul 14, 2021
1577aa8
Remove unnecessary checks in test
carmocca Jul 14, 2021
1f3ae63
Update test after disallowing updates on None attributes
carmocca Jul 14, 2021
ad8224c
Typing
carmocca Jul 15, 2021
403ea9d
Minor test cleanup
carmocca Jul 15, 2021
6492cde
Fix and move loop test
carmocca Jul 15, 2021
bc5544d
Move test from progress to loops
carmocca Jul 15, 2021
098c7b5
Reset the scheduler progress
carmocca Jul 15, 2021
ef7c9e0
SchedulerProgress fix
carmocca Jul 15, 2021
7938403
Consistent whitespace
carmocca Jul 15, 2021
7799101
Fix final test
carmocca Jul 15, 2021
a375607
Minor test changes
carmocca Jul 15, 2021
dc30c4c
Merge branch 'master' into add_progress_tracking_on_loops
tchaton Jul 15, 2021
abb08a0
One test to rule them all
carmocca Jul 15, 2021
fc18c16
Formatting
carmocca Jul 15, 2021
e550e6d
Rename and clean variables
carmocca Jul 15, 2021
01a8a45
Shorter names
carmocca Jul 15, 2021
1a6c2a1
Shorter scheduler name
carmocca Jul 15, 2021
e1906b7
Fix optimizer step calculation for stop_batch=2
carmocca Jul 15, 2021
2951700
Merge branch 'master' into add_progress_tracking_on_loops
carmocca Jul 15, 2021
5eaf5b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2021
29ce552
Remove empty connects
carmocca Jul 15, 2021
3984578
Update CHANGELOG
carmocca Jul 15, 2021
70a9bca
Holy shit finally got the formula right
carmocca Jul 15, 2021
ae94d7a
Fix final thing!!!
carmocca Jul 16, 2021
83b3dd6
Do not check state dicts
carmocca Jul 16, 2021
5af9730
parametrize multiple_dataloader progress test
awaelchli Jul 16, 2021
d1a8bc0
Update CHANGELOG.md
awaelchli Jul 16, 2021
b510a96
resolve flake8
tchaton Jul 19, 2021
9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -33,9 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Progress tracking
* Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
* Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
* Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244))
* Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))


- Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))
@@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fault-tolerant training
* Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
* Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))
* Set `Loop.restarting=False` at the end of the first iteration ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Save the loops state with the checkpoint (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Save a checkpoint to restore the state on exception (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))


- Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
@@ -393,8 +396,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287))




### Fixed

- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
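
The progress-tracking dataclasses referenced in the changelog above expose `state_dict`/`load_state_dict` so that loop position can be saved and restored with a checkpoint. The following is a minimal, self-contained sketch of that idea; the class and field names are simplified stand-ins, not the actual `pytorch_lightning.trainer.progress` API.

```python
# Illustrative sketch only: simplified stand-ins for the progress dataclasses
# referenced in the changelog, not the actual pytorch_lightning implementation.
from dataclasses import asdict, dataclass, field
from typing import Dict


@dataclass
class ReadyCompletedTracker:
    """Counts how many times an event became ready and how many times it completed."""
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0


@dataclass
class Progress:
    """Pairs a total counter (never reset) with a current counter (reset each epoch)."""
    total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
    current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)

    def increment_ready(self) -> None:
        self.total.ready += 1
        self.current.ready += 1

    def increment_completed(self) -> None:
        self.total.completed += 1
        self.current.completed += 1

    def state_dict(self) -> Dict:
        return asdict(self)

    def load_state_dict(self, state_dict: Dict) -> None:
        self.total = ReadyCompletedTracker(**state_dict["total"])
        self.current = ReadyCompletedTracker(**state_dict["current"])


progress = Progress()
progress.increment_ready()
progress.increment_completed()

# round-trip the counters, as a checkpoint reload would
restored = Progress()
restored.load_state_dict(progress.state_dict())
assert restored.total.completed == 1
```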
12 changes: 7 additions & 5 deletions pytorch_lightning/loops/base.py
@@ -46,6 +46,7 @@ class Loop(ABC):
"""

def __init__(self) -> None:
# TODO: replace by progress tracking
self.iteration_count: int = 0
self.restarting = False
self._trainer: Optional['pl.Trainer'] = None
@@ -56,8 +57,8 @@ def trainer(self) -> Optional['pl.Trainer']:

@trainer.setter
def trainer(self, trainer: 'pl.Trainer'):
"""Connect the Trainer to this loop and all children."""
if not isinstance(trainer, pl.Trainer) and trainer is not None:
"""Connects this loop's trainer and its children"""
if not isinstance(trainer, pl.Trainer):
raise MisconfigurationException(
f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
)
@@ -112,6 +113,7 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
self.advance(*args, **kwargs)
self.on_advance_end()
self.iteration_count += 1
self.restarting = False
except StopIteration:
break

@@ -158,7 +160,7 @@ def on_save_checkpoint(self) -> Dict:
"""
return {}

def on_load_checkpoint(self, state_dict: Dict):
def on_load_checkpoint(self, state_dict: Dict) -> None:
"""Called when loading a model checkpoint, use to reload loop state."""

def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = "") -> Dict:
@@ -183,14 +185,14 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] =

return destination

def load_state_dict(self, state_dict: Dict, prefix="", restart_progress: bool = True):
def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None:
""" Loads the state of this loop and all its children. """
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)

def _load_from_state_dict(self, state_dict, prefix, restart_progress):
def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None:
for k, v in self.__dict__.items():
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[prefix + k])
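
The `base.py` changes above add a `restarting` flag that `run()` clears after the first iteration, and they recurse over child loops when saving and loading state under dotted prefixes. Below is a minimal sketch of that recursive, prefixed state handling; it is heavily simplified and is not the real `Loop` class.

```python
# Toy sketch of the prefixed, recursive state handling shown in the diff above.
from typing import Dict, Optional


class SketchLoop:
    """A heavily simplified stand-in for a loop with nested child loops."""

    def __init__(self) -> None:
        self.iteration_count = 0
        self.restarting = False

    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
        if destination is None:
            destination = {}
        # this loop's own state lives under its prefix ...
        destination[prefix + "state_dict"] = {"iteration_count": self.iteration_count}
        # ... and child loops are flattened into the same dict with a dotted prefix
        for name, child in self.__dict__.items():
            if isinstance(child, SketchLoop):
                child.state_dict(destination, prefix + name + ".")
        return destination

    def load_state_dict(self, state_dict: Dict, prefix: str = "") -> None:
        self.iteration_count = state_dict[prefix + "state_dict"]["iteration_count"]
        # mark the loop as restarting; the real Loop.run() flips this back to
        # False once the first iteration after resuming has finished
        self.restarting = True
        for name, child in self.__dict__.items():
            if isinstance(child, SketchLoop):
                child.load_state_dict(state_dict, prefix + name + ".")


fit_loop = SketchLoop()
fit_loop.epoch_loop = SketchLoop()  # nest a child loop
fit_loop.epoch_loop.iteration_count = 3

saved = fit_loop.state_dict()
assert saved["epoch_loop.state_dict"]["iteration_count"] == 3

restored = SketchLoop()
restored.epoch_loop = SketchLoop()
restored.load_state_dict(saved)
assert restored.restarting and restored.epoch_loop.iteration_count == 3
```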
35 changes: 17 additions & 18 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -23,12 +23,11 @@
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, OptimizationProgress
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -50,7 +49,6 @@ def __init__(self) -> None:
self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
self.batch_idx: int = 0
self.split_idx: Optional[int] = None
self.progress = BatchProgress()
self.optim_progress = OptimizationProgress()

self._warning_cache: WarningCache = WarningCache()
@@ -59,21 +57,6 @@ def __init__(self) -> None:
self._remaining_splits: Optional[List[Any]] = None
self._skip_backward: bool = False

def connect(
self,
trainer: 'pl.Trainer',
*args: Any,
progress: Optional[BatchProgress] = None,
optim_progress: Optional[OptimizationProgress] = None,
**kwargs: Any
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress
if optim_progress is not None:
self.optim_progress = optim_progress

@property
def done(self) -> bool:
"""Returns if all batch splits have been processed already"""
@@ -109,6 +92,8 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
if response == -1:
return AttributeDict(signal=-1)

self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()

super().run(batch, batch_idx, dataloader_idx)
output = AttributeDict(signal=0, training_step_output=self.batch_outputs)
self.batch_outputs = None # free memory
@@ -149,6 +134,13 @@ def advance(self, batch, batch_idx, dataloader_idx):

if self.trainer.lightning_module.automatic_optimization:
for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
# handle optimization restart
if self.restarting:
if opt_idx < self.optim_progress.optimizer_idx:
continue

self.optim_progress.optimizer_idx = opt_idx

result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
if result:
self.batch_outputs[opt_idx].append(result.training_step_output)
@@ -395,6 +387,8 @@ def _optimizer_step(
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)

self.optim_progress.optimizer.step.increment_ready()

# model hook
model_ref.optimizer_step(
self.trainer.current_epoch,
@@ -407,13 +401,17 @@
using_lbfgs=is_lbfgs,
)

self.optim_progress.optimizer.step.increment_completed()

def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
"""Calls the ``on_before_zero_grad`` hook.

Args:
optimizer: the current optimizer
"""
self.optim_progress.optimizer.zero_grad.increment_ready()
self.trainer.call_hook('on_before_zero_grad', optimizer)
self.optim_progress.optimizer.zero_grad.increment_started()

def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
"""Zeroes out all gradients of parameters optimized by the current optimizer.
@@ -424,6 +422,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
opt_idx: the index of the current optimizer
"""
self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self.optim_progress.optimizer.zero_grad.increment_completed()

def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.
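
The `training_batch_loop.py` changes above bracket each optimizer step and `zero_grad` call with `increment_ready()`/`increment_started()`/`increment_completed()` calls and, on restart, skip optimizers whose index is below the recorded `optimizer_idx`. The following is a toy sketch of that pattern, using simplified stand-ins rather than the real `OptimizationProgress`.

```python
# Sketch of the ready/completed bookkeeping around optimizer steps and of the
# restart skip based on optimizer_idx. Not the actual Lightning classes.
from dataclasses import dataclass, field
from typing import List


@dataclass
class StepProgress:
    ready: int = 0
    completed: int = 0

    def increment_ready(self) -> None:
        self.ready += 1

    def increment_completed(self) -> None:
        self.completed += 1


@dataclass
class OptimizationProgressSketch:
    optimizer_idx: int = 0
    step: StepProgress = field(default_factory=StepProgress)


def run_optimizers(progress: OptimizationProgressSketch, num_optimizers: int, restarting: bool) -> List[int]:
    ran = []
    for opt_idx in range(num_optimizers):
        if restarting and opt_idx < progress.optimizer_idx:
            # this optimizer already finished before the interruption
            continue
        progress.optimizer_idx = opt_idx
        progress.step.increment_ready()      # about to call optimizer.step()
        ran.append(opt_idx)                  # ... optimizer.step() would happen here ...
        progress.step.increment_completed()  # step finished
    return ran


# simulate resuming after optimizer 0 finished but optimizer 1 did not
state = OptimizationProgressSketch(optimizer_idx=1)
assert run_optimizers(state, num_optimizers=2, restarting=True) == [1]
```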
20 changes: 16 additions & 4 deletions pytorch_lightning/loops/dataloader/dataloader_loop.py
@@ -13,16 +13,21 @@
# limitations under the License.

from abc import abstractmethod
from typing import Sequence
from typing import Any, Sequence

from torch.utils.data import DataLoader

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.trainer.progress import DataLoaderProgress


class DataLoaderLoop(Loop):
"""Base class to loop over all dataloaders"""

def __init__(self):
super().__init__()
self.dataloader_progress = DataLoaderProgress()

@property
@abstractmethod
def dataloaders(self) -> Sequence[DataLoader]:
@@ -31,7 +36,7 @@ def dataloaders(self) -> Sequence[DataLoader]:
@property
def current_dataloader_idx(self) -> int:
"""Returns the index of the current dataloader"""
return self.iteration_count
return self.dataloader_progress.current.ready - 1

@property
def current_dataloader(self) -> DataLoader:
@@ -46,8 +51,15 @@ def num_dataloaders(self) -> int:
@property
def done(self) -> bool:
"""Returns whether all dataloaders have been processed"""
return self.current_dataloader_idx >= self.num_dataloaders
return self.dataloader_progress.current.completed >= self.num_dataloaders

def reset(self) -> None:
"""Resets the internal state"""
self.iteration_count = 0
if not self.restarting:
self.dataloader_progress.current.reset()

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
self.dataloader_progress.increment_ready()

def on_advance_end(self) -> None:
self.dataloader_progress.increment_completed()
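
In the `dataloader_loop.py` changes above, the loop no longer keeps its own iteration counter; the current dataloader index and the `done` condition are derived from the ready/completed counters of a progress object. A small illustrative sketch of that derivation follows (simplified, not the real `DataLoaderProgress`).

```python
# Sketch: deriving the dataloader index and the done condition from counters.
from dataclasses import dataclass


@dataclass
class CurrentCounter:
    ready: int = 0
    completed: int = 0


class DataLoaderProgressSketch:
    def __init__(self, num_dataloaders: int) -> None:
        self.num_dataloaders = num_dataloaders
        self.current = CurrentCounter()

    @property
    def current_dataloader_idx(self) -> int:
        # `ready` is incremented just before a dataloader starts, so the
        # dataloader currently being processed has index ready - 1
        return self.current.ready - 1

    @property
    def done(self) -> bool:
        return self.current.completed >= self.num_dataloaders


progress = DataLoaderProgressSketch(num_dataloaders=2)
seen = []
while not progress.done:
    progress.current.ready += 1        # on_advance_start
    seen.append(progress.current_dataloader_idx)
    progress.current.completed += 1    # on_advance_end
assert seen == [0, 1] and progress.done
```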
16 changes: 5 additions & 11 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -21,7 +21,6 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import EpochLoopProgress
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@@ -33,8 +32,6 @@ class EvaluationLoop(DataLoaderLoop):
def __init__(self):
super().__init__()
self.outputs = []
self.progress = EpochLoopProgress()

self.epoch_loop = EvaluationEpochLoop()

self._results = ResultCollection(training=False)
@@ -66,19 +63,15 @@ def predictions(self):
"""Returns the predictions from all dataloaders"""
return self.epoch_loop.predictions

def connect(
self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any
) -> None:
def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress
self.epoch_loop.connect(trainer, progress=self.progress.epoch)
self.epoch_loop.connect(trainer)

@property
def done(self) -> bool:
"""Returns whether all dataloaders are processed or evaluation should be skipped altogether"""
return (self.current_dataloader_idx >= len(self.dataloaders)) or self.skip
return super().done or self.skip

@property
def skip(self) -> bool:
@@ -88,14 +81,15 @@ def skip(self) -> bool:

def reset(self) -> None:
"""Resets the internal state of the loop"""
self.iteration_count = 0
self._max_batches = self.get_max_batches()
# bookkeeping
self.outputs = []

if isinstance(self._max_batches, int):
self._max_batches = [self._max_batches] * len(self.dataloaders)

super().reset()

def on_skip(self) -> List:
return []

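
The `evaluation_loop.py` changes above also simplify `connect()`: progress objects are no longer passed from parent to child, because each loop now owns its own progress state and only the trainer reference is forwarded. A tiny sketch of that shape, with made-up class names:

```python
# Sketch of the simplified connect() pattern: children own their progress,
# so the parent only forwards the trainer. Not the real Lightning loops.
class EpochLoopSketch:
    def __init__(self) -> None:
        self.trainer = None
        self.batch_progress = {"ready": 0, "completed": 0}  # owned by the loop itself

    def connect(self, trainer) -> None:
        self.trainer = trainer


class EvaluationLoopSketch:
    def __init__(self) -> None:
        self.trainer = None
        self.epoch_loop = EpochLoopSketch()

    def connect(self, trainer) -> None:
        # no progress objects are passed down anymore; only the trainer is shared
        self.trainer = trainer
        self.epoch_loop.connect(trainer)


trainer = object()
loop = EvaluationLoopSketch()
loop.connect(trainer)
assert loop.epoch_loop.trainer is trainer
```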
16 changes: 2 additions & 14 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -7,7 +7,6 @@
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.trainer.progress import EpochLoopProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT

@@ -19,8 +18,6 @@ def __init__(self):
super().__init__()
self.predictions: Optional[List[List[Any]]] = None
self.epoch_batch_indices: Optional[List[List[int]]] = None
self.progress = EpochLoopProgress()

self.epoch_loop = PredictionEpochLoop()

self._results = None # for `trainer._results` access
@@ -67,23 +64,14 @@ def dataloaders(self) -> Sequence[DataLoader]:
"""Returns all prediction dataloaders"""
return self.trainer.predict_dataloaders

@property
def done(self) -> bool:
"""Whether prediction is finished: Max batches run or all dataloaders processed"""
return self.current_dataloader_idx >= len(self.dataloaders)

@property
def skip(self) -> bool:
return sum(self.max_batches) == 0

def connect(
self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any
) -> None:
def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress
self.epoch_loop.connect(trainer, progress=self.progress.epoch)
self.epoch_loop.connect(trainer)

def reset(self) -> None:
"""Resets the internal state of the loop for a new run"""
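
Taken together, the changes in this PR let a loop record how far it got and resume from that point after an interruption. The toy example below illustrates the end-to-end idea; it is not the Lightning fault-tolerant mechanism itself, just a sketch of the save/restore/skip-ahead behavior the progress counters enable.

```python
# Toy end-to-end sketch: a loop whose position survives an interruption because
# it resumes from a saved counter instead of always starting at zero.
from typing import Dict, List


class ResumableBatchLoop:
    def __init__(self, num_batches: int) -> None:
        self.num_batches = num_batches
        self.completed = 0
        self.restarting = False

    def state_dict(self) -> Dict:
        return {"completed": self.completed}

    def load_state_dict(self, state_dict: Dict) -> None:
        self.completed = state_dict["completed"]
        self.restarting = True

    def run(self, fail_at: int = -1) -> List[int]:
        processed = []
        start = self.completed if self.restarting else 0
        self.restarting = False
        for batch_idx in range(start, self.num_batches):
            if batch_idx == fail_at:
                raise RuntimeError("simulated interruption")
            processed.append(batch_idx)
            self.completed = batch_idx + 1
        return processed


loop = ResumableBatchLoop(num_batches=5)
try:
    loop.run(fail_at=3)  # batches 0-2 complete, then we "crash"
except RuntimeError:
    checkpoint = loop.state_dict()

resumed = ResumableBatchLoop(num_batches=5)
resumed.load_state_dict(checkpoint)
assert resumed.run() == [3, 4]  # picks up exactly where it stopped
```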