Commit 176df20

Mark evaluation epoch loops attributes as protected (#8420)
* Mark evaluation epoch loops attributes as protected

* Fix pre-commit
carmocca authored Jul 15, 2021
1 parent 7d1f4ce commit 176df20
Showing 2 changed files with 16 additions and 18 deletions.
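
The rename follows the usual Python convention of marking internal attributes with a single leading underscore. As a hedged illustration only (the class below is hypothetical, not code from the Lightning repository), the convention looks like this:

from typing import Any, List, Optional


class LoopSketch:
    """Hypothetical class illustrating the underscore convention; not Lightning code."""

    def __init__(self) -> None:
        # Public attribute: part of the loop's API, safe for callers to use.
        self.outputs: List[Any] = []
        # Protected by convention: internal bookkeeping, not meant for external access.
        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None

The leading underscore has no runtime effect; it only signals to callers and linters that the attribute is an implementation detail, which is what this commit does for the evaluation epoch loop's counters.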
28 changes: 12 additions & 16 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -37,9 +37,8 @@ def __init__(self) -> None:
         super().__init__()
         self.predictions: Optional[PredictionCollection] = None
         self.dataloader: Optional[Iterator] = None
-        self.dl_max_batches: Optional[int] = None
-        self.dataloader_idx: Optional[int] = None
-        self.num_dataloaders: Optional[int] = None
+        self._dl_max_batches: Optional[int] = None
+        self._num_dataloaders: Optional[int] = None
         self.outputs: List[STEP_OUTPUT] = []
         self.progress = EpochProgress()

@@ -54,15 +53,14 @@ def connect(
     @property
     def done(self) -> bool:
         """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
-        return self.iteration_count >= self.dl_max_batches
+        return self.iteration_count >= self._dl_max_batches
 
     def reset(self) -> None:
         """Resets the loop's internal state."""
         self.iteration_count = 0
         self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
-        self.dl_max_batches = None
-        self.dataloader_idx = None
-        self.num_dataloaders = None
+        self._dl_max_batches = None
+        self._num_dataloaders = None
         self.outputs = []
 
     def on_run_start(
@@ -80,11 +78,9 @@ def on_run_start(
             dl_max_batches: maximum number of batches the dataloader can produce
             num_dataloaders: the total number of dataloaders
         """
-        void(dataloader_iter)
-
-        self.dl_max_batches = dl_max_batches
-        self.dataloader_idx = dataloader_idx
-        self.num_dataloaders = num_dataloaders
+        void(dataloader_iter, dataloader_idx)
+        self._dl_max_batches = dl_max_batches
+        self._num_dataloaders = num_dataloaders
 
     def advance(
         self,
@@ -182,8 +178,8 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         """
         self.trainer.logger_connector.on_batch_start()
 
-        assert self.num_dataloaders is not None
-        self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self.num_dataloaders)
+        assert self._num_dataloaders is not None
+        self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders)
 
         if self.trainer.testing:
             self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
@@ -243,8 +239,8 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
         # make dataloader_idx arg in validation_step optional
         step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
 
-        multiple_val_loaders = not self.trainer.testing and self.num_dataloaders > 1
-        multiple_test_loaders = self.trainer.testing and self.num_dataloaders > 1
+        multiple_val_loaders = not self.trainer.testing and self._num_dataloaders > 1
+        multiple_test_loaders = self.trainer.testing and self._num_dataloaders > 1
 
         if multiple_test_loaders or multiple_val_loaders:
             step_kwargs["dataloader_idx"] = dataloader_idx
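
A note on the on_run_start hunk above: because dataloader_idx is no longer stored on the loop, it is now passed to void together with dataloader_iter. The real helper lives elsewhere in the Lightning codebase; purely as an assumption-labelled sketch, a helper of this kind simply accepts and discards its arguments so that parameters required by the hook signature can be marked as intentionally unused:

from typing import Any


def void(*args: Any, **kwargs: Any) -> None:
    """Hypothetical sketch only: accept any arguments and ignore them.

    Writing void(dataloader_iter, dataloader_idx) then documents that both values
    are deliberately unused here, without changing behaviour.
    """
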
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -459,8 +459,10 @@ def select_precision_plugin(self) -> PrecisionPlugin:
                     "You have asked for native AMP on CPU, but AMP is only available on GPU."
                 )
             if not _NATIVE_AMP_AVAILABLE:
-                msg = "You have asked for native AMP but your PyTorch version does not support it." \
-                    " Consider upgrading with `pip install torch>=1.6`."
+                msg = (
+                    "You have asked for native AMP but your PyTorch version does not support it."
+                    " Consider upgrading with `pip install torch>=1.6`."
+                )
                 if _APEX_AVAILABLE:
                     self.amp_type = AMPType.APEX
                     msg += " We will attempt to use NVIDIA Apex for this session."
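
The second file is the pre-commit fix: a backslash-continued string literal is rewritten as parenthesized implicit concatenation. Both forms produce the same message; shown below as a standalone sketch rather than the connector's actual surrounding code:

# Style removed by the commit: backslash line continuation.
msg = "You have asked for native AMP but your PyTorch version does not support it." \
    " Consider upgrading with `pip install torch>=1.6`."

# Style introduced by the commit: adjacent string literals inside parentheses are
# concatenated at compile time, and the block needs no trailing backslash.
msg = (
    "You have asked for native AMP but your PyTorch version does not support it."
    " Consider upgrading with `pip install torch>=1.6`."
)

assert "Consider upgrading" in msg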
