From da81f93295298c9a80670c52778294fc3fbd77c4 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 26 Apr 2021 14:50:17 +0200 Subject: [PATCH 001/455] loops --- pytorch_lightning/loops/__init__.py | 0 pytorch_lightning/loops/base.py | 55 ++++ pytorch_lightning/loops/epoch_loop.py | 381 ++++++++++++++++++++++++++ pytorch_lightning/trainer/trainer.py | 102 ++++++- 4 files changed, 536 insertions(+), 2 deletions(-) create mode 100644 pytorch_lightning/loops/__init__.py create mode 100644 pytorch_lightning/loops/base.py create mode 100644 pytorch_lightning/loops/epoch_loop.py diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py new file mode 100644 index 0000000000000..93856ac3fcd11 --- /dev/null +++ b/pytorch_lightning/loops/base.py @@ -0,0 +1,55 @@ +from abc import ABCMeta, abstractmethod +from typing import Any, Counter, List + +import pytorch_lightning as pl + + +class Loop(metaclass=ABCMeta): + + def __init__(self): + self.iteration_count: int = 0 + self.trainer: 'pl.Trainer' + + @abstractmethod + def connect(self): + """Connects Loop with all the necessary things like connectors and accelerators""" + + @property + @abstractmethod + def done(self): + """Property indicating when loop is finished""" + + @abstractmethod + def advance(self, *args: Any, **kwargs: Any): + """What to do within a single step""" + + def on_run_start(self, *args: Any, **kwargs: Any): + pass + + def on_run_end(self, outputs: List) -> List: + return outputs + + def on_advance_start(self, *args: Any, **kwargs: Any): + pass + + def on_advance_end(self, curr_output: Any) -> Any: + return curr_output + + def run(self, *args: Any, **kwargs: Any): + self.on_start(*args, **kwargs) + + outputs = [] + + while not self.done: + + self.on_advance_start(*args, **kwargs) + curr_output = self.advance(*args, **kwargs) + curr_output = self.on_advance_end(curr_output) + + outputs.append(curr_output) + + self.iteration_count += 1 + + outputs = self.on_end(outputs) + + return outputs diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py new file mode 100644 index 0000000000000..6af2bf155adb5 --- /dev/null +++ b/pytorch_lightning/loops/epoch_loop.py @@ -0,0 +1,381 @@ +from contextlib import suppress +from copy import deepcopy +from logging import log +from typing import Any, List, Optional + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.trainer.supporters import TensorRunningAccum +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.warnings import WarningCache + + +class EpochLoop(Loop): + + def connect( + self, + num_epochs: int, + max_steps: Optional[int], + trainer: 'pl.Trainer', + *loops_to_run: Loop, + ): + self.num_epochs = num_epochs + self.max_steps = max_steps + self.trainer = trainer + for loop in loops_to_run: + if isinstance(loop, Loop) or hasattr(loop, 'run'): + self.loops_to_run.append(loop) + + @property + def done(self) -> bool: + stop_steps = self.trainer.max_steps and self.trainer.max_steps <= self.trainer.global_step + + should_stop = False + if self.trainer.should_stop: + # 
early stopping + met_min_epochs = (self.iteration_count >= self.trainer.min_epochs - 1) if self.trainer.min_epochs else True + met_min_steps = self.trainer.global_step >= self.trainer.min_steps if self.trainer.min_steps else True + if met_min_epochs and met_min_steps: + self.train_loop.on_train_end() + should_stop = True + else: + log.info( + 'Trainer was signaled to stop but required minimum epochs' + f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' + ' not been met. Training will continue...' + ) + self.trainer.should_stop = False + + stop_epochs = self.iteration_count >= self.num_epochs + + return stop_steps or should_stop or stop_epochs + + def on_run_start(self): + # hook + self.trainer.call_hook("on_train_start") + + def on_run_end(self): + if self._teardown_already_run: + return + self._teardown_already_run = True + + # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + + # hook + self.trainer.call_hook("on_train_end") + + # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers. + # It might be related to xla tensors blocked when moving the cpu + # kill loggers + if self.trainer.logger is not None: + self.trainer.logger.finalize("success") + + # summarize profile results + self.trainer.profiler.describe() + + # give accelerators a chance to finish + self.trainer.accelerator.on_train_end() + + # reset bookkeeping + self.trainer._running_stage = None + + def on_advance_start(self): # equal to on train epoch start + # implemented here since this code has to be run always no matter the actual epoch implementation + epoch = self.iteration_count + 1 + + # update training progress in trainer + self.trainer.current_epoch = epoch + + model = self.trainer.lightning_module + + # reset train dataloader + if epoch != 0 and self.trainer.reload_dataloaders_every_epoch: + self.trainer.reset_train_dataloader(model) + + # todo: specify the possible exception + with suppress(Exception): + # set seed for distributed sampler (enables shuffling for each epoch) + self.trainer.train_dataloader.sampler.set_epoch(epoch) + + # changing gradient according accumulation_scheduler + self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) + + # stores accumulated grad fractions per batch + self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) + + # hook + self.trainer.call_hook("on_epoch_start") + self.trainer.call_hook("on_train_epoch_start") + + def on_advance_end(self, outputs): + # handle epoch_output on epoch end + self.on_train_epoch_end(outputs) + + # log epoch metrics + self.trainer.logger_connector.log_train_epoch_end_metrics(outputs) + + should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.trainer.is_last_batch, on_epoch=True) + should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) + should_train_only = self.trainer.disable_validation or should_skip_eval + + # update epoch level lr_schedulers if no val loop outside train loop is triggered + if (val_loop_called and not should_check_val) or should_train_only: + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) + + 
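+        # run validation on the epoch-level output if requested; the trainer is
+        # switched back to the training stage once the evaluation run finishes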
if should_check_val: + self.trainer.validating = True + self.trainer.run_evaluation(on_epoch=True) + self.trainer.training = True + + # increment the global step once + # progress global step according to grads progress + self.increment_accumulated_grad_global_step() + + def advance(self): + ret_vals = [] + with self.trainer.profiler.profile("run_training_epoch"): + # run train epoch + for loop in self.loops_to_run: + ret_vals.append(loop.run()) + + return ret_vals + + +class TrainingLoop(Loop): + + def connect(self, trainer: 'pl.Trainer'): + self.trainer = trainer + self.batch_loop = BatchLoop + + def on_run_start(self): + # modify dataloader if needed (ddp, etc...) + train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + + self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) + self._dataloader_idx = 0 + + def advance(self): + batch_idx, (batch, is_last) = next(self._train_dataloader) + + self.trainer.batch_idx = batch_idx + self.trainer.is_last_batch = is_last + + # ------------------------------------ + # TRAINING_STEP + TRAINING_STEP_END + # ------------------------------------ + with self.trainer.profiler.profile("run_training_batch"): + batch_output = self.run_training_batch(batch, batch_idx, self._dataloader_idx) + + # when returning -1 from train_step, we end epoch early + if batch_output.signal == -1: + self._skip_remaining_steps = True + return + + # hook + # TODO: add outputs to batches + self.on_train_batch_end( + epoch_output, + batch_output.training_step_output_for_epoch_end, + batch, + batch_idx, + self._dataloader_idx, + ) + + def on_advance_end(self, output): + # ----------------------------------------- + # SAVE METRICS TO LOGGERS + # ----------------------------------------- + self.trainer.logger_connector.log_train_step_metrics(output) + + # ----------------------------------------- + # VALIDATE IF NEEDED + CHECKPOINT CALLBACK + # ----------------------------------------- + should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.trainer.is_last_batch) + if should_check_val: + self.trainer.validating = True + self.trainer.run_evaluation() + self.trainer.training = True + + # ----------------------------------------- + # SAVE LOGGERS (ie: Tensorboard, etc...) 
+ # ----------------------------------------- + self.save_loggers_on_train_batch_end() + + # update LR schedulers + monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) + self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) + self.trainer.checkpoint_connector.has_trained = True + + # progress global step according to grads progress + self.increment_accumulated_grad_global_step() + + @property + def done(self): + # max steps reached, end training + if ( + self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 + and self._accumulated_batches_reached() + ): + return True + + # end epoch early + # stop when the flag is changed or we've gone past the amount + # requested in the batches + if self.trainer.should_stop: + return True + + self.trainer.total_batch_idx += 1 + + # stop epoch if we limited the number of training batches + if self._num_training_batches_reached(self.trainer.is_last_batch): + return True + + def on_run_end(self, outputs): + # inform logger the batch loop has finished + self.trainer.logger_connector.on_train_epoch_end() + + # prepare epoch output + processed_outputs = self._prepare_outputs(outputs, batch_mode=False) + + # get the model and call model.training_epoch_end + model = self.trainer.lightning_module + + if is_overridden('training_epoch_end', model=model): + # run training_epoch_end + # refresh the result for custom logging at the epoch level + model._current_fx_name = 'training_epoch_end' + + # lightningmodule hook + training_epoch_end_output = model.training_epoch_end(processed_outputs) + + if training_epoch_end_output is not None: + raise MisconfigurationException( + 'training_epoch_end expects a return of None. ' + 'HINT: remove the return statement in training_epoch_end' + ) + + # capture logging + self.trainer.logger_connector.cache_logged_metrics() + + # call train epoch end hooks + self.trainer.call_hook('on_train_epoch_end', processed_outputs) + self.trainer.call_hook('on_epoch_end') + + # increment the global step once + # progress global step according to grads progress + self.increment_accumulated_grad_global_step() + + +class BatchLoop(Loop): + + def on_run_start(self, batch, batch_idx, dataloader_idx): + self._grad_norm_dic = {} + self.trainer.hiddens = None + self._optimizers = self.prepare_optimizers() + # lightning module hook + self._splits = self.tbptt_split_batch(batch) + + def on_advance_start(self): + return super().on_advance_start() + + def advance(self, *args: Any, **kwargs: Any): + return super().advance(*args, **kwargs) + + def run(self, batch, batch_idx, dataloader_idx): + if batch is None: + return AttributeDict(signal=0, grad_norm_dic={}) + + # hook + response = self.trainer.call_hook("on_batch_start") + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic={}) + + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic={}) + + return super().run(batch, batch_idx, dataloader_idx) + + +def run_training_batch(self, batch, batch_idx, dataloader_idx): + + for split_idx, split_batch in enumerate(splits): + + # create an iterable for optimizers and loop over them + for opt_idx, optimizer in optimizers: + + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + + if self.should_accumulate(): + # For gradient accumulation + + # ------------------- + # calculate loss 
(train step + train step end) + # ------------------- + + # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # automatic_optimization=False: don't block synchronization here + with self.block_ddp_sync_behaviour(): + self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # ------------------------------ + # BACKWARD PASS + # ------------------------------ + # gradient update with accumulated gradients + + else: + if self.automatic_optimization: + + def train_step_and_backward_closure(): + result = self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + return None if result is None else result.loss + + # optimizer step + self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) + + else: + self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + + if self._curr_step_result is None: + # user decided to skip optimization + # make sure to zero grad. + continue + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + grad_norm_dic = self._cur_grad_norm_dict + self._cur_grad_norm_dict = None + + # update running loss + reset accumulated loss + self.update_running_loss() + + result = AttributeDict( + signal=0, + grad_norm_dic=grad_norm_dic, + training_step_output_for_epoch_end=batch_outputs, + ) + return result diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2a6a53a7c192c..bff300e8b5880 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -375,7 +375,7 @@ def __init__( truncated_bptt_steps, terminate_on_nan, ) - self.train_loop.on_trainer_init( + self._setup_fit_on_init( max_epochs, min_epochs, max_steps, @@ -412,6 +412,50 @@ def __init__( # Callback system self.on_init_end() + def _setup_on_init( + self, + max_epochs: Optional[int], + min_epochs: Optional[int], + max_steps: Optional[int], + min_steps: Optional[int], + num_sanity_val_steps: int, + ): + self.global_step = 0 + self.current_epoch = 0 + self.should_stop = False + self._state = TrainerState.INITIALIZING + + self.total_batch_idx = 0 + self.batch_idx = 0 + self.num_training_batches = 0 + self.train_dataloader = None + + # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 + self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs + # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 + self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs + self.max_steps = max_steps + self.min_steps = min_steps + + if num_sanity_val_steps == -1: + self.num_sanity_val_steps = float("inf") + else: + self.num_sanity_val_steps = num_sanity_val_steps + + def _setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): + # clean hparams + if hasattr(model, "hparams"): + parsing.clean_namespace(model.hparams) + + # links data to the trainer + self.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) + + # check that model is configured correctly + self.config_validator.verify_loop_configurations(model) + + # attach model log function to callback + 
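+        # this gives the callbacks access to the LightningModule's `log`/`log_dict`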
self.callback_connector.attach_model_logging_functions(model) + def fit( self, model: LightningModule, @@ -834,6 +878,60 @@ def _pre_training_routine(self): ref_model.on_pretrain_routine_end() def run_train(self) -> None: + + new_loop = False + + if new_loop: + self._run_train_new_loop() + else: + self._run_train_old_loop() + + def _should_skip_training(self) -> bool: + should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps + should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs + return should_by_max_steps or should_by_epoch or self.num_training_batches == 0 + + def _run_train_new_loop(self) -> None: + self._pre_training_routine() + if not self.is_global_zero and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() + + self.run_sanity_check(self.lightning_module) + + self.checkpoint_connector.has_trained = False + + # enable train mode + self.model.train() + torch.set_grad_enabled(True) + + # reload data when needed + model = self.lightning_module + + # This might move somewhere else + self.train_loop.reset_train_val_dataloaders(model) + + try: + if self._should_skip_training(): + return + self.train_loop.run() + except KeyboardInterrupt: + rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') + # user could press Ctrl+c many times... only shutdown once + if not self.interrupted: + self.state = TrainerState.INTERRUPTED + self.on_keyboard_interrupt() + # same treatment as below + self.accelerator.on_train_end() + self._running_stage = None + except BaseException: + # give accelerators a chance to finish + self.accelerator.on_train_end() + # reset bookkeeping + self._running_stage = None + raise + + def _run_train_old_loop(self) -> None: + self._pre_training_routine() if not self.is_global_zero and self.progress_bar_callback is not None: @@ -855,7 +953,7 @@ def run_train(self) -> None: self.train_loop.on_train_start() try: - if self.train_loop.should_skip_training(): + if self._should_skip_training(): return # run all epochs epochs = range(self.current_epoch, self.max_epochs) if self.max_epochs else count(self.current_epoch) From 91454b00fa9ea769a218e97fab89b317d4fb24d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Apr 2021 03:51:22 +0200 Subject: [PATCH 002/455] continue loop structure for batch loop --- pytorch_lightning/loops/base.py | 7 +- pytorch_lightning/loops/epoch_loop.py | 142 ++++++++++++++------------ 2 files changed, 82 insertions(+), 67 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 93856ac3fcd11..a69df82592dbe 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -11,7 +11,7 @@ def __init__(self): self.trainer: 'pl.Trainer' @abstractmethod - def connect(self): + def connect(self, *args, **kwargs): """Connects Loop with all the necessary things like connectors and accelerators""" @property @@ -36,7 +36,7 @@ def on_advance_end(self, curr_output: Any) -> Any: return curr_output def run(self, *args: Any, **kwargs: Any): - self.on_start(*args, **kwargs) + self.on_run_start(*args, **kwargs) outputs = [] @@ -50,6 +50,5 @@ def run(self, *args: Any, **kwargs: Any): self.iteration_count += 1 - outputs = self.on_end(outputs) - + outputs = self.on_run_end(outputs) return outputs diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 6af2bf155adb5..46f55f011843d 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ 
b/pytorch_lightning/loops/epoch_loop.py @@ -8,6 +8,7 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.supporters import TensorRunningAccum +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.warnings import WarningCache @@ -25,6 +26,7 @@ def connect( self.num_epochs = num_epochs self.max_steps = max_steps self.trainer = trainer + self.loops_to_run = [] for loop in loops_to_run: if isinstance(loop, Loop) or hasattr(loop, 'run'): self.loops_to_run.append(loop) @@ -114,6 +116,7 @@ def on_advance_start(self): # equal to on train epoch start self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") + # why is this not the same as the old on_train_epoch_end? def on_advance_end(self, outputs): # handle epoch_output on epoch end self.on_train_epoch_end(outputs) @@ -153,10 +156,11 @@ def advance(self): class TrainingLoop(Loop): + """ Runs over all batches in a dataloader (one epoch). """ def connect(self, trainer: 'pl.Trainer'): self.trainer = trainer - self.batch_loop = BatchLoop + self.batch_loop = BatchLoop() def on_run_start(self): # modify dataloader if needed (ddp, etc...) @@ -166,6 +170,7 @@ def on_run_start(self): self._dataloader_idx = 0 def advance(self): + # TODO: profiling is gone batch_idx, (batch, is_last) = next(self._train_dataloader) self.trainer.batch_idx = batch_idx @@ -175,7 +180,8 @@ def advance(self): # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.run_training_batch(batch, batch_idx, self._dataloader_idx) + # batch_output = self.run_training_batch(batch, batch_idx, self._dataloader_idx) + batch_output = self.batch_loop.run(batch, batch_idx, self._dataloader_idx) # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: @@ -241,6 +247,7 @@ def done(self): if self._num_training_batches_reached(self.trainer.is_last_batch): return True + # this is the old on train_epoch_end? def on_run_end(self, outputs): # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() @@ -278,21 +285,32 @@ def on_run_end(self, outputs): class BatchLoop(Loop): + """ Runs over a single batch of data. """ def on_run_start(self, batch, batch_idx, dataloader_idx): self._grad_norm_dic = {} self.trainer.hiddens = None self._optimizers = self.prepare_optimizers() # lightning module hook - self._splits = self.tbptt_split_batch(batch) + self._splits = enumerate(self.tbptt_split_batch(batch)) + self.tbptt_loop = BatchSplitLoop(self._optimizers) def on_advance_start(self): return super().on_advance_start() - def advance(self, *args: Any, **kwargs: Any): - return super().advance(*args, **kwargs) + def advance(self, batch, batch_idx): + split_idx, split_batch = next(self._splits) + batch_outputs = self.tbptt_loop.run(split_batch, split_idx, batch_idx) + + result = AttributeDict( + signal=0, + grad_norm_dic=grad_norm_dic, + training_step_output_for_epoch_end=batch_outputs, + ) + return result def run(self, batch, batch_idx, dataloader_idx): + # TODO why is this not in on_run_start? 
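+        # a None batch or a -1 returned by the hooks below is reported through the
+        # `signal` field so the caller (the training loop) can end the epoch early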
if batch is None: return AttributeDict(signal=0, grad_norm_dic={}) @@ -309,73 +327,71 @@ def run(self, batch, batch_idx, dataloader_idx): return super().run(batch, batch_idx, dataloader_idx) -def run_training_batch(self, batch, batch_idx, dataloader_idx): - - for split_idx, split_batch in enumerate(splits): +class BatchSplitLoop(Loop): + """ Runs over a single split of a batch of data (TBPTT). """ - # create an iterable for optimizers and loop over them - for opt_idx, optimizer in optimizers: + def __init__(self, optimizers): + super().__init__() + self._optimizers = enumerate(optimizers) - # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + def advance(self, split_batch, split_idx, batch_idx): + opt_idx, optimizer = next(self._optimizers) - if self.should_accumulate(): - # For gradient accumulation + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - # ------------------- - # calculate loss (train step + train step end) - # ------------------- + if self.should_accumulate(): + # For gradient accumulation - # automatic_optimization=True: perform dpp sync only when performing optimizer_step - # automatic_optimization=False: don't block synchronization here - with self.block_ddp_sync_behaviour(): - self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) + # ------------------- + # calculate loss (train step + train step end) + # ------------------- - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # ------------------------------ - # BACKWARD PASS - # ------------------------------ - # gradient update with accumulated gradients + # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # automatic_optimization=False: don't block synchronization here + with self.block_ddp_sync_behaviour(): + self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) - else: - if self.automatic_optimization: + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) - def train_step_and_backward_closure(): - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - return None if result is None else result.loss + # ------------------------------ + # BACKWARD PASS + # ------------------------------ + # gradient update with accumulated gradients - # optimizer step - self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) + else: + if self.automatic_optimization: - else: - self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + def train_step_and_backward_closure(): + result = self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + return None if result is None else result.loss - if self._curr_step_result is None: - # user decided to skip optimization - # make sure to zero grad. 
- continue + # optimizer step + self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dic = self._cur_grad_norm_dict - self._cur_grad_norm_dict = None - - # update running loss + reset accumulated loss - self.update_running_loss() - - result = AttributeDict( - signal=0, - grad_norm_dic=grad_norm_dic, - training_step_output_for_epoch_end=batch_outputs, - ) - return result + else: + self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + + if self._curr_step_result is None: + # user decided to skip optimization + # make sure to zero grad. + # TODO add logic to skip in the outer loop + return + # continue + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + grad_norm_dic = self._cur_grad_norm_dict + self._cur_grad_norm_dict = None + + # update running loss + reset accumulated loss + self.update_running_loss() From 1405d088fecd791e0bbf5323e79c590afd603628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Apr 2021 06:29:03 +0200 Subject: [PATCH 003/455] trainer ref --- pytorch_lightning/loops/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index a69df82592dbe..992e819f7a3ba 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -1,3 +1,4 @@ +from _weakref import proxy from abc import ABCMeta, abstractmethod from typing import Any, Counter, List @@ -8,11 +9,12 @@ class Loop(metaclass=ABCMeta): def __init__(self): self.iteration_count: int = 0 - self.trainer: 'pl.Trainer' + self.trainer: 'pl.Trainer' = None @abstractmethod - def connect(self, *args, **kwargs): + def connect(self, trainer, *args, **kwargs): """Connects Loop with all the necessary things like connectors and accelerators""" + self.trainer = proxy(trainer) @property @abstractmethod From 64c6273d62c1ec696a616cc15b57be935238eb4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Apr 2021 06:33:47 +0200 Subject: [PATCH 004/455] split files --- pytorch_lightning/loops/batch_loop.py | 115 +++++++++++ pytorch_lightning/loops/epoch_loop.py | 241 ----------------------- pytorch_lightning/loops/training_loop.py | 136 +++++++++++++ 3 files changed, 251 insertions(+), 241 deletions(-) create mode 100644 pytorch_lightning/loops/batch_loop.py create mode 100644 pytorch_lightning/loops/training_loop.py diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py new file mode 100644 index 0000000000000..da4be9a94ce80 --- /dev/null +++ b/pytorch_lightning/loops/batch_loop.py @@ -0,0 +1,115 @@ +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.utilities import AttributeDict + + +class BatchLoop(Loop): + """ Runs over a single batch of data. 
""" + + def on_run_start(self, batch, batch_idx, dataloader_idx): + self._grad_norm_dic = {} + self.trainer.hiddens = None + self._optimizers = self.prepare_optimizers() + # lightning module hook + self._splits = enumerate(self.tbptt_split_batch(batch)) + self.tbptt_loop = BatchSplitLoop(self._optimizers) + + def on_advance_start(self): + return super().on_advance_start() + + def advance(self, batch, batch_idx): + split_idx, split_batch = next(self._splits) + batch_outputs = self.tbptt_loop.run(split_batch, split_idx, batch_idx) + + result = AttributeDict( + signal=0, + grad_norm_dic=grad_norm_dic, + training_step_output_for_epoch_end=batch_outputs, + ) + return result + + def run(self, batch, batch_idx, dataloader_idx): + # TODO why is this not in on_run_start? + if batch is None: + return AttributeDict(signal=0, grad_norm_dic={}) + + # hook + response = self.trainer.call_hook("on_batch_start") + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic={}) + + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic={}) + + return super().run(batch, batch_idx, dataloader_idx) + + +class BatchSplitLoop(Loop): + """ Runs over a single split of a batch of data (TBPTT). """ + + def __init__(self, optimizers): + super().__init__() + self._optimizers = enumerate(optimizers) + + def advance(self, split_batch, split_idx, batch_idx): + opt_idx, optimizer = next(self._optimizers) + + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + + if self.should_accumulate(): + # For gradient accumulation + + # ------------------- + # calculate loss (train step + train step end) + # ------------------- + + # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # automatic_optimization=False: don't block synchronization here + with self.block_ddp_sync_behaviour(): + self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # ------------------------------ + # BACKWARD PASS + # ------------------------------ + # gradient update with accumulated gradients + + else: + if self.automatic_optimization: + + def train_step_and_backward_closure(): + result = self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + return None if result is None else result.loss + + # optimizer step + self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) + + else: + self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + + if self._curr_step_result is None: + # user decided to skip optimization + # make sure to zero grad. 
+ # TODO add logic to skip in the outer loop + return + # continue + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + grad_norm_dic = self._cur_grad_norm_dict + self._cur_grad_norm_dict = None + + # update running loss + reset accumulated loss + self.update_running_loss() \ No newline at end of file diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 46f55f011843d..868de024c4f2e 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -154,244 +154,3 @@ def advance(self): return ret_vals - -class TrainingLoop(Loop): - """ Runs over all batches in a dataloader (one epoch). """ - - def connect(self, trainer: 'pl.Trainer'): - self.trainer = trainer - self.batch_loop = BatchLoop() - - def on_run_start(self): - # modify dataloader if needed (ddp, etc...) - train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - - self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) - self._dataloader_idx = 0 - - def advance(self): - # TODO: profiling is gone - batch_idx, (batch, is_last) = next(self._train_dataloader) - - self.trainer.batch_idx = batch_idx - self.trainer.is_last_batch = is_last - - # ------------------------------------ - # TRAINING_STEP + TRAINING_STEP_END - # ------------------------------------ - with self.trainer.profiler.profile("run_training_batch"): - # batch_output = self.run_training_batch(batch, batch_idx, self._dataloader_idx) - batch_output = self.batch_loop.run(batch, batch_idx, self._dataloader_idx) - - # when returning -1 from train_step, we end epoch early - if batch_output.signal == -1: - self._skip_remaining_steps = True - return - - # hook - # TODO: add outputs to batches - self.on_train_batch_end( - epoch_output, - batch_output.training_step_output_for_epoch_end, - batch, - batch_idx, - self._dataloader_idx, - ) - - def on_advance_end(self, output): - # ----------------------------------------- - # SAVE METRICS TO LOGGERS - # ----------------------------------------- - self.trainer.logger_connector.log_train_step_metrics(output) - - # ----------------------------------------- - # VALIDATE IF NEEDED + CHECKPOINT CALLBACK - # ----------------------------------------- - should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.trainer.is_last_batch) - if should_check_val: - self.trainer.validating = True - self.trainer.run_evaluation() - self.trainer.training = True - - # ----------------------------------------- - # SAVE LOGGERS (ie: Tensorboard, etc...) 
- # ----------------------------------------- - self.save_loggers_on_train_batch_end() - - # update LR schedulers - monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) - self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) - self.trainer.checkpoint_connector.has_trained = True - - # progress global step according to grads progress - self.increment_accumulated_grad_global_step() - - @property - def done(self): - # max steps reached, end training - if ( - self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 - and self._accumulated_batches_reached() - ): - return True - - # end epoch early - # stop when the flag is changed or we've gone past the amount - # requested in the batches - if self.trainer.should_stop: - return True - - self.trainer.total_batch_idx += 1 - - # stop epoch if we limited the number of training batches - if self._num_training_batches_reached(self.trainer.is_last_batch): - return True - - # this is the old on train_epoch_end? - def on_run_end(self, outputs): - # inform logger the batch loop has finished - self.trainer.logger_connector.on_train_epoch_end() - - # prepare epoch output - processed_outputs = self._prepare_outputs(outputs, batch_mode=False) - - # get the model and call model.training_epoch_end - model = self.trainer.lightning_module - - if is_overridden('training_epoch_end', model=model): - # run training_epoch_end - # refresh the result for custom logging at the epoch level - model._current_fx_name = 'training_epoch_end' - - # lightningmodule hook - training_epoch_end_output = model.training_epoch_end(processed_outputs) - - if training_epoch_end_output is not None: - raise MisconfigurationException( - 'training_epoch_end expects a return of None. ' - 'HINT: remove the return statement in training_epoch_end' - ) - - # capture logging - self.trainer.logger_connector.cache_logged_metrics() - - # call train epoch end hooks - self.trainer.call_hook('on_train_epoch_end', processed_outputs) - self.trainer.call_hook('on_epoch_end') - - # increment the global step once - # progress global step according to grads progress - self.increment_accumulated_grad_global_step() - - -class BatchLoop(Loop): - """ Runs over a single batch of data. """ - - def on_run_start(self, batch, batch_idx, dataloader_idx): - self._grad_norm_dic = {} - self.trainer.hiddens = None - self._optimizers = self.prepare_optimizers() - # lightning module hook - self._splits = enumerate(self.tbptt_split_batch(batch)) - self.tbptt_loop = BatchSplitLoop(self._optimizers) - - def on_advance_start(self): - return super().on_advance_start() - - def advance(self, batch, batch_idx): - split_idx, split_batch = next(self._splits) - batch_outputs = self.tbptt_loop.run(split_batch, split_idx, batch_idx) - - result = AttributeDict( - signal=0, - grad_norm_dic=grad_norm_dic, - training_step_output_for_epoch_end=batch_outputs, - ) - return result - - def run(self, batch, batch_idx, dataloader_idx): - # TODO why is this not in on_run_start? 
- if batch is None: - return AttributeDict(signal=0, grad_norm_dic={}) - - # hook - response = self.trainer.call_hook("on_batch_start") - if response == -1: - return AttributeDict(signal=-1, grad_norm_dic={}) - - # hook - response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) - if response == -1: - return AttributeDict(signal=-1, grad_norm_dic={}) - - return super().run(batch, batch_idx, dataloader_idx) - - -class BatchSplitLoop(Loop): - """ Runs over a single split of a batch of data (TBPTT). """ - - def __init__(self, optimizers): - super().__init__() - self._optimizers = enumerate(optimizers) - - def advance(self, split_batch, split_idx, batch_idx): - opt_idx, optimizer = next(self._optimizers) - - # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - - if self.should_accumulate(): - # For gradient accumulation - - # ------------------- - # calculate loss (train step + train step end) - # ------------------- - - # automatic_optimization=True: perform dpp sync only when performing optimizer_step - # automatic_optimization=False: don't block synchronization here - with self.block_ddp_sync_behaviour(): - self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # ------------------------------ - # BACKWARD PASS - # ------------------------------ - # gradient update with accumulated gradients - - else: - if self.automatic_optimization: - - def train_step_and_backward_closure(): - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - return None if result is None else result.loss - - # optimizer step - self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - - else: - self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) - - if self._curr_step_result is None: - # user decided to skip optimization - # make sure to zero grad. - # TODO add logic to skip in the outer loop - return - # continue - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dic = self._cur_grad_norm_dict - self._cur_grad_norm_dict = None - - # update running loss + reset accumulated loss - self.update_running_loss() diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py new file mode 100644 index 0000000000000..77a35f283c52e --- /dev/null +++ b/pytorch_lightning/loops/training_loop.py @@ -0,0 +1,136 @@ +from copy import deepcopy + +import pytorch_lightning as pl +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.epoch_loop import BatchLoop +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden + + +class TrainingLoop(Loop): + """ Runs over all batches in a dataloader (one epoch). """ + + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): + self.trainer = trainer + self.batch_loop = BatchLoop() + + def on_run_start(self): + # modify dataloader if needed (ddp, etc...) 
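+        # the training type plugin behind the accelerator may wrap or replace the dataloader here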
+ train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + + self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) + self._dataloader_idx = 0 + + def advance(self): + # TODO: profiling is gone + batch_idx, (batch, is_last) = next(self._train_dataloader) + + self.trainer.batch_idx = batch_idx + self.trainer.is_last_batch = is_last + + # ------------------------------------ + # TRAINING_STEP + TRAINING_STEP_END + # ------------------------------------ + with self.trainer.profiler.profile("run_training_batch"): + # batch_output = self.run_training_batch(batch, batch_idx, self._dataloader_idx) + batch_output = self.batch_loop.run(batch, batch_idx, self._dataloader_idx) + + # when returning -1 from train_step, we end epoch early + if batch_output.signal == -1: + self._skip_remaining_steps = True + return + + # hook + # TODO: add outputs to batches + self.on_train_batch_end( + epoch_output, + batch_output.training_step_output_for_epoch_end, + batch, + batch_idx, + self._dataloader_idx, + ) + + def on_advance_end(self, output): + # ----------------------------------------- + # SAVE METRICS TO LOGGERS + # ----------------------------------------- + self.trainer.logger_connector.log_train_step_metrics(output) + + # ----------------------------------------- + # VALIDATE IF NEEDED + CHECKPOINT CALLBACK + # ----------------------------------------- + should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.trainer.is_last_batch) + if should_check_val: + self.trainer.validating = True + self.trainer.run_evaluation() + self.trainer.training = True + + # ----------------------------------------- + # SAVE LOGGERS (ie: Tensorboard, etc...) + # ----------------------------------------- + self.save_loggers_on_train_batch_end() + + # update LR schedulers + monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) + self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) + self.trainer.checkpoint_connector.has_trained = True + + # progress global step according to grads progress + self.increment_accumulated_grad_global_step() + + @property + def done(self): + # max steps reached, end training + if ( + self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 + and self._accumulated_batches_reached() + ): + return True + + # end epoch early + # stop when the flag is changed or we've gone past the amount + # requested in the batches + if self.trainer.should_stop: + return True + + self.trainer.total_batch_idx += 1 + + # stop epoch if we limited the number of training batches + if self._num_training_batches_reached(self.trainer.is_last_batch): + return True + + # this is the old on train_epoch_end? + def on_run_end(self, outputs): + # inform logger the batch loop has finished + self.trainer.logger_connector.on_train_epoch_end() + + # prepare epoch output + processed_outputs = self._prepare_outputs(outputs, batch_mode=False) + + # get the model and call model.training_epoch_end + model = self.trainer.lightning_module + + if is_overridden('training_epoch_end', model=model): + # run training_epoch_end + # refresh the result for custom logging at the epoch level + model._current_fx_name = 'training_epoch_end' + + # lightningmodule hook + training_epoch_end_output = model.training_epoch_end(processed_outputs) + + if training_epoch_end_output is not None: + raise MisconfigurationException( + 'training_epoch_end expects a return of None. 
' + 'HINT: remove the return statement in training_epoch_end' + ) + + # capture logging + self.trainer.logger_connector.cache_logged_metrics() + + # call train epoch end hooks + self.trainer.call_hook('on_train_epoch_end', processed_outputs) + self.trainer.call_hook('on_epoch_end') + + # increment the global step once + # progress global step according to grads progress + self.increment_accumulated_grad_global_step() \ No newline at end of file From 4f79a0f7c229b3eb85f478d7b095649bb474c53b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Apr 2021 15:03:37 +0200 Subject: [PATCH 005/455] rename batchsplitloop to optimizerloop and add missing helper methods --- pytorch_lightning/loops/batch_loop.py | 81 +--- pytorch_lightning/loops/optimizer_loop.py | 443 ++++++++++++++++++++++ 2 files changed, 453 insertions(+), 71 deletions(-) create mode 100644 pytorch_lightning/loops/optimizer_loop.py diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index da4be9a94ce80..4f10b01cce2ac 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -1,4 +1,5 @@ from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.optimizer_loop import OptimizerLoop from pytorch_lightning.utilities import AttributeDict @@ -8,10 +9,10 @@ class BatchLoop(Loop): def on_run_start(self, batch, batch_idx, dataloader_idx): self._grad_norm_dic = {} self.trainer.hiddens = None - self._optimizers = self.prepare_optimizers() + # self._optimizers = self.prepare_optimizers() # lightning module hook self._splits = enumerate(self.tbptt_split_batch(batch)) - self.tbptt_loop = BatchSplitLoop(self._optimizers) + self.tbptt_loop = OptimizerLoop() def on_advance_start(self): return super().on_advance_start() @@ -44,72 +45,10 @@ def run(self, batch, batch_idx, dataloader_idx): return super().run(batch, batch_idx, dataloader_idx) - -class BatchSplitLoop(Loop): - """ Runs over a single split of a batch of data (TBPTT). 
""" - - def __init__(self, optimizers): - super().__init__() - self._optimizers = enumerate(optimizers) - - def advance(self, split_batch, split_idx, batch_idx): - opt_idx, optimizer = next(self._optimizers) - - # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - - if self.should_accumulate(): - # For gradient accumulation - - # ------------------- - # calculate loss (train step + train step end) - # ------------------- - - # automatic_optimization=True: perform dpp sync only when performing optimizer_step - # automatic_optimization=False: don't block synchronization here - with self.block_ddp_sync_behaviour(): - self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # ------------------------------ - # BACKWARD PASS - # ------------------------------ - # gradient update with accumulated gradients - - else: - if self.automatic_optimization: - - def train_step_and_backward_closure(): - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - return None if result is None else result.loss - - # optimizer step - self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - - else: - self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) - - if self._curr_step_result is None: - # user decided to skip optimization - # make sure to zero grad. - # TODO add logic to skip in the outer loop - return - # continue - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dic = self._cur_grad_norm_dict - self._cur_grad_norm_dict = None - - # update running loss + reset accumulated loss - self.update_running_loss() \ No newline at end of file + def tbptt_split_batch(self, batch): + splits = [batch] + if self.trainer.truncated_bptt_steps is not None: + model_ref = self.trainer.lightning_module + with self.trainer.profiler.profile("tbptt_split_batch"): + splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) + return splits diff --git a/pytorch_lightning/loops/optimizer_loop.py b/pytorch_lightning/loops/optimizer_loop.py new file mode 100644 index 0000000000000..726abe352467a --- /dev/null +++ b/pytorch_lightning/loops/optimizer_loop.py @@ -0,0 +1,443 @@ +from contextlib import contextmanager +from copy import copy + +import numpy as np +import torch + +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.trainer.supporters import TensorRunningAccum +from pytorch_lightning.utilities import AttributeDict, DeviceType, AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.finite_checks import detect_nan_parameters +from pytorch_lightning.utilities.imports import _TPU_AVAILABLE +from pytorch_lightning.utilities.warnings import WarningCache + + +class OptimizerLoop(Loop): + """ Runs over a single split of a batch of data (TBPTT). 
""" + + def __init__(self): + super().__init__() + self._optimizers = enumerate(self.prepare_optimizers()) + + self.accumulated_loss = None + self.warning_cache = WarningCache() + # self._teardown_already_run = False + self.running_loss = TensorRunningAccum(window_length=20) + self.automatic_optimization = True + self._curr_step_result = None + # self._cur_grad_norm_dict = None + # self._multiple_trainloader_mode = multiple_trainloader_mode + self._skip_backward = False + # self.trainer._multiple_trainloader_mode = multiple_trainloader_mode + + def advance(self, split_batch, split_idx, batch_idx): + opt_idx, optimizer = next(self._optimizers) + + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + + if self.should_accumulate(): + # For gradient accumulation + + # ------------------- + # calculate loss (train step + train step end) + # ------------------- + + # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # automatic_optimization=False: don't block synchronization here + with self.block_ddp_sync_behaviour(): + self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # ------------------------------ + # BACKWARD PASS + # ------------------------------ + # gradient update with accumulated gradients + + else: + if self.automatic_optimization: + + def train_step_and_backward_closure(): + result = self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + return None if result is None else result.loss + + # optimizer step + self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) + + else: + self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + + if self._curr_step_result is None: + # user decided to skip optimization + # make sure to zero grad. + # TODO add logic to skip in the outer loop + return + # continue + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + grad_norm_dic = self._cur_grad_norm_dict + self._cur_grad_norm_dict = None + + # update running loss + reset accumulated loss + self.update_running_loss() + + return batch_outputs + + def prepare_optimizers(self): + # in manual optimization we loop over all optimizers at once + optimizers = self.get_optimizers_iterable() + if not self.automatic_optimization: + optimizers = [optimizers[0]] + return optimizers + + def get_optimizers_iterable(self): + """ + Generates an iterable with (idx, optimizer) for each optimizer. 
+ """ + if not self.trainer.optimizer_frequencies: + # call training_step once per optimizer + return list(enumerate(self.trainer.optimizers)) + + optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) + optimizers_loop_length = optimizer_freq_cumsum[-1] + current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length + + # find optimzier index by looking for the first {item > current_place} in the cumsum list + opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) + return [[opt_idx, self.trainer.optimizers[opt_idx]]] + + def on_after_backward(self, training_step_output, batch_idx, untouched_loss): + training_step_output.detach() + + # insert after step hook + self.trainer.call_hook("on_after_backward") + + # when in dev debugging track the losses + self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) + + def _check_training_step_output(self, training_step_output): + if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: + if training_step_output.grad_fn is None: + # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... + raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") + + def training_step(self, split_batch, batch_idx, opt_idx, hiddens): + # give the PL module a result for logging + model_ref = self.trainer.lightning_module + + with self.trainer.profiler.profile("model_forward"): + args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) + + # manually capture logged metrics + model_ref._current_fx_name = 'training_step' + model_ref._results = Result() + with self.trainer.profiler.profile("training_step"): + training_step_output = self.trainer.accelerator.training_step(args) + self.trainer.accelerator.post_training_step() + + self.trainer.logger_connector.cache_logged_metrics() + + self._check_training_step_output(training_step_output) + + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) + + training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( + training_step_output, split_batch + ) + if training_step_output_for_epoch_end is None: + return + + # enable empty loss when using manual opt + closure_loss = None + untouched_loss = None + + if self.automatic_optimization: + # accumulate loss. if accumulate_grad_batches==1, no effect + closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches + + # the loss will get scaled for amp. 
avoid any modifications to it + untouched_loss = closure_loss.detach().clone() + + # result + result = AttributeDict( + closure_loss=closure_loss, + loss=untouched_loss, + training_step_output=training_step_output, + training_step_output_for_epoch_end=training_step_output_for_epoch_end, + ) + return result + + def _process_training_step_output(self, training_step_output, split_batch): + training_step_output_for_epoch_end = training_step_output + + # enable validation_step return None + if training_step_output_for_epoch_end is None: + return None, None + + result = self.trainer.lightning_module._results + + loss = None + hiddens = None + result["extra"] = {} + + # handle dict return + if isinstance(training_step_output, dict): + loss = training_step_output.pop("loss", None) + hiddens = training_step_output.pop("hiddens", None) + if hiddens is not None: + hiddens = hiddens.detach() + result["extra"] = training_step_output + + # handle scalar return + elif isinstance(training_step_output, torch.Tensor): + loss = training_step_output + + # map to results under the hood + result.minimize = loss + self.trainer.hiddens = hiddens + + # track batch for manual reduction with result + result.track_batch_size(len(split_batch)) + + # track metrics without grads for epoch reduction + training_step_output_for_epoch_end = copy(result) + training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() + if self.trainer.move_metrics_to_cpu: + training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() + + return training_step_output_for_epoch_end, result + + def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): + model_ref = self.trainer.lightning_module + + is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) + using_native_amp = self.trainer.amp_backend == AMPType.NATIVE + + # native amp + lbfgs is a no go right now + if using_native_amp and is_lbfgs: + raise MisconfigurationException( + 'native PyTorch amp and lbfgs are not compatible.' 
+ ' To request, please file a Github issue in PyTorch and tag @mcarilli' + ) + + # wraps into LightningOptimizer only for running step + optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) + + # model hook + model_ref.optimizer_step( + self.trainer.current_epoch, + batch_idx, + optimizer, + opt_idx, + train_step_and_backward_closure, + on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE), + using_native_amp=using_native_amp, + using_lbfgs=is_lbfgs, + ) + + def on_before_zero_grad(self, optimizer): + self.trainer.call_hook('on_before_zero_grad', optimizer) + + def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): + self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + + def track_and_norm_grad(self, optimizer): + # track gradient norms + grad_norm_dic = self._track_gradient_norm() + + # clip gradients + self.trainer.accelerator.clip_gradients( + optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm + ) + self._cur_grad_norm_dict = grad_norm_dic + + def _track_gradient_norm(self): + grad_norm_dict = {} + if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: + if float(self.trainer.track_grad_norm) > 0: + model = self.trainer.lightning_module + grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) + return grad_norm_dict + + def _accumulated_batches_reached(self): + return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 + + def _num_training_batches_reached(self, is_last_batch=False): + return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch + + def should_accumulate(self): + # checks if backward or backward + optimizer step (via closure) + accumulation_done = self._accumulated_batches_reached() + is_final_batch = self._num_training_batches_reached() + return not (accumulation_done or is_final_batch) + + def build_train_args(self, batch, batch_idx, opt_idx, hiddens): + # enable not needing to add opt_idx to training_step + args = [batch, batch_idx] + + if len(self.trainer.optimizers) > 1: + if self.trainer.has_arg("training_step", "optimizer_idx"): + if not self.automatic_optimization: + self.warning_cache.warn( + "`training_step` hook signature has changed in v1.3." + " `optimizer_idx` argument has been removed in case of manual optimization. Support for" + " the old signature will be removed in v1.5", DeprecationWarning + ) + args.append(opt_idx) + elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization: + raise ValueError( + f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" + ' `training_step` is missing the `optimizer_idx` argument.' + ) + + # pass hiddens if using tbptt + if self.trainer.truncated_bptt_steps is not None: + args.append(hiddens) + + return args + + def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): + # set split_idx to trainer for tracking + self.trainer.split_idx = split_idx + + # make sure only the gradients of the current optimizer's parameters are calculated + # in the training step to prevent dangling gradients in multiple-optimizer setup. 
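The accumulation helpers above (`should_accumulate`, `_accumulated_batches_reached`, `_num_training_batches_reached`) boil down to modular arithmetic on the batch index. A minimal standalone sketch of that arithmetic, with hypothetical settings and no real Trainer involved:

    # illustration only: the same checks with plain integers
    accumulate_grad_batches = 4   # assumed Trainer setting
    num_training_batches = 10     # assumed dataloader length

    def runs_optimizer_step(batch_idx: int) -> bool:
        # inverse of `should_accumulate`: step once a full accumulation window is done
        # or the last batch of the epoch is reached
        accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
        is_final_batch = (batch_idx + 1) == num_training_batches
        return accumulation_done or is_final_batch

    # the optimizer would step after batches 3, 7 and the final batch 9
    assert [i for i in range(num_training_batches) if runs_optimizer_step(i)] == [3, 7, 9]
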
+ if self.automatic_optimization and len(self.trainer.optimizers) > 1: + model = self.trainer.lightning_module + model.toggle_optimizer(optimizer, opt_idx) + + # use to track metrics internally + self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) + + @contextmanager + def block_ddp_sync_behaviour(self, should_block_sync: bool = False): + """ + automatic_optimization = True + Blocks ddp sync gradients behaviour on backwards pass. + This is useful for skipping sync when accumulating gradients, reducing communication overhead + + automatic_optimization = False + do not block ddp gradient sync when using manual optimization + as gradients are needed within the training step + + Returns: + context manager with sync behaviour off + + """ + if ( + isinstance(self.trainer.training_type_plugin, ParallelPlugin) + and (self.automatic_optimization or should_block_sync) + ): + with self.trainer.training_type_plugin.block_backward_sync(): + yield None + else: + yield None + + def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: + opt_closure_result = self._curr_step_result + + if opt_closure_result is not None: + + # cache metrics + self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) + + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + self._check_finite(opt_closure_result.loss) + + # track all the outputs across all steps + batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 + batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) + + if self.automatic_optimization: + # track total loss for logging (avoid mem leaks) + self.accumulated_loss.append(opt_closure_result.loss) + + self._curr_step_result = None + + return batch_outputs + + def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): + """Wrap forward, zero_grad and backward in a closure so second order methods work""" + with self.trainer.profiler.profile("training_step_and_backward"): + # lightning module hook + result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) + self._curr_step_result = result + + if not self._skip_backward and self.automatic_optimization: + is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 + + if is_first_batch_to_accumulate: + self.on_before_zero_grad(optimizer) + self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) + + # backward pass + if result is not None: + with self.trainer.profiler.profile("backward"): + self.backward(result, optimizer, opt_idx) + + # hook - call this hook only + # when gradients have finished to accumulate + if not self.should_accumulate(): + self.on_after_backward(result.training_step_output, batch_idx, result.loss) + + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + self._check_finite(result.loss) + + else: + self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...") + + if len(self.trainer.optimizers) > 1: + # revert back to previous state + self.trainer.lightning_module.untoggle_optimizer(opt_idx) + + return result + + def _check_finite(self, loss: torch.Tensor) -> None: + if not torch.isfinite(loss).all(): + raise ValueError(f'The loss returned in `training_step` is {loss}.') + model = self.trainer.lightning_module + detect_nan_parameters(model) + + def backward(self, result, optimizer, opt_idx, *args, **kwargs): + self.trainer.dev_debugger.track_event("backward_call") + + 
should_accumulate = self.should_accumulate() + + # backward can be called manually in the training loop + if isinstance(result, torch.Tensor): + self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) + else: + result.closure_loss = self.trainer.accelerator.backward( + result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs + ) + + if not self.should_accumulate(): + # track gradients + self.track_and_norm_grad(optimizer=optimizer) + + def update_running_loss(self): + accumulated_loss = self.accumulated_loss.mean() + + if accumulated_loss is not None: + # calculate running loss for display + self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) + + # reset for next set of accumulated grads + self.accumulated_loss.reset() From 102baa35a2adf84729e73a8dd36b957cc10a76dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Apr 2021 15:18:15 +0200 Subject: [PATCH 006/455] add missing helpers to training loop --- pytorch_lightning/loops/training_loop.py | 134 ++++++++++++++++++++++- 1 file changed, 132 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 77a35f283c52e..7eb7bc129539f 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -1,8 +1,10 @@ from copy import deepcopy +from typing import List, Dict, Union import pytorch_lightning as pl +from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop -from pytorch_lightning.loops.epoch_loop import BatchLoop +from pytorch_lightning.loops.batch_loop import BatchLoop from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -133,4 +135,132 @@ def on_run_end(self, outputs): # increment the global step once # progress global step according to grads progress - self.increment_accumulated_grad_global_step() \ No newline at end of file + self.increment_accumulated_grad_global_step() + + def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): + batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)] + + processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) + + # hook + self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, batch_idx, dataloader_idx) + self.trainer.call_hook('on_batch_end') + + # figure out what to track for epoch end + self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) + + # reset batch logger internals + self.trainer.logger_connector.on_train_batch_end() + + def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): + + # track the outputs to reduce at the end of the epoch + for opt_idx, opt_outputs in enumerate(batch_end_outputs): + sample_output = opt_outputs[-1] + + # decide if we need to reduce at the end of the epoch automatically + auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end + hook_overridden = ( + is_overridden("training_epoch_end", model=self.trainer.lightning_module) + or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module) + ) + + # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end + if not (hook_overridden or auto_reduce_tng_result): + continue + + # 
with 1 step (no tbptt) don't use a sequence at epoch end + if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): + opt_outputs = opt_outputs[0] + + epoch_output[opt_idx].append(opt_outputs) + + @staticmethod + def _prepare_outputs( + outputs: List[List[List[Result]]], + batch_mode: bool, + ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: + """ + Extract required information from batch or epoch end results. + + Args: + outputs: A 3-dimensional list of ``Result`` objects with dimensions: + [optimizer outs][batch outs][tbptt steps]. + + batch_mode: If True, ignore the batch output dimension. + + Returns: + The cleaned outputs with ``Result`` objects converted to dictionaries. All list dimensions of size one will + be collapsed. + """ + processed_outputs = [] + for opt_outputs in outputs: + # handle an edge case where an optimizer output is the empty list + if len(opt_outputs) == 0: + continue + + processed_batch_outputs = [] + + if batch_mode: + opt_outputs = [opt_outputs] + + for batch_outputs in opt_outputs: + processed_tbptt_outputs = [] + + for tbptt_output in batch_outputs: + out = tbptt_output.extra + out['loss'] = tbptt_output.minimize + processed_tbptt_outputs.append(out) + + # if there was only one tbptt step then we can collapse that dimension + if len(processed_tbptt_outputs) == 1: + processed_tbptt_outputs = processed_tbptt_outputs[0] + processed_batch_outputs.append(processed_tbptt_outputs) + + # batch_outputs should be just one dict (or a list of dicts if using tbptt) per optimizer + if batch_mode: + processed_batch_outputs = processed_batch_outputs[0] + processed_outputs.append(processed_batch_outputs) + + # if there is only one optimiser then we collapse that dimension + if len(processed_outputs) == 1: + processed_outputs = processed_outputs[0] + return processed_outputs + + def update_train_loop_lr_schedulers(self, monitor_metrics=None): + num_accumulated_batches_reached = self._accumulated_batches_reached() + num_training_batches_reached = self._num_training_batches_reached() + + if num_accumulated_batches_reached or num_training_batches_reached: + # update lr + self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) + + def increment_accumulated_grad_global_step(self): + num_accumulated_batches_reached = self._accumulated_batches_reached() + num_training_batches_reached = self._num_training_batches_reached() + + # progress global step according to grads progress + if num_accumulated_batches_reached or num_training_batches_reached: + self.trainer.global_step = self.trainer.accelerator.update_global_step( + self.trainer.total_batch_idx, self.trainer.global_step + ) + + def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): + # decide if we should run validation + is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 + is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 + can_check_val = self.trainer.enable_validation and is_val_check_epoch + is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") + epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 + + should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop + or is_last_batch_for_infinite_dataset + ) if on_epoch else (is_val_check_batch and not epoch_end_val_check) + + return should_check_val and can_check_val + + def 
save_loggers_on_train_batch_end(self): + # when loggers should save to disk + should_flush_logs = self.trainer.logger_connector.should_flush_logs + if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: + self.trainer.logger.save() \ No newline at end of file From 4a236bc42e9a0a4621cf2b2845bafb429575a861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Apr 2021 15:24:47 +0200 Subject: [PATCH 007/455] add separation between new API and helpers --- pytorch_lightning/loops/base.py | 4 ++-- pytorch_lightning/loops/epoch_loop.py | 4 +++- pytorch_lightning/loops/optimizer_loop.py | 6 +++++- pytorch_lightning/loops/training_loop.py | 5 +++++ 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 992e819f7a3ba..aeac69f50c9a4 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -1,6 +1,6 @@ from _weakref import proxy from abc import ABCMeta, abstractmethod -from typing import Any, Counter, List +from typing import Any, Counter, List, Optional import pytorch_lightning as pl @@ -9,7 +9,7 @@ class Loop(metaclass=ABCMeta): def __init__(self): self.iteration_count: int = 0 - self.trainer: 'pl.Trainer' = None + self.trainer: Optional['pl.Trainer'] = None @abstractmethod def connect(self, trainer, *args, **kwargs): diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 868de024c4f2e..b94642f698975 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -1,6 +1,6 @@ +import logging from contextlib import suppress from copy import deepcopy -from logging import log from typing import Any, List, Optional import pytorch_lightning as pl @@ -13,6 +13,8 @@ from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.warnings import WarningCache +log = logging.getLogger(__name__) + class EpochLoop(Loop): diff --git a/pytorch_lightning/loops/optimizer_loop.py b/pytorch_lightning/loops/optimizer_loop.py index 726abe352467a..e771c07826241 100644 --- a/pytorch_lightning/loops/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer_loop.py @@ -29,7 +29,7 @@ def __init__(self): self.running_loss = TensorRunningAccum(window_length=20) self.automatic_optimization = True self._curr_step_result = None - # self._cur_grad_norm_dict = None + self._cur_grad_norm_dict = None # self._multiple_trainloader_mode = multiple_trainloader_mode self._skip_backward = False # self.trainer._multiple_trainloader_mode = multiple_trainloader_mode @@ -98,6 +98,10 @@ def train_step_and_backward_closure(): return batch_outputs +# ------------------------------------------------------------------------------------------------------------ +# HELPER +# ------------------------------------------------------------------------------------------------------------ + def prepare_optimizers(self): # in manual optimization we loop over all optimizers at once optimizers = self.get_optimizers_iterable() diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 7eb7bc129539f..5589cd77b30bf 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -137,6 +137,11 @@ def on_run_end(self, outputs): # progress global step according to grads progress self.increment_accumulated_grad_global_step() +# 
------------------------------------------------------------------------------------------------------------ +# HELPER +# ------------------------------------------------------------------------------------------------------------ + + # TODO move to on_advance_end() def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)] From aadef6b44c9e84cbe81fb3c0205b4a7c19703e82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Apr 2021 15:27:31 +0200 Subject: [PATCH 008/455] wip --- pytorch_lightning/trainer/training_loop.py | 695 ++++++++++++++++++--- 1 file changed, 592 insertions(+), 103 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 790dc4c70bdeb..29fe17a782c9c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -39,14 +39,14 @@ class TrainLoop: def __init__(self, trainer, multiple_trainloader_mode: str): self.trainer = trainer - self.accumulated_loss = None - self.warning_cache = WarningCache() + # self.accumulated_loss = None + # self.warning_cache = WarningCache() self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) self._curr_step_result = None self._cur_grad_norm_dict = None self._multiple_trainloader_mode = multiple_trainloader_mode - self._skip_backward = False + # self._skip_backward = False self.trainer._multiple_trainloader_mode = multiple_trainloader_mode self._optimizer_freq_cumsum = None @@ -84,6 +84,10 @@ def on_trainer_init( def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) return num_optimizers + # + # def on_train_start(self): + # # hook + # self.trainer.call_hook("on_train_start") @property def optimizer_freq_cumsum(self): @@ -196,7 +200,28 @@ def reset_train_val_dataloaders(self, model) -> None: if self.trainer.val_dataloaders is None: self.trainer.reset_val_dataloader(model) - def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): + # def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): + # + # # track the outputs to reduce at the end of the epoch + # for opt_idx, opt_outputs in enumerate(batch_end_outputs): + # sample_output = opt_outputs[-1] + # + # # decide if we need to reduce at the end of the epoch automatically + # auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end + # hook_overridden = ( + # is_overridden("training_epoch_end", model=self.trainer.lightning_module) + # or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module) + # ) + # + # # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end + # if not (hook_overridden or auto_reduce_tng_result): + # continue + # + # # with 1 step (no tbptt) don't use a sequence at epoch end + # if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): + # opt_outputs = opt_outputs[0] + # + # epoch_output[opt_idx].append(opt_outputs) hook_overridden = self._should_add_batch_output_to_epoch_output() @@ -207,15 +232,201 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end - # only track when a) 
it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - if not (hook_overridden or auto_reduce_tng_result): - continue - - # with 1 step (no tbptt) don't use a sequence at epoch end - if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): - opt_outputs = opt_outputs[0] - - epoch_output[opt_idx].append(opt_outputs) + # def _check_training_step_output(self, training_step_output): + # if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: + # if training_step_output.grad_fn is None: + # # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... + # raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") + # + # def training_step(self, split_batch, batch_idx, opt_idx, hiddens): + # # give the PL module a result for logging + # model_ref = self.trainer.lightning_module + # + # with self.trainer.profiler.profile("model_forward"): + # args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) + # + # # manually capture logged metrics + # model_ref._current_fx_name = 'training_step' + # model_ref._results = Result() + # with self.trainer.profiler.profile("training_step"): + # training_step_output = self.trainer.accelerator.training_step(args) + # self.trainer.accelerator.post_training_step() + # + # self.trainer.logger_connector.cache_logged_metrics() + # + # self._check_training_step_output(training_step_output) + # + # training_step_output = self.trainer.call_hook("training_step_end", training_step_output) + # + # training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( + # training_step_output, split_batch + # ) + # if training_step_output_for_epoch_end is None: + # return + # + # # enable empty loss when using manual opt + # closure_loss = None + # untouched_loss = None + # + # if self.automatic_optimization: + # # accumulate loss. if accumulate_grad_batches==1, no effect + # closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches + # + # # the loss will get scaled for amp. 
avoid any modifications to it + # untouched_loss = closure_loss.detach().clone() + # + # # result + # result = AttributeDict( + # closure_loss=closure_loss, + # loss=untouched_loss, + # training_step_output=training_step_output, + # training_step_output_for_epoch_end=training_step_output_for_epoch_end, + # ) + # return result + # + # def _process_training_step_output(self, training_step_output, split_batch): + # training_step_output_for_epoch_end = training_step_output + # + # # enable validation_step return None + # if training_step_output_for_epoch_end is None: + # return None, None + # + # result = self.trainer.lightning_module._results + # + # loss = None + # hiddens = None + # result["extra"] = {} + # + # # handle dict return + # if isinstance(training_step_output, dict): + # loss = training_step_output.pop("loss", None) + # hiddens = training_step_output.pop("hiddens", None) + # if hiddens is not None: + # hiddens = hiddens.detach() + # result["extra"] = training_step_output + # + # # handle scalar return + # elif isinstance(training_step_output, torch.Tensor): + # loss = training_step_output + # + # # map to results under the hood + # result.minimize = loss + # self.trainer.hiddens = hiddens + # + # # track batch for manual reduction with result + # result.track_batch_size(len(split_batch)) + # + # # track metrics without grads for epoch reduction + # training_step_output_for_epoch_end = copy(result) + # training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() + # if self.trainer.move_metrics_to_cpu: + # training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() + # + # return training_step_output_for_epoch_end, result + + # @staticmethod + # def _prepare_outputs( + # outputs: List[List[List[Result]]], + # batch_mode: bool, + # ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: + # """ + # Extract required information from batch or epoch end results. + # + # Args: + # outputs: A 3-dimensional list of ``Result`` objects with dimensions: + # [optimizer outs][batch outs][tbptt steps]. + # + # batch_mode: If True, ignore the batch output dimension. + # + # Returns: + # The cleaned outputs with ``Result`` objects converted to dictionaries. All list dimensions of size one will + # be collapsed. 
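As a toy illustration of the collapsing rule described above (shapes only, with plain dicts standing in for ``Result`` objects; this does not call the real method):

    # one optimizer, two batches, a single tbptt step each
    outputs = [                 # [optimizer outs]
        [                       #   [batch outs]
            [{"loss": 0.5}],    #     [tbptt steps]
            [{"loss": 0.4}],
        ]
    ]
    # in epoch mode (batch_mode=False) the size-one tbptt and optimizer dimensions
    # collapse, leaving one dict per batch
    expected = [{"loss": 0.5}, {"loss": 0.4}]
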
+ # """ + # processed_outputs = [] + # for opt_outputs in outputs: + # # handle an edge case where an optimizer output is the empty list + # if len(opt_outputs) == 0: + # continue + # + # processed_batch_outputs = [] + # + # if batch_mode: + # opt_outputs = [opt_outputs] + # + # for batch_outputs in opt_outputs: + # processed_tbptt_outputs = [] + # + # for tbptt_output in batch_outputs: + # out = tbptt_output.extra + # out['loss'] = tbptt_output.minimize + # processed_tbptt_outputs.append(out) + # + # # if there was only one tbptt step then we can collapse that dimension + # if len(processed_tbptt_outputs) == 1: + # processed_tbptt_outputs = processed_tbptt_outputs[0] + # processed_batch_outputs.append(processed_tbptt_outputs) + # + # # batch_outputs should be just one dict (or a list of dicts if using tbptt) per optimizer + # if batch_mode: + # processed_batch_outputs = processed_batch_outputs[0] + # processed_outputs.append(processed_batch_outputs) + # + # # if there is only one optimiser then we collapse that dimension + # if len(processed_outputs) == 1: + # processed_outputs = processed_outputs[0] + # return processed_outputs + + # def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): + # model_ref = self.trainer.lightning_module + # + # is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) + # using_native_amp = self.trainer.amp_backend == AMPType.NATIVE + # + # # native amp + lbfgs is a no go right now + # if using_native_amp and is_lbfgs: + # raise MisconfigurationException( + # 'native PyTorch amp and lbfgs are not compatible.' + # ' To request, please file a Github issue in PyTorch and tag @mcarilli' + # ) + # + # # wraps into LightningOptimizer only for running step + # optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) + # + # # model hook + # model_ref.optimizer_step( + # self.trainer.current_epoch, + # batch_idx, + # optimizer, + # opt_idx, + # train_step_and_backward_closure, + # on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE, + # using_native_amp=using_native_amp, + # using_lbfgs=is_lbfgs, + # ) + # + # def on_before_zero_grad(self, optimizer): + # self.trainer.call_hook('on_before_zero_grad', optimizer) + # + # def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): + # self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + # + # def track_and_norm_grad(self, optimizer): + # # track gradient norms + # grad_norm_dic = self._track_gradient_norm() + # + # # clip gradients + # self.trainer.accelerator.clip_gradients( + # optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm + # ) + # self._cur_grad_norm_dict = grad_norm_dic + # + # def _track_gradient_norm(self): + # grad_norm_dict = {} + # if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: + # if float(self.trainer.track_grad_norm) > 0: + # model = self.trainer.lightning_module + # grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) + # return grad_norm_dict def _should_add_batch_output_to_epoch_output(self) -> bool: # We add to the epoch outputs if @@ -483,11 +694,30 @@ def run_training_epoch(self): self.trainer.batch_idx = batch_idx self.trainer.is_last_batch = is_last_batch - # ------------------------------------ - # TRAINING_STEP + TRAINING_STEP_END - # ------------------------------------ - with self.trainer.profiler.profile("run_training_batch"): - batch_output = 
self.run_training_batch(batch, batch_idx, dataloader_idx) + # # ----------------------------------------- + # # SAVE METRICS TO LOGGERS + # # ----------------------------------------- + # self.trainer.logger_connector.log_train_step_metrics(batch_output) + # + # # ----------------------------------------- + # # VALIDATE IF NEEDED + CHECKPOINT CALLBACK + # # ----------------------------------------- + # should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) + # if should_check_val: + # self.trainer.validating = True + # self.trainer.run_evaluation() + # self.trainer.training = True + # val_loop_called = True + # + # # ----------------------------------------- + # # SAVE LOGGERS (ie: Tensorboard, etc...) + # # ----------------------------------------- + # self.save_loggers_on_train_batch_end() + # + # # update LR schedulers + # monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) + # self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) + # self.trainer.checkpoint_connector.has_trained = True # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: @@ -529,23 +759,84 @@ def run_training_epoch(self): self.trainer.checkpoint_connector.has_trained = True # max steps reached, end training - if ( - self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 - and self._accumulated_batches_reached() - ): - break - - # end epoch early - # stop when the flag is changed or we've gone past the amount - # requested in the batches - if self.trainer.should_stop: - break - - self.trainer.total_batch_idx += 1 - - # stop epoch if we limited the number of training batches - if self._num_training_batches_reached(is_last_batch): - break + # if ( + # self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 + # and self._accumulated_batches_reached() + # ): + # break + # + # # end epoch early + # # stop when the flag is changed or we've gone past the amount + # # requested in the batches + # if self.trainer.should_stop: + # break + # + # self.trainer.total_batch_idx += 1 + # + # # stop epoch if we limited the number of training batches + # if self._num_training_batches_reached(is_last_batch): + # break + + # # progress global step according to grads progress + # self.increment_accumulated_grad_global_step() + # + # # handle epoch_output on epoch end + # self.on_train_epoch_end(epoch_output) + + # # log epoch metrics + # self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) + # + # should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) + # should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) + # should_train_only = self.trainer.disable_validation or should_skip_eval + # + # # update epoch level lr_schedulers if no val loop outside train loop is triggered + # if (val_loop_called and not should_check_val) or should_train_only: + # self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + # + # if should_train_only: + # self.check_checkpoint_callback(True) + # self.check_early_stopping_callback(True) + # + # if should_check_val: + # self.trainer.validating = True + # self.trainer.run_evaluation(on_epoch=True) + # self.trainer.training = True + # + # # increment the global step once + # # progress global step according to grads progress + # self.increment_accumulated_grad_global_step() + + # def on_train_epoch_end(self, epoch_output: 
List[List[List[Result]]]) -> None: + # # inform logger the batch loop has finished + # self.trainer.logger_connector.on_train_epoch_end() + # + # # prepare epoch output + # processed_epoch_output = TrainLoop._prepare_outputs(epoch_output, batch_mode=False) + # + # # get the model and call model.training_epoch_end + # model = self.trainer.lightning_module + # + # if is_overridden('training_epoch_end', model=model): + # # run training_epoch_end + # # refresh the result for custom logging at the epoch level + # model._current_fx_name = 'training_epoch_end' + # + # # lightningmodule hook + # training_epoch_end_output = model.training_epoch_end(processed_epoch_output) + # + # if training_epoch_end_output is not None: + # raise MisconfigurationException( + # 'training_epoch_end expects a return of None. ' + # 'HINT: remove the return statement in training_epoch_end' + # ) + # + # # capture logging + # self.trainer.logger_connector.cache_logged_metrics() + # + # # call train epoch end hooks + # self.trainer.call_hook('on_train_epoch_end', processed_epoch_output) + # self.trainer.call_hook('on_epoch_end') # progress global step according to grads progress self.increment_accumulated_grad_global_step() @@ -660,10 +951,13 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # bookkeeping self.trainer.hiddens = None - optimizers = self.prepare_optimizers() + # optimizers = self.prepare_optimizers() # track all outputs across time and num of optimizers batch_outputs = [[] for _ in range(len(optimizers))] + # + # if batch is None: + # return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) if batch is None: self.warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") @@ -673,81 +967,276 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): training_step_output_for_epoch_end=batch_outputs, ) - # hook - response = self.trainer.call_hook("on_batch_start") - if response == -1: - return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) - - # hook - response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) - if response == -1: - return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) - # lightning module hook splits = self._tbptt_split_batch(batch) - for split_idx, split_batch in enumerate(splits): - # create an iterable for optimizers and loop over them - for opt_idx, optimizer in optimizers: - - # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - - if self.should_accumulate(): - # For gradient accumulation - - # ------------------- - # calculate loss (train step + train step end) - # ------------------- - - # automatic_optimization=True: perform dpp sync only when performing optimizer_step - # automatic_optimization=False: don't block synchronization here - with self.block_ddp_sync_behaviour(): - self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # ------------------------------ - # BACKWARD PASS - # ------------------------------ - # gradient update with accumulated gradients + # for opt_idx, optimizer in optimizers: + # + # # toggle model params + set info to logger_connector + # self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + # + # if self.should_accumulate(): + # # For gradient accumulation + # + # # ------------------- + # # 
calculate loss (train step + train step end) + # # ------------------- + # + # # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # # automatic_optimization=False: don't block synchronization here + # with self.block_ddp_sync_behaviour(): + # self.training_step_and_backward( + # split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + # ) + # + # batch_outputs = self._process_closure_result( + # batch_outputs=batch_outputs, + # opt_idx=opt_idx, + # ) + # + # # ------------------------------ + # # BACKWARD PASS + # # ------------------------------ + # # gradient update with accumulated gradients + # + # else: + # if self.automatic_optimization: + # + # def train_step_and_backward_closure(): + # result = self.training_step_and_backward( + # split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + # ) + # return None if result is None else result.loss + # + # # optimizer step + # self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) + # + # else: + # self._curr_step_result = self.training_step( + # split_batch, batch_idx, opt_idx, self.trainer.hiddens + # ) + # + # if self._curr_step_result is None: + # # user decided to skip optimization + # # make sure to zero grad. + # continue + # + # batch_outputs = self._process_closure_result( + # batch_outputs=batch_outputs, + # opt_idx=opt_idx, + # ) + # + # # todo: Properly aggregate grad_norm accros opt_idx and split_idx + # grad_norm_dic = self._cur_grad_norm_dict + # self._cur_grad_norm_dict = None + # + # # update running loss + reset accumulated loss + # self.update_running_loss() + # + # result = AttributeDict( + # signal=0, + # grad_norm_dic=grad_norm_dic, + # training_step_output_for_epoch_end=batch_outputs, + # ) + # return result + + # @contextmanager + # def block_ddp_sync_behaviour(self, should_block_sync: bool = False): + # """ + # automatic_optimization = True + # Blocks ddp sync gradients behaviour on backwards pass. 
+ # This is useful for skipping sync when accumulating gradients, reducing communication overhead + # + # automatic_optimization = False + # do not block ddp gradient sync when using manual optimization + # as gradients are needed within the training step + # + # Returns: + # context manager with sync behaviour off + # + # """ + # if ( + # isinstance(self.trainer.training_type_plugin, ParallelPlugin) + # and (self.automatic_optimization or should_block_sync) + # ): + # with self.trainer.training_type_plugin.block_backward_sync(): + # yield None + # else: + # yield None + + # def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: + # opt_closure_result = self._curr_step_result + # + # if opt_closure_result is not None: + # + # # cache metrics + # self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) + # + # # check if loss or model weights are nan + # if self.trainer.terminate_on_nan: + # self._check_finite(opt_closure_result.loss) + # + # # track all the outputs across all steps + # batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 + # batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) + # + # if self.automatic_optimization: + # # track total loss for logging (avoid mem leaks) + # self.accumulated_loss.append(opt_closure_result.loss) + # + # self._curr_step_result = None + # + # return batch_outputs + # + # def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): + # """Wrap forward, zero_grad and backward in a closure so second order methods work""" + # with self.trainer.profiler.profile("training_step_and_backward"): + # # lightning module hook + # result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) + # self._curr_step_result = result + # + # if not self._skip_backward and self.automatic_optimization: + # is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 + # + # if is_first_batch_to_accumulate: + # self.on_before_zero_grad(optimizer) + # self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) + # + # # backward pass + # if result is not None: + # with self.trainer.profiler.profile("backward"): + # self.backward(result, optimizer, opt_idx) + # + # # hook - call this hook only + # # when gradients have finished to accumulate + # if not self.should_accumulate(): + # self.on_after_backward(result.training_step_output, batch_idx, result.loss) + # + # # check if loss or model weights are nan + # if self.trainer.terminate_on_nan: + # self._check_finite(result.loss) + # + # else: + # self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...") + # + # if len(self.trainer.optimizers) > 1: + # # revert back to previous state + # self.trainer.lightning_module.untoggle_optimizer(opt_idx) + # + # return result + + # def _check_finite(self, loss: torch.Tensor) -> None: + # if not torch.isfinite(loss).all(): + # raise ValueError(f'The loss returned in `training_step` is {loss}.') + # model = self.trainer.lightning_module + # detect_nan_parameters(model) + # + # def backward(self, result, optimizer, opt_idx, *args, **kwargs): + # self.trainer.dev_debugger.track_event("backward_call") + # + # should_accumulate = self.should_accumulate() + # + # # backward can be called manually in the training loop + # if isinstance(result, torch.Tensor): + # self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) + # else: + # result.closure_loss = 
self.trainer.accelerator.backward( + # result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs + # ) + # + # if not self.should_accumulate(): + # # track gradients + # self.track_and_norm_grad(optimizer=optimizer) + + # def update_train_loop_lr_schedulers(self, monitor_metrics=None): + # num_accumulated_batches_reached = self._accumulated_batches_reached() + # num_training_batches_reached = self._num_training_batches_reached() + # + # if num_accumulated_batches_reached or num_training_batches_reached: + # # update lr + # self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) + # + # def increment_accumulated_grad_global_step(self): + # num_accumulated_batches_reached = self._accumulated_batches_reached() + # num_training_batches_reached = self._num_training_batches_reached() + # + # # progress global step according to grads progress + # if num_accumulated_batches_reached or num_training_batches_reached: + # self.trainer.global_step = self.trainer.accelerator.update_global_step( + # self.trainer.total_batch_idx, self.trainer.global_step + # ) + + # def _accumulated_batches_reached(self): + # return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 + # + # def _num_training_batches_reached(self, is_last_batch=False): + # return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch + + # def should_accumulate(self): + # # checks if backward or backward + optimizer step (via closure) + # accumulation_done = self._accumulated_batches_reached() + # is_final_batch = self._num_training_batches_reached() + # return not (accumulation_done or is_final_batch) else: if self.trainer.lightning_module.automatic_optimization: - def train_step_and_backward_closure(): - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - return None if result is None else result.loss - - # optimizer step - self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - - else: - self._curr_step_result = self.training_step( - split_batch, batch_idx, opt_idx, self.trainer.hiddens - ) - - if self._curr_step_result is None: - # user decided to skip optimization - # make sure to zero grad. - continue - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dic = self._cur_grad_norm_dict - self._cur_grad_norm_dict = None + # def build_train_args(self, batch, batch_idx, opt_idx, hiddens): + # # enable not needing to add opt_idx to training_step + # args = [batch, batch_idx] + # + # if len(self.trainer.optimizers) > 1: + # if self.trainer.has_arg("training_step", "optimizer_idx"): + # if not self.automatic_optimization: + # self.warning_cache.warn( + # "`training_step` hook signature has changed in v1.3." + # " `optimizer_idx` argument has been removed in case of manual optimization. Support for" + # " the old signature will be removed in v1.5", DeprecationWarning + # ) + # args.append(opt_idx) + # elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization: + # raise ValueError( + # f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" + # ' `training_step` is missing the `optimizer_idx` argument.' 
+ # ) + # + # # pass hiddens if using tbptt + # if self.trainer.truncated_bptt_steps is not None: + # args.append(hiddens) + # + # return args + + # def save_loggers_on_train_batch_end(self): + # # when loggers should save to disk + # should_flush_logs = self.trainer.logger_connector.should_flush_logs + # if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: + # self.trainer.logger.save() + + + + # def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): + # # set split_idx to trainer for tracking + # self.trainer.split_idx = split_idx + # + # # make sure only the gradients of the current optimizer's parameters are calculated + # # in the training step to prevent dangling gradients in multiple-optimizer setup. + # if self.automatic_optimization and len(self.trainer.optimizers) > 1: + # model = self.trainer.lightning_module + # model.toggle_optimizer(optimizer, opt_idx) + # + # # use to track metrics internally + # self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) + + # def update_running_loss(self): + # accumulated_loss = self.accumulated_loss.mean() + # + # if accumulated_loss is not None: + # # calculate running loss for display + # self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) + # + # # reset for next set of accumulated grads + # self.accumulated_loss.reset() # update running loss + reset accumulated loss self.update_running_loss() From 9dd0ee6f022875cb84795e4c3875f37b32fc50df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 13:58:40 +0200 Subject: [PATCH 009/455] statedict --- pytorch_lightning/loops/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index aeac69f50c9a4..1e232ca5d869c 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -54,3 +54,8 @@ def run(self, *args: Any, **kwargs: Any): outputs = self.on_run_end(outputs) return outputs + + def state_dict(self): + return dict() + + From 921e3c907b9db2fe903095a6724ded834d648831 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 15:15:13 +0200 Subject: [PATCH 010/455] comment TODOs --- pytorch_lightning/loops/epoch_loop.py | 63 +++++++++++----------- pytorch_lightning/loops/optimizer_loop.py | 2 +- pytorch_lightning/loops/training_loop.py | 14 ++--- pytorch_lightning/trainer/training_loop.py | 2 +- 4 files changed, 40 insertions(+), 41 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index b94642f698975..e673540b75051 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -7,6 +7,7 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.training_loop import TrainingLoop from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -18,23 +19,18 @@ class EpochLoop(Loop): - def connect( - self, - num_epochs: int, - max_steps: Optional[int], - trainer: 'pl.Trainer', - *loops_to_run: Loop, - ): - self.num_epochs = num_epochs - self.max_steps = max_steps + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): + 
self.num_epochs = trainer.max_epochs + self.min_epochs = trainer.min_epochs + # TODO: let inner loop track the steps + self.max_steps = trainer.max_steps + self.min_steps = trainer.min_steps self.trainer = trainer - self.loops_to_run = [] - for loop in loops_to_run: - if isinstance(loop, Loop) or hasattr(loop, 'run'): - self.loops_to_run.append(loop) + self.training_loop = TrainingLoop() @property def done(self) -> bool: + # TODO: Move track steps inside training loop and move part of these condition inside training loop stop_steps = self.trainer.max_steps and self.trainer.max_steps <= self.trainer.global_step should_stop = False @@ -43,7 +39,7 @@ def done(self) -> bool: met_min_epochs = (self.iteration_count >= self.trainer.min_epochs - 1) if self.trainer.min_epochs else True met_min_steps = self.trainer.global_step >= self.trainer.min_steps if self.trainer.min_steps else True if met_min_epochs and met_min_steps: - self.train_loop.on_train_end() + self.training_loop.on_train_end() should_stop = True else: log.info( @@ -54,7 +50,6 @@ def done(self) -> bool: self.trainer.should_stop = False stop_epochs = self.iteration_count >= self.num_epochs - return stop_steps or should_stop or stop_epochs def on_run_start(self): @@ -69,7 +64,8 @@ def on_run_end(self): # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) + # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 + # self.check_checkpoint_callback(should_update=True, is_last=True) self.trainer.global_step += 1 # hook @@ -120,39 +116,40 @@ def on_advance_start(self): # equal to on train epoch start # why is this not the same as the old on_train_epoch_end? 
def on_advance_end(self, outputs): - # handle epoch_output on epoch end - self.on_train_epoch_end(outputs) + # # handle epoch_output on epoch end + # self.on_train_epoch_end(outputs) # Handled in on_run_end of training_loop now # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics(outputs) - should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.trainer.is_last_batch, on_epoch=True) + # should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.trainer.is_last_batch, on_epoch=True) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval # update epoch level lr_schedulers if no val loop outside train loop is triggered - if (val_loop_called and not should_check_val) or should_train_only: - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + # if (val_loop_called and not should_check_val) or should_train_only: + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - if should_train_only: - self.check_checkpoint_callback(True) - self.check_early_stopping_callback(True) + # if should_train_only: + # self.check_checkpoint_callback(True) + # self.check_early_stopping_callback(True) - if should_check_val: - self.trainer.validating = True - self.trainer.run_evaluation(on_epoch=True) - self.trainer.training = True + # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 + # if should_check_val: + # self.trainer.validating = True + # self.trainer.run_evaluation(on_epoch=True) + # self.trainer.training = True # increment the global step once # progress global step according to grads progress - self.increment_accumulated_grad_global_step() + # TODO: move inside training_loop.on_run_end? equivalent? order? + self.training_loop.increment_accumulated_grad_global_step() def advance(self): - ret_vals = [] + with self.trainer.profiler.profile("run_training_epoch"): # run train epoch - for loop in self.loops_to_run: - ret_vals.append(loop.run()) + output = self.training_loop.run() - return ret_vals + return output diff --git a/pytorch_lightning/loops/optimizer_loop.py b/pytorch_lightning/loops/optimizer_loop.py index e771c07826241..74fa0711449b5 100644 --- a/pytorch_lightning/loops/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer_loop.py @@ -99,7 +99,7 @@ def train_step_and_backward_closure(): return batch_outputs # ------------------------------------------------------------------------------------------------------------ -# HELPER +# HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ def prepare_optimizers(self): diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 5589cd77b30bf..9aed70470468b 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -12,8 +12,14 @@ class TrainingLoop(Loop): """ Runs over all batches in a dataloader (one epoch). 
""" + def __init__(self): + super().__init__() + # cache of all outputs in a single training run / epoch + # self.epoch_output = [[]] + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer + # self.epoch_output = [[] for _ in range(len(trainer.optimizers))] self.batch_loop = BatchLoop() def on_run_start(self): @@ -43,7 +49,7 @@ def advance(self): return # hook - # TODO: add outputs to batches + epoch_output = [[]] # TODO: track and return output, let loop base concatenate all outputs into a list etc. self.on_train_batch_end( epoch_output, batch_output.training_step_output_for_epoch_end, @@ -133,12 +139,8 @@ def on_run_end(self, outputs): self.trainer.call_hook('on_train_epoch_end', processed_outputs) self.trainer.call_hook('on_epoch_end') - # increment the global step once - # progress global step according to grads progress - self.increment_accumulated_grad_global_step() - # ------------------------------------------------------------------------------------------------------------ -# HELPER +# HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ # TODO move to on_advance_end() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 29fe17a782c9c..5cedba648222c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -954,7 +954,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # optimizers = self.prepare_optimizers() # track all outputs across time and num of optimizers - batch_outputs = [[] for _ in range(len(optimizers))] + # batch_outputs = [[] for _ in range(len(optimizers))] # # if batch is None: # return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) From d9d57c8031afe2201004d9d8bfbae23e5c7b4afa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 15:15:23 +0200 Subject: [PATCH 011/455] trainer on init call --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bff300e8b5880..a3e2586426a2f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -375,7 +375,7 @@ def __init__( truncated_bptt_steps, terminate_on_nan, ) - self._setup_fit_on_init( + self._setup_on_init( max_epochs, min_epochs, max_steps, From e23bee6d06387089911f7ea12a98b3e64e166912 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 16:00:49 +0200 Subject: [PATCH 012/455] restore --- pytorch_lightning/trainer/training_loop.py | 697 +++------------------ 1 file changed, 104 insertions(+), 593 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5cedba648222c..790dc4c70bdeb 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -39,14 +39,14 @@ class TrainLoop: def __init__(self, trainer, multiple_trainloader_mode: str): self.trainer = trainer - # self.accumulated_loss = None - # self.warning_cache = WarningCache() + self.accumulated_loss = None + self.warning_cache = WarningCache() self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) self._curr_step_result = None self._cur_grad_norm_dict = None self._multiple_trainloader_mode = multiple_trainloader_mode - # self._skip_backward = False + 
self._skip_backward = False self.trainer._multiple_trainloader_mode = multiple_trainloader_mode self._optimizer_freq_cumsum = None @@ -84,10 +84,6 @@ def on_trainer_init( def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) return num_optimizers - # - # def on_train_start(self): - # # hook - # self.trainer.call_hook("on_train_start") @property def optimizer_freq_cumsum(self): @@ -200,28 +196,7 @@ def reset_train_val_dataloaders(self, model) -> None: if self.trainer.val_dataloaders is None: self.trainer.reset_val_dataloader(model) - # def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): - # - # # track the outputs to reduce at the end of the epoch - # for opt_idx, opt_outputs in enumerate(batch_end_outputs): - # sample_output = opt_outputs[-1] - # - # # decide if we need to reduce at the end of the epoch automatically - # auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end - # hook_overridden = ( - # is_overridden("training_epoch_end", model=self.trainer.lightning_module) - # or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module) - # ) - # - # # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - # if not (hook_overridden or auto_reduce_tng_result): - # continue - # - # # with 1 step (no tbptt) don't use a sequence at epoch end - # if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): - # opt_outputs = opt_outputs[0] - # - # epoch_output[opt_idx].append(opt_outputs) + def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): hook_overridden = self._should_add_batch_output_to_epoch_output() @@ -232,201 +207,15 @@ def reset_train_val_dataloaders(self, model) -> None: # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end - # def _check_training_step_output(self, training_step_output): - # if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: - # if training_step_output.grad_fn is None: - # # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... 
- # raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") - # - # def training_step(self, split_batch, batch_idx, opt_idx, hiddens): - # # give the PL module a result for logging - # model_ref = self.trainer.lightning_module - # - # with self.trainer.profiler.profile("model_forward"): - # args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) - # - # # manually capture logged metrics - # model_ref._current_fx_name = 'training_step' - # model_ref._results = Result() - # with self.trainer.profiler.profile("training_step"): - # training_step_output = self.trainer.accelerator.training_step(args) - # self.trainer.accelerator.post_training_step() - # - # self.trainer.logger_connector.cache_logged_metrics() - # - # self._check_training_step_output(training_step_output) - # - # training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - # - # training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( - # training_step_output, split_batch - # ) - # if training_step_output_for_epoch_end is None: - # return - # - # # enable empty loss when using manual opt - # closure_loss = None - # untouched_loss = None - # - # if self.automatic_optimization: - # # accumulate loss. if accumulate_grad_batches==1, no effect - # closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches - # - # # the loss will get scaled for amp. avoid any modifications to it - # untouched_loss = closure_loss.detach().clone() - # - # # result - # result = AttributeDict( - # closure_loss=closure_loss, - # loss=untouched_loss, - # training_step_output=training_step_output, - # training_step_output_for_epoch_end=training_step_output_for_epoch_end, - # ) - # return result - # - # def _process_training_step_output(self, training_step_output, split_batch): - # training_step_output_for_epoch_end = training_step_output - # - # # enable validation_step return None - # if training_step_output_for_epoch_end is None: - # return None, None - # - # result = self.trainer.lightning_module._results - # - # loss = None - # hiddens = None - # result["extra"] = {} - # - # # handle dict return - # if isinstance(training_step_output, dict): - # loss = training_step_output.pop("loss", None) - # hiddens = training_step_output.pop("hiddens", None) - # if hiddens is not None: - # hiddens = hiddens.detach() - # result["extra"] = training_step_output - # - # # handle scalar return - # elif isinstance(training_step_output, torch.Tensor): - # loss = training_step_output - # - # # map to results under the hood - # result.minimize = loss - # self.trainer.hiddens = hiddens - # - # # track batch for manual reduction with result - # result.track_batch_size(len(split_batch)) - # - # # track metrics without grads for epoch reduction - # training_step_output_for_epoch_end = copy(result) - # training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() - # if self.trainer.move_metrics_to_cpu: - # training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() - # - # return training_step_output_for_epoch_end, result - - # @staticmethod - # def _prepare_outputs( - # outputs: List[List[List[Result]]], - # batch_mode: bool, - # ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: - # """ - # Extract required information from batch or epoch end results. 
- # - # Args: - # outputs: A 3-dimensional list of ``Result`` objects with dimensions: - # [optimizer outs][batch outs][tbptt steps]. - # - # batch_mode: If True, ignore the batch output dimension. - # - # Returns: - # The cleaned outputs with ``Result`` objects converted to dictionaries. All list dimensions of size one will - # be collapsed. - # """ - # processed_outputs = [] - # for opt_outputs in outputs: - # # handle an edge case where an optimizer output is the empty list - # if len(opt_outputs) == 0: - # continue - # - # processed_batch_outputs = [] - # - # if batch_mode: - # opt_outputs = [opt_outputs] - # - # for batch_outputs in opt_outputs: - # processed_tbptt_outputs = [] - # - # for tbptt_output in batch_outputs: - # out = tbptt_output.extra - # out['loss'] = tbptt_output.minimize - # processed_tbptt_outputs.append(out) - # - # # if there was only one tbptt step then we can collapse that dimension - # if len(processed_tbptt_outputs) == 1: - # processed_tbptt_outputs = processed_tbptt_outputs[0] - # processed_batch_outputs.append(processed_tbptt_outputs) - # - # # batch_outputs should be just one dict (or a list of dicts if using tbptt) per optimizer - # if batch_mode: - # processed_batch_outputs = processed_batch_outputs[0] - # processed_outputs.append(processed_batch_outputs) - # - # # if there is only one optimiser then we collapse that dimension - # if len(processed_outputs) == 1: - # processed_outputs = processed_outputs[0] - # return processed_outputs - - # def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): - # model_ref = self.trainer.lightning_module - # - # is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) - # using_native_amp = self.trainer.amp_backend == AMPType.NATIVE - # - # # native amp + lbfgs is a no go right now - # if using_native_amp and is_lbfgs: - # raise MisconfigurationException( - # 'native PyTorch amp and lbfgs are not compatible.' 
- # ' To request, please file a Github issue in PyTorch and tag @mcarilli' - # ) - # - # # wraps into LightningOptimizer only for running step - # optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) - # - # # model hook - # model_ref.optimizer_step( - # self.trainer.current_epoch, - # batch_idx, - # optimizer, - # opt_idx, - # train_step_and_backward_closure, - # on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE, - # using_native_amp=using_native_amp, - # using_lbfgs=is_lbfgs, - # ) - # - # def on_before_zero_grad(self, optimizer): - # self.trainer.call_hook('on_before_zero_grad', optimizer) - # - # def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): - # self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - # - # def track_and_norm_grad(self, optimizer): - # # track gradient norms - # grad_norm_dic = self._track_gradient_norm() - # - # # clip gradients - # self.trainer.accelerator.clip_gradients( - # optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm - # ) - # self._cur_grad_norm_dict = grad_norm_dic - # - # def _track_gradient_norm(self): - # grad_norm_dict = {} - # if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: - # if float(self.trainer.track_grad_norm) > 0: - # model = self.trainer.lightning_module - # grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) - # return grad_norm_dict + # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end + if not (hook_overridden or auto_reduce_tng_result): + continue + + # with 1 step (no tbptt) don't use a sequence at epoch end + if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): + opt_outputs = opt_outputs[0] + + epoch_output[opt_idx].append(opt_outputs) def _should_add_batch_output_to_epoch_output(self) -> bool: # We add to the epoch outputs if @@ -694,30 +483,11 @@ def run_training_epoch(self): self.trainer.batch_idx = batch_idx self.trainer.is_last_batch = is_last_batch - # # ----------------------------------------- - # # SAVE METRICS TO LOGGERS - # # ----------------------------------------- - # self.trainer.logger_connector.log_train_step_metrics(batch_output) - # - # # ----------------------------------------- - # # VALIDATE IF NEEDED + CHECKPOINT CALLBACK - # # ----------------------------------------- - # should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) - # if should_check_val: - # self.trainer.validating = True - # self.trainer.run_evaluation() - # self.trainer.training = True - # val_loop_called = True - # - # # ----------------------------------------- - # # SAVE LOGGERS (ie: Tensorboard, etc...) 
- # # ----------------------------------------- - # self.save_loggers_on_train_batch_end() - # - # # update LR schedulers - # monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) - # self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) - # self.trainer.checkpoint_connector.has_trained = True + # ------------------------------------ + # TRAINING_STEP + TRAINING_STEP_END + # ------------------------------------ + with self.trainer.profiler.profile("run_training_batch"): + batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx) # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: @@ -759,84 +529,23 @@ def run_training_epoch(self): self.trainer.checkpoint_connector.has_trained = True # max steps reached, end training - # if ( - # self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 - # and self._accumulated_batches_reached() - # ): - # break - # - # # end epoch early - # # stop when the flag is changed or we've gone past the amount - # # requested in the batches - # if self.trainer.should_stop: - # break - # - # self.trainer.total_batch_idx += 1 - # - # # stop epoch if we limited the number of training batches - # if self._num_training_batches_reached(is_last_batch): - # break - - # # progress global step according to grads progress - # self.increment_accumulated_grad_global_step() - # - # # handle epoch_output on epoch end - # self.on_train_epoch_end(epoch_output) - - # # log epoch metrics - # self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) - # - # should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) - # should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) - # should_train_only = self.trainer.disable_validation or should_skip_eval - # - # # update epoch level lr_schedulers if no val loop outside train loop is triggered - # if (val_loop_called and not should_check_val) or should_train_only: - # self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - # - # if should_train_only: - # self.check_checkpoint_callback(True) - # self.check_early_stopping_callback(True) - # - # if should_check_val: - # self.trainer.validating = True - # self.trainer.run_evaluation(on_epoch=True) - # self.trainer.training = True - # - # # increment the global step once - # # progress global step according to grads progress - # self.increment_accumulated_grad_global_step() - - # def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: - # # inform logger the batch loop has finished - # self.trainer.logger_connector.on_train_epoch_end() - # - # # prepare epoch output - # processed_epoch_output = TrainLoop._prepare_outputs(epoch_output, batch_mode=False) - # - # # get the model and call model.training_epoch_end - # model = self.trainer.lightning_module - # - # if is_overridden('training_epoch_end', model=model): - # # run training_epoch_end - # # refresh the result for custom logging at the epoch level - # model._current_fx_name = 'training_epoch_end' - # - # # lightningmodule hook - # training_epoch_end_output = model.training_epoch_end(processed_epoch_output) - # - # if training_epoch_end_output is not None: - # raise MisconfigurationException( - # 'training_epoch_end expects a return of None. 
' - # 'HINT: remove the return statement in training_epoch_end' - # ) - # - # # capture logging - # self.trainer.logger_connector.cache_logged_metrics() - # - # # call train epoch end hooks - # self.trainer.call_hook('on_train_epoch_end', processed_epoch_output) - # self.trainer.call_hook('on_epoch_end') + if ( + self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 + and self._accumulated_batches_reached() + ): + break + + # end epoch early + # stop when the flag is changed or we've gone past the amount + # requested in the batches + if self.trainer.should_stop: + break + + self.trainer.total_batch_idx += 1 + + # stop epoch if we limited the number of training batches + if self._num_training_batches_reached(is_last_batch): + break # progress global step according to grads progress self.increment_accumulated_grad_global_step() @@ -951,13 +660,10 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # bookkeeping self.trainer.hiddens = None - # optimizers = self.prepare_optimizers() + optimizers = self.prepare_optimizers() # track all outputs across time and num of optimizers - # batch_outputs = [[] for _ in range(len(optimizers))] - # - # if batch is None: - # return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) + batch_outputs = [[] for _ in range(len(optimizers))] if batch is None: self.warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") @@ -967,276 +673,81 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): training_step_output_for_epoch_end=batch_outputs, ) + # hook + response = self.trainer.call_hook("on_batch_start") + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + # lightning module hook splits = self._tbptt_split_batch(batch) + for split_idx, split_batch in enumerate(splits): + # create an iterable for optimizers and loop over them - # for opt_idx, optimizer in optimizers: - # - # # toggle model params + set info to logger_connector - # self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - # - # if self.should_accumulate(): - # # For gradient accumulation - # - # # ------------------- - # # calculate loss (train step + train step end) - # # ------------------- - # - # # automatic_optimization=True: perform dpp sync only when performing optimizer_step - # # automatic_optimization=False: don't block synchronization here - # with self.block_ddp_sync_behaviour(): - # self.training_step_and_backward( - # split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - # ) - # - # batch_outputs = self._process_closure_result( - # batch_outputs=batch_outputs, - # opt_idx=opt_idx, - # ) - # - # # ------------------------------ - # # BACKWARD PASS - # # ------------------------------ - # # gradient update with accumulated gradients - # - # else: - # if self.automatic_optimization: - # - # def train_step_and_backward_closure(): - # result = self.training_step_and_backward( - # split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - # ) - # return None if result is None else result.loss - # - # # optimizer step - # self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - # - # else: - # self._curr_step_result = self.training_step( - # split_batch, batch_idx, opt_idx, self.trainer.hiddens - 
# ) - # - # if self._curr_step_result is None: - # # user decided to skip optimization - # # make sure to zero grad. - # continue - # - # batch_outputs = self._process_closure_result( - # batch_outputs=batch_outputs, - # opt_idx=opt_idx, - # ) - # - # # todo: Properly aggregate grad_norm accros opt_idx and split_idx - # grad_norm_dic = self._cur_grad_norm_dict - # self._cur_grad_norm_dict = None - # - # # update running loss + reset accumulated loss - # self.update_running_loss() - # - # result = AttributeDict( - # signal=0, - # grad_norm_dic=grad_norm_dic, - # training_step_output_for_epoch_end=batch_outputs, - # ) - # return result - - # @contextmanager - # def block_ddp_sync_behaviour(self, should_block_sync: bool = False): - # """ - # automatic_optimization = True - # Blocks ddp sync gradients behaviour on backwards pass. - # This is useful for skipping sync when accumulating gradients, reducing communication overhead - # - # automatic_optimization = False - # do not block ddp gradient sync when using manual optimization - # as gradients are needed within the training step - # - # Returns: - # context manager with sync behaviour off - # - # """ - # if ( - # isinstance(self.trainer.training_type_plugin, ParallelPlugin) - # and (self.automatic_optimization or should_block_sync) - # ): - # with self.trainer.training_type_plugin.block_backward_sync(): - # yield None - # else: - # yield None - - # def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: - # opt_closure_result = self._curr_step_result - # - # if opt_closure_result is not None: - # - # # cache metrics - # self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) - # - # # check if loss or model weights are nan - # if self.trainer.terminate_on_nan: - # self._check_finite(opt_closure_result.loss) - # - # # track all the outputs across all steps - # batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 - # batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) - # - # if self.automatic_optimization: - # # track total loss for logging (avoid mem leaks) - # self.accumulated_loss.append(opt_closure_result.loss) - # - # self._curr_step_result = None - # - # return batch_outputs - # - # def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): - # """Wrap forward, zero_grad and backward in a closure so second order methods work""" - # with self.trainer.profiler.profile("training_step_and_backward"): - # # lightning module hook - # result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) - # self._curr_step_result = result - # - # if not self._skip_backward and self.automatic_optimization: - # is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 - # - # if is_first_batch_to_accumulate: - # self.on_before_zero_grad(optimizer) - # self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) - # - # # backward pass - # if result is not None: - # with self.trainer.profiler.profile("backward"): - # self.backward(result, optimizer, opt_idx) - # - # # hook - call this hook only - # # when gradients have finished to accumulate - # if not self.should_accumulate(): - # self.on_after_backward(result.training_step_output, batch_idx, result.loss) - # - # # check if loss or model weights are nan - # if self.trainer.terminate_on_nan: - # self._check_finite(result.loss) - # - # else: - # self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...") - # - # if 
len(self.trainer.optimizers) > 1: - # # revert back to previous state - # self.trainer.lightning_module.untoggle_optimizer(opt_idx) - # - # return result - - # def _check_finite(self, loss: torch.Tensor) -> None: - # if not torch.isfinite(loss).all(): - # raise ValueError(f'The loss returned in `training_step` is {loss}.') - # model = self.trainer.lightning_module - # detect_nan_parameters(model) - # - # def backward(self, result, optimizer, opt_idx, *args, **kwargs): - # self.trainer.dev_debugger.track_event("backward_call") - # - # should_accumulate = self.should_accumulate() - # - # # backward can be called manually in the training loop - # if isinstance(result, torch.Tensor): - # self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) - # else: - # result.closure_loss = self.trainer.accelerator.backward( - # result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs - # ) - # - # if not self.should_accumulate(): - # # track gradients - # self.track_and_norm_grad(optimizer=optimizer) - - # def update_train_loop_lr_schedulers(self, monitor_metrics=None): - # num_accumulated_batches_reached = self._accumulated_batches_reached() - # num_training_batches_reached = self._num_training_batches_reached() - # - # if num_accumulated_batches_reached or num_training_batches_reached: - # # update lr - # self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) - # - # def increment_accumulated_grad_global_step(self): - # num_accumulated_batches_reached = self._accumulated_batches_reached() - # num_training_batches_reached = self._num_training_batches_reached() - # - # # progress global step according to grads progress - # if num_accumulated_batches_reached or num_training_batches_reached: - # self.trainer.global_step = self.trainer.accelerator.update_global_step( - # self.trainer.total_batch_idx, self.trainer.global_step - # ) - - # def _accumulated_batches_reached(self): - # return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 - # - # def _num_training_batches_reached(self, is_last_batch=False): - # return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch - - # def should_accumulate(self): - # # checks if backward or backward + optimizer step (via closure) - # accumulation_done = self._accumulated_batches_reached() - # is_final_batch = self._num_training_batches_reached() - # return not (accumulation_done or is_final_batch) + for opt_idx, optimizer in optimizers: + + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + + if self.should_accumulate(): + # For gradient accumulation + + # ------------------- + # calculate loss (train step + train step end) + # ------------------- + + # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # automatic_optimization=False: don't block synchronization here + with self.block_ddp_sync_behaviour(): + self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # ------------------------------ + # BACKWARD PASS + # ------------------------------ + # gradient update with accumulated gradients else: if self.trainer.lightning_module.automatic_optimization: - # def build_train_args(self, batch, batch_idx, opt_idx, hiddens): - # # enable not needing 
to add opt_idx to training_step - # args = [batch, batch_idx] - # - # if len(self.trainer.optimizers) > 1: - # if self.trainer.has_arg("training_step", "optimizer_idx"): - # if not self.automatic_optimization: - # self.warning_cache.warn( - # "`training_step` hook signature has changed in v1.3." - # " `optimizer_idx` argument has been removed in case of manual optimization. Support for" - # " the old signature will be removed in v1.5", DeprecationWarning - # ) - # args.append(opt_idx) - # elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization: - # raise ValueError( - # f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" - # ' `training_step` is missing the `optimizer_idx` argument.' - # ) - # - # # pass hiddens if using tbptt - # if self.trainer.truncated_bptt_steps is not None: - # args.append(hiddens) - # - # return args - - # def save_loggers_on_train_batch_end(self): - # # when loggers should save to disk - # should_flush_logs = self.trainer.logger_connector.should_flush_logs - # if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: - # self.trainer.logger.save() - - - - # def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): - # # set split_idx to trainer for tracking - # self.trainer.split_idx = split_idx - # - # # make sure only the gradients of the current optimizer's parameters are calculated - # # in the training step to prevent dangling gradients in multiple-optimizer setup. - # if self.automatic_optimization and len(self.trainer.optimizers) > 1: - # model = self.trainer.lightning_module - # model.toggle_optimizer(optimizer, opt_idx) - # - # # use to track metrics internally - # self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) - - # def update_running_loss(self): - # accumulated_loss = self.accumulated_loss.mean() - # - # if accumulated_loss is not None: - # # calculate running loss for display - # self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) - # - # # reset for next set of accumulated grads - # self.accumulated_loss.reset() + def train_step_and_backward_closure(): + result = self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + return None if result is None else result.loss + + # optimizer step + self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) + + else: + self._curr_step_result = self.training_step( + split_batch, batch_idx, opt_idx, self.trainer.hiddens + ) + + if self._curr_step_result is None: + # user decided to skip optimization + # make sure to zero grad. 
+ continue + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + grad_norm_dic = self._cur_grad_norm_dict + self._cur_grad_norm_dict = None # update running loss + reset accumulated loss self.update_running_loss() From 7bfa49ea8249b85855d3bbf10ab95d893622ae25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 16:01:05 +0200 Subject: [PATCH 013/455] new loop entry point --- pytorch_lightning/trainer/trainer.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a3e2586426a2f..80b5981a8d032 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,6 +29,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import Result from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loops.epoch_loop import EpochLoop from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.profiler import BaseProfiler @@ -54,7 +55,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus +from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus, TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.lr_finder import _LRFinder @@ -423,7 +424,7 @@ def _setup_on_init( self.global_step = 0 self.current_epoch = 0 self.should_stop = False - self._state = TrainerState.INITIALIZING + self.state = TrainerState() self.total_batch_idx = 0 self.batch_idx = 0 @@ -877,11 +878,26 @@ def _pre_training_routine(self): self.on_pretrain_routine_end() ref_model.on_pretrain_routine_end() + def reset_train_val_dataloaders(self, model) -> None: + """ + Resets train and val dataloaders if none are attached to the trainer. + + The val dataloader must be initialized before training loop starts, as the training loop + inspects the val dataloader to determine whether to run the evaluation loop. 
+ """ + if self.train_dataloader is None: + self.reset_train_dataloader(model) + + if self.val_dataloaders is None: + self.reset_val_dataloader(model) + def run_train(self) -> None: - new_loop = False + new_loop = True if new_loop: + self.train_loop = EpochLoop() + self.train_loop.connect(self) self._run_train_new_loop() else: self._run_train_old_loop() @@ -908,7 +924,7 @@ def _run_train_new_loop(self) -> None: model = self.lightning_module # This might move somewhere else - self.train_loop.reset_train_val_dataloaders(model) + self.reset_train_val_dataloaders(model) try: if self._should_skip_training(): From adbf829a2211704b5d11e8272e54005112e8a674 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 16:01:32 +0200 Subject: [PATCH 014/455] is last --- pytorch_lightning/loops/training_loop.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 9aed70470468b..28dc63ba8be2a 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -14,6 +14,7 @@ class TrainingLoop(Loop): def __init__(self): super().__init__() + self.is_last_batch = True # cache of all outputs in a single training run / epoch # self.epoch_output = [[]] @@ -34,7 +35,7 @@ def advance(self): batch_idx, (batch, is_last) = next(self._train_dataloader) self.trainer.batch_idx = batch_idx - self.trainer.is_last_batch = is_last + self.is_last_batch = is_last # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END @@ -67,7 +68,7 @@ def on_advance_end(self, output): # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- - should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.trainer.is_last_batch) + should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.is_last_batch) if should_check_val: self.trainer.validating = True self.trainer.run_evaluation() @@ -104,7 +105,7 @@ def done(self): self.trainer.total_batch_idx += 1 # stop epoch if we limited the number of training batches - if self._num_training_batches_reached(self.trainer.is_last_batch): + if self._num_training_batches_reached(self.is_last_batch): return True # this is the old on train_epoch_end? @@ -143,6 +144,9 @@ def on_run_end(self, outputs): # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ + def _num_training_batches_reached(self, is_last_batch=False): + return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch + # TODO move to on_advance_end() def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)] From a462b6394bf77637cc7908baa45c109a0f7d255f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 16:01:46 +0200 Subject: [PATCH 015/455] missing connect --- pytorch_lightning/loops/batch_loop.py | 8 ++++++++ pytorch_lightning/loops/epoch_loop.py | 1 + 2 files changed, 9 insertions(+) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 4f10b01cce2ac..9740d9c2a9b42 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -6,6 +6,14 @@ class BatchLoop(Loop): """ Runs over a single batch of data. 
""" + def connect(self, trainer, *args, **kwargs): + self.trainer = trainer + + @property + def done(self): + # TODO this + return True + def on_run_start(self, batch, batch_idx, dataloader_idx): self._grad_norm_dic = {} self.trainer.hiddens = None diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index e673540b75051..30e6ebd49c1dc 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -27,6 +27,7 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.min_steps = trainer.min_steps self.trainer = trainer self.training_loop = TrainingLoop() + self.training_loop.connect(trainer) @property def done(self) -> bool: From 8af2a835a27f0c9f464f3596b1926c16c3ff6ca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 18:42:16 +0200 Subject: [PATCH 016/455] optimizer iteration --- pytorch_lightning/loops/base.py | 2 -- pytorch_lightning/loops/optimizer_loop.py | 18 +++++++++++++++--- pytorch_lightning/loops/training_loop.py | 5 +++-- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 1e232ca5d869c..106e15f12d3cb 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -57,5 +57,3 @@ def run(self, *args: Any, **kwargs: Any): def state_dict(self): return dict() - - diff --git a/pytorch_lightning/loops/optimizer_loop.py b/pytorch_lightning/loops/optimizer_loop.py index 74fa0711449b5..4d3bc8c270bec 100644 --- a/pytorch_lightning/loops/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer_loop.py @@ -8,7 +8,7 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.supporters import TensorRunningAccum +from pytorch_lightning.trainer.supporters import TensorRunningAccum, prefetch_iterator from pytorch_lightning.utilities import AttributeDict, DeviceType, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters @@ -21,7 +21,8 @@ class OptimizerLoop(Loop): def __init__(self): super().__init__() - self._optimizers = enumerate(self.prepare_optimizers()) + self._optimizers = self.prepare_optimizers() + self._current_optmizer_idx = 0 self.accumulated_loss = None self.warning_cache = WarningCache() @@ -34,8 +35,19 @@ def __init__(self): self._skip_backward = False # self.trainer._multiple_trainloader_mode = multiple_trainloader_mode + def connect(self, trainer, *args, **kwargs): + self.trainer = trainer + + @property + def done(self): + return self._current_optmizer_idx >= len(self._optimizers) + + def next_optimizer(self): + next(self._optimizers) + def advance(self, split_batch, split_idx, batch_idx): - opt_idx, optimizer = next(self._optimizers) + opt_idx = self._current_optmizer_idx + optimizer = self._optimizers[opt_idx] # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 28dc63ba8be2a..8e4109c3b9bf4 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -14,7 +14,7 @@ class TrainingLoop(Loop): def __init__(self): super().__init__() - self.is_last_batch = True + self.is_last_batch = False # cache of all outputs in a 
single training run / epoch # self.epoch_output = [[]] @@ -22,6 +22,7 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer # self.epoch_output = [[] for _ in range(len(trainer.optimizers))] self.batch_loop = BatchLoop() + self.batch_loop.connect(trainer) def on_run_start(self): # modify dataloader if needed (ddp, etc...) @@ -145,7 +146,7 @@ def on_run_end(self, outputs): # ------------------------------------------------------------------------------------------------------------ def _num_training_batches_reached(self, is_last_batch=False): - return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch + return self.iteration_count == self.trainer.num_training_batches or is_last_batch # TODO move to on_advance_end() def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): From acc02b6333cf49376f62d70861c612b4eafabe2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 10 May 2021 11:51:31 +0200 Subject: [PATCH 017/455] reunite batch loop and optimizer loop --- pytorch_lightning/loops/batch_loop.py | 470 +++++++++++++++++++++- pytorch_lightning/loops/optimizer_loop.py | 459 --------------------- pytorch_lightning/loops/training_loop.py | 6 +- 3 files changed, 454 insertions(+), 481 deletions(-) delete mode 100644 pytorch_lightning/loops/optimizer_loop.py diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 9740d9c2a9b42..82c5e60ee95da 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -1,43 +1,119 @@ +from contextlib import contextmanager +from copy import copy + +import numpy as np +import torch + +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop -from pytorch_lightning.loops.optimizer_loop import OptimizerLoop -from pytorch_lightning.utilities import AttributeDict +from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.trainer.supporters import TensorRunningAccum, prefetch_iterator +from pytorch_lightning.utilities import AttributeDict, DeviceType, AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.finite_checks import detect_nan_parameters +from pytorch_lightning.utilities.imports import _TPU_AVAILABLE +from pytorch_lightning.utilities.warnings import WarningCache class BatchLoop(Loop): """ Runs over a single batch of data. 
""" + def __init__(self): + super().__init__() + # self.accumulated_loss = None # TODO: needs to be done over epoch + self.warning_cache = WarningCache() + # self._teardown_already_run = False + self.running_loss = TensorRunningAccum(window_length=20) + self.automatic_optimization = True + self._curr_step_result = None + self._cur_grad_norm_dict = None + # self._multiple_trainloader_mode = multiple_trainloader_mode + self._skip_backward = False + # self.trainer._multiple_trainloader_mode = multiple_trainloader_mode + def connect(self, trainer, *args, **kwargs): self.trainer = trainer + self._optimizers = self.prepare_optimizers() @property def done(self): - # TODO this - return True + return len(self._remaining_splits) == 0 def on_run_start(self, batch, batch_idx, dataloader_idx): self._grad_norm_dic = {} self.trainer.hiddens = None # self._optimizers = self.prepare_optimizers() # lightning module hook - self._splits = enumerate(self.tbptt_split_batch(batch)) - self.tbptt_loop = OptimizerLoop() + self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) - def on_advance_start(self): - return super().on_advance_start() + def advance(self, batch, batch_idx, dataloader_idx): + split_idx, split_batch = self._remaining_splits.pop(0) - def advance(self, batch, batch_idx): - split_idx, split_batch = next(self._splits) - batch_outputs = self.tbptt_loop.run(split_batch, split_idx, batch_idx) + batch_outputs = [[] for _ in range(len(self._optimizers))] - result = AttributeDict( - signal=0, - grad_norm_dic=grad_norm_dic, - training_step_output_for_epoch_end=batch_outputs, - ) - return result + for opt_idx, optimizer in self._optimizers: + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + + if self.should_accumulate(): + # For gradient accumulation + + # ------------------- + # calculate loss (train step + train step end) + # ------------------- + + # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # automatic_optimization=False: don't block synchronization here + with self.block_ddp_sync_behaviour(): + self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # ------------------------------ + # BACKWARD PASS + # ------------------------------ + # gradient update with accumulated gradients + + else: + if self.automatic_optimization: + + def train_step_and_backward_closure(): + result = self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + return None if result is None else result.loss + + # optimizer step + self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) + + else: + self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + + if self._curr_step_result is None: + # user decided to skip optimization + # make sure to zero grad. 
+ # TODO add logic to skip in the outer loop + return + # continue + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + grad_norm_dic = self._cur_grad_norm_dict + self._cur_grad_norm_dict = None + + # update running loss + reset accumulated loss + # self.update_running_loss() + return batch_outputs def run(self, batch, batch_idx, dataloader_idx): - # TODO why is this not in on_run_start? if batch is None: return AttributeDict(signal=0, grad_norm_dic={}) @@ -51,7 +127,14 @@ def run(self, batch, batch_idx, dataloader_idx): if response == -1: return AttributeDict(signal=-1, grad_norm_dic={}) - return super().run(batch, batch_idx, dataloader_idx) + batch_outputs = super().run(batch, batch_idx, dataloader_idx) + + result = AttributeDict( + signal=0, + grad_norm_dic=self._cur_grad_norm_dict, + training_step_output_for_epoch_end=batch_outputs, + ) + return result def tbptt_split_batch(self, batch): splits = [batch] @@ -60,3 +143,352 @@ def tbptt_split_batch(self, batch): with self.trainer.profiler.profile("tbptt_split_batch"): splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) return splits + +# ------------------------------------------------------------------------------------------------------------ +# HELPER --- TO BE CLEANED UP +# ------------------------------------------------------------------------------------------------------------ + + def prepare_optimizers(self): + # in manual optimization we loop over all optimizers at once + optimizers = self.get_optimizers_iterable() + if not self.automatic_optimization: + optimizers = [optimizers[0]] + return optimizers + + def get_optimizers_iterable(self): + """ + Generates an iterable with (idx, optimizer) for each optimizer. + """ + if not self.trainer.optimizer_frequencies: + # call training_step once per optimizer + return list(enumerate(self.trainer.optimizers)) + + optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) + optimizers_loop_length = optimizer_freq_cumsum[-1] + current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length + + # find optimzier index by looking for the first {item > current_place} in the cumsum list + opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) + return [[opt_idx, self.trainer.optimizers[opt_idx]]] + + def on_after_backward(self, training_step_output, batch_idx, untouched_loss): + training_step_output.detach() + + # insert after step hook + self.trainer.call_hook("on_after_backward") + + # when in dev debugging track the losses + self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) + + def _check_training_step_output(self, training_step_output): + if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: + if training_step_output.grad_fn is None: + # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... 
+ raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") + + def training_step(self, split_batch, batch_idx, opt_idx, hiddens): + # give the PL module a result for logging + model_ref = self.trainer.lightning_module + + with self.trainer.profiler.profile("model_forward"): + args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) + + # manually capture logged metrics + model_ref._current_fx_name = 'training_step' + model_ref._results = Result() + with self.trainer.profiler.profile("training_step"): + training_step_output = self.trainer.accelerator.training_step(args) + self.trainer.accelerator.post_training_step() + + self.trainer.logger_connector.cache_logged_metrics() + + self._check_training_step_output(training_step_output) + + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) + + training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( + training_step_output, split_batch + ) + if training_step_output_for_epoch_end is None: + return + + # enable empty loss when using manual opt + closure_loss = None + untouched_loss = None + + if self.automatic_optimization: + # accumulate loss. if accumulate_grad_batches==1, no effect + closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches + + # the loss will get scaled for amp. avoid any modifications to it + untouched_loss = closure_loss.detach().clone() + + # result + result = AttributeDict( + closure_loss=closure_loss, + loss=untouched_loss, + training_step_output=training_step_output, + training_step_output_for_epoch_end=training_step_output_for_epoch_end, + ) + return result + + def _process_training_step_output(self, training_step_output, split_batch): + training_step_output_for_epoch_end = training_step_output + + # enable validation_step return None + if training_step_output_for_epoch_end is None: + return None, None + + result = self.trainer.lightning_module._results + + loss = None + hiddens = None + result["extra"] = {} + + # handle dict return + if isinstance(training_step_output, dict): + loss = training_step_output.pop("loss", None) + hiddens = training_step_output.pop("hiddens", None) + if hiddens is not None: + hiddens = hiddens.detach() + result["extra"] = training_step_output + + # handle scalar return + elif isinstance(training_step_output, torch.Tensor): + loss = training_step_output + + # map to results under the hood + result.minimize = loss + self.trainer.hiddens = hiddens + + # track batch for manual reduction with result + result.track_batch_size(len(split_batch)) + + # track metrics without grads for epoch reduction + training_step_output_for_epoch_end = copy(result) + training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() + if self.trainer.move_metrics_to_cpu: + training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() + + return training_step_output_for_epoch_end, result + + def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): + model_ref = self.trainer.lightning_module + + is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) + using_native_amp = self.trainer.amp_backend == AMPType.NATIVE + + # native amp + lbfgs is a no go right now + if using_native_amp and is_lbfgs: + raise MisconfigurationException( + 'native PyTorch amp and lbfgs are not compatible.' 
+ ' To request, please file a Github issue in PyTorch and tag @mcarilli' + ) + + # wraps into LightningOptimizer only for running step + optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) + + # model hook + model_ref.optimizer_step( + self.trainer.current_epoch, + batch_idx, + optimizer, + opt_idx, + train_step_and_backward_closure, + on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE), + using_native_amp=using_native_amp, + using_lbfgs=is_lbfgs, + ) + + def on_before_zero_grad(self, optimizer): + self.trainer.call_hook('on_before_zero_grad', optimizer) + + def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): + self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + + def track_and_norm_grad(self, optimizer): + # track gradient norms + grad_norm_dic = self._track_gradient_norm() + + # clip gradients + self.trainer.accelerator.clip_gradients( + optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm + ) + self._cur_grad_norm_dict = grad_norm_dic + + def _track_gradient_norm(self): + grad_norm_dict = {} + if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: + if float(self.trainer.track_grad_norm) > 0: + model = self.trainer.lightning_module + grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) + return grad_norm_dict + + def _accumulated_batches_reached(self): + return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 + + def _num_training_batches_reached(self, is_last_batch=False): + return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch + + def should_accumulate(self): + # checks if backward or backward + optimizer step (via closure) + accumulation_done = self._accumulated_batches_reached() + is_final_batch = self._num_training_batches_reached() + return not (accumulation_done or is_final_batch) + + def build_train_args(self, batch, batch_idx, opt_idx, hiddens): + # enable not needing to add opt_idx to training_step + args = [batch, batch_idx] + + if len(self.trainer.optimizers) > 1: + if self.trainer.has_arg("training_step", "optimizer_idx"): + if not self.automatic_optimization: + self.warning_cache.warn( + "`training_step` hook signature has changed in v1.3." + " `optimizer_idx` argument has been removed in case of manual optimization. Support for" + " the old signature will be removed in v1.5", DeprecationWarning + ) + args.append(opt_idx) + elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization: + raise ValueError( + f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" + ' `training_step` is missing the `optimizer_idx` argument.' + ) + + # pass hiddens if using tbptt + if self.trainer.truncated_bptt_steps is not None: + args.append(hiddens) + + return args + + def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): + # set split_idx to trainer for tracking + self.trainer.split_idx = split_idx + + # make sure only the gradients of the current optimizer's parameters are calculated + # in the training step to prevent dangling gradients in multiple-optimizer setup. 
+ if self.automatic_optimization and len(self.trainer.optimizers) > 1: + model = self.trainer.lightning_module + model.toggle_optimizer(optimizer, opt_idx) + + # use to track metrics internally + self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) + + @contextmanager + def block_ddp_sync_behaviour(self, should_block_sync: bool = False): + """ + automatic_optimization = True + Blocks ddp sync gradients behaviour on backwards pass. + This is useful for skipping sync when accumulating gradients, reducing communication overhead + + automatic_optimization = False + do not block ddp gradient sync when using manual optimization + as gradients are needed within the training step + + Returns: + context manager with sync behaviour off + + """ + if ( + isinstance(self.trainer.training_type_plugin, ParallelPlugin) + and (self.automatic_optimization or should_block_sync) + ): + with self.trainer.training_type_plugin.block_backward_sync(): + yield None + else: + yield None + + def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: + opt_closure_result = self._curr_step_result + + if opt_closure_result is not None: + + # cache metrics + self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) + + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + self._check_finite(opt_closure_result.loss) + + # track all the outputs across all steps + batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 + batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) + + if self.automatic_optimization: + # track total loss for logging (avoid mem leaks) + # self.accumulated_loss.append(opt_closure_result.loss) + pass + + self._curr_step_result = None + + return batch_outputs + + def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): + """Wrap forward, zero_grad and backward in a closure so second order methods work""" + with self.trainer.profiler.profile("training_step_and_backward"): + # lightning module hook + result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) + self._curr_step_result = result + + if not self._skip_backward and self.automatic_optimization: + is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 + + if is_first_batch_to_accumulate: + self.on_before_zero_grad(optimizer) + self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) + + # backward pass + if result is not None: + with self.trainer.profiler.profile("backward"): + self.backward(result, optimizer, opt_idx) + + # hook - call this hook only + # when gradients have finished to accumulate + if not self.should_accumulate(): + self.on_after_backward(result.training_step_output, batch_idx, result.loss) + + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + self._check_finite(result.loss) + + else: + self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...") + + if len(self.trainer.optimizers) > 1: + # revert back to previous state + self.trainer.lightning_module.untoggle_optimizer(opt_idx) + + return result + + def _check_finite(self, loss: torch.Tensor) -> None: + if not torch.isfinite(loss).all(): + raise ValueError(f'The loss returned in `training_step` is {loss}.') + model = self.trainer.lightning_module + detect_nan_parameters(model) + + def backward(self, result, optimizer, opt_idx, *args, **kwargs): + self.trainer.dev_debugger.track_event("backward_call") + + 
should_accumulate = self.should_accumulate() + + # backward can be called manually in the training loop + if isinstance(result, torch.Tensor): + self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) + else: + result.closure_loss = self.trainer.accelerator.backward( + result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs + ) + + if not self.should_accumulate(): + # track gradients + self.track_and_norm_grad(optimizer=optimizer) + + def update_running_loss(self): + accumulated_loss = self.accumulated_loss.mean() + + if accumulated_loss is not None: + # calculate running loss for display + self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) + + # reset for next set of accumulated grads + self.accumulated_loss.reset() diff --git a/pytorch_lightning/loops/optimizer_loop.py b/pytorch_lightning/loops/optimizer_loop.py deleted file mode 100644 index 4d3bc8c270bec..0000000000000 --- a/pytorch_lightning/loops/optimizer_loop.py +++ /dev/null @@ -1,459 +0,0 @@ -from contextlib import contextmanager -from copy import copy - -import numpy as np -import torch - -from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.step_result import Result -from pytorch_lightning.loops.base import Loop -from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.supporters import TensorRunningAccum, prefetch_iterator -from pytorch_lightning.utilities import AttributeDict, DeviceType, AMPType -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.finite_checks import detect_nan_parameters -from pytorch_lightning.utilities.imports import _TPU_AVAILABLE -from pytorch_lightning.utilities.warnings import WarningCache - - -class OptimizerLoop(Loop): - """ Runs over a single split of a batch of data (TBPTT). 
""" - - def __init__(self): - super().__init__() - self._optimizers = self.prepare_optimizers() - self._current_optmizer_idx = 0 - - self.accumulated_loss = None - self.warning_cache = WarningCache() - # self._teardown_already_run = False - self.running_loss = TensorRunningAccum(window_length=20) - self.automatic_optimization = True - self._curr_step_result = None - self._cur_grad_norm_dict = None - # self._multiple_trainloader_mode = multiple_trainloader_mode - self._skip_backward = False - # self.trainer._multiple_trainloader_mode = multiple_trainloader_mode - - def connect(self, trainer, *args, **kwargs): - self.trainer = trainer - - @property - def done(self): - return self._current_optmizer_idx >= len(self._optimizers) - - def next_optimizer(self): - next(self._optimizers) - - def advance(self, split_batch, split_idx, batch_idx): - opt_idx = self._current_optmizer_idx - optimizer = self._optimizers[opt_idx] - - # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - - if self.should_accumulate(): - # For gradient accumulation - - # ------------------- - # calculate loss (train step + train step end) - # ------------------- - - # automatic_optimization=True: perform dpp sync only when performing optimizer_step - # automatic_optimization=False: don't block synchronization here - with self.block_ddp_sync_behaviour(): - self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # ------------------------------ - # BACKWARD PASS - # ------------------------------ - # gradient update with accumulated gradients - - else: - if self.automatic_optimization: - - def train_step_and_backward_closure(): - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - return None if result is None else result.loss - - # optimizer step - self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - - else: - self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) - - if self._curr_step_result is None: - # user decided to skip optimization - # make sure to zero grad. - # TODO add logic to skip in the outer loop - return - # continue - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dic = self._cur_grad_norm_dict - self._cur_grad_norm_dict = None - - # update running loss + reset accumulated loss - self.update_running_loss() - - return batch_outputs - -# ------------------------------------------------------------------------------------------------------------ -# HELPER --- TO BE CLEANED UP -# ------------------------------------------------------------------------------------------------------------ - - def prepare_optimizers(self): - # in manual optimization we loop over all optimizers at once - optimizers = self.get_optimizers_iterable() - if not self.automatic_optimization: - optimizers = [optimizers[0]] - return optimizers - - def get_optimizers_iterable(self): - """ - Generates an iterable with (idx, optimizer) for each optimizer. 
- """ - if not self.trainer.optimizer_frequencies: - # call training_step once per optimizer - return list(enumerate(self.trainer.optimizers)) - - optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) - optimizers_loop_length = optimizer_freq_cumsum[-1] - current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length - - # find optimzier index by looking for the first {item > current_place} in the cumsum list - opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) - return [[opt_idx, self.trainer.optimizers[opt_idx]]] - - def on_after_backward(self, training_step_output, batch_idx, untouched_loss): - training_step_output.detach() - - # insert after step hook - self.trainer.call_hook("on_after_backward") - - # when in dev debugging track the losses - self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) - - def _check_training_step_output(self, training_step_output): - if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: - if training_step_output.grad_fn is None: - # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... - raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") - - def training_step(self, split_batch, batch_idx, opt_idx, hiddens): - # give the PL module a result for logging - model_ref = self.trainer.lightning_module - - with self.trainer.profiler.profile("model_forward"): - args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) - - # manually capture logged metrics - model_ref._current_fx_name = 'training_step' - model_ref._results = Result() - with self.trainer.profiler.profile("training_step"): - training_step_output = self.trainer.accelerator.training_step(args) - self.trainer.accelerator.post_training_step() - - self.trainer.logger_connector.cache_logged_metrics() - - self._check_training_step_output(training_step_output) - - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - - training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( - training_step_output, split_batch - ) - if training_step_output_for_epoch_end is None: - return - - # enable empty loss when using manual opt - closure_loss = None - untouched_loss = None - - if self.automatic_optimization: - # accumulate loss. if accumulate_grad_batches==1, no effect - closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches - - # the loss will get scaled for amp. 
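# A toy check (plain PyTorch, made-up numbers) of why `minimize` is divided by
# `accumulate_grad_batches` above: `.backward()` adds into `.grad`, so scaling each
# micro-batch loss by 1/N makes the accumulated gradient equal the average over N
# micro-batches instead of their sum.
import torch

w = torch.tensor(1.0, requires_grad=True)
micro_losses = [2.0 * w, 4.0 * w]     # two micro-batches with d/dw = 2 and 4
n = len(micro_losses)

for loss in micro_losses:
    (loss / n).backward()             # scale before accumulating into w.grad

assert torch.isclose(w.grad, torch.tensor(3.0))   # (2 + 4) / 2, the averaged gradient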
avoid any modifications to it - untouched_loss = closure_loss.detach().clone() - - # result - result = AttributeDict( - closure_loss=closure_loss, - loss=untouched_loss, - training_step_output=training_step_output, - training_step_output_for_epoch_end=training_step_output_for_epoch_end, - ) - return result - - def _process_training_step_output(self, training_step_output, split_batch): - training_step_output_for_epoch_end = training_step_output - - # enable validation_step return None - if training_step_output_for_epoch_end is None: - return None, None - - result = self.trainer.lightning_module._results - - loss = None - hiddens = None - result["extra"] = {} - - # handle dict return - if isinstance(training_step_output, dict): - loss = training_step_output.pop("loss", None) - hiddens = training_step_output.pop("hiddens", None) - if hiddens is not None: - hiddens = hiddens.detach() - result["extra"] = training_step_output - - # handle scalar return - elif isinstance(training_step_output, torch.Tensor): - loss = training_step_output - - # map to results under the hood - result.minimize = loss - self.trainer.hiddens = hiddens - - # track batch for manual reduction with result - result.track_batch_size(len(split_batch)) - - # track metrics without grads for epoch reduction - training_step_output_for_epoch_end = copy(result) - training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() - if self.trainer.move_metrics_to_cpu: - training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() - - return training_step_output_for_epoch_end, result - - def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): - model_ref = self.trainer.lightning_module - - is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) - using_native_amp = self.trainer.amp_backend == AMPType.NATIVE - - # native amp + lbfgs is a no go right now - if using_native_amp and is_lbfgs: - raise MisconfigurationException( - 'native PyTorch amp and lbfgs are not compatible.' 
- ' To request, please file a Github issue in PyTorch and tag @mcarilli' - ) - - # wraps into LightningOptimizer only for running step - optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) - - # model hook - model_ref.optimizer_step( - self.trainer.current_epoch, - batch_idx, - optimizer, - opt_idx, - train_step_and_backward_closure, - on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE), - using_native_amp=using_native_amp, - using_lbfgs=is_lbfgs, - ) - - def on_before_zero_grad(self, optimizer): - self.trainer.call_hook('on_before_zero_grad', optimizer) - - def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): - self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - - def track_and_norm_grad(self, optimizer): - # track gradient norms - grad_norm_dic = self._track_gradient_norm() - - # clip gradients - self.trainer.accelerator.clip_gradients( - optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm - ) - self._cur_grad_norm_dict = grad_norm_dic - - def _track_gradient_norm(self): - grad_norm_dict = {} - if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: - if float(self.trainer.track_grad_norm) > 0: - model = self.trainer.lightning_module - grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) - return grad_norm_dict - - def _accumulated_batches_reached(self): - return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 - - def _num_training_batches_reached(self, is_last_batch=False): - return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch - - def should_accumulate(self): - # checks if backward or backward + optimizer step (via closure) - accumulation_done = self._accumulated_batches_reached() - is_final_batch = self._num_training_batches_reached() - return not (accumulation_done or is_final_batch) - - def build_train_args(self, batch, batch_idx, opt_idx, hiddens): - # enable not needing to add opt_idx to training_step - args = [batch, batch_idx] - - if len(self.trainer.optimizers) > 1: - if self.trainer.has_arg("training_step", "optimizer_idx"): - if not self.automatic_optimization: - self.warning_cache.warn( - "`training_step` hook signature has changed in v1.3." - " `optimizer_idx` argument has been removed in case of manual optimization. Support for" - " the old signature will be removed in v1.5", DeprecationWarning - ) - args.append(opt_idx) - elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization: - raise ValueError( - f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" - ' `training_step` is missing the `optimizer_idx` argument.' - ) - - # pass hiddens if using tbptt - if self.trainer.truncated_bptt_steps is not None: - args.append(hiddens) - - return args - - def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): - # set split_idx to trainer for tracking - self.trainer.split_idx = split_idx - - # make sure only the gradients of the current optimizer's parameters are calculated - # in the training step to prevent dangling gradients in multiple-optimizer setup. 
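# The accumulation check above, restated as a standalone function with assumed numbers:
# the optimizer only steps once `accumulate_grad_batches` batches have been seen, or when
# the epoch runs out of batches; every other batch only accumulates gradients.
def should_accumulate(batch_idx, accumulate_grad_batches, num_training_batches, is_last_batch=False):
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    is_final_batch = (batch_idx + 1) == num_training_batches or is_last_batch
    return not (accumulation_done or is_final_batch)

# accumulate_grad_batches=2 with 5 batches per epoch: steps happen after batches 1, 3 and 4
assert [should_accumulate(i, 2, 5) for i in range(5)] == [True, False, True, False, False]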
- if self.automatic_optimization and len(self.trainer.optimizers) > 1: - model = self.trainer.lightning_module - model.toggle_optimizer(optimizer, opt_idx) - - # use to track metrics internally - self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) - - @contextmanager - def block_ddp_sync_behaviour(self, should_block_sync: bool = False): - """ - automatic_optimization = True - Blocks ddp sync gradients behaviour on backwards pass. - This is useful for skipping sync when accumulating gradients, reducing communication overhead - - automatic_optimization = False - do not block ddp gradient sync when using manual optimization - as gradients are needed within the training step - - Returns: - context manager with sync behaviour off - - """ - if ( - isinstance(self.trainer.training_type_plugin, ParallelPlugin) - and (self.automatic_optimization or should_block_sync) - ): - with self.trainer.training_type_plugin.block_backward_sync(): - yield None - else: - yield None - - def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: - opt_closure_result = self._curr_step_result - - if opt_closure_result is not None: - - # cache metrics - self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) - - # check if loss or model weights are nan - if self.trainer.terminate_on_nan: - self._check_finite(opt_closure_result.loss) - - # track all the outputs across all steps - batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 - batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) - - if self.automatic_optimization: - # track total loss for logging (avoid mem leaks) - self.accumulated_loss.append(opt_closure_result.loss) - - self._curr_step_result = None - - return batch_outputs - - def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): - """Wrap forward, zero_grad and backward in a closure so second order methods work""" - with self.trainer.profiler.profile("training_step_and_backward"): - # lightning module hook - result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) - self._curr_step_result = result - - if not self._skip_backward and self.automatic_optimization: - is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 - - if is_first_batch_to_accumulate: - self.on_before_zero_grad(optimizer) - self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) - - # backward pass - if result is not None: - with self.trainer.profiler.profile("backward"): - self.backward(result, optimizer, opt_idx) - - # hook - call this hook only - # when gradients have finished to accumulate - if not self.should_accumulate(): - self.on_after_backward(result.training_step_output, batch_idx, result.loss) - - # check if loss or model weights are nan - if self.trainer.terminate_on_nan: - self._check_finite(result.loss) - - else: - self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...") - - if len(self.trainer.optimizers) > 1: - # revert back to previous state - self.trainer.lightning_module.untoggle_optimizer(opt_idx) - - return result - - def _check_finite(self, loss: torch.Tensor) -> None: - if not torch.isfinite(loss).all(): - raise ValueError(f'The loss returned in `training_step` is {loss}.') - model = self.trainer.lightning_module - detect_nan_parameters(model) - - def backward(self, result, optimizer, opt_idx, *args, **kwargs): - self.trainer.dev_debugger.track_event("backward_call") - - 
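# The intent behind `block_ddp_sync_behaviour`, sketched in plain PyTorch (this assumes a
# DistributedDataParallel model and an initialized process group, neither of which is set
# up here): skip the gradient all-reduce on accumulation steps and only pay the
# communication cost on the backward pass that precedes `optimizer.step()`.
from contextlib import nullcontext

def backward_with_optional_sync(ddp_model, loss, accumulating: bool):
    # DistributedDataParallel.no_sync() suppresses the gradient all-reduce for this backward
    context = ddp_model.no_sync() if accumulating else nullcontext()
    with context:
        loss.backward()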
should_accumulate = self.should_accumulate() - - # backward can be called manually in the training loop - if isinstance(result, torch.Tensor): - self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) - else: - result.closure_loss = self.trainer.accelerator.backward( - result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs - ) - - if not self.should_accumulate(): - # track gradients - self.track_and_norm_grad(optimizer=optimizer) - - def update_running_loss(self): - accumulated_loss = self.accumulated_loss.mean() - - if accumulated_loss is not None: - # calculate running loss for display - self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) - - # reset for next set of accumulated grads - self.accumulated_loss.reset() diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 8e4109c3b9bf4..24e57f0156f67 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -46,9 +46,9 @@ def advance(self): batch_output = self.batch_loop.run(batch, batch_idx, self._dataloader_idx) # when returning -1 from train_step, we end epoch early - if batch_output.signal == -1: - self._skip_remaining_steps = True - return + # if batch_output.signal == -1: + # self._skip_remaining_steps = True + # return # hook epoch_output = [[]] # TODO: track and return output, let loop base concatenate all outputs into a list etc. From f41f3733d65f91ba745a42b604d43da1363c98bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 10 May 2021 13:55:31 +0200 Subject: [PATCH 018/455] remove accumulation, manual optimization --- pl_examples/bug_report_model.py | 2 +- pytorch_lightning/loops/batch_loop.py | 73 ++++++++---------------- pytorch_lightning/loops/epoch_loop.py | 15 ++++- pytorch_lightning/loops/training_loop.py | 9 ++- 4 files changed, 45 insertions(+), 54 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index abb65ba86fd93..e7e9167cff62e 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -54,7 +54,7 @@ def run(): trainer = Trainer( default_root_dir=os.getcwd(), limit_train_batches=1, - limit_val_batches=1, + limit_val_batches=0, num_sanity_val_steps=0, max_epochs=1, weights_summary=None, diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 82c5e60ee95da..2bb55614ea656 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -1,5 +1,6 @@ from contextlib import contextmanager from copy import copy +from typing import List import numpy as np import torch @@ -56,61 +57,32 @@ def advance(self, batch, batch_idx, dataloader_idx): # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - if self.should_accumulate(): - # For gradient accumulation - - # ------------------- - # calculate loss (train step + train step end) - # ------------------- - - # automatic_optimization=True: perform dpp sync only when performing optimizer_step - # automatic_optimization=False: don't block synchronization here - with self.block_ddp_sync_behaviour(): - self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, + def train_step_and_backward_closure(): + result = 
self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens ) + return None if result is None else result.loss - # ------------------------------ - # BACKWARD PASS - # ------------------------------ - # gradient update with accumulated gradients - - else: - if self.automatic_optimization: + # optimizer step + self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - def train_step_and_backward_closure(): - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - return None if result is None else result.loss + if self._curr_step_result is None: + # user decided to skip optimization + # make sure to zero grad. + continue - # optimizer step - self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) - else: - self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) - - if self._curr_step_result is None: - # user decided to skip optimization - # make sure to zero grad. - # TODO add logic to skip in the outer loop - return - # continue - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + grad_norm_dic = self._cur_grad_norm_dict + self._cur_grad_norm_dict = None - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dic = self._cur_grad_norm_dict - self._cur_grad_norm_dict = None + # update running loss + reset accumulated loss + # self.update_running_loss() - # update running loss + reset accumulated loss - # self.update_running_loss() return batch_outputs def run(self, batch, batch_idx, dataloader_idx): @@ -132,10 +104,13 @@ def run(self, batch, batch_idx, dataloader_idx): result = AttributeDict( signal=0, grad_norm_dic=self._cur_grad_norm_dict, - training_step_output_for_epoch_end=batch_outputs, + training_step_output_for_epoch_end=batch_outputs[0], ) return result + def on_run_end(self, outputs: List) -> List: + return outputs + def tbptt_split_batch(self, batch): splits = [batch] if self.trainer.truncated_bptt_steps is not None: diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 30e6ebd49c1dc..d9bbd1309f326 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -3,6 +3,8 @@ from copy import deepcopy from typing import Any, List, Optional +import torch + import pytorch_lightning as pl from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.core.step_result import Result @@ -19,6 +21,11 @@ class EpochLoop(Loop): + def __init__(self): + super().__init__() + self.running_loss = torch.tensor(0.0) # dummy TODO: + self._teardown_already_run = False + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.num_epochs = trainer.max_epochs self.min_epochs = trainer.min_epochs @@ -29,6 +36,10 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.training_loop = TrainingLoop() self.training_loop.connect(trainer) + def should_accumulate(self): + # TODO + return False + @property def done(self) -> bool: # TODO: Move track steps inside training loop and move part of these condition inside training loop @@ -57,7 +68,7 @@ def on_run_start(self): # hook self.trainer.call_hook("on_train_start") - def on_run_end(self): + def 
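# A simplified stand-in for `tbptt_split_batch` above, assuming a tensor batch laid out as
# [batch, time, features]: when `truncated_bptt_steps` is set, the sequence is cut into
# fixed-size windows along the time dimension and each window becomes one "split".
import torch

def split_for_tbptt(batch, tbptt_steps):
    return [batch[:, t:t + tbptt_steps] for t in range(0, batch.size(1), tbptt_steps)]

batch = torch.randn(8, 100, 16)               # 8 sequences, 100 time steps, 16 features
splits = split_for_tbptt(batch, tbptt_steps=25)
assert len(splits) == 4 and splits[0].shape == (8, 25, 16)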
on_run_end(self, outputs): if self._teardown_already_run: return self._teardown_already_run = True @@ -87,6 +98,8 @@ def on_run_end(self): # reset bookkeeping self.trainer._running_stage = None + return outputs + def on_advance_start(self): # equal to on train epoch start # implemented here since this code has to be run always no matter the actual epoch implementation epoch = self.iteration_count + 1 diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 24e57f0156f67..775b30c4c8965 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -59,6 +59,7 @@ def advance(self): batch_idx, self._dataloader_idx, ) + return epoch_output def on_advance_end(self, output): # ----------------------------------------- @@ -87,6 +88,7 @@ def on_advance_end(self, output): # progress global step according to grads progress self.increment_accumulated_grad_global_step() + return output @property def done(self): @@ -138,8 +140,9 @@ def on_run_end(self, outputs): self.trainer.logger_connector.cache_logged_metrics() # call train epoch end hooks - self.trainer.call_hook('on_train_epoch_end', processed_outputs) + # self.trainer.call_hook('on_train_epoch_end', processed_outputs) self.trainer.call_hook('on_epoch_end') + return processed_outputs # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP @@ -240,7 +243,7 @@ def _prepare_outputs( return processed_outputs def update_train_loop_lr_schedulers(self, monitor_metrics=None): - num_accumulated_batches_reached = self._accumulated_batches_reached() + num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() if num_accumulated_batches_reached or num_training_batches_reached: @@ -248,7 +251,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) def increment_accumulated_grad_global_step(self): - num_accumulated_batches_reached = self._accumulated_batches_reached() + num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() # progress global step according to grads progress From 00d3508286c9e99c1d7680256b4cb7fa25b4d499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 10 May 2021 18:04:06 +0200 Subject: [PATCH 019/455] run multiple batches --- pl_examples/bug_report_model.py | 6 +-- pytorch_lightning/loops/batch_loop.py | 6 ++- pytorch_lightning/loops/training_loop.py | 59 ++++++++++++++++++++-- pytorch_lightning/trainer/training_loop.py | 6 +-- 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index e7e9167cff62e..57011841a3731 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -53,9 +53,9 @@ def run(): model = BoringModel() trainer = Trainer( default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=0, - num_sanity_val_steps=0, + limit_train_batches=2, + limit_val_batches=2, + num_sanity_val_steps=2, max_epochs=1, weights_summary=None, ) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 2bb55614ea656..000d094985a2f 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ 
b/pytorch_lightning/loops/batch_loop.py @@ -74,7 +74,7 @@ def train_step_and_backward_closure(): batch_outputs = self._process_closure_result( batch_outputs=batch_outputs, opt_idx=opt_idx, - ) + ) # 1 optimizer case: batch_outputs[0][0] = Result object # todo: Properly aggregate grad_norm accros opt_idx and split_idx grad_norm_dic = self._cur_grad_norm_dict @@ -101,10 +101,12 @@ def run(self, batch, batch_idx, dataloader_idx): batch_outputs = super().run(batch, batch_idx, dataloader_idx) + batch_outputs = batch_outputs[0] # TODO: hack for poc + result = AttributeDict( signal=0, grad_norm_dic=self._cur_grad_norm_dict, - training_step_output_for_epoch_end=batch_outputs[0], + training_step_output_for_epoch_end=batch_outputs, ) return result diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 775b30c4c8965..afcd56e4715e1 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -46,9 +46,9 @@ def advance(self): batch_output = self.batch_loop.run(batch, batch_idx, self._dataloader_idx) # when returning -1 from train_step, we end epoch early - # if batch_output.signal == -1: - # self._skip_remaining_steps = True - # return + if batch_output.signal == -1: + self._skip_remaining_steps = True + return # hook epoch_output = [[]] # TODO: track and return output, let loop base concatenate all outputs into a list etc. @@ -113,6 +113,9 @@ def done(self): # this is the old on train_epoch_end? def on_run_end(self, outputs): + # hack for poc + outputs = outputs[0] + # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() @@ -140,7 +143,7 @@ def on_run_end(self, outputs): self.trainer.logger_connector.cache_logged_metrics() # call train epoch end hooks - # self.trainer.call_hook('on_train_epoch_end', processed_outputs) + self._on_train_epoch_end_hook(processed_outputs) self.trainer.call_hook('on_epoch_end') return processed_outputs @@ -148,14 +151,60 @@ def on_run_end(self, outputs): # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ + def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: + # We cannot rely on Trainer.call_hook because the signatures might be different across + # lightning module and callback + # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end` + + # This implementation is copied from Trainer.call_hook + hook_name = "on_train_epoch_end" + + # set hook_name to model + reset Result obj + skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name) + + # always profile hooks + with self.trainer.profiler.profile(hook_name): + + # first call trainer hook + if hasattr(self.trainer, hook_name): + trainer_hook = getattr(self.trainer, hook_name) + trainer_hook(processed_epoch_output) + + # next call hook in lightningModule + model_ref = self.trainer.lightning_module + if is_overridden(hook_name, model_ref): + hook_fx = getattr(model_ref, hook_name) + if is_param_in_hook_signature(hook_fx, "outputs"): + self.warning_cache.warn( + "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3." + " `outputs` parameter has been deprecated." 
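# A stdlib-only version of the signature check used above to decide whether a user's
# overridden `on_train_epoch_end` still accepts the deprecated `outputs` argument
# (the class names here are made up for the example).
import inspect

def accepts_param(fn, name):
    return name in inspect.signature(fn).parameters

class OldStyleModule:
    def on_train_epoch_end(self, outputs): ...

class NewStyleModule:
    def on_train_epoch_end(self): ...

assert accepts_param(OldStyleModule().on_train_epoch_end, "outputs")
assert not accepts_param(NewStyleModule().on_train_epoch_end, "outputs")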
+ " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + model_ref.on_train_epoch_end(processed_epoch_output) + else: + model_ref.on_train_epoch_end() + + # if the PL module doesn't have the hook then call the accelerator + # used to auto-reduce things for the user with Results obj + elif hasattr(self.trainer.accelerator, hook_name): + accelerator_hook = getattr(self.trainer.accelerator, hook_name) + accelerator_hook() + + if not skip: + self.trainer._cache_logged_metrics() + def _num_training_batches_reached(self, is_last_batch=False): return self.iteration_count == self.trainer.num_training_batches or is_last_batch # TODO move to on_advance_end() def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): + + # epoch output : [[] ... ] + # batch_end_outputs[0][0] = Result obj + batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)] - processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) + processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) # dict with loss # hook self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 790dc4c70bdeb..b6245a9675fd4 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -307,8 +307,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): result = AttributeDict( closure_loss=closure_loss, loss=untouched_loss, - training_step_output=training_step_output, - training_step_output_for_epoch_end=training_step_output_for_epoch_end, + training_step_output=training_step_output, # Result object + training_step_output_for_epoch_end=training_step_output_for_epoch_end, # Result object ) return result @@ -784,7 +784,7 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): yield None def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: - opt_closure_result = self._curr_step_result + opt_closure_result = self._curr_step_result # AttributeDict if opt_closure_result is not None: From 821615375559629d04c42eddad427da3457f7f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 10 May 2021 20:02:20 +0200 Subject: [PATCH 020/455] fix multi-epoch training --- pl_examples/bug_report_model.py | 2 +- pytorch_lightning/loops/batch_loop.py | 11 ++++++----- pytorch_lightning/loops/training_loop.py | 6 ++++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 57011841a3731..1e6bd099af1f9 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -56,7 +56,7 @@ def run(): limit_train_batches=2, limit_val_batches=2, num_sanity_val_steps=2, - max_epochs=1, + max_epochs=2, weights_summary=None, ) trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 000d094985a2f..0bd3b484e6d73 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -27,11 +27,6 @@ def __init__(self): # self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) self.automatic_optimization = True - self._curr_step_result = None - self._cur_grad_norm_dict = None - # 
self._multiple_trainloader_mode = multiple_trainloader_mode - self._skip_backward = False - # self.trainer._multiple_trainloader_mode = multiple_trainloader_mode def connect(self, trainer, *args, **kwargs): self.trainer = trainer @@ -48,6 +43,12 @@ def on_run_start(self, batch, batch_idx, dataloader_idx): # lightning module hook self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) + self._curr_step_result = None + self._cur_grad_norm_dict = None + # self._multiple_trainloader_mode = multiple_trainloader_mode + self._skip_backward = False + # self.trainer._multiple_trainloader_mode = multiple_trainloader_mode + def advance(self, batch, batch_idx, dataloader_idx): split_idx, split_batch = self._remaining_splits.pop(0) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index afcd56e4715e1..ea49c777f1021 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -14,7 +14,6 @@ class TrainingLoop(Loop): def __init__(self): super().__init__() - self.is_last_batch = False # cache of all outputs in a single training run / epoch # self.epoch_output = [[]] @@ -30,6 +29,8 @@ def on_run_start(self): self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) self._dataloader_idx = 0 + self.trainer.batch_idx = 0 + self.is_last_batch = False def advance(self): # TODO: profiling is gone @@ -113,6 +114,7 @@ def done(self): # this is the old on train_epoch_end? def on_run_end(self, outputs): + # hack for poc outputs = outputs[0] @@ -194,7 +196,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: self.trainer._cache_logged_metrics() def _num_training_batches_reached(self, is_last_batch=False): - return self.iteration_count == self.trainer.num_training_batches or is_last_batch + return self.trainer.batch_idx == self.trainer.num_training_batches or is_last_batch # TODO move to on_advance_end() def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): From c64aedbcb68681269632c796864c05a7cb4285f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 12 May 2021 17:28:04 +0200 Subject: [PATCH 021/455] refactor results --- pytorch_lightning/trainer/training_loop.py | 30 ++++++++++------------ 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 43ed2c7ffa964..a4e6c5d1d2a6f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -51,8 +51,6 @@ def __init__( self.warning_cache = WarningCache() self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) - self._curr_step_result = None - self._cur_grad_norm_dict = None self._multiple_trainloader_mode = multiple_trainloader_mode self._skip_backward = False self.trainer._multiple_trainloader_mode = multiple_trainloader_mode @@ -437,15 +435,15 @@ def on_before_zero_grad(self, optimizer): def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - def track_and_norm_grad(self, optimizer): + def track_and_norm_grad(self, optimizer) -> dict: # track gradient norms - grad_norm_dic = self._track_gradient_norm() + grad_norm_dict = self._track_gradient_norm() # clip gradients self.trainer.accelerator.clip_gradients( optimizer, self.trainer.gradient_clip_val, 
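# What `track_and_norm_grad` boils down to, sketched with plain PyTorch on a throwaway
# model (the dict key is only an assumed, Lightning-style name for illustration): record
# the gradient norm for logging, then clip before the optimizer step.
import torch

model = torch.nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).sum()
loss.backward()

# clip_grad_norm_ returns the total norm computed *before* clipping, handy for logging
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
grad_norm_dict = {"grad_2.0_norm_total": round(float(total_norm), 3)}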
gradient_clip_algorithm=self.trainer.gradient_clip_algorithm ) - self._cur_grad_norm_dict = grad_norm_dic + return grad_norm_dict def _track_gradient_norm(self): grad_norm_dict = {} @@ -693,6 +691,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + result = AttributeDict() if self.should_accumulate(): # For gradient accumulation @@ -703,11 +702,12 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # automatic_optimization=True: perform dpp sync only when performing optimizer_step # automatic_optimization=False: don't block synchronization here with self.block_ddp_sync_behaviour(): - self.training_step_and_backward( + result = self.training_step_and_backward( split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens ) batch_outputs = self._process_closure_result( + opt_closure_result=result, batch_outputs=batch_outputs, opt_idx=opt_idx, ) @@ -721,6 +721,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: def train_step_and_backward_closure(): + nonlocal result result = self.training_step_and_backward( split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens ) @@ -730,23 +731,23 @@ def train_step_and_backward_closure(): self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) else: - self._curr_step_result = self.training_step( + result = self.training_step( split_batch, batch_idx, opt_idx, self.trainer.hiddens ) - if self._curr_step_result is None: + if result is None: # user decided to skip optimization # make sure to zero grad. continue batch_outputs = self._process_closure_result( + opt_closure_result=result, batch_outputs=batch_outputs, opt_idx=opt_idx, ) # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dic = self._cur_grad_norm_dict - self._cur_grad_norm_dict = None + grad_norm_dic = result.get("grad_norm_dict", {}) # update running loss + reset accumulated loss self.update_running_loss() @@ -782,9 +783,7 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): else: yield None - def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: - opt_closure_result = self._curr_step_result - + def _process_closure_result(self, opt_closure_result: Optional[AttributeDict], batch_outputs: list, opt_idx: int) -> list: if opt_closure_result is not None: # cache metrics @@ -802,8 +801,6 @@ def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: # track total loss for logging (avoid mem leaks) self.accumulated_loss.append(opt_closure_result.loss) - self._curr_step_result = None - return batch_outputs def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): @@ -811,7 +808,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, with self.trainer.profiler.profile("training_step_and_backward"): # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) - self._curr_step_result = result if not self._skip_backward and self.trainer.lightning_module.automatic_optimization: is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 @@ -866,7 +862,7 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): if not self.should_accumulate(): # track gradients - self.track_and_norm_grad(optimizer=optimizer) + 
result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) def update_train_loop_lr_schedulers(self, monitor_metrics=None): num_accumulated_batches_reached = self._accumulated_batches_reached() From 0bd991bc267eaea705c7049105735cb0f2689aaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 12 May 2021 17:30:19 +0200 Subject: [PATCH 022/455] rename dic -> dict --- .../logger_connector/logger_connector.py | 16 ++++++++-------- pytorch_lightning/trainer/training_loop.py | 12 ++++++------ tests/trainer/loops/test_evaluation_loop_flow.py | 4 ++-- .../loops/test_training_loop_flow_scalar.py | 4 ++-- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8c09de075147a..1c8298557662b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -195,14 +195,14 @@ def cache_training_step_metrics(self, opt_closure_result): self._callback_metrics.update(callback_metrics_tmp) self._logged_metrics.update(logged_metrics_tmp) - def log_metrics(self, metrics, grad_norm_dic, step=None): + def log_metrics(self, metrics, grad_norm_dict, step=None): """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses metrics["step"] as a step Args: metrics (dict): Metric values - grad_norm_dic (dict): Gradient norms + grad_norm_dict (dict): Gradient norms step (int): Step for which metrics should be logged. Default value is `self.global_step` during training or the total validation / test log step count during validation and testing. """ @@ -212,7 +212,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): metrics.update(mem_map) # add norms - metrics.update(grad_norm_dic) + metrics.update(grad_norm_dict) # turn all tensors to scalars scalar_metrics = metrics_to_scalars(metrics) @@ -368,11 +368,11 @@ def log_train_step_metrics(self, batch_output): # when metrics should be logged if self.should_update_logs or self.trainer.fast_dev_run is True: # logs user requested information to logger - grad_norm_dic = batch_output.grad_norm_dic - if grad_norm_dic is None: - grad_norm_dic = {} - if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0: - self.log_metrics(batch_log_metrics, grad_norm_dic) + grad_norm_dict = batch_output.grad_norm_dict + if grad_norm_dict is None: + grad_norm_dict = {} + if len(batch_log_metrics) > 0 or len(grad_norm_dict) > 0: + self.log_metrics(batch_log_metrics, grad_norm_dict) self._callback_metrics.update(batch_log_metrics) @property diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a4e6c5d1d2a6f..0a7534ffafa43 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -652,7 +652,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms - grad_norm_dic = {} + grad_norm_dict = {} # bookkeeping self.trainer.hiddens = None @@ -666,19 +666,19 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): self.warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") return AttributeDict( signal=0, - grad_norm_dic=grad_norm_dic, + grad_norm_dict=grad_norm_dict, training_step_output_for_epoch_end=batch_outputs, ) # hook response = self.trainer.call_hook("on_batch_start") if response == -1: - return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + return AttributeDict(signal=-1, grad_norm_dict=grad_norm_dict) # hook response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) if response == -1: - return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + return AttributeDict(signal=-1, grad_norm_dict=grad_norm_dict) # lightning module hook splits = self._tbptt_split_batch(batch) @@ -747,14 +747,14 @@ def train_step_and_backward_closure(): ) # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dic = result.get("grad_norm_dict", {}) + grad_norm_dict = result.get("grad_norm_dict", {}) # update running loss + reset accumulated loss self.update_running_loss() result = AttributeDict( signal=0, - grad_norm_dic=grad_norm_dic, + grad_norm_dict=grad_norm_dict, training_step_output_for_epoch_end=batch_outputs, ) return result diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 8fdb321b6f230..3177a3aa09156 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -71,7 +71,7 @@ def backward(self, loss, optimizer, optimizer_idx): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) + assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 @@ -140,7 +140,7 @@ def backward(self, loss, optimizer, optimizer_idx): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) + assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 816134ee52941..f14f7d339d83f 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -155,7 +155,7 @@ def backward(self, loss, optimizer, optimizer_idx): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) + assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 @@ -231,7 +231,7 @@ def backward(self, loss, optimizer, optimizer_idx): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) + assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 From 5102c2c4a6cad62a0dc1d62e5bd730a91885fed8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 12 May 2021 17:38:07 +0200 Subject: [PATCH 023/455] simplify --- 
pytorch_lightning/trainer/training_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0a7534ffafa43..1cdc40032472b 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -666,19 +666,19 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): self.warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") return AttributeDict( signal=0, - grad_norm_dict=grad_norm_dict, + grad_norm_dict={}, training_step_output_for_epoch_end=batch_outputs, ) # hook response = self.trainer.call_hook("on_batch_start") if response == -1: - return AttributeDict(signal=-1, grad_norm_dict=grad_norm_dict) + return AttributeDict(signal=-1, grad_norm_dict={}) # hook response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) if response == -1: - return AttributeDict(signal=-1, grad_norm_dict=grad_norm_dict) + return AttributeDict(signal=-1, grad_norm_dict={}) # lightning module hook splits = self._tbptt_split_batch(batch) From 0b101173b615485a91146554082b3025533d654e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 May 2021 15:42:54 +0000 Subject: [PATCH 024/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/training_loop.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1cdc40032472b..0ecefcf38e71a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -731,9 +731,7 @@ def train_step_and_backward_closure(): self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) else: - result = self.training_step( - split_batch, batch_idx, opt_idx, self.trainer.hiddens - ) + result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) if result is None: # user decided to skip optimization @@ -783,7 +781,9 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): else: yield None - def _process_closure_result(self, opt_closure_result: Optional[AttributeDict], batch_outputs: list, opt_idx: int) -> list: + def _process_closure_result( + self, opt_closure_result: Optional[AttributeDict], batch_outputs: list, opt_idx: int + ) -> list: if opt_closure_result is not None: # cache metrics From 3ee429cb9ba89ff34157950f8aab5250e59752ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 12 May 2021 17:45:55 +0200 Subject: [PATCH 025/455] changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index adf5113e9eba6..a15320cc29986 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Refactored Loops * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025)) - + * Refactor result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506)) + - `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238)) From 26f89823936832b5d661d1f86e14e81eb7a90066 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 May 2021 15:46:59 +0000 Subject: [PATCH 026/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a15320cc29986..94c2a5b4a451f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Refactored Loops * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025)) * Refactor result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506)) - + - `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238)) From 0b34749a6e5e1771e003ac05ddb51c9891f0bf56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 12 May 2021 21:04:43 +0200 Subject: [PATCH 027/455] fix None check --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0ecefcf38e71a..8edda9253c51a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -733,7 +733,7 @@ def train_step_and_backward_closure(): else: result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) - if result is None: + if not result: # user decided to skip optimization # make sure to zero grad. continue From 0bdca77ca5ae3968c19283ad61e3d7382a8054d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 12 May 2021 21:07:38 +0200 Subject: [PATCH 028/455] chlog wording --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94c2a5b4a451f..606e6eaec0579 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Refactored Loops * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025)) - * Refactor result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506)) + * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506)) - `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238)) From b8ff2e06cecbdc144b608d8cb4400d720baf6c1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 00:17:56 +0200 Subject: [PATCH 029/455] move process_closure_result to the end --- pytorch_lightning/trainer/training_loop.py | 34 +++++++++------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8edda9253c51a..3f269a1cbc146 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -706,17 +706,10 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens ) - batch_outputs = self._process_closure_result( - opt_closure_result=result, - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - # ------------------------------ # BACKWARD PASS # ------------------------------ # gradient update with accumulated gradients - else: if self.trainer.lightning_module.automatic_optimization: @@ -738,17 +731,17 @@ def train_step_and_backward_closure(): # make sure to zero grad. 
continue - batch_outputs = self._process_closure_result( - opt_closure_result=result, - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) - # todo: Properly aggregate grad_norm accros opt_idx and split_idx grad_norm_dict = result.get("grad_norm_dict", {}) # update running loss + reset accumulated loss - self.update_running_loss() + self.update_running_loss(result.loss) + + batch_outputs = self._process_closure_result( + opt_closure_result=result, + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) result = AttributeDict( signal=0, @@ -784,8 +777,7 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): def _process_closure_result( self, opt_closure_result: Optional[AttributeDict], batch_outputs: list, opt_idx: int ) -> list: - if opt_closure_result is not None: - + if opt_closure_result: # cache metrics self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) @@ -797,10 +789,6 @@ def _process_closure_result( batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) - if self.trainer.lightning_module.automatic_optimization: - # track total loss for logging (avoid mem leaks) - self.accumulated_loss.append(opt_closure_result.loss) - return batch_outputs def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): @@ -991,7 +979,11 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # use to track metrics internally self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) - def update_running_loss(self): + def update_running_loss(self, current_loss: torch.Tensor) -> None: + if self.trainer.lightning_module.automatic_optimization: + # track total loss for logging (avoid mem leaks) + self.accumulated_loss.append(current_loss) + accumulated_loss = self.accumulated_loss.mean() if accumulated_loss is not None: From 7e66b89bc843850a6437fab60f84e3ff9bb299a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 02:34:16 +0200 Subject: [PATCH 030/455] extract method --- pytorch_lightning/trainer/training_loop.py | 100 +++++++++++---------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3f269a1cbc146..a84b9452d6602 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -687,68 +687,72 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # create an iterable for optimizers and loop over them for opt_idx, optimizer in optimizers: + self.run_batch_split(batch_outputs, batch_idx, split_idx, split_batch, opt_idx, optimizer) - # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - result = AttributeDict() - if self.should_accumulate(): - # For gradient accumulation + result = AttributeDict( + signal=0, + grad_norm_dict=grad_norm_dict, + training_step_output_for_epoch_end=batch_outputs, + ) + return result - # ------------------- - # calculate loss (train step + train step end) - # ------------------- + def run_batch_split(self, batch_outputs, batch_idx, split_idx, split_batch, opt_idx, optimizer): + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - # automatic_optimization=True: perform dpp sync only when performing optimizer_step - # 
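# A tiny stand-in (stdlib + torch) for the `accumulated_loss` / `running_loss` pair used
# by `update_running_loss` above: the progress-bar loss is a windowed running mean of the
# per-step losses, so single noisy batches do not dominate the display.
from collections import deque
import torch

class RunningMean:
    def __init__(self, window):
        self.values = deque(maxlen=window)    # oldest entries fall out automatically

    def append(self, value):
        self.values.append(float(value.detach()))

    def mean(self):
        return sum(self.values) / len(self.values)

running_loss = RunningMean(window=20)
for step in range(3):
    running_loss.append(torch.tensor(1.0 + step))
assert running_loss.mean() == 2.0             # mean of 1.0, 2.0, 3.0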
automatic_optimization=False: don't block synchronization here - with self.block_ddp_sync_behaviour(): - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) + result = AttributeDict() + if self.should_accumulate(): + # For gradient accumulation - # ------------------------------ - # BACKWARD PASS - # ------------------------------ - # gradient update with accumulated gradients - else: - if self.trainer.lightning_module.automatic_optimization: + # ------------------- + # calculate loss (train step + train step end) + # ------------------- + + # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # automatic_optimization=False: don't block synchronization here + with self.block_ddp_sync_behaviour(): + result = self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) - def train_step_and_backward_closure(): - nonlocal result - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - return None if result is None else result.loss + # ------------------------------ + # BACKWARD PASS + # ------------------------------ + # gradient update with accumulated gradients + else: + if self.trainer.lightning_module.automatic_optimization: - # optimizer step - self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) + def train_step_and_backward_closure(): + nonlocal result + result = self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + return None if result is None else result.loss - else: - result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + # optimizer step + self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - if not result: - # user decided to skip optimization - # make sure to zero grad. - continue + else: + result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dict = result.get("grad_norm_dict", {}) + if not result: + # user decided to skip optimization + # make sure to zero grad. 
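For readers skimming this refactor: `block_ddp_sync_behaviour()` wraps the backward pass so that, while gradients are still being accumulated, DDP does not have to all-reduce them on every micro-batch; synchronization is only needed on the step that actually runs `optimizer.step()`. A minimal, illustrative sketch of the same idea outside of Lightning, assuming a generic `model` and an already-computed `loss` (`no_sync()` is the standard `DistributedDataParallel` context manager):

    from contextlib import nullcontext

    from torch.nn.parallel import DistributedDataParallel


    def backward_with_optional_sync(model, loss, block_sync: bool) -> None:
        # Skip the gradient all-reduce while accumulating; let DDP sync only on
        # the micro-batch that will actually step the optimizer.
        if isinstance(model, DistributedDataParallel) and block_sync:
            ctx = model.no_sync()
        else:
            ctx = nullcontext()
        with ctx:
            loss.backward()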
+ return batch_outputs - # update running loss + reset accumulated loss - self.update_running_loss(result.loss) + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + grad_norm_dict = result.get("grad_norm_dict", {}) - batch_outputs = self._process_closure_result( - opt_closure_result=result, - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) + # update running loss + reset accumulated loss + self.update_running_loss(result.loss) - result = AttributeDict( - signal=0, - grad_norm_dict=grad_norm_dict, - training_step_output_for_epoch_end=batch_outputs, + batch_outputs = self._process_closure_result( + opt_closure_result=result, + batch_outputs=batch_outputs, + opt_idx=opt_idx, ) - return result + return batch_outputs @contextmanager def block_ddp_sync_behaviour(self, should_block_sync: bool = False): From 4d995d5b6891d0b3bd182369e948c89f67a566af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 02:57:12 +0200 Subject: [PATCH 031/455] split automatic optimization --- pytorch_lightning/trainer/training_loop.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a84b9452d6602..8684883f1ae40 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -685,10 +685,11 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): for split_idx, split_batch in enumerate(splits): - # create an iterable for optimizers and loop over them - for opt_idx, optimizer in optimizers: - self.run_batch_split(batch_outputs, batch_idx, split_idx, split_batch, opt_idx, optimizer) - + if self.trainer.lightning_module.automatic_optimization: + for opt_idx, optimizer in optimizers: + self.run_batch_split(batch_outputs, batch_idx, split_idx, split_batch, opt_idx, optimizer) + else: + self.run_batch_split(batch_outputs, batch_idx, split_idx, split_batch) result = AttributeDict( signal=0, @@ -697,7 +698,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): ) return result - def run_batch_split(self, batch_outputs, batch_idx, split_idx, split_batch, opt_idx, optimizer): + def run_batch_split(self, batch_outputs, batch_idx, split_idx, split_batch, opt_idx=None, optimizer=None): # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) From cfa8b8c48366517ee5b40e3d394d93959c628958 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 10:52:06 +0200 Subject: [PATCH 032/455] pull out skip condition --- pytorch_lightning/trainer/training_loop.py | 27 ++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8684883f1ae40..bb8c173b07f6d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -230,6 +230,22 @@ def _should_add_batch_output_to_epoch_output(self) -> bool: return False + # TODO: find a better way to compute this + def _should_skip_optimizer(self, opt_idx: int, batch_idx: Optional[int] = None) -> bool: + """ Determine if the optimizer should be skipped based on desired frequencies. 
""" + if not self.trainer.optimizer_frequencies: + return False + + if batch_idx is None: + batch_idx = self.total_batch_idx + + optimizers_loop_length = self.optimizer_freq_cumsum[-1] + current_place_in_loop = batch_idx % optimizers_loop_length + + # find optimzier index by looking for the first {item > current_place} in the cumsum list + return opt_idx != np.argmax(self.optimizer_freq_cumsum > current_place_in_loop) + + # TODO: get rid of this method def get_optimizers_iterable(self, batch_idx=None): """ Generates an iterable with (idx, optimizer) for each optimizer. @@ -657,7 +673,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # bookkeeping self.trainer.hiddens = None - optimizers = self.prepare_optimizers() + optimizers = list(enumerate(self.trainer.optimizers)) # track all outputs across time and num of optimizers batch_outputs = [[] for _ in range(len(optimizers))] @@ -687,6 +703,8 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in optimizers: + if self._should_skip_optimizer(opt_idx, batch_idx): + continue self.run_batch_split(batch_outputs, batch_idx, split_idx, split_batch, opt_idx, optimizer) else: self.run_batch_split(batch_outputs, batch_idx, split_idx, split_batch) @@ -780,7 +798,7 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): yield None def _process_closure_result( - self, opt_closure_result: Optional[AttributeDict], batch_outputs: list, opt_idx: int + self, opt_closure_result: Optional[AttributeDict], batch_outputs: list, opt_idx: Optional[int] ) -> list: if opt_closure_result: # cache metrics @@ -791,8 +809,9 @@ def _process_closure_result( self._check_finite(opt_closure_result.loss) # track all the outputs across all steps - batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 - batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) + # batch_opt_idx = opt_idx if self.trainer.lightning_module.automatic_optimization else 0 + opt_idx = 0 if opt_idx is None else opt_idx + batch_outputs[opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) return batch_outputs From 53097e097859218fc2ead2e39bda28db09756e19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 11:09:36 +0200 Subject: [PATCH 033/455] pull out batch_outputs from process_output function --- pytorch_lightning/trainer/training_loop.py | 26 +++++++++------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index bb8c173b07f6d..fdfb6399d0958 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -766,11 +766,15 @@ def train_step_and_backward_closure(): # update running loss + reset accumulated loss self.update_running_loss(result.loss) - batch_outputs = self._process_closure_result( - opt_closure_result=result, - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) + self._process_closure_result(opt_closure_result=result) + # track all the outputs across all steps + + if result is not None: + # this if check is required for tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_and_accumulated_grad + # TODO: make grad accumulation + manual optimization incompatible to simplify this logic here! 
+ opt_idx = 0 if opt_idx is None else opt_idx + batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) + return batch_outputs @contextmanager @@ -797,9 +801,8 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): else: yield None - def _process_closure_result( - self, opt_closure_result: Optional[AttributeDict], batch_outputs: list, opt_idx: Optional[int] - ) -> list: + def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None: + """ For manual_optimization, opt_idx is None. """ if opt_closure_result: # cache metrics self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) @@ -808,13 +811,6 @@ def _process_closure_result( if self.trainer.terminate_on_nan: self._check_finite(opt_closure_result.loss) - # track all the outputs across all steps - # batch_opt_idx = opt_idx if self.trainer.lightning_module.automatic_optimization else 0 - opt_idx = 0 if opt_idx is None else opt_idx - batch_outputs[opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) - - return batch_outputs - def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): """Wrap forward, zero_grad and backward in a closure so second order methods work""" with self.trainer.profiler.profile("training_step_and_backward"): From fe8c18040a25d0b9acb500b7e9cf4cdfa1d28b53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 11:33:57 +0200 Subject: [PATCH 034/455] move result handling outside --- pytorch_lightning/trainer/training_loop.py | 25 +++++++++++----------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index fdfb6399d0958..3367482b6d174 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -705,9 +705,16 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): for opt_idx, optimizer in optimizers: if self._should_skip_optimizer(opt_idx, batch_idx): continue - self.run_batch_split(batch_outputs, batch_idx, split_idx, split_batch, opt_idx, optimizer) + result = self.run_batch_split(batch_idx, split_idx, split_batch, opt_idx, optimizer) + if result: + batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) else: - self.run_batch_split(batch_outputs, batch_idx, split_idx, split_batch) + result = self.run_batch_split(batch_idx, split_idx, split_batch) + if result: + # this if check is required for + # tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_and_accumulated_grad + # TODO: make grad accumulation + manual optimization incompatible to simplify this logic here! + batch_outputs[0].append(result.training_step_output_for_epoch_end) result = AttributeDict( signal=0, @@ -716,7 +723,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): ) return result - def run_batch_split(self, batch_outputs, batch_idx, split_idx, split_batch, opt_idx=None, optimizer=None): + def run_batch_split(self, batch_idx, split_idx, split_batch, opt_idx=None, optimizer=None): # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) @@ -758,7 +765,7 @@ def train_step_and_backward_closure(): if not result: # user decided to skip optimization # make sure to zero grad. 
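`_check_finite` is only reached when `terminate_on_nan` is set; the intent is to abort training as soon as the loss stops being finite rather than silently optimizing on NaN/Inf values. A minimal stand-in for such a guard (not the library's exact implementation):

    import torch


    def check_finite(loss: torch.Tensor) -> None:
        # fail fast instead of continuing to train on a NaN/Inf loss
        if not torch.isfinite(loss).all():
            raise ValueError(f"Training stopped: the loss is no longer finite ({loss}).")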
- return batch_outputs + return result # todo: Properly aggregate grad_norm accros opt_idx and split_idx grad_norm_dict = result.get("grad_norm_dict", {}) @@ -767,15 +774,7 @@ def train_step_and_backward_closure(): self.update_running_loss(result.loss) self._process_closure_result(opt_closure_result=result) - # track all the outputs across all steps - - if result is not None: - # this if check is required for tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_and_accumulated_grad - # TODO: make grad accumulation + manual optimization incompatible to simplify this logic here! - opt_idx = 0 if opt_idx is None else opt_idx - batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) - - return batch_outputs + return result @contextmanager def block_ddp_sync_behaviour(self, should_block_sync: bool = False): From de8fedf502af3c7d820e16281939333a25b5a339 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 12:43:07 +0200 Subject: [PATCH 035/455] return early from process_result --- pytorch_lightning/trainer/training_loop.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3367482b6d174..37302a47111cc 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -802,13 +802,15 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None: """ For manual_optimization, opt_idx is None. """ - if opt_closure_result: - # cache metrics - self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) + if not opt_closure_result: + return + + # cache metrics + self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) - # check if loss or model weights are nan - if self.trainer.terminate_on_nan: - self._check_finite(opt_closure_result.loss) + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + self._check_finite(opt_closure_result.loss) def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): """Wrap forward, zero_grad and backward in a closure so second order methods work""" From 626d58c8bf759fd768e4e6fe44881a2cd90c67ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 12:43:28 +0200 Subject: [PATCH 036/455] move grad_norm_dict out --- pytorch_lightning/trainer/training_loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 37302a47111cc..f06d4224a0e4f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -704,10 +704,12 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in optimizers: if self._should_skip_optimizer(opt_idx, batch_idx): + # frequency of this optimizer doesnt align with current batch index, skip it continue result = self.run_batch_split(batch_idx, split_idx, split_batch, opt_idx, optimizer) if result: batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) + grad_norm_dict = result.get("grad_norm_dict", {}) else: result = self.run_batch_split(batch_idx, split_idx, split_batch) if result: @@ -716,12 +718,13 @@ def 
run_training_batch(self, batch, batch_idx, dataloader_idx): # TODO: make grad accumulation + manual optimization incompatible to simplify this logic here! batch_outputs[0].append(result.training_step_output_for_epoch_end) - result = AttributeDict( + output = AttributeDict( signal=0, + # todo: Properly aggregate grad_norm accros opt_idx and split_idx grad_norm_dict=grad_norm_dict, training_step_output_for_epoch_end=batch_outputs, ) - return result + return output def run_batch_split(self, batch_idx, split_idx, split_batch, opt_idx=None, optimizer=None): # toggle model params + set info to logger_connector @@ -767,9 +770,6 @@ def train_step_and_backward_closure(): # make sure to zero grad. return result - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dict = result.get("grad_norm_dict", {}) - # update running loss + reset accumulated loss self.update_running_loss(result.loss) From 8ae812ee6827f7a903303055671cb74563ab6964 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 12:59:55 +0200 Subject: [PATCH 037/455] spelling --- pytorch_lightning/trainer/training_loop.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f06d4224a0e4f..ba7e03c5aa8ac 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -704,19 +704,14 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in optimizers: if self._should_skip_optimizer(opt_idx, batch_idx): - # frequency of this optimizer doesnt align with current batch index, skip it + # frequency of this optimizer does not align with current batch index, skip it continue result = self.run_batch_split(batch_idx, split_idx, split_batch, opt_idx, optimizer) - if result: - batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) - grad_norm_dict = result.get("grad_norm_dict", {}) + batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) + grad_norm_dict = result.get("grad_norm_dict", {}) else: result = self.run_batch_split(batch_idx, split_idx, split_batch) - if result: - # this if check is required for - # tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_and_accumulated_grad - # TODO: make grad accumulation + manual optimization incompatible to simplify this logic here! - batch_outputs[0].append(result.training_step_output_for_epoch_end) + batch_outputs[0].append(result.training_step_output_for_epoch_end) output = AttributeDict( signal=0, From 8fa5df07ee79102b86d0d7e68afeee166648d66d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 12:59:55 +0200 Subject: [PATCH 038/455] Revert "spelling" This reverts commit 8ae812ee6827f7a903303055671cb74563ab6964. 
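`AttributeDict` (from `pytorch_lightning.utilities.parsing`) is what allows the code above to write `result.loss` or `output.signal` on what is otherwise a plain dictionary. A rough stand-in that captures only the behaviour relied on in these diffs (attribute-style get/set on top of `dict`), with no claim to match the real class exactly:

    class AttrDict(dict):
        """Illustrative dict with attribute-style access."""

        def __getattr__(self, key):
            try:
                return self[key]
            except KeyError as exc:
                raise AttributeError(key) from exc

        def __setattr__(self, key, value):
            self[key] = value


    output = AttrDict(signal=0, grad_norm_dict={})
    output.loss = 1.25
    assert output["loss"] == 1.25 and output.signal == 0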
--- pytorch_lightning/trainer/training_loop.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ba7e03c5aa8ac..f06d4224a0e4f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -704,14 +704,19 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in optimizers: if self._should_skip_optimizer(opt_idx, batch_idx): - # frequency of this optimizer does not align with current batch index, skip it + # frequency of this optimizer doesnt align with current batch index, skip it continue result = self.run_batch_split(batch_idx, split_idx, split_batch, opt_idx, optimizer) - batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) - grad_norm_dict = result.get("grad_norm_dict", {}) + if result: + batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) + grad_norm_dict = result.get("grad_norm_dict", {}) else: result = self.run_batch_split(batch_idx, split_idx, split_batch) - batch_outputs[0].append(result.training_step_output_for_epoch_end) + if result: + # this if check is required for + # tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_and_accumulated_grad + # TODO: make grad accumulation + manual optimization incompatible to simplify this logic here! + batch_outputs[0].append(result.training_step_output_for_epoch_end) output = AttributeDict( signal=0, From a734aad271a6a8b349d2c0eb546c4eb5689e242c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 13:04:27 +0200 Subject: [PATCH 039/455] spelling --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f06d4224a0e4f..311412edc679d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -704,7 +704,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in optimizers: if self._should_skip_optimizer(opt_idx, batch_idx): - # frequency of this optimizer doesnt align with current batch index, skip it + # frequency of this optimizer does not align with current batch index, skip it continue result = self.run_batch_split(batch_idx, split_idx, split_batch, opt_idx, optimizer) if result: From 456fa8c3b51e76e94c32491ce8f5a6546bd51e1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 13:23:23 +0200 Subject: [PATCH 040/455] change test for old signature --- tests/deprecated_api/test_remove_1-5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 91b93a88f0055..9c3eb0f9f35ed 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -196,7 +196,7 @@ def __init__(self): self.automatic_optimization = False def training_step(self, batch, batch_idx, optimizer_idx): - assert optimizer_idx is not None + assert optimizer_idx is None return super().training_step(batch, batch_idx) def configure_optimizers(self): From cbea408515cf44942fc772fa496ab72b15b1243e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= 
Date: Thu, 13 May 2021 13:28:46 +0200 Subject: [PATCH 041/455] rename get_optimizer_iterable and remove prepare_optimizers functions --- pytorch_lightning/callbacks/finetuning.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 36 +++++++++------------- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index d3f52b4ba9a15..a6c13d1b0c0db 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -282,7 +282,7 @@ def _store( def on_train_epoch_start(self, trainer, pl_module): """Called when the epoch begins.""" - for opt_idx, optimizer in trainer.train_loop.prepare_optimizers(): + for opt_idx, optimizer in trainer.train_loop.get_active_optimizers(): num_param_groups = len(optimizer.param_groups) self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) current_param_groups = optimizer.param_groups diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a9a431ddbba5e..07bc93ad3f1eb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1000,7 +1000,7 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: self.optimizer_connector.update_learning_rates( interval='epoch', opt_indices=[ - opt_idx for opt_idx, _ in self.train_loop.get_optimizers_iterable( + opt_idx for opt_idx, _ in self.train_loop.get_active_optimizers( batch_idx=(self.train_loop.total_batch_idx - 1) ) # Select the optimizers which were used in the last batch of the epoch ], diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 311412edc679d..4682816ac3b4e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,10 +14,11 @@ from contextlib import contextmanager, suppress from copy import copy, deepcopy -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Tuple import numpy as np import torch +from torch.optim import Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result @@ -78,9 +79,8 @@ def __init__( self.trainer.num_sanity_val_steps = num_sanity_val_steps @property - def num_optimizers(self): - num_optimizers = len(self.get_optimizers_iterable()) - return num_optimizers + def num_active_optimizers(self): + return len(self.get_active_optimizers()) @property def optimizer_freq_cumsum(self): @@ -245,24 +245,25 @@ def _should_skip_optimizer(self, opt_idx: int, batch_idx: Optional[int] = None) # find optimzier index by looking for the first {item > current_place} in the cumsum list return opt_idx != np.argmax(self.optimizer_freq_cumsum > current_place_in_loop) - # TODO: get rid of this method - def get_optimizers_iterable(self, batch_idx=None): + def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: """ - Generates an iterable with (idx, optimizer) for each optimizer. + Returns the currently active optimizers. When multiple optimizers are used with different frequencies, + only one of the optimizers is active at a time. + + Returns: + A list of tuples (opt_idx, optimizer) of currently active optimizers. 
""" if not self.trainer.optimizer_frequencies: # call training_step once per optimizer return list(enumerate(self.trainer.optimizers)) - if batch_idx is None: - batch_idx = self.total_batch_idx - + batch_idx = self.total_batch_idx if batch_idx is None else batch_idx optimizers_loop_length = self.optimizer_freq_cumsum[-1] current_place_in_loop = batch_idx % optimizers_loop_length # find optimzier index by looking for the first {item > current_place} in the cumsum list - opt_idx = np.argmax(self.optimizer_freq_cumsum > current_place_in_loop) - return [[opt_idx, self.trainer.optimizers[opt_idx]]] + opt_idx = int(np.argmax(self.optimizer_freq_cumsum > current_place_in_loop)) + return [(opt_idx, self.trainer.optimizers[opt_idx])] def on_after_backward(self, training_step_output, batch_idx, untouched_loss): training_step_output.detach() @@ -483,7 +484,7 @@ def run_training_epoch(self): train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) # track epoch output - epoch_output = [[] for _ in range(self.num_optimizers)] + epoch_output = [[] for _ in range(self.num_active_optimizers)] train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 @@ -882,7 +883,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): self.trainer.optimizer_connector.update_learning_rates( interval="step", monitor_metrics=monitor_metrics, - opt_indices=[opt_idx for opt_idx, _ in self.get_optimizers_iterable()], + opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()], ) def increment_accumulated_grad_global_step(self): @@ -980,13 +981,6 @@ def save_loggers_on_train_batch_end(self): if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() - def prepare_optimizers(self): - # in manual optimization we loop over all optimizers at once - optimizers = self.get_optimizers_iterable() - if not self.trainer.lightning_module.automatic_optimization: - optimizers = [optimizers[0]] - return optimizers - def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # set split_idx to trainer for tracking self.trainer.split_idx = split_idx From 3ccfe98ccb8aa46895478976b8c7480aa1b46fa1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 May 2021 11:29:36 +0000 Subject: [PATCH 042/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4682816ac3b4e..1aa0560f67c00 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,7 +14,7 @@ from contextlib import contextmanager, suppress from copy import copy, deepcopy -from typing import Any, Dict, List, Optional, Union, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch From f6ca31b43ef5a39f30331b8a07ea78e163707c22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 13:53:51 +0200 Subject: [PATCH 043/455] add changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c04dd2b673481..059fc0ecdb421 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,8 @@ The format is based on 
[Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Refactored Loops * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025)) * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506)) - + * Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526)) + - `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238)) From ea1bbbeca96a89eb1805d00d79c70877abef2dae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 14:16:22 +0200 Subject: [PATCH 044/455] simplify --- pytorch_lightning/trainer/training_loop.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1aa0560f67c00..b58ff64be1946 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -230,20 +230,10 @@ def _should_add_batch_output_to_epoch_output(self) -> bool: return False - # TODO: find a better way to compute this def _should_skip_optimizer(self, opt_idx: int, batch_idx: Optional[int] = None) -> bool: """ Determine if the optimizer should be skipped based on desired frequencies. """ - if not self.trainer.optimizer_frequencies: - return False - - if batch_idx is None: - batch_idx = self.total_batch_idx - - optimizers_loop_length = self.optimizer_freq_cumsum[-1] - current_place_in_loop = batch_idx % optimizers_loop_length - - # find optimzier index by looking for the first {item > current_place} in the cumsum list - return opt_idx != np.argmax(self.optimizer_freq_cumsum > current_place_in_loop) + active_indices = [idx for (idx, _) in self.get_active_optimizers(batch_idx)] + return opt_idx not in active_indices def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: """ From 338a0c14fb26c25561ba6c0b4040ca3cb0ce2020 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 May 2021 12:17:22 +0000 Subject: [PATCH 045/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 059fc0ecdb421..3cb297a9606f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025)) * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506)) * Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526)) - + - `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238)) From 52f1244554de1ff72c93434e54f4bdf675ba77ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 15:58:16 +0200 Subject: [PATCH 046/455] clean up --- pytorch_lightning/trainer/training_loop.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b58ff64be1946..c86ac832a1d8b 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -764,7 +764,7 @@ def train_step_and_backward_closure(): # update running loss + reset accumulated loss self.update_running_loss(result.loss) - self._process_closure_result(opt_closure_result=result) + self._process_closure_result(result) return result @contextmanager @@ -792,7 +792,6 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): yield None def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None: - """ For manual_optimization, opt_idx is None. """ if not opt_closure_result: return From 5e33ec09d71a9f6ec54798735c838d36af6caa9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 16:01:39 +0200 Subject: [PATCH 047/455] simplify --- pytorch_lightning/trainer/training_loop.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c86ac832a1d8b..ff2c2168ab4c9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -702,11 +702,9 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) grad_norm_dict = result.get("grad_norm_dict", {}) else: + # in manual optimization, there is no looping over optimizers result = self.run_batch_split(batch_idx, split_idx, split_batch) if result: - # this if check is required for - # tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_and_accumulated_grad - # TODO: make grad accumulation + manual optimization incompatible to simplify this logic here! 
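`update_running_loss(result.loss)`, kept as context in the hunks above, feeds every optimization step's loss into an accumulator and later takes its mean (an earlier patch in this series moved the `accumulated_loss.append` call inside it). As a rough illustration of that pattern only, a fixed-size window with `append`/`mean`, not the actual `TensorRunningAccum`:

    from collections import deque
    from typing import Optional

    import torch


    class RunningLoss:
        """Fixed-window loss accumulator; an illustrative sketch."""

        def __init__(self, window_length: int = 20) -> None:
            self._window = deque(maxlen=window_length)

        def append(self, loss: torch.Tensor) -> None:
            # keep only a detached copy so each step's autograd graph can be freed
            self._window.append(loss.detach())

        def mean(self) -> Optional[torch.Tensor]:
            return torch.stack(tuple(self._window)).mean() if self._window else None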
batch_outputs[0].append(result.training_step_output_for_epoch_end) output = AttributeDict( From 262bf49ebc9d5568498b110127edd11fe4cec6b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 16:34:41 +0200 Subject: [PATCH 048/455] simplify --- pytorch_lightning/trainer/training_loop.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ff2c2168ab4c9..9232b88901568 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -693,10 +693,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): for split_idx, split_batch in enumerate(splits): if self.trainer.lightning_module.automatic_optimization: - for opt_idx, optimizer in optimizers: - if self._should_skip_optimizer(opt_idx, batch_idx): - # frequency of this optimizer does not align with current batch index, skip it - continue + for opt_idx, optimizer in self.get_active_optimizers(batch_idx): result = self.run_batch_split(batch_idx, split_idx, split_batch, opt_idx, optimizer) if result: batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) From 338cfbca9f00e432cf1d70aa7958f05d3e52ed4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 13 May 2021 16:50:13 +0200 Subject: [PATCH 049/455] better name for method --- pytorch_lightning/trainer/training_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9232b88901568..ddff048355e5c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -694,13 +694,13 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): - result = self.run_batch_split(batch_idx, split_idx, split_batch, opt_idx, optimizer) + result = self.run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) if result: batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers - result = self.run_batch_split(batch_idx, split_idx, split_batch) + result = self.run_optimization(batch_idx, split_idx, split_batch) if result: batch_outputs[0].append(result.training_step_output_for_epoch_end) @@ -712,7 +712,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): ) return output - def run_batch_split(self, batch_idx, split_idx, split_batch, opt_idx=None, optimizer=None): + def run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=None, optimizer=None): # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) From 771b6891bdb52ac5329f15b96833a017b7b2a98f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 14 May 2021 14:54:09 +0200 Subject: [PATCH 050/455] update with closure --- pytorch_lightning/core/optimizer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 41 ++++++++++++++-------- tests/core/test_lightning_optimizer.py | 2 +- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 162e17ca47bf5..fbe4cf612d91e 100644 --- 
a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -207,7 +207,7 @@ def closure_dis(): profiler_name = "closure_{self._optimizer_idx}" closure = do_nothing_closure else: - if not isinstance(closure, types.FunctionType): + if not isinstance(closure, Callable): raise MisconfigurationException("When closure is provided, it should be a function") profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ddff048355e5c..ec4e2f58538b2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,7 +14,8 @@ from contextlib import contextmanager, suppress from copy import copy, deepcopy -from typing import Any, Dict, List, Optional, Tuple, Union +from functools import partial, update_wrapper +from typing import Any, Dict, List, Optional, Tuple, Union, Callable import numpy as np import torch @@ -717,6 +718,8 @@ def run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=None, opti self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) result = AttributeDict() + closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens, result) + if self.should_accumulate(): # For gradient accumulation @@ -727,9 +730,7 @@ def run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=None, opti # automatic_optimization=True: perform dpp sync only when performing optimizer_step # automatic_optimization=False: don't block synchronization here with self.block_ddp_sync_behaviour(): - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) + closure() # ------------------------------ # BACKWARD PASS @@ -737,17 +738,7 @@ def run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=None, opti # gradient update with accumulated gradients else: if self.trainer.lightning_module.automatic_optimization: - - def train_step_and_backward_closure(): - nonlocal result - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - return None if result is None else result.loss - - # optimizer step - self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - + self.optimizer_step(optimizer, opt_idx, batch_idx, closure) else: result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) @@ -762,6 +753,26 @@ def train_step_and_backward_closure(): self._process_closure_result(result) return result + def training_step_and_backward_closure( + self, + split_batch: Any, + batch_idx: int, + opt_idx: int, + optimizer: Optimizer, + hiddens, + return_result: AttributeDict, + ) -> Optional[torch.Tensor]: + + step_result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + if step_result is not None: + return_result.update(step_result) + return return_result.loss + + def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable: + """ Wraps the training step closure into a partial object which will be called within ``optimizer.step``. 
""" + partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs) + return update_wrapper(partial_func, self.training_step_and_backward_closure) + @contextmanager def block_ddp_sync_behaviour(self, should_block_sync: bool = False): """ diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index 8858129b221f9..d79cae75956a2 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -243,7 +243,7 @@ def training_epoch_end(self, outputs): ... def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_): - assert optimizer_closure.__name__ == "train_step_and_backward_closure" + assert optimizer_closure.__name__ == "training_step_and_backward_closure" # not passing the closure to the optimizer because step is mocked # zero_grad is called inside the closure if isinstance(optimizer, SGD) and batch_idx % 2 == 0: From 534863ae29db74ad2127a46e74358518b5e9c27b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 May 2021 12:54:55 +0000 Subject: [PATCH 051/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ec4e2f58538b2..57738a5485434 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -15,7 +15,7 @@ from contextlib import contextmanager, suppress from copy import copy, deepcopy from functools import partial, update_wrapper -from typing import Any, Dict, List, Optional, Tuple, Union, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch From 2f12edb88aa921062adb8ef1dd422f0f5b504a05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 14 May 2021 15:29:04 +0200 Subject: [PATCH 052/455] Update pytorch_lightning/trainer/training_loop.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/trainer/training_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 57738a5485434..7f4e30e2bcd75 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -744,7 +744,6 @@ def run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=None, opti if not result: # user decided to skip optimization - # make sure to zero grad. 
return result # update running loss + reset accumulated loss From 6cfa75978e22094bd97faba2d84adce1d1a3d35f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 15 May 2021 00:21:47 +0200 Subject: [PATCH 053/455] Update pytorch_lightning/trainer/training_loop.py Co-authored-by: ananthsub --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7f4e30e2bcd75..f9f2fdbaa62ba 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -727,7 +727,7 @@ def run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=None, opti # calculate loss (train step + train step end) # ------------------- - # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # automatic_optimization=True: perform ddp sync only when performing optimizer_step # automatic_optimization=False: don't block synchronization here with self.block_ddp_sync_behaviour(): closure() From d8e464ba7892fca843a3f669ee6a52c3e32bc2df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 15 May 2021 03:24:36 +0200 Subject: [PATCH 054/455] remove unused method --- pytorch_lightning/trainer/training_loop.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f9f2fdbaa62ba..2c179c8fa47c3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -231,11 +231,6 @@ def _should_add_batch_output_to_epoch_output(self) -> bool: return False - def _should_skip_optimizer(self, opt_idx: int, batch_idx: Optional[int] = None) -> bool: - """ Determine if the optimizer should be skipped based on desired frequencies. """ - active_indices = [idx for (idx, _) in self.get_active_optimizers(batch_idx)] - return opt_idx not in active_indices - def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: """ Returns the currently active optimizers. 
When multiple optimizers are used with different frequencies, From 90fb8d369fad19001d357f18182864b4146cdb06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 15 May 2021 03:26:19 +0200 Subject: [PATCH 055/455] protect eyes of user --- pytorch_lightning/trainer/training_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2c179c8fa47c3..d28eee6d76b04 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -690,13 +690,13 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): - result = self.run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) + result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) if result: batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers - result = self.run_optimization(batch_idx, split_idx, split_batch) + result = self._run_optimization(batch_idx, split_idx, split_batch) if result: batch_outputs[0].append(result.training_step_output_for_epoch_end) @@ -708,7 +708,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): ) return output - def run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=None, optimizer=None): + def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=None, optimizer=None): # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) From d7a6367e756e58630d73dcf91660d3917a7c1caf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 15 May 2021 03:30:29 +0200 Subject: [PATCH 056/455] add a nice type hint --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d28eee6d76b04..5cc4c1f292367 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -80,7 +80,7 @@ def __init__( self.trainer.num_sanity_val_steps = num_sanity_val_steps @property - def num_active_optimizers(self): + def num_active_optimizers(self) -> int: return len(self.get_active_optimizers()) @property From fa3155875776d15846691da7e70e723cb60b2aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 May 2021 10:36:09 +0200 Subject: [PATCH 057/455] callable check Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pytorch_lightning/core/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index fbe4cf612d91e..07e64a79ec8a9 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -207,7 +207,7 @@ def closure_dis(): profiler_name = "closure_{self._optimizer_idx}" closure = do_nothing_closure else: - if not isinstance(closure, Callable): + if not callable(closure): raise MisconfigurationException("When closure is provided, it should be a function") profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" From a8a0ca793e406273ef40056df1629a3e53ba1cf1 Mon Sep 
17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 May 2021 11:00:07 +0200 Subject: [PATCH 058/455] change opt_idx back to 0 in manual opt --- pytorch_lightning/trainer/training_loop.py | 5 ++++- tests/deprecated_api/test_remove_1-5.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e023921e583c3..e19adb57756bb 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -712,7 +712,10 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): ) return output - def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=None, optimizer=None): + def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimizer=None): + # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change + # opt_idx=0 to opt_idx=None in the signature here + # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 9c3eb0f9f35ed..d6c9b6d8f8f31 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -196,7 +196,7 @@ def __init__(self): self.automatic_optimization = False def training_step(self, batch, batch_idx, optimizer_idx): - assert optimizer_idx is None + assert optimizer_idx == 0 return super().training_step(batch, batch_idx) def configure_optimizers(self): From e8d315554017e45f5b7b759a6bd26ab374536e7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 May 2021 11:00:45 +0200 Subject: [PATCH 059/455] resolve conflict --- tests/accelerators/test_accelerator_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 92dd5c21ac420..e60b86513e5ff 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -295,7 +295,6 @@ def test_accelerator_choice_ddp_kubeflow(device_count_mock, setup_distributed_mo class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp assert isinstance(trainer.accelerator, GPUAccelerator) assert isinstance(trainer.training_type_plugin, DDPPlugin) assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment) @@ -331,7 +330,6 @@ def test_accelerator_choice_ddp_cpu_kubeflow(device_count_mock, setup_distribute class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp assert isinstance(trainer.accelerator, CPUAccelerator) assert isinstance(trainer.training_type_plugin, DDPPlugin) assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment) From 172031d67046e7a7974f65c29a83504bc02c5921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 May 2021 12:57:36 +0200 Subject: [PATCH 060/455] fix pepe --- pytorch_lightning/core/optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 07e64a79ec8a9..174631ae73e8b 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
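The `callable(closure)` check adopted a few patches earlier and the removal of the now-unused `types` import in this one go together: presumably, once the closure became a `functools.partial` wrapped with `update_wrapper`, it would no longer pass a `types.FunctionType` check even though it is perfectly callable. A quick illustration:

    import types
    from functools import partial


    def closure_body() -> float:
        return 0.0


    wrapped = partial(closure_body)

    assert callable(wrapped)                            # accepted by the new check
    assert not isinstance(wrapped, types.FunctionType)  # would be rejected by the original FunctionType check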
# See the License for the specific language governing permissions and # limitations under the License. -import types from contextlib import contextmanager from typing import Callable, Optional from weakref import proxy From 677d7e13a03d7c9ce778cd20f1e2ae7dcc2b4848 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 May 2021 14:31:09 +0200 Subject: [PATCH 061/455] incorporate changes from PR #7526 and master --- pytorch_lightning/loops/batch_loop.py | 281 +++++++++++++++-------- pytorch_lightning/loops/epoch_loop.py | 57 +++-- pytorch_lightning/loops/training_loop.py | 25 +- pytorch_lightning/trainer/trainer.py | 33 +-- 4 files changed, 249 insertions(+), 147 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 0bd3b484e6d73..0188c74c671ef 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -1,9 +1,12 @@ +from collections import OrderedDict from contextlib import contextmanager from copy import copy -from typing import List +from functools import partial, update_wrapper +from typing import List, Any, Optional, Callable, Tuple import numpy as np import torch +from torch.optim import Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result @@ -14,6 +17,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.warnings import WarningCache @@ -26,11 +30,14 @@ def __init__(self): self.warning_cache = WarningCache() # self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) - self.automatic_optimization = True + self.accumulated_loss = None + self._skip_backward = False + self._hiddens = None + + self.split_idx = None def connect(self, trainer, *args, **kwargs): self.trainer = trainer - self._optimizers = self.prepare_optimizers() @property def done(self): @@ -38,51 +45,29 @@ def done(self): def on_run_start(self, batch, batch_idx, dataloader_idx): self._grad_norm_dic = {} - self.trainer.hiddens = None - # self._optimizers = self.prepare_optimizers() + self._hiddens = None + self._optimizers = self.get_active_optimizers() # lightning module hook self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) - self._curr_step_result = None - self._cur_grad_norm_dict = None - # self._multiple_trainloader_mode = multiple_trainloader_mode - self._skip_backward = False - # self.trainer._multiple_trainloader_mode = multiple_trainloader_mode - def advance(self, batch, batch_idx, dataloader_idx): split_idx, split_batch = self._remaining_splits.pop(0) + self.split_idx = split_idx - batch_outputs = [[] for _ in range(len(self._optimizers))] - - for opt_idx, optimizer in self._optimizers: - # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - - def train_step_and_backward_closure(): - result = self.training_step_and_backward( - split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens - ) - return None if result is None else result.loss - - # optimizer step - self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - - if self._curr_step_result is None: - # user decided to skip 
optimization - # make sure to zero grad. - continue - - batch_outputs = self._process_closure_result( - batch_outputs=batch_outputs, - opt_idx=opt_idx, - ) # 1 optimizer case: batch_outputs[0][0] = Result object - - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - grad_norm_dic = self._cur_grad_norm_dict - self._cur_grad_norm_dict = None + # TODO: this list needs to go outside this loop + batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - # update running loss + reset accumulated loss - # self.update_running_loss() + if self.trainer.lightning_module.automatic_optimization: + for opt_idx, optimizer in self.get_active_optimizers(batch_idx): + result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) + if result: + batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) + grad_norm_dict = result.get("grad_norm_dict", {}) + else: + # in manual optimization, there is no looping over optimizers + result = self._run_optimization(batch_idx, split_idx, split_batch) + if result: + batch_outputs[0].append(result.training_step_output_for_epoch_end) return batch_outputs @@ -104,12 +89,14 @@ def run(self, batch, batch_idx, dataloader_idx): batch_outputs = batch_outputs[0] # TODO: hack for poc - result = AttributeDict( + output = AttributeDict( signal=0, - grad_norm_dic=self._cur_grad_norm_dict, + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + # grad_norm_dict=grad_norm_dict, + grad_norm_dict={}, training_step_output_for_epoch_end=batch_outputs, ) - return result + return output def on_run_end(self, outputs: List) -> List: return outputs @@ -125,29 +112,78 @@ def tbptt_split_batch(self, batch): # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ + def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimizer=None): + # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change + # opt_idx=0 to opt_idx=None in the signature here - def prepare_optimizers(self): - # in manual optimization we loop over all optimizers at once - optimizers = self.get_optimizers_iterable() - if not self.automatic_optimization: - optimizers = [optimizers[0]] - return optimizers + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) - def get_optimizers_iterable(self): - """ - Generates an iterable with (idx, optimizer) for each optimizer. 
- """ - if not self.trainer.optimizer_frequencies: - # call training_step once per optimizer - return list(enumerate(self.trainer.optimizers)) + result = AttributeDict() + closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result) - optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) - optimizers_loop_length = optimizer_freq_cumsum[-1] - current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length + if self.should_accumulate(): + # For gradient accumulation - # find optimzier index by looking for the first {item > current_place} in the cumsum list - opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) - return [[opt_idx, self.trainer.optimizers[opt_idx]]] + # ------------------- + # calculate loss (train step + train step end) + # ------------------- + + # automatic_optimization=True: perform ddp sync only when performing optimizer_step + # automatic_optimization=False: don't block synchronization here + with self.block_ddp_sync_behaviour(): + closure() + + # ------------------------------ + # BACKWARD PASS + # ------------------------------ + # gradient update with accumulated gradients + else: + if self.trainer.lightning_module.automatic_optimization: + self.optimizer_step(optimizer, opt_idx, batch_idx, closure) + else: + result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens) + + if not result: + # user decided to skip optimization + return result + + # update running loss + reset accumulated loss + self.update_running_loss(result.loss) + + self._process_closure_result(result) + return result + + def training_step_and_backward_closure( + self, + split_batch: Any, + batch_idx: int, + opt_idx: int, + optimizer: Optimizer, + hiddens, + return_result: AttributeDict, + ) -> Optional[torch.Tensor]: + + step_result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + if step_result is not None: + return_result.update(step_result) + return return_result.loss + + def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable: + """ Wraps the training step closure into a partial object which will be called within ``optimizer.step``. """ + partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs) + return update_wrapper(partial_func, self.training_step_and_backward_closure) + + def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None: + if not opt_closure_result: + return + + # cache metrics + self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) + + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + self._check_finite(opt_closure_result.loss) def on_after_backward(self, training_step_output, batch_idx, untouched_loss): training_step_output.detach() @@ -159,7 +195,7 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss): self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) def _check_training_step_output(self, training_step_output): - if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: + if isinstance(training_step_output, torch.Tensor) and not self.trainer.lightning_module.automatic_optimization: if training_step_output.grad_fn is None: # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... 
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") @@ -169,13 +205,13 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("model_forward"): - args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) + step_kwargs = self._build_kwargs(split_batch, batch_idx, opt_idx, hiddens) # manually capture logged metrics model_ref._current_fx_name = 'training_step' model_ref._results = Result() with self.trainer.profiler.profile("training_step"): - training_step_output = self.trainer.accelerator.training_step(args) + training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() self.trainer.logger_connector.cache_logged_metrics() @@ -194,7 +230,7 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): closure_loss = None untouched_loss = None - if self.automatic_optimization: + if self.trainer.lightning_module.automatic_optimization: # accumulate loss. if accumulate_grad_batches==1, no effect closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches @@ -205,8 +241,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): result = AttributeDict( closure_loss=closure_loss, loss=untouched_loss, - training_step_output=training_step_output, - training_step_output_for_epoch_end=training_step_output_for_epoch_end, + training_step_output=training_step_output, # Result object + training_step_output_for_epoch_end=training_step_output_for_epoch_end, # Result object ) return result @@ -320,14 +356,14 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): if len(self.trainer.optimizers) > 1: if self.trainer.has_arg("training_step", "optimizer_idx"): - if not self.automatic_optimization: + if not self.trainer.lightning_module.automatic_optimization: self.warning_cache.warn( "`training_step` hook signature has changed in v1.3." " `optimizer_idx` argument has been removed in case of manual optimization. Support for" " the old signature will be removed in v1.5", DeprecationWarning ) args.append(opt_idx) - elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization: + elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.trainer.lightning_module.automatic_optimization: raise ValueError( f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" ' `training_step` is missing the `optimizer_idx` argument.' @@ -345,7 +381,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. 
- if self.automatic_optimization and len(self.trainer.optimizers) > 1: + if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1: model = self.trainer.lightning_module model.toggle_optimizer(optimizer, opt_idx) @@ -369,46 +405,20 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): """ if ( isinstance(self.trainer.training_type_plugin, ParallelPlugin) - and (self.automatic_optimization or should_block_sync) + and (self.trainer.lightning_module.automatic_optimization or should_block_sync) ): with self.trainer.training_type_plugin.block_backward_sync(): yield None else: yield None - def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: - opt_closure_result = self._curr_step_result - - if opt_closure_result is not None: - - # cache metrics - self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) - - # check if loss or model weights are nan - if self.trainer.terminate_on_nan: - self._check_finite(opt_closure_result.loss) - - # track all the outputs across all steps - batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 - batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) - - if self.automatic_optimization: - # track total loss for logging (avoid mem leaks) - # self.accumulated_loss.append(opt_closure_result.loss) - pass - - self._curr_step_result = None - - return batch_outputs - def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): """Wrap forward, zero_grad and backward in a closure so second order methods work""" with self.trainer.profiler.profile("training_step_and_backward"): # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) - self._curr_step_result = result - if not self._skip_backward and self.automatic_optimization: + if not self._skip_backward and self.trainer.lightning_module.automatic_optimization: is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 if is_first_batch_to_accumulate: @@ -430,7 +440,9 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, self._check_finite(result.loss) else: - self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...") + self.warning_cache.warn( + "training_step returned None. If this was on purpose, ignore this warning..." + ) if len(self.trainer.optimizers) > 1: # revert back to previous state @@ -461,7 +473,11 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): # track gradients self.track_and_norm_grad(optimizer=optimizer) - def update_running_loss(self): + def update_running_loss(self, current_loss: torch.Tensor) -> None: + if self.trainer.lightning_module.automatic_optimization: + # track total loss for logging (avoid mem leaks) + self.accumulated_loss.append(current_loss) + accumulated_loss = self.accumulated_loss.mean() if accumulated_loss is not None: @@ -470,3 +486,64 @@ def update_running_loss(self): # reset for next set of accumulated grads self.accumulated_loss.reset() + + + def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: + """ + Returns the currently active optimizers. When multiple optimizers are used with different frequencies, + only one of the optimizers is active at a time. + + Returns: + A list of tuples (opt_idx, optimizer) of currently active optimizers. 
+ """ + if not self.trainer.optimizer_frequencies: + # call training_step once per optimizer + return list(enumerate(self.trainer.optimizers)) + + batch_idx = self.total_batch_idx if batch_idx is None else batch_idx + optimizers_loop_length = self.optimizer_freq_cumsum[-1] + current_place_in_loop = batch_idx % optimizers_loop_length + + # find optimzier index by looking for the first {item > current_place} in the cumsum list + opt_idx = int(np.argmax(self.optimizer_freq_cumsum > current_place_in_loop)) + return [(opt_idx, self.trainer.optimizers[opt_idx])] + + def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens): + # enable not needing to add opt_idx to training_step + step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) + + lightning_module = self.trainer.lightning_module + + if len(self.trainer.optimizers) > 1: + training_step_fx = getattr(lightning_module, "training_step") + has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") + if has_opt_idx_in_train_step: + if not lightning_module.automatic_optimization: + self.warning_cache.warn( + "`training_step` hook signature has changed in v1.3." + " `optimizer_idx` argument has been removed in case of manual optimization. Support for" + " the old signature will be removed in v1.5", DeprecationWarning + ) + step_kwargs['optimizer_idx'] = opt_idx + elif not has_opt_idx_in_train_step and self.trainer.lightning_module.automatic_optimization: + raise ValueError( + f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" + ' `training_step` is missing the `optimizer_idx` argument.' + ) + + # pass hiddens if using tbptt + if self._truncated_bptt_enabled(): + step_kwargs['hiddens'] = hiddens + + return step_kwargs + + def _truncated_bptt_enabled(self) -> bool: + """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. 
""" + return self._truncated_bptt_steps() > 0 + + def _truncated_bptt_steps(self) -> int: + lightning_module = self.trainer.lightning_module + # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 + if lightning_module.truncated_bptt_steps > 0: + return lightning_module.truncated_bptt_steps + return self.trainer.truncated_bptt_steps or 0 diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index d9bbd1309f326..667aa78fe5087 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -21,19 +21,46 @@ class EpochLoop(Loop): - def __init__(self): + def __init__(self, min_epochs, max_epochs, min_steps, max_steps): super().__init__() self.running_loss = torch.tensor(0.0) # dummy TODO: self._teardown_already_run = False + # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 + self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs + # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 + self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs + + self.current_epoch = 0 + + self.training_loop = TrainingLoop(min_steps, max_steps) + + @property + def global_step(self): + return self.training_loop.global_step + + @property + def total_batch_idx(self): + return self.training_loop.total_batch_idx + + @property + def batch_idx(self): + return self.training_loop.batch_idx + + @property + def split_idx(self): + return self.training_loop.split_idx + + @property + def min_steps(self): + return self.training_loop.min_steps + + @property + def max_steps(self): + return self.training_loop.max_steps + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): - self.num_epochs = trainer.max_epochs - self.min_epochs = trainer.min_epochs - # TODO: let inner loop track the steps - self.max_steps = trainer.max_steps - self.min_steps = trainer.min_steps self.trainer = trainer - self.training_loop = TrainingLoop() self.training_loop.connect(trainer) def should_accumulate(self): @@ -43,13 +70,13 @@ def should_accumulate(self): @property def done(self) -> bool: # TODO: Move track steps inside training loop and move part of these condition inside training loop - stop_steps = self.trainer.max_steps and self.trainer.max_steps <= self.trainer.global_step + stop_steps = self.max_steps and self.max_steps <= self.global_step should_stop = False if self.trainer.should_stop: # early stopping - met_min_epochs = (self.iteration_count >= self.trainer.min_epochs - 1) if self.trainer.min_epochs else True - met_min_steps = self.trainer.global_step >= self.trainer.min_steps if self.trainer.min_steps else True + met_min_epochs = (self.iteration_count >= self.min_epochs - 1) if self.min_epochs else True + met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: self.training_loop.on_train_end() should_stop = True @@ -61,7 +88,7 @@ def done(self) -> bool: ) self.trainer.should_stop = False - stop_epochs = self.iteration_count >= self.num_epochs + stop_epochs = self.iteration_count >= self.max_epochs return stop_steps or should_stop or stop_epochs def on_run_start(self): @@ -75,10 +102,10 @@ def on_run_end(self, outputs): # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 + self.training_loop.global_step -= 1 # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 # self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 + self.training_loop.global_step += 1 # hook self.trainer.call_hook("on_train_end") @@ -105,7 +132,7 @@ def on_advance_start(self): # equal to on train epoch start epoch = self.iteration_count + 1 # update training progress in trainer - self.trainer.current_epoch = epoch + self.current_epoch = epoch model = self.trainer.lightning_module @@ -122,7 +149,7 @@ def on_advance_start(self): # equal to on train epoch start self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) # stores accumulated grad fractions per batch - self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) + self.training_loop.batch_loop.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) # hook self.trainer.call_hook("on_epoch_start") diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index ea49c777f1021..1b9ea6895d3e9 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -7,15 +7,27 @@ from pytorch_lightning.loops.batch_loop import BatchLoop from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature class TrainingLoop(Loop): """ Runs over all batches in a dataloader (one epoch). 
""" - def __init__(self): + def __init__(self, min_steps, max_steps): super().__init__() # cache of all outputs in a single training run / epoch # self.epoch_output = [[]] + self.min_steps = min_steps + self.max_steps = max_steps + + self.global_step = 0 + + # the total batch index across all epochs + self.total_batch_idx = 0 + # the current batch index in the loop that runs over the dataloader(s) + self.batch_idx = 0 + # the current split index when the batch gets split into chunks in truncated backprop through time + self.split_idx = None def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer @@ -29,12 +41,13 @@ def on_run_start(self): self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) self._dataloader_idx = 0 - self.trainer.batch_idx = 0 + self.batch_idx = 0 self.is_last_batch = False def advance(self): # TODO: profiling is gone batch_idx, (batch, is_last) = next(self._train_dataloader) + self.batch_idx = batch_idx self.trainer.batch_idx = batch_idx self.is_last_batch = is_last @@ -106,7 +119,7 @@ def done(self): if self.trainer.should_stop: return True - self.trainer.total_batch_idx += 1 + self.total_batch_idx += 1 # stop epoch if we limited the number of training batches if self._num_training_batches_reached(self.is_last_batch): @@ -196,7 +209,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: self.trainer._cache_logged_metrics() def _num_training_batches_reached(self, is_last_batch=False): - return self.trainer.batch_idx == self.trainer.num_training_batches or is_last_batch + return self.batch_idx == self.trainer.num_training_batches or is_last_batch # TODO move to on_advance_end() def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): @@ -307,8 +320,8 @@ def increment_accumulated_grad_global_step(self): # progress global step according to grads progress if num_accumulated_batches_reached or num_training_batches_reached: - self.trainer.global_step = self.trainer.accelerator.update_global_step( - self.trainer.total_batch_idx, self.trainer.global_step + self.global_step = self.trainer.accelerator.update_global_step( + self.total_batch_idx, self.trainer.global_step ) def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 310aa02bdeb5a..2244debf5c805 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -30,6 +30,7 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops.epoch_loop import EpochLoop +from pytorch_lightning.loops.training_loop import TrainingLoop from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.profiler import BaseProfiler @@ -330,6 +331,11 @@ def __init__( self.checkpoint_connector = CheckpointConnector(self) self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) + + self.new_epoch_loop = EpochLoop(min_epochs, max_epochs, min_steps, max_steps) + self.new_epoch_loop.connect(self) + + # old loops: self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) self.evaluation_loop = EvaluationLoop(self) self.predict_loop = PredictLoop(self) @@ -377,10 +383,6 @@ def __init__( terminate_on_nan, ) self._setup_on_init( - max_epochs, - min_epochs, - max_steps, - 
min_steps, num_sanity_val_steps, ) self.evaluation_loop.on_trainer_init() @@ -416,29 +418,13 @@ def __init__( def _setup_on_init( self, - max_epochs: Optional[int], - min_epochs: Optional[int], - max_steps: Optional[int], - min_steps: Optional[int], num_sanity_val_steps: int, ): - self.global_step = 0 - self.current_epoch = 0 self.should_stop = False self.state = TrainerState() - - self.total_batch_idx = 0 - self.batch_idx = 0 self.num_training_batches = 0 self.train_dataloader = None - # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 - self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs - # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 - self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.max_steps = max_steps - self.min_steps = min_steps - if num_sanity_val_steps == -1: self.num_sanity_val_steps = float("inf") else: @@ -901,13 +887,12 @@ def reset_train_val_dataloaders(self, model) -> None: if self.val_dataloaders is None: self.reset_val_dataloader(model) - def run_train(self) -> None: + def _run_train(self) -> None: new_loop = True if new_loop: - self.train_loop = EpochLoop() - self.train_loop.connect(self) + self.train_loop = self.new_epoch_loop self._run_train_new_loop() else: self._run_train_old_loop() @@ -922,7 +907,7 @@ def _run_train_new_loop(self) -> None: if not self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() - self.run_sanity_check(self.lightning_module) + self._run_sanity_check(self.lightning_module) self.checkpoint_connector.has_trained = False From 3c6c9bd5b959b99dfde8c42f4c51effcb0211303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 18 May 2021 00:49:01 +0200 Subject: [PATCH 062/455] merge fixes merge fixes --- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/loops/batch_loop.py | 4 ++-- pytorch_lightning/loops/epoch_loop.py | 7 +++++-- pytorch_lightning/loops/training_loop.py | 19 ++++++++++++------- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f1269e4ef1982..5206f91084bd4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1317,7 +1317,7 @@ def training_step(...): # backward self._running_manual_backward = True - self.trainer.train_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) + self.trainer.train_loop.training_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) self._running_manual_backward = False def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 0188c74c671ef..1f701e73cb2df 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -339,10 +339,10 @@ def _track_gradient_norm(self): return grad_norm_dict def _accumulated_batches_reached(self): - return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 + return (self.iteration_count + 1) % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self, is_last_batch=False): - return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch + return (self.iteration_count + 1) == self.trainer.num_training_batches or is_last_batch def 
should_accumulate(self): # checks if backward or backward + optimizer step (via closure) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 667aa78fe5087..06e2d36f0dab5 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -63,9 +63,12 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer self.training_loop.connect(trainer) + # TODO: is it used anywhere? def should_accumulate(self): - # TODO - return False + return self.training_loop.batch_loop.should_accumulate() + + def _accumulated_batches_reached(self): + return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 @property def done(self) -> bool: diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 1b9ea6895d3e9..1bcf0733ca07a 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -29,6 +29,11 @@ def __init__(self, min_steps, max_steps): # the current split index when the batch gets split into chunks in truncated backprop through time self.split_idx = None + self.batch_loop = None + self._train_dataloader = None + self._dataloader_idx = None + self.is_last_batch = None + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer # self.epoch_output = [[] for _ in range(len(trainer.optimizers))] @@ -39,6 +44,7 @@ def on_run_start(self): # modify dataloader if needed (ddp, etc...) train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + # reset self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) self._dataloader_idx = 0 self.batch_idx = 0 @@ -48,8 +54,6 @@ def advance(self): # TODO: profiling is gone batch_idx, (batch, is_last) = next(self._train_dataloader) self.batch_idx = batch_idx - - self.trainer.batch_idx = batch_idx self.is_last_batch = is_last # ------------------------------------ @@ -73,21 +77,22 @@ def advance(self): batch_idx, self._dataloader_idx, ) - return epoch_output - def on_advance_end(self, output): # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- - self.trainer.logger_connector.log_train_step_metrics(output) + self.trainer.logger_connector.log_train_step_metrics(epoch_output) + return epoch_output + + def on_advance_end(self, output): # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- - should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.is_last_batch) + should_check_val = self.should_check_val_fx(self.batch_idx, self.is_last_batch) if should_check_val: self.trainer.validating = True - self.trainer.run_evaluation() + self.trainer._run_evaluation() self.trainer.training = True # ----------------------------------------- From 6a9b2a58335a5e5c59c3c812d453a29a6f9e9f93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 19 May 2021 23:03:05 +0200 Subject: [PATCH 063/455] update grad norm tracking --- pytorch_lightning/loops/batch_loop.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 1f701e73cb2df..489d692df7292 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -13,7 +13,7 @@ from pytorch_lightning.loops.base import Loop from 
pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.supporters import TensorRunningAccum, prefetch_iterator -from pytorch_lightning.utilities import AttributeDict, DeviceType, AMPType +from pytorch_lightning.utilities import AttributeDict, DeviceType, AMPType, grad_norm from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE @@ -44,7 +44,6 @@ def done(self): return len(self._remaining_splits) == 0 def on_run_start(self, batch, batch_idx, dataloader_idx): - self._grad_norm_dic = {} self._hiddens = None self._optimizers = self.get_active_optimizers() # lightning module hook @@ -320,22 +319,22 @@ def on_before_zero_grad(self, optimizer): def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - def track_and_norm_grad(self, optimizer): + def track_and_norm_grad(self, optimizer) -> dict: # track gradient norms - grad_norm_dic = self._track_gradient_norm() + grad_norm_dict = self._track_gradient_norm() # clip gradients self.trainer.accelerator.clip_gradients( optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm ) - self._cur_grad_norm_dict = grad_norm_dic + return grad_norm_dict def _track_gradient_norm(self): grad_norm_dict = {} if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: if float(self.trainer.track_grad_norm) > 0: model = self.trainer.lightning_module - grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) + grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm) return grad_norm_dict def _accumulated_batches_reached(self): @@ -471,7 +470,7 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): if not self.should_accumulate(): # track gradients - self.track_and_norm_grad(optimizer=optimizer) + result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) def update_running_loss(self, current_loss: torch.Tensor) -> None: if self.trainer.lightning_module.automatic_optimization: From cd81aef17cdfdf1f959cf5347492d6a44aa7e79c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 19 May 2021 23:31:23 +0200 Subject: [PATCH 064/455] handle outputs globally --- pytorch_lightning/loops/batch_loop.py | 42 ++++++++++++++---------- pytorch_lightning/loops/training_loop.py | 18 ++++++---- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 489d692df7292..dcf428d325456 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -45,30 +45,32 @@ def done(self): def on_run_start(self, batch, batch_idx, dataloader_idx): self._hiddens = None - self._optimizers = self.get_active_optimizers() - # lightning module hook self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) + # TODO: let loops track individual outputs + self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] + def advance(self, batch, batch_idx, dataloader_idx): split_idx, split_batch = self._remaining_splits.pop(0) self.split_idx = split_idx # TODO: this list needs to go outside this loop - batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] + # batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] if 
self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) if result: - batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) + self.batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_idx, split_batch) if result: - batch_outputs[0].append(result.training_step_output_for_epoch_end) + self.batch_outputs[0].append(result.training_step_output_for_epoch_end) - return batch_outputs + # TODO: return and accumulate batch outputs + return None def run(self, batch, batch_idx, dataloader_idx): if batch is None: @@ -84,33 +86,30 @@ def run(self, batch, batch_idx, dataloader_idx): if response == -1: return AttributeDict(signal=-1, grad_norm_dic={}) - batch_outputs = super().run(batch, batch_idx, dataloader_idx) + super().run(batch, batch_idx, dataloader_idx) - batch_outputs = batch_outputs[0] # TODO: hack for poc + # batch_outputs = batch_outputs[0] # TODO: hack for poc output = AttributeDict( signal=0, # todo: Properly aggregate grad_norm accros opt_idx and split_idx # grad_norm_dict=grad_norm_dict, grad_norm_dict={}, - training_step_output_for_epoch_end=batch_outputs, + training_step_output_for_epoch_end=self.batch_outputs, ) return output def on_run_end(self, outputs: List) -> List: return outputs - def tbptt_split_batch(self, batch): - splits = [batch] - if self.trainer.truncated_bptt_steps is not None: - model_ref = self.trainer.lightning_module - with self.trainer.profiler.profile("tbptt_split_batch"): - splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) - return splits - # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ + + @property + def num_active_optimizers(self) -> int: + return len(self.get_active_optimizers()) + def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimizer=None): # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change # opt_idx=0 to opt_idx=None in the signature here @@ -349,6 +348,14 @@ def should_accumulate(self): is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) + def tbptt_split_batch(self, batch): + splits = [batch] + if self.trainer.truncated_bptt_steps is not None: + model_ref = self.trainer.lightning_module + with self.trainer.profiler.profile("tbptt_split_batch"): + splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) + return splits + def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step args = [batch, batch_idx] @@ -486,7 +493,6 @@ def update_running_loss(self, current_loss: torch.Tensor) -> None: # reset for next set of accumulated grads self.accumulated_loss.reset() - def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: """ Returns the currently active optimizers. 
When multiple optimizers are used with different frequencies, diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 1bcf0733ca07a..dac44bf0da813 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -50,6 +50,9 @@ def on_run_start(self): self.batch_idx = 0 self.is_last_batch = False + # track epoch output + self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers)] + def advance(self): # TODO: profiling is gone batch_idx, (batch, is_last) = next(self._train_dataloader) @@ -69,9 +72,9 @@ def advance(self): return # hook - epoch_output = [[]] # TODO: track and return output, let loop base concatenate all outputs into a list etc. + # epoch_output = [[]] # TODO: track and return output, let loop base concatenate all outputs into a list etc. self.on_train_batch_end( - epoch_output, + self.epoch_output, batch_output.training_step_output_for_epoch_end, batch, batch_idx, @@ -81,9 +84,10 @@ def advance(self): # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- - self.trainer.logger_connector.log_train_step_metrics(epoch_output) + self.trainer.logger_connector.log_train_step_metrics(batch_output) - return epoch_output + # TODO + return None def on_advance_end(self, output): # ----------------------------------------- @@ -107,7 +111,7 @@ def on_advance_end(self, output): # progress global step according to grads progress self.increment_accumulated_grad_global_step() - return output + return None @property def done(self): @@ -134,13 +138,13 @@ def done(self): def on_run_end(self, outputs): # hack for poc - outputs = outputs[0] + # outputs = outputs[0] # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() # prepare epoch output - processed_outputs = self._prepare_outputs(outputs, batch_mode=False) + processed_outputs = self._prepare_outputs(self.epoch_output, batch_mode=False) # get the model and call model.training_epoch_end model = self.trainer.lightning_module From c0421643fb425688a040ccc8174574046eabc3c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 00:24:01 +0200 Subject: [PATCH 065/455] update accumulate, total_batch_idx, optimizer freq --- pytorch_lightning/loops/batch_loop.py | 11 ++++++++--- pytorch_lightning/loops/training_loop.py | 10 +++++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index dcf428d325456..f6c1cfd8fc547 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -33,6 +33,7 @@ def __init__(self): self.accumulated_loss = None self._skip_backward = False self._hiddens = None + self._optimizer_freq_cumsum = None self.split_idx = None @@ -106,9 +107,14 @@ def on_run_end(self, outputs: List) -> List: # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ + def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int: + return len(self.get_active_optimizers(batch_idx)) + @property - def num_active_optimizers(self) -> int: - return len(self.get_active_optimizers()) + def optimizer_freq_cumsum(self): + if self._optimizer_freq_cumsum is None: + self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) + return self._optimizer_freq_cumsum def _run_optimization(self, batch_idx, 
split_idx, split_batch, opt_idx=0, optimizer=None): # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change @@ -505,7 +511,6 @@ def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[i # call training_step once per optimizer return list(enumerate(self.trainer.optimizers)) - batch_idx = self.total_batch_idx if batch_idx is None else batch_idx optimizers_loop_length = self.optimizer_freq_cumsum[-1] current_place_in_loop = batch_idx % optimizers_loop_length diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index dac44bf0da813..c1364b24dc106 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -51,7 +51,7 @@ def on_run_start(self): self.is_last_batch = False # track epoch output - self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers)] + self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] def advance(self): # TODO: profiling is gone @@ -118,7 +118,7 @@ def done(self): # max steps reached, end training if ( self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 - and self._accumulated_batches_reached() + and self.batch_loop._accumulated_batches_reached() ): return True @@ -321,7 +321,11 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): if num_accumulated_batches_reached or num_training_batches_reached: # update lr - self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) + self.trainer.optimizer_connector.update_learning_rates( + interval="step", + monitor_metrics=monitor_metrics, + opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)], + ) def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached() From ed6103329e7ff7c96607e452f0604a3819ce9133 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 00:38:08 +0200 Subject: [PATCH 066/455] epoch output --- pytorch_lightning/loops/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index c1364b24dc106..43368396b2e9d 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -169,7 +169,7 @@ def on_run_end(self, outputs): # call train epoch end hooks self._on_train_epoch_end_hook(processed_outputs) self.trainer.call_hook('on_epoch_end') - return processed_outputs + return self.epoch_output # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP From e4395891c3ef5ad85866d9005c9e66e37c75784c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 01:18:17 +0200 Subject: [PATCH 067/455] val loop calling and checkpoints --- pytorch_lightning/loops/epoch_loop.py | 64 +++++++++++++++++++----- pytorch_lightning/loops/training_loop.py | 2 + 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 06e2d36f0dab5..c0a9b0937803d 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -11,6 +11,7 @@ from pytorch_lightning.loops.base import Loop from 
pytorch_lightning.loops.training_loop import TrainingLoop from pytorch_lightning.trainer.supporters import TensorRunningAccum +from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.parsing import AttributeDict @@ -107,7 +108,7 @@ def on_run_end(self, outputs): # when a checkpoint was saved at the last step self.training_loop.global_step -= 1 # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 - # self.check_checkpoint_callback(should_update=True, is_last=True) + self.check_checkpoint_callback(should_update=True, is_last=True) self.training_loop.global_step += 1 # hook @@ -166,23 +167,21 @@ def on_advance_end(self, outputs): # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics(outputs) - # should_check_val = self.should_check_val_fx(self.trainer.batch_idx, self.trainer.is_last_batch, on_epoch=True) + should_check_val = self.training_loop.should_check_val_fx(self.batch_idx, self.training_loop.is_last_batch, on_epoch=True) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval # update epoch level lr_schedulers if no val loop outside train loop is triggered - # if (val_loop_called and not should_check_val) or should_train_only: - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if (self.training_loop.val_loop_called and not should_check_val) or should_train_only: + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - # if should_train_only: - # self.check_checkpoint_callback(True) - # self.check_early_stopping_callback(True) + if should_train_only: + self.check_checkpoint_callback(True) - # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 - # if should_check_val: - # self.trainer.validating = True - # self.trainer.run_evaluation(on_epoch=True) - # self.trainer.training = True + if should_check_val: + self.trainer.validating = True + self.trainer._run_evaluation(on_epoch=True) + self.trainer.training = True # increment the global step once # progress global step according to grads progress @@ -197,3 +196,44 @@ def advance(self): return output + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = self.trainer.checkpoint_callbacks + + if is_last and any(cb.save_last and cb.verbose for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") + + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + + # def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: + # """ Decide if we should run validation. 
""" + # + # if not self.trainer.enable_validation: + # return False + # + # # check if this epoch is eligible to run validation + # if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: + # return False + # + # # val_check_batch is inf for iterable datasets with no length defined + # # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch + # is_val_check_batch = False + # if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): + # is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 + # elif self.trainer.val_check_batch != float('inf'): + # is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 + # + # # Note: num_training_batches is also inf for iterable datasets with no length defined + # epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 + # is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") + # + # if on_epoch: + # return ( + # is_val_check_batch and epoch_end_val_check + # ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset + # else: + # return is_val_check_batch and not epoch_end_val_check diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 43368396b2e9d..e7d4032e87d36 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -49,6 +49,7 @@ def on_run_start(self): self._dataloader_idx = 0 self.batch_idx = 0 self.is_last_batch = False + self.val_loop_called = False # track epoch output self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] @@ -98,6 +99,7 @@ def on_advance_end(self, output): self.trainer.validating = True self.trainer._run_evaluation() self.trainer.training = True + self.val_loop_called = True # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
From 7ecc81c1be9d07775ce2d509dc3998873f7a0544 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 01:32:17 +0200 Subject: [PATCH 068/455] fix merge with master and add missing attributes/methods --- pytorch_lightning/loops/epoch_loop.py | 6 ++-- pytorch_lightning/loops/training_loop.py | 41 ++++++++++++++++-------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index c0a9b0937803d..d8f52893b0a3f 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -68,8 +68,8 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): def should_accumulate(self): return self.training_loop.batch_loop.should_accumulate() - def _accumulated_batches_reached(self): - return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 + def get_active_optimizers(self, batch_idx): + return self.training_loop.batch_loop.get_active_optimizers(batch_idx) @property def done(self) -> bool: @@ -172,7 +172,7 @@ def on_advance_end(self, outputs): should_train_only = self.trainer.disable_validation or should_skip_eval # update epoch level lr_schedulers if no val loop outside train loop is triggered - if (self.training_loop.val_loop_called and not should_check_val) or should_train_only: + if not should_check_val or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') if should_train_only: diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index e7d4032e87d36..63abda7f5d7b2 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -49,7 +49,6 @@ def on_run_start(self): self._dataloader_idx = 0 self.batch_idx = 0 self.is_last_batch = False - self.val_loop_called = False # track epoch output self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] @@ -99,7 +98,6 @@ def on_advance_end(self, output): self.trainer.validating = True self.trainer._run_evaluation() self.trainer.training = True - self.val_loop_called = True # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) @@ -186,7 +184,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: hook_name = "on_train_epoch_end" # set hook_name to model + reset Result obj - skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name) + skip = self.trainer._reset_result_and_set_fx_name(hook_name) # always profile hooks with self.trainer.profiler.profile(hook_name): @@ -339,19 +337,34 @@ def increment_accumulated_grad_global_step(self): self.total_batch_idx, self.trainer.global_step ) - def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): - # decide if we should run validation - is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 - is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 - can_check_val = self.trainer.enable_validation and is_val_check_epoch - is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") - epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 + def should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: + """ Decide if we should run validation. 
""" + + if not self.trainer.enable_validation: + return False + + # check if this epoch is eligible to run validation + if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: + return False - should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop - or is_last_batch_for_infinite_dataset - ) if on_epoch else (is_val_check_batch and not epoch_end_val_check) + # val_check_batch is inf for iterable datasets with no length defined + # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch + is_val_check_batch = False + if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): + is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 + elif self.trainer.val_check_batch != float('inf'): + is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 + + # Note: num_training_batches is also inf for iterable datasets with no length defined + epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 + is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") - return should_check_val and can_check_val + if on_epoch: + return ( + is_val_check_batch and epoch_end_val_check + ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset + else: + return is_val_check_batch and not epoch_end_val_check def save_loggers_on_train_batch_end(self): # when loggers should save to disk From 1132288e04de4db8fdf01aa7587a357a3047ca3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 09:22:00 +0200 Subject: [PATCH 069/455] fox NoneType >= error --- pytorch_lightning/loops/epoch_loop.py | 4 ++-- pytorch_lightning/loops/training_loop.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index d8f52893b0a3f..e74b6332b284c 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -79,7 +79,7 @@ def done(self) -> bool: should_stop = False if self.trainer.should_stop: # early stopping - met_min_epochs = (self.iteration_count >= self.min_epochs - 1) if self.min_epochs else True + met_min_epochs = (self.current_epoch >= self.min_epochs - 1) if self.min_epochs else True met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: self.training_loop.on_train_end() @@ -92,7 +92,7 @@ def done(self) -> bool: ) self.trainer.should_stop = False - stop_epochs = self.iteration_count >= self.max_epochs + stop_epochs = self.current_epoch >= self.max_epochs if self.max_epochs is not None else False return stop_steps or should_stop or stop_epochs def on_run_start(self): diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 63abda7f5d7b2..6df3430a92185 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -90,6 +90,10 @@ def advance(self): return None def on_advance_end(self, output): + + # TODO: where is the right place update this !!!!????? 
+ self.total_batch_idx += 1 + # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- @@ -117,7 +121,7 @@ def on_advance_end(self, output): def done(self): # max steps reached, end training if ( - self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 + self.max_steps is not None and self.max_steps <= self.global_step + 1 and self.batch_loop._accumulated_batches_reached() ): return True @@ -128,7 +132,8 @@ def done(self): if self.trainer.should_stop: return True - self.total_batch_idx += 1 + # TODO: moved to on_advance_end, check if correct? + # self.total_batch_idx += 1 # stop epoch if we limited the number of training batches if self._num_training_batches_reached(self.is_last_batch): From 8bbb4c51ec67c83c65a8cd73a31ac720cda7007b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 10:16:07 +0200 Subject: [PATCH 070/455] add missing warning cache --- pytorch_lightning/loops/training_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 6df3430a92185..b2e5a3d5c87cd 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -33,6 +33,7 @@ def __init__(self, min_steps, max_steps): self._train_dataloader = None self._dataloader_idx = None self.is_last_batch = None + self.warning_cache = WarningCache() def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer From 789603afd6d42eabe7d3668048fabda809b60412 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 10:16:36 +0200 Subject: [PATCH 071/455] add missing warning cache --- pytorch_lightning/loops/training_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index b2e5a3d5c87cd..1662d719d1190 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -8,6 +8,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import WarningCache class TrainingLoop(Loop): From cbdcf82c4f2af682ad4749e57a533d5481c17bb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 10:17:06 +0200 Subject: [PATCH 072/455] change epoch end and trainnig end conditions --- pytorch_lightning/loops/epoch_loop.py | 3 ++- pytorch_lightning/loops/training_loop.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index e74b6332b284c..ff2be83ee20a0 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -82,7 +82,8 @@ def done(self) -> bool: met_min_epochs = (self.current_epoch >= self.min_epochs - 1) if self.min_epochs else True met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: - self.training_loop.on_train_end() + # TODO: THIS is now in on_run_end, always run? 
+ # self.training_loop.on_train_end() should_stop = True else: log.info( diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 1662d719d1190..d727ed1973e3c 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -225,7 +225,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: self.trainer._cache_logged_metrics() def _num_training_batches_reached(self, is_last_batch=False): - return self.batch_idx == self.trainer.num_training_batches or is_last_batch + return self.batch_idx == self.trainer.num_training_batches - 1 or is_last_batch # TODO move to on_advance_end() def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): From 7164f10b43f5e51d72834d9912371b04806c12ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 12:37:15 +0200 Subject: [PATCH 073/455] update handling of outputs in loops --- pytorch_lightning/loops/base.py | 24 +++++-------- pytorch_lightning/loops/batch_loop.py | 8 ----- pytorch_lightning/loops/epoch_loop.py | 44 +++--------------------- pytorch_lightning/loops/training_loop.py | 15 ++------ 4 files changed, 17 insertions(+), 74 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 106e15f12d3cb..6505796218335 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -25,35 +25,29 @@ def done(self): def advance(self, *args: Any, **kwargs: Any): """What to do within a single step""" - def on_run_start(self, *args: Any, **kwargs: Any): + def on_run_start(self, *args: Any, **kwargs: Any) -> None: pass - def on_run_end(self, outputs: List) -> List: - return outputs + def on_run_end(self) -> Any: + pass - def on_advance_start(self, *args: Any, **kwargs: Any): + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: pass - def on_advance_end(self, curr_output: Any) -> Any: - return curr_output + def on_advance_end(self) -> None: + pass def run(self, *args: Any, **kwargs: Any): self.on_run_start(*args, **kwargs) - outputs = [] - while not self.done: self.on_advance_start(*args, **kwargs) - curr_output = self.advance(*args, **kwargs) - curr_output = self.on_advance_end(curr_output) - - outputs.append(curr_output) - + self.advance(*args, **kwargs) + self.on_advance_end() self.iteration_count += 1 - outputs = self.on_run_end(outputs) - return outputs + return self.on_run_end() def state_dict(self): return dict() diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index f6c1cfd8fc547..e33cb95092f83 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -70,9 +70,6 @@ def advance(self, batch, batch_idx, dataloader_idx): if result: self.batch_outputs[0].append(result.training_step_output_for_epoch_end) - # TODO: return and accumulate batch outputs - return None - def run(self, batch, batch_idx, dataloader_idx): if batch is None: return AttributeDict(signal=0, grad_norm_dic={}) @@ -89,8 +86,6 @@ def run(self, batch, batch_idx, dataloader_idx): super().run(batch, batch_idx, dataloader_idx) - # batch_outputs = batch_outputs[0] # TODO: hack for poc - output = AttributeDict( signal=0, # todo: Properly aggregate grad_norm accros opt_idx and split_idx @@ -100,9 +95,6 @@ def run(self, batch, batch_idx, dataloader_idx): ) return output - def on_run_end(self, outputs: List) -> List: - return outputs - # 
------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index ff2be83ee20a0..b8c9d75c15c90 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -100,7 +100,7 @@ def on_run_start(self): # hook self.trainer.call_hook("on_train_start") - def on_run_end(self, outputs): + def on_run_end(self): if self._teardown_already_run: return self._teardown_already_run = True @@ -130,8 +130,6 @@ def on_run_end(self, outputs): # reset bookkeeping self.trainer._running_stage = None - return outputs - def on_advance_start(self): # equal to on train epoch start # implemented here since this code has to be run always no matter the actual epoch implementation epoch = self.iteration_count + 1 @@ -161,13 +159,10 @@ def on_advance_start(self): # equal to on train epoch start self.trainer.call_hook("on_train_epoch_start") # why is this not the same as the old on_train_epoch_end? - def on_advance_end(self, outputs): + def on_advance_end(self): # # handle epoch_output on epoch end # self.on_train_epoch_end(outputs) # Handled in on_run_end of training_loop now - # log epoch metrics - self.trainer.logger_connector.log_train_epoch_end_metrics(outputs) - should_check_val = self.training_loop.should_check_val_fx(self.batch_idx, self.training_loop.is_last_batch, on_epoch=True) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval @@ -193,9 +188,9 @@ def advance(self): with self.trainer.profiler.profile("run_training_epoch"): # run train epoch - output = self.training_loop.run() - - return output + epoch_output = self.training_loop.run() + # log epoch metrics + self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) def check_checkpoint_callback(self, should_update, is_last=False): # TODO bake this logic into the ModelCheckpoint callback @@ -209,32 +204,3 @@ def check_checkpoint_callback(self, should_update, is_last=False): for cb in callbacks: cb.on_validation_end(self.trainer, model) - - # def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: - # """ Decide if we should run validation. 
""" - # - # if not self.trainer.enable_validation: - # return False - # - # # check if this epoch is eligible to run validation - # if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: - # return False - # - # # val_check_batch is inf for iterable datasets with no length defined - # # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch - # is_val_check_batch = False - # if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): - # is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 - # elif self.trainer.val_check_batch != float('inf'): - # is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 - # - # # Note: num_training_batches is also inf for iterable datasets with no length defined - # epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 - # is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") - # - # if on_epoch: - # return ( - # is_val_check_batch and epoch_end_val_check - # ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset - # else: - # return is_val_check_batch and not epoch_end_val_check diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index d727ed1973e3c..2375a3a588b58 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -88,11 +88,7 @@ def advance(self): # ----------------------------------------- self.trainer.logger_connector.log_train_step_metrics(batch_output) - # TODO - return None - - def on_advance_end(self, output): - + def on_advance_end(self): # TODO: where is the right place update this !!!!????? self.total_batch_idx += 1 @@ -117,7 +113,6 @@ def on_advance_end(self, output): # progress global step according to grads progress self.increment_accumulated_grad_global_step() - return None @property def done(self): @@ -142,11 +137,7 @@ def done(self): return True # this is the old on train_epoch_end? - def on_run_end(self, outputs): - - # hack for poc - # outputs = outputs[0] - + def on_run_end(self): # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() @@ -227,7 +218,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: def _num_training_batches_reached(self, is_last_batch=False): return self.batch_idx == self.trainer.num_training_batches - 1 or is_last_batch - # TODO move to on_advance_end() + # TODO move to on_advance_end() ?? def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): # epoch output : [[] ... ] From 5677e95834372e8b0a7af2a0f7f17d53604458bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 12:41:39 +0200 Subject: [PATCH 074/455] cache draft notes --- pytorch_lightning/loops/cache.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 pytorch_lightning/loops/cache.py diff --git a/pytorch_lightning/loops/cache.py b/pytorch_lightning/loops/cache.py new file mode 100644 index 0000000000000..440d2b27d73aa --- /dev/null +++ b/pytorch_lightning/loops/cache.py @@ -0,0 +1,26 @@ +from typing import Tuple + + +class Cache: + + def __init__(self): + self._store = ... 
+ + def add(self, obj: object, **tags): + pass + + def merge(self, cache: "Cache"): + pass + + def filter_by(self, tags: Tuple[str]): + pass + + + +self.cache = Cache() +self.cache.add("abc", result, batch_idx=, opt_idx=..) +self.cache.add("abc", result, batch_idx=) + +self.cache.group_by("abc", ("batch_idx", "opt_idx")) + + From 18b89816af9afa31e5474fbae3d1df3778dcfd52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 12:52:45 +0200 Subject: [PATCH 075/455] loop stop condition based on batches seen instead of current batch_idx --- pytorch_lightning/loops/training_loop.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 2375a3a588b58..ab322d3777f53 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -34,6 +34,7 @@ def __init__(self, min_steps, max_steps): self._train_dataloader = None self._dataloader_idx = None self.is_last_batch = None + self.batches_seen = 0 self.warning_cache = WarningCache() def connect(self, trainer: 'pl.Trainer', *args, **kwargs): @@ -50,6 +51,7 @@ def on_run_start(self): self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) self._dataloader_idx = 0 self.batch_idx = 0 + self.batches_seen = 0 self.is_last_batch = False # track epoch output @@ -67,6 +69,7 @@ def advance(self): with self.trainer.profiler.profile("run_training_batch"): # batch_output = self.run_training_batch(batch, batch_idx, self._dataloader_idx) batch_output = self.batch_loop.run(batch, batch_idx, self._dataloader_idx) + self.batches_seen += 1 # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: @@ -216,7 +219,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: self.trainer._cache_logged_metrics() def _num_training_batches_reached(self, is_last_batch=False): - return self.batch_idx == self.trainer.num_training_batches - 1 or is_last_batch + return self.batches_seen == self.trainer.num_training_batches or is_last_batch # TODO move to on_advance_end() ?? 
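# The `Cache` drafted in PATCH 074 above is only pseudo-code (the
# `add("abc", result, batch_idx=, opt_idx=..)` notes do not parse). One way the
# tag-based API could be fleshed out -- `add`, `merge`, `filter_by` and
# `group_by` come from the draft notes, while the storage layout and the
# equality-based tag matching are assumptions made purely for illustration:

from collections import defaultdict
from typing import Any, Dict, List, Tuple


class Cache:

    def __init__(self) -> None:
        # every entry stored under a name keeps the object plus the tags it was added with
        self._store: Dict[str, List[Tuple[Any, Dict[str, Any]]]] = defaultdict(list)

    def add(self, name: str, obj: Any, **tags: Any) -> None:
        self._store[name].append((obj, dict(tags)))

    def merge(self, cache: "Cache") -> None:
        for name, entries in cache._store.items():
            self._store[name].extend(entries)

    def filter_by(self, name: str, **tags: Any) -> List[Any]:
        # return every object under `name` whose tags match all given key/value pairs
        return [obj for obj, t in self._store[name] if all(t.get(k) == v for k, v in tags.items())]

    def group_by(self, name: str, keys: Tuple[str, ...]) -> Dict[Tuple, List[Any]]:
        # bucket the stored objects by the values of the requested tag keys
        groups: Dict[Tuple, List[Any]] = defaultdict(list)
        for obj, t in self._store[name]:
            groups[tuple(t.get(k) for k in keys)].append(obj)
        return dict(groups)


# mirroring the draft notes (hypothetical `step_result` objects):
# cache = Cache()
# cache.add("abc", step_result_0, batch_idx=0, opt_idx=0)
# cache.add("abc", step_result_1, batch_idx=1, opt_idx=0)
# cache.group_by("abc", ("batch_idx", "opt_idx"))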
def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): From 41cb7640d03182af535819d6ff0fa7ed4d67d147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 13:03:22 +0200 Subject: [PATCH 076/455] fix attribute errors in tests --- pytorch_lightning/loops/epoch_loop.py | 4 ++++ pytorch_lightning/trainer/trainer.py | 7 ++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index b8c9d75c15c90..58a35be3506fd 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -40,6 +40,10 @@ def __init__(self, min_epochs, max_epochs, min_steps, max_steps): def global_step(self): return self.training_loop.global_step + @global_step.setter + def global_step(self, value): + self.training_loop.global_step = value + @property def total_batch_idx(self): return self.training_loop.total_batch_idx diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f52b3cb7ed251..bb8272d336e37 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -929,16 +929,17 @@ def _run_train_new_loop(self) -> None: rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') # user could press Ctrl+c many times... only shutdown once if not self.interrupted: - self.state = TrainerState.INTERRUPTED + self.state.status = TrainerStatus.INTERRUPTED self.on_keyboard_interrupt() # same treatment as below self.accelerator.on_train_end() - self._running_stage = None + self.state.stage = None except BaseException: + self.state.status = TrainerStatus.INTERRUPTED # give accelerators a chance to finish self.accelerator.on_train_end() # reset bookkeeping - self._running_stage = None + self.state.stage = None raise def _run_train_old_loop(self) -> None: From cb5f1a0a0253746bdbc6920930438daa592269a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 14:17:48 +0200 Subject: [PATCH 077/455] move new loop creation earler (fast_dev_run setup) --- pytorch_lightning/loops/epoch_loop.py | 5 +++++ pytorch_lightning/trainer/trainer.py | 18 +++++++++--------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 58a35be3506fd..95c7f1d2c7b88 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -64,6 +64,11 @@ def min_steps(self): def max_steps(self): return self.training_loop.max_steps + @max_steps.setter + def max_steps(self, value): + # TODO: This setter is required by debugging connector (fast dev run) + self.training_loop.max_steps = value + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer self.training_loop.connect(trainer) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bb8272d336e37..3e2f94ec18235 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -76,6 +76,8 @@ 'please use torch.distributed.ReduceOp instead' ) +NEW_LOOP = True + class Trainer( TrainerProperties, @@ -332,11 +334,13 @@ def __init__( self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - self.new_epoch_loop = EpochLoop(min_epochs, max_epochs, min_steps, max_steps) - self.new_epoch_loop.connect(self) + if NEW_LOOP: + self.train_loop = EpochLoop(min_epochs, max_epochs, min_steps, 
max_steps) + self.train_loop.connect(self) + else: + # old loops: + self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) - # old loops: - self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) self.evaluation_loop = EvaluationLoop(self) self.predict_loop = PredictLoop(self) @@ -888,11 +892,7 @@ def reset_train_val_dataloaders(self, model) -> None: self.reset_val_dataloader(model) def _run_train(self) -> None: - - new_loop = True - - if new_loop: - self.train_loop = self.new_epoch_loop + if NEW_LOOP: self._run_train_new_loop() else: self._run_train_old_loop() From efbf6f196a9484d8ad3679b8887a2696d1fd5efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 16:50:14 +0200 Subject: [PATCH 078/455] change stopping condition for train loop --- pytorch_lightning/loops/training_loop.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index ab322d3777f53..fd23147b558c2 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -30,16 +30,18 @@ def __init__(self, min_steps, max_steps): # the current split index when the batch gets split into chunks in truncated backprop through time self.split_idx = None - self.batch_loop = None self._train_dataloader = None self._dataloader_idx = None + self._should_stop = False + self.is_last_batch = None self.batches_seen = 0 self.warning_cache = WarningCache() + self.batch_loop = None + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer - # self.epoch_output = [[] for _ in range(len(trainer.optimizers))] self.batch_loop = BatchLoop() self.batch_loop.connect(trainer) @@ -50,6 +52,7 @@ def on_run_start(self): # reset self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) self._dataloader_idx = 0 + self._should_stop = False self.batch_idx = 0 self.batches_seen = 0 self.is_last_batch = False @@ -114,11 +117,16 @@ def on_advance_end(self): self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) self.trainer.checkpoint_connector.has_trained = True + self._should_stop = self.stopping_condition() + # progress global step according to grads progress self.increment_accumulated_grad_global_step() @property def done(self): + return self._should_stop + + def stopping_condition(self): # max steps reached, end training if ( self.max_steps is not None and self.max_steps <= self.global_step + 1 From 57b4a32cbe7953192411785788b5570ba44f3dae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 17:31:52 +0200 Subject: [PATCH 079/455] allow stopping training loop mid epoch --- pytorch_lightning/loops/training_loop.py | 25 +++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index fd23147b558c2..fbf3be51879cc 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -76,11 +76,9 @@ def advance(self): # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: - self._skip_remaining_steps = True - return + raise StopIteration # hook - # epoch_output = [[]] # TODO: track and return output, let loop base concatenate all outputs into a list etc. 
self.on_train_batch_end( self.epoch_output, batch_output.training_step_output_for_epoch_end, @@ -117,16 +115,14 @@ def on_advance_end(self): self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) self.trainer.checkpoint_connector.has_trained = True - self._should_stop = self.stopping_condition() + if self.done: + raise StopIteration # progress global step according to grads progress self.increment_accumulated_grad_global_step() @property def done(self): - return self._should_stop - - def stopping_condition(self): # max steps reached, end training if ( self.max_steps is not None and self.max_steps <= self.global_step + 1 @@ -180,6 +176,21 @@ def on_run_end(self): self.trainer.call_hook('on_epoch_end') return self.epoch_output + def run(self, *args, **kwargs): + self.on_run_start() + + while True: + try: + self.on_advance_start() + self.advance() + self.on_advance_end() + except StopIteration: + break + + self.iteration_count += 1 + + return self.on_run_end() + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ From b090e4f232a71ed7784633a99681725aea5215c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 20 May 2021 17:53:24 +0200 Subject: [PATCH 080/455] standardize order of loop hooks --- pytorch_lightning/loops/base.py | 38 +++++----- pytorch_lightning/loops/batch_loop.py | 50 +++++++------- pytorch_lightning/loops/epoch_loop.py | 88 ++++++++++++------------ pytorch_lightning/loops/training_loop.py | 74 ++++++++++---------- 4 files changed, 125 insertions(+), 125 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 6505796218335..5346109a767d9 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -1,11 +1,11 @@ from _weakref import proxy -from abc import ABCMeta, abstractmethod -from typing import Any, Counter, List, Optional +from abc import abstractmethod, ABC +from typing import Any, Optional import pytorch_lightning as pl -class Loop(metaclass=ABCMeta): +class Loop(ABC): def __init__(self): self.iteration_count: int = 0 @@ -21,22 +21,6 @@ def connect(self, trainer, *args, **kwargs): def done(self): """Property indicating when loop is finished""" - @abstractmethod - def advance(self, *args: Any, **kwargs: Any): - """What to do within a single step""" - - def on_run_start(self, *args: Any, **kwargs: Any) -> None: - pass - - def on_run_end(self) -> Any: - pass - - def on_advance_start(self, *args: Any, **kwargs: Any) -> None: - pass - - def on_advance_end(self) -> None: - pass - def run(self, *args: Any, **kwargs: Any): self.on_run_start(*args, **kwargs) @@ -49,5 +33,21 @@ def run(self, *args: Any, **kwargs: Any): return self.on_run_end() + def on_run_start(self, *args: Any, **kwargs: Any) -> None: + pass + + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + pass + + @abstractmethod + def advance(self, *args: Any, **kwargs: Any): + """What to do within a single step""" + + def on_advance_end(self) -> None: + pass + + def on_run_end(self) -> Any: + pass + def state_dict(self): return dict() diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index e33cb95092f83..ed35c12d711fa 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -44,6 +44,31 @@ def connect(self, trainer, *args, 
**kwargs): def done(self): return len(self._remaining_splits) == 0 + def run(self, batch, batch_idx, dataloader_idx): + if batch is None: + return AttributeDict(signal=0, grad_norm_dic={}) + + # hook + response = self.trainer.call_hook("on_batch_start") + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic={}) + + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic={}) + + super().run(batch, batch_idx, dataloader_idx) + + output = AttributeDict( + signal=0, + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + # grad_norm_dict=grad_norm_dict, + grad_norm_dict={}, + training_step_output_for_epoch_end=self.batch_outputs, + ) + return output + def on_run_start(self, batch, batch_idx, dataloader_idx): self._hiddens = None self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) @@ -70,31 +95,6 @@ def advance(self, batch, batch_idx, dataloader_idx): if result: self.batch_outputs[0].append(result.training_step_output_for_epoch_end) - def run(self, batch, batch_idx, dataloader_idx): - if batch is None: - return AttributeDict(signal=0, grad_norm_dic={}) - - # hook - response = self.trainer.call_hook("on_batch_start") - if response == -1: - return AttributeDict(signal=-1, grad_norm_dic={}) - - # hook - response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) - if response == -1: - return AttributeDict(signal=-1, grad_norm_dic={}) - - super().run(batch, batch_idx, dataloader_idx) - - output = AttributeDict( - signal=0, - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - # grad_norm_dict=grad_norm_dict, - grad_norm_dict={}, - training_step_output_for_epoch_end=self.batch_outputs, - ) - return output - # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 95c7f1d2c7b88..9a366c85705d8 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -73,13 +73,6 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer self.training_loop.connect(trainer) - # TODO: is it used anywhere? - def should_accumulate(self): - return self.training_loop.batch_loop.should_accumulate() - - def get_active_optimizers(self, batch_idx): - return self.training_loop.batch_loop.get_active_optimizers(batch_idx) - @property def done(self) -> bool: # TODO: Move track steps inside training loop and move part of these condition inside training loop @@ -109,36 +102,6 @@ def on_run_start(self): # hook self.trainer.call_hook("on_train_start") - def on_run_end(self): - if self._teardown_already_run: - return - self._teardown_already_run = True - - # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.training_loop.global_step -= 1 - # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.training_loop.global_step += 1 - - # hook - self.trainer.call_hook("on_train_end") - - # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers. 
- # It might be related to xla tensors blocked when moving the cpu - # kill loggers - if self.trainer.logger is not None: - self.trainer.logger.finalize("success") - - # summarize profile results - self.trainer.profiler.describe() - - # give accelerators a chance to finish - self.trainer.accelerator.on_train_end() - - # reset bookkeeping - self.trainer._running_stage = None - def on_advance_start(self): # equal to on train epoch start # implemented here since this code has to be run always no matter the actual epoch implementation epoch = self.iteration_count + 1 @@ -167,7 +130,14 @@ def on_advance_start(self): # equal to on train epoch start self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") - # why is this not the same as the old on_train_epoch_end? + def advance(self): + + with self.trainer.profiler.profile("run_training_epoch"): + # run train epoch + epoch_output = self.training_loop.run() + # log epoch metrics + self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) + def on_advance_end(self): # # handle epoch_output on epoch end # self.on_train_epoch_end(outputs) # Handled in on_run_end of training_loop now @@ -193,13 +163,42 @@ def on_advance_end(self): # TODO: move inside training_loop.on_run_end? equivalent? order? self.training_loop.increment_accumulated_grad_global_step() - def advance(self): + # why is this not the same as the old on_train_epoch_end? + def on_run_end(self): + if self._teardown_already_run: + return + self._teardown_already_run = True - with self.trainer.profiler.profile("run_training_epoch"): - # run train epoch - epoch_output = self.training_loop.run() - # log epoch metrics - self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) + # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.training_loop.global_step -= 1 + # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.training_loop.global_step += 1 + + # hook + self.trainer.call_hook("on_train_end") + + # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers. 
+ # It might be related to xla tensors blocked when moving the cpu + # kill loggers + if self.trainer.logger is not None: + self.trainer.logger.finalize("success") + + # summarize profile results + self.trainer.profiler.describe() + + # give accelerators a chance to finish + self.trainer.accelerator.on_train_end() + + # reset bookkeeping + self.trainer._running_stage = None + + def should_accumulate(self): + return self.training_loop.batch_loop.should_accumulate() + + def get_active_optimizers(self, batch_idx): + return self.training_loop.batch_loop.get_active_optimizers(batch_idx) def check_checkpoint_callback(self, should_update, is_last=False): # TODO bake this logic into the ModelCheckpoint callback @@ -213,3 +212,4 @@ def check_checkpoint_callback(self, should_update, is_last=False): for cb in callbacks: cb.on_validation_end(self.trainer, model) + diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index fbf3be51879cc..3674971bb96d0 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -45,6 +45,43 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.batch_loop = BatchLoop() self.batch_loop.connect(trainer) + @property + def done(self): + # max steps reached, end training + if ( + self.max_steps is not None and self.max_steps <= self.global_step + 1 + and self.batch_loop._accumulated_batches_reached() + ): + return True + + # end epoch early + # stop when the flag is changed or we've gone past the amount + # requested in the batches + if self.trainer.should_stop: + return True + + # TODO: moved to on_advance_end, check if correct? + # self.total_batch_idx += 1 + + # stop epoch if we limited the number of training batches + if self._num_training_batches_reached(self.is_last_batch): + return True + + def run(self, *args, **kwargs): + self.on_run_start() + + while True: + try: + self.on_advance_start() + self.advance() + self.on_advance_end() + except StopIteration: + break + + self.iteration_count += 1 + + return self.on_run_end() + def on_run_start(self): # modify dataloader if needed (ddp, etc...) train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) @@ -121,28 +158,6 @@ def on_advance_end(self): # progress global step according to grads progress self.increment_accumulated_grad_global_step() - @property - def done(self): - # max steps reached, end training - if ( - self.max_steps is not None and self.max_steps <= self.global_step + 1 - and self.batch_loop._accumulated_batches_reached() - ): - return True - - # end epoch early - # stop when the flag is changed or we've gone past the amount - # requested in the batches - if self.trainer.should_stop: - return True - - # TODO: moved to on_advance_end, check if correct? - # self.total_batch_idx += 1 - - # stop epoch if we limited the number of training batches - if self._num_training_batches_reached(self.is_last_batch): - return True - # this is the old on train_epoch_end? 
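# A toy subclass makes the hook order standardised by PATCH 080 concrete:
# `run()` in the `Loop` base calls `on_run_start`, then repeats
# `on_advance_start` -> `advance` -> `on_advance_end` (bumping `iteration_count`)
# until `done` is True, and finally returns whatever `on_run_end` gives back.
# This sketch assumes the `Loop` base exactly as it reads after that patch;
# `ListLoop`, `items` and `processed` are invented purely for illustration.

from pytorch_lightning.loops.base import Loop


class ListLoop(Loop):
    """Toy loop that consumes a list one item per `advance` call."""

    def __init__(self, items):
        super().__init__()
        self.items = list(items)
        self.processed = []

    def connect(self, trainer, *args, **kwargs):
        # a real loop wires up the trainer here; the toy loop does not need it
        self.trainer = trainer

    @property
    def done(self) -> bool:
        # `iteration_count` is incremented by the base `run()` after each advance
        return self.iteration_count >= len(self.items)

    def advance(self, *args, **kwargs):
        self.processed.append(self.items[self.iteration_count])

    def on_run_end(self):
        # the return value of `run()` is whatever this hook returns
        return self.processed


# ListLoop("abc").run()  # -> ['a', 'b', 'c']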
def on_run_end(self): # inform logger the batch loop has finished @@ -176,21 +191,6 @@ def on_run_end(self): self.trainer.call_hook('on_epoch_end') return self.epoch_output - def run(self, *args, **kwargs): - self.on_run_start() - - while True: - try: - self.on_advance_start() - self.advance() - self.on_advance_end() - except StopIteration: - break - - self.iteration_count += 1 - - return self.on_run_end() - # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ From f2eceab856da7467c852d3504a2c6b2597e152df Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 20 May 2021 19:07:18 +0200 Subject: [PATCH 081/455] WIP --- pytorch_lightning/core/step_result.py | 656 ++++++++++---------------- 1 file changed, 248 insertions(+), 408 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 3c54b19f99d4e..396dc943d47df 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Result class for easier logging and epoch-wise reduction.""" - import numbers -from copy import copy -from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Tuple, Union import torch from torch import Tensor @@ -24,61 +22,132 @@ from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed -class Result(Dict): +@dataclass +class Metadata: + prog_bar: bool = False + logger: bool = True + on_step: bool = False + on_epoch: bool = True + reduce_fx: Callable = torch.mean + tbptt_reduce_fx: Callable = torch.mean + tbptt_pad_token: int = 0 + dataloader_idx: Optional[int] = None - def __init__(self, minimize: Optional[Tensor] = None): - super().__init__() + @property + def forked(self) -> bool: + return self.on_step and self.on_epoch - if minimize is not None: - err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end' - self._assert_grad_tensor_metric('minimize', minimize, err) - self.minimize = minimize + def names(self, name: str) -> List[str]: + names = [name] + # check both + if self.on_step: + names += name + '_step' + if self.on_epoch: + names += name + '_epoch' + return names - self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}} - def __getitem__(self, key: Union[str, Any]) -> Any: - try: - return super().__getitem__(key) - except KeyError: - return super().__getitem__(f'{key}_step') +@dataclass +class Result: + data: Any # TODO: Union[Tensor, Metric]? 
+ meta: Metadata = field(repr=False) + batch_sizes: List[int] = field(default_factory=list, init=False) - def __getattr__(self, key: str) -> Any: + @staticmethod + def extract_batch_size(batch: Any) -> int: try: - if key == 'batch_log_metrics': - return self.get_batch_log_metrics() - elif key == 'batch_pbar_metrics': - return self.get_batch_pbar_metrics() - elif key == 'epoch_log_metrics': - return self.get_epoch_log_metrics() - elif key == 'epoch_pbar_metrics': - return self.get_epoch_pbar_metrics() - else: - return self[key] - except KeyError: - return None + return Result._extract_batch_size(batch) + except RecursionError: + return 1 - def __setattr__(self, key: str, val: Union[Tensor, Any]): - # ensure tensors are detached - if isinstance(val, torch.Tensor) and key != 'minimize': - val = val.detach() - self[key] = val + @staticmethod + def _extract_batch_size(batch: Any) -> int: + """ + Recursively unpack a batch to find a torch.Tensor. - def __getstate__(self): - return self + Returns: + ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable. + """ + if isinstance(batch, torch.Tensor): + size = batch.size(0) + elif isinstance(batch, str): + return len(batch) + elif isinstance(batch, dict): + sample = next(iter(batch.values()), 1) + size = Result._extract_batch_size(sample) + elif isinstance(batch, Iterable): + sample = next(iter(batch), 1) + size = Result._extract_batch_size(sample) + else: + size = 1 + return size - def __setstate__(self, d): - self.update(d) - def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''): - if x is not None: - if not isinstance(x, Tensor): - raise TypeError(f'{name} must be a torch.Tensor') +class ResultCollection(dict): - m = f'{name} must have a computational graph.' 
+ def __init__(self) -> None: + super().__init__() + self.minimize: Optional[Tensor] = None + self.should_reduce_on_epoch_end = False + + #@staticmethod + #def removesuffix(s: str, suffix: str) -> str: + # # available from Python 3.9 + # if suffix and s.endswith(suffix): + # return s[:-len(suffix)] + # return s + + #@staticmethod + #def _parse_key(key: str) -> str: + # key = ResultCollection.removesuffix(key, '_epoch') + # key = ResultCollection.removesuffix(key, '_step') + # return key + + #def __getitem__(self, key: str) -> Result: + # if not isinstance(key, str): + # raise ValueError(f'`Result` keys must be `str`, found: {key}') + # if key in self: + # return super().__getitem__(key) + # # try removing `_epoch` and `_step` suffixes + # key = self._parse_key(key) + # return super().__getitem__(key) + + def get_callback_metrics(self): + return self.items() + + def get_logger_metrics(self): + pass + #ret = {} + #for item in self.items(): + # for name in item.names(): # names knows whether it is forked + # ret[name] = item.data + # checks whether is forked and returns all + # return self.items_prefixes() - if additional_err: - m += f' {additional_err}' - assert x.grad_fn is not None, m + @staticmethod + def _sync( + value, + sync_fn: Optional[Callable] = None, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, + device: torch.device = None, + ): + """Sync across workers when using distributed training""" + if not isinstance(value, (torch.Tensor, numbers.Number)): + return value + + sync_fn = sync_fn or sync_ddp_if_available + dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() + if not sync_dist or not dist_available: + return value + + # TODO: Find a way to make the reduction only once, so we don't need to clone. + if isinstance(value, torch.Tensor): + value = value.clone() + else: + value = torch.tensor(value, device=device, dtype=torch.float) + return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) def log( self, @@ -98,335 +167,139 @@ def log( sync_fn: Callable = None, dataloader_idx: Optional[int] = None, device: torch.device = None, + batch_size: Optional[int] = None, ): + """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() - # sync across workers when using distributed training - sync_fn = sync_fn or sync_ddp_if_available - - if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)): - is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized() - # TODO: Find a way to make the reduction only once, so we don't need to clone. - if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor): - value = value.clone() - else: - value = torch.tensor(value, device=device, dtype=torch.float) - value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) + # TODO: should this be in the caller? + value = self._sync( + value, + sync_fn=sync_fn, + sync_dist=sync_dist, + sync_dist_op=sync_dist_op, + sync_dist_group=sync_dist_group, + device=device, + ) if isinstance(value, torch.Tensor) and value.device.type == "xla": value = value.cpu() - if 'meta' not in self: - self.__setitem__('meta', {}) - - # if user requests both step and epoch, then we split the metric in two automatically - # one will be logged per step. 
the other per epoch - was_forked = False - if on_step and on_epoch: - was_forked = True - - # set step version - step_name = f'{name}_step' - - self.__set_meta( - step_name, - value, - prog_bar, - logger, - on_step=True, - on_epoch=False, - reduce_fx=reduce_fx, - tbptt_reduce_fx=tbptt_reduce_fx, - tbptt_pad_token=tbptt_pad_token, - forked=False, - dataloader_idx=dataloader_idx, - ) - - self.__setitem__(step_name, value) - - # set epoch version - epoch_name = f'{name}_epoch' - - self.__set_meta( - epoch_name, - value, - prog_bar, - logger, - on_step=False, - on_epoch=True, + result = Result( + value, + Metadata( + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=False, dataloader_idx=dataloader_idx, - ) - self.__setitem__(epoch_name, value) - - # always log the original metric - self.__set_meta( - name, - value, - prog_bar, - logger, - on_step, - on_epoch, - reduce_fx, - tbptt_reduce_fx=tbptt_reduce_fx, - tbptt_pad_token=tbptt_pad_token, - forked=was_forked, - dataloader_idx=dataloader_idx, - ) - - # set the value - self.__setitem__(name, value) - - def __set_meta( - self, - name: str, - value: Any, - prog_bar: bool, - logger: bool, - on_step: bool, - on_epoch: bool, - reduce_fx: Callable, - tbptt_pad_token: int, - tbptt_reduce_fx: Callable, - forked: bool, - dataloader_idx: Union[int, None], - ): - # set the meta for the item - meta_value = value - meta = dict( - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - value=meta_value, - tbptt_reduce_fx=tbptt_reduce_fx, - tbptt_pad_token=tbptt_pad_token, - forked=forked, - dataloader_idx=dataloader_idx, + ), ) - - self['meta'][name] = meta - - # track whether any input requires reduction on epoch end - _internal = self['meta']['_internal'] - _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch) - - def track_batch_size(self, batch): - batch_size = Result.extract_batch_size(batch) - Result.attach_batch_size(batch_size, self) + if batch_size is None: + batch_size = Result.extract_batch_size(value) + result.batch_sizes.append(batch_size) + self[name] = result + self.should_reduce_on_epoch_end |= on_epoch @staticmethod - def extract_batch_size(batch): - try: - batch_size = Result.unpack_batch_size(batch) - except RecursionError: - batch_size = 1 - return batch_size + def _add_dl_idx(key: str, dl_idx: Union[int, None]) -> str: + if dl_idx is not None: + return f"{key}/dataloader_idx_{dl_idx}" + return key @staticmethod - def attach_batch_size(batch_size: Union[int, None], result: 'Result') -> None: - if batch_size is not None: - meta = result['meta'] - meta['_internal']['batch_sizes'].append(batch_size) - - def get_batch_sizes(self): - meta = self['meta'] - return torch.tensor(meta['_internal']['batch_sizes']) - - def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str: - if dataloader_idx is not None and add_dataloader_idx: - return f"{k}/dataloader_idx_{dataloader_idx}" - return k - - def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict: - """ - Gets the metrics to log at the end of the batch step - - """ + def _filter(self: 'Result', fields: List[str], add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: # TODO result = {} - - meta = self['meta'] - for k, options in meta.items(): - if k == '_internal': - continue - - if options['forked'] and not 
include_forked_originals: - continue - - dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) - - if options['logger'] and options['on_step']: - if isinstance(self[k], Metric) and self[k]._forward_cache is not None: - result[dl_key] = self[k]._forward_cache.detach() - else: - result[dl_key] = self[k] - + for k, item in self.items(): + # check if we need to add the suffix + if 'on_step' in fields and 'on_epoch' not in fields: + k += '_step' + elif 'on_step' not in fields and 'on_epoch' in fields: + k += '_epoch' + + if all(getattr(item.meta, f, False) for f in fields): + k = Result._add_dl_idx(k, item.meta.dataloader_idx) + result[k] = item.data return result - def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: - """ - Gets the metrics to log at the end of epoch - """ - result = {} - meta = self['meta'] - for k, options in meta.items(): - if k == '_internal': - continue - - if options['forked']: - continue - - dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) - - if options['logger'] and options['on_epoch']: - if isinstance(self[k], Metric): - result[dl_key] = self[k].compute().detach() - else: - result[dl_key] = self[k] - - if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # compute for reuse later - self[k].compute() - - return result - - def get_epoch_pbar_metrics(self, add_dataloader_idx=False): - """ - Gets the metrics to log at the end of epoch - """ - result = {} - - meta = self['meta'] - for k, options in meta.items(): - if k == '_internal': - continue - - if options['forked']: - continue - - dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) - - if options['prog_bar'] and options['on_epoch']: - if isinstance(self[k], Metric): - result[dl_key] = self[k].compute().detach() - else: - result[dl_key] = self[k] - - if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # compute for reuse later - self[k].compute() - - return result - - def get_forked_metrics(self, add_dataloader_idx=False): - """ - Gets the metrics to log at the end of epoch - """ - result = {} - - meta = self['meta'] - for k, options in meta.items(): - if k == '_internal': - continue - - dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) - - if options['forked']: - if isinstance(self[k], Metric): - result[dl_key] = self[k].compute().detach() - else: - result[dl_key] = self[k] - - return result - - def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False): - """ - Gets the metrics to log at the end of the batch step - """ - result = {} - - meta = self['meta'] - for k, options in meta.items(): - if k == '_internal': - continue - - if options['forked'] and not include_forked_originals: - continue - - dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) - - if options['prog_bar'] and options['on_step']: - if isinstance(self[k], Metric) and self[k]._forward_cache is not None: - result[dl_key] = self[k]._forward_cache - else: - result[dl_key] = self[k] - - return result - - def detach(self) -> 'Result': - for k, v in self.items(): - if isinstance(v, torch.Tensor): - self.__setitem__(k, v.detach()) + def get_batch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: + """Gets the metrics to log at the end of the batch""" + # TODO: remove dl idx + results = self._filter(self, ['logger', 'on_step'], add_dataloader_idx=add_dataloader_idx) + for k, v in 
results: + if isinstance(v, Metric) and v._forward_cache is not None: + results[k] = v._foward_cache.detach() + return results + + def get_epoch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: + """Gets the metrics to log at the end of epoch""" + results = self._filter(self, ['logger', 'on_epoch'], add_dataloader_idx=add_dataloader_idx) + for k, v in results: + if isinstance(v, Metric) and v._forward_cache is not None: + results[k] = v._foward_cache.compute().detach() + # TODO: this? + # if k in self and not options['on_epoch'] and isinstance(self[k], Metric): + # # compute for reuse later + # self[k].compute() + return results + + def get_batch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: + """Gets the metrics to include in the progress_bar at the end of epoch""" + results = self._filter(self, ['prog_bar', 'on_step'], add_dataloader_idx=add_dataloader_idx) + for k, v in results: + if isinstance(v, Metric) and v._forward_cache is not None: + results[k] = v._foward_cache.detach() + return results + + def get_epoch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: + """Gets the metrics to include in the progress_bar at the end of epoch""" + results = self._filter(self, ['prog_bar', 'on_epoch'], add_dataloader_idx=add_dataloader_idx) + for k, v in results: + if isinstance(v, Metric) and v._forward_cache is not None: + results[k] = v._foward_cache.compute().detach() + # TODO: this? + # if k in self and not options['on_epoch'] and isinstance(self[k], Metric): + # # compute for reuse later + # self[k].compute() + return results + + def get_forked_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: + results = self._filter(self, [], add_dataloader_idx=add_dataloader_idx) + for k, v in results: + if isinstance(v, Metric) and v._forward_cache is not None: + results[k] = v._foward_cache.compute().detach() + return results + + def detach(self) -> 'ResultCollection': + for k, item in self.items(): + if isinstance(item.data, torch.Tensor): + item.data = item.data.detach() return self - def to(self, *args, **kwargs) -> 'Result': - """Move all self attributes to the given device.""" - for k, v in self.items(): - if isinstance(v, torch.Tensor): - self.__setitem__(k, v.to(*args, **kwargs)) + def to(self, *args, **kwargs) -> 'ResultCollection': + """Move all data to the given device.""" + for k, item in self.items(): + if isinstance(item.data, torch.Tensor): + item.data = item.data.to(*args, **kwargs) return self def cpu(self) -> 'Result': - """Move all self attributes to CPU.""" - return self.to(torch.device("cpu")) - - def __repr__(self): - self_copy = self.copy() - - if 'meta' in self_copy: - del self_copy['meta'] - - return str(self_copy) - - def __str__(self): - copy = self.copy() - del copy['meta'] - - return str(copy) - - def __copy__(self): - newone = type(self)() - for k, v in self.items(): - if isinstance(v, torch.Tensor): - v = v.detach() - newone[k] = copy(v) - return newone - - @staticmethod - def unpack_batch_size(sample): - """ - Recursively unpack sample to find a torch.Tensor. - returns len(tensor) when found, or 1 when it hits an empty or non iterable. 
- """ - if isinstance(sample, torch.Tensor): - size = sample.size(0) - elif isinstance(sample, str): - return len(sample) - elif isinstance(sample, dict): - sample = next(iter(sample.values()), 1) - size = Result.unpack_batch_size(sample) - elif isinstance(sample, Iterable): - sample = next(iter(sample), 1) - size = Result.unpack_batch_size(sample) - else: - size = 1 - return size + """Move all data to CPU.""" + return self.to(device="cpu") + + # TODO: need this with detach? + #def __copy__(self): + # newone = type(self)() + # for k, v in self.items(): + # if isinstance(v, torch.Tensor): + # v = v.detach() + # newone[k] = copy(v) + # return newone @classmethod def gather(cls, outputs): @@ -448,16 +321,13 @@ def padded_gather(cls, outputs): # find the padding used for other values default_padding_idx = 0 for name, value in result.items(): - if ( - name != 'minimize' and isinstance(value, list) and len(value) > 0 - and isinstance(value[0], torch.Tensor) - ): + if name != 'minimize' and isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor): default_padding_idx = meta[name]['tbptt_pad_token'] break # pad across each key individually for name, value in result.items(): - if (isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor)): + if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor): padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token'] padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key) result[name] = padded @@ -514,9 +384,9 @@ def reduce_on_epoch_end(cls, outputs): return result @classmethod - def reduce_across_time(cls, time_outputs): + def reduce_across_time(cls, time_outputs: List['Result']) -> 'Result': # auto-reduce across time for tbptt - meta = time_outputs[0]['meta'] + meta = time_outputs[0].meta result = cls() result = recursive_gather(time_outputs, result) @@ -551,34 +421,15 @@ def dp_reduce(self): self[k] = value.mean(dim=-1) - @property - def should_reduce_on_epoch_end(self) -> bool: - return self['meta']['_internal']['_reduce_on_epoch'] - - def rename_keys(self, map_dict: dict): - """ - Maps key values to the target values. Useful when renaming variables in mass. 
- - Args: - map_dict: - """ - meta = self.meta - for source, dest in map_dict.items(): - # map the main keys - self[dest] = self[source] - del self[source] - - # map meta - meta[dest] = meta[source] - del meta[source] + def get_non_metrics_keys(self) -> List[str]: + """This function is used to filter metric keys for which the value isn't a Metric""" + return [k for k, v in self.items() if not isinstance(v.data, Metric)] def reset(self) -> None: - """ - Call at the end of epoch to reset all metric objects - """ - for k, value in self.items(): - if isinstance(value, Metric): - value.reset() + """Call at the end of epoch to reset all metric objects""" + for item in self.values(): + if isinstance(item.data, Metric): + item.data.reset() def choose_last(x): @@ -589,37 +440,26 @@ def choose_last(x): x[k] = x[k][-1] -def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]: +def recursive_gather(outputs: List[Result], result: Result) -> Result: for out in outputs: - if 'meta' in out: - del out['meta'] - - for k, v in out.items(): - # support manual opt where the user does not return a minimize key - if k == 'minimize' and v is None: - continue - - if isinstance(v, dict): - in_d = result.get(k, {}) - v = recursive_gather([v], in_d) - result[k] = v + for k, item in out.items(): + if isinstance(item.data, dict): + in_d = result.get(k, Result()) + result[k] = recursive_gather([item], in_d) + elif isinstance(item.data, Metric): + # if v is a metric, just keep one of them, + # don't keep on adding a list of them + result[k] = item else: - if isinstance(v, Metric): - # if v is a metric, just keep one of them, - # don't keep on adding a list of them - result[k] = v - else: - if k not in result: - result[k] = [] - result[k].append(v) - + result.setdefault(k, []) + result[k].append(item) return result def recursive_stack(result: MutableMapping): - for k, v in result.items(): - if isinstance(v, dict): - recursive_stack(v) + for k, item in result.items(): + if isinstance(item.data, dict): + recursive_stack(item.data) result[k] = collate_tensors(v) From 347e034255bf5dd49e4e4b7c5ec2722741731a93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 May 2021 14:20:50 +0000 Subject: [PATCH 082/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/step_result.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index c40708f3951e5..600f472e4563e 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -46,6 +46,7 @@ def names(self, name: str) -> List[str]: names += name + '_epoch' return names + @dataclass class Result: data: Any # TODO: Union[Tensor, Metric]? 
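Before the larger rewrite in the next patch, the idea behind the WIP result handling of
PATCH 081/082 is easier to see in isolation: every logged value carries a small `Metadata`
record, and the collection is filtered by those flags instead of maintaining a parallel
`meta` dictionary per key. The sketch below is illustrative only; the flag names
(`prog_bar`, `logger`, `on_step`, `on_epoch`) and the `_step`/`_epoch` forking mirror
PATCH 081, while `LoggedValue`, `log` and `metrics` are simplified stand-ins for that
patch's `Result`, `ResultCollection.log` and `_filter`.

from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Metadata:
    prog_bar: bool = False
    logger: bool = True
    on_step: bool = False
    on_epoch: bool = True

    @property
    def forked(self) -> bool:
        # logged on both step and epoch -> exposed twice, with suffixes
        return self.on_step and self.on_epoch


@dataclass
class LoggedValue:
    data: Any
    meta: Metadata
    batch_sizes: List[int] = field(default_factory=list)


class ResultCollection(dict):

    def log(self, name: str, value: Any, batch_size: int = 1, **flags: bool) -> None:
        item = LoggedValue(value, Metadata(**flags))
        item.batch_sizes.append(batch_size)
        self[name] = item

    def metrics(self, *fields: str) -> Dict[str, Any]:
        """Return every logged value whose metadata has all the given flags set."""
        out = {}
        for name, item in self.items():
            if not all(getattr(item.meta, f) for f in fields):
                continue
            if item.meta.forked and "on_step" in fields:
                name += "_step"
            elif item.meta.forked and "on_epoch" in fields:
                name += "_epoch"
            out[name] = item.data
        return out


results = ResultCollection()
results.log("loss", 0.25, prog_bar=True, on_step=True, on_epoch=True)
results.log("acc", 0.9, on_step=False, on_epoch=True)
assert results.metrics("prog_bar", "on_step") == {"loss_step": 0.25}
assert results.metrics("logger", "on_epoch") == {"loss_epoch": 0.25, "acc": 0.9}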
From ca08d66b3b49a0d627e40a9bab1212dbad261f68 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 21 May 2021 16:31:29 +0100 Subject: [PATCH 083/455] update --- pytorch_lightning/core/step_result.py | 544 ++++++------------ .../connectors/test_logger_connectors.py | 132 +++++ 2 files changed, 296 insertions(+), 380 deletions(-) create mode 100644 tests/trainer/connectors/test_logger_connectors.py diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 396dc943d47df..55ae28bc3dcf3 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -12,51 +12,111 @@ # See the License for the specific language governing permissions and # limitations under the License. import numbers +from copy import deepcopy from dataclasses import dataclass, field from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Tuple, Union import torch -from torch import Tensor +from torch import Tensor, tensor from torchmetrics import Metric from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed +from pytorch_lightning.utilities.enums import LightningEnum +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class DefaultMetricsKeys(LightningEnum): + + CALLBACK_METRICS = "callback_metrics" + PBAR_METRICS = "pbar_metrics" + LOG_METRICS = "log_metrics" + + +class Result: + pass @dataclass class Metadata: - prog_bar: bool = False - logger: bool = True + + hook_name: str + name: str + on_prog_bar: bool + on_logger: bool = True on_step: bool = False on_epoch: bool = True reduce_fx: Callable = torch.mean tbptt_reduce_fx: Callable = torch.mean tbptt_pad_token: int = 0 dataloader_idx: Optional[int] = None + is_tensor: bool = True @property def forked(self) -> bool: return self.on_step and self.on_epoch - def names(self, name: str) -> List[str]: - names = [name] - # check both - if self.on_step: - names += name + '_step' - if self.on_epoch: - names += name + '_epoch' - return names + @property + def forked_step_name(self) -> str: + if self.forked: + return self.name + "_step" + return self.name + @property + def forked_epoch_name(self) -> str: + if self.forked: + return self.name + "_epoch" + return self.name -@dataclass -class Result: - data: Any # TODO: Union[Tensor, Metric]? 
- meta: Metadata = field(repr=False) - batch_sizes: List[int] = field(default_factory=list, init=False) + +class ResultMetric(Metric): + + def __init__(self, metatata: Metadata): + super().__init__() + + self.meta = metatata + if self.meta.is_tensor: + self.add_state("values", []) + self.add_state("batch_sizes", []) + + def update(self, value: Union[torch.Tensor, Metric], batch_size: int) -> None: + if self.meta.is_tensor: + self.values.append(value) + self.batch_sizes.append(batch_size) + else: + self.value = value + + def compute(self) -> torch.Tensor: + if self.meta.is_tensor: + if self.reduce_fx == torch.mean: + return (tensor(self.values) * tensor(self.batch_sizes)).sum() / sum(self.batch_sizes) + elif self.reduce_fx == torch.max: + return max(self.values) + else: + raise MisconfigurationException("Only mean, max are supported.") + else: + return self.value.compute() + + +class ResultCollection(dict): + + def __init__(self) -> None: + super().__init__() + self.minimize: Optional[Tensor] = None + self.on_epoch_end_reached: bool = False + self.default_metrics = { + DefaultMetricsKeys.CALLBACK_METRICS: {}, + DefaultMetricsKeys.PBAR_METRICS: {}, + DefaultMetricsKeys.LOG_METRICS: {}, + } + + @property + def metrics_fn(self): + return self.get_epoch_metrics if self.on_epoch_end_reached else self.get_batch_metrics @staticmethod def extract_batch_size(batch: Any) -> int: try: - return Result._extract_batch_size(batch) + return ResultCollection._extract_batch_size(batch) except RecursionError: return 1 @@ -74,56 +134,14 @@ def _extract_batch_size(batch: Any) -> int: return len(batch) elif isinstance(batch, dict): sample = next(iter(batch.values()), 1) - size = Result._extract_batch_size(sample) + size = ResultCollection._extract_batch_size(sample) elif isinstance(batch, Iterable): sample = next(iter(batch), 1) - size = Result._extract_batch_size(sample) + size = ResultCollection._extract_batch_size(sample) else: size = 1 return size - -class ResultCollection(dict): - - def __init__(self) -> None: - super().__init__() - self.minimize: Optional[Tensor] = None - self.should_reduce_on_epoch_end = False - - #@staticmethod - #def removesuffix(s: str, suffix: str) -> str: - # # available from Python 3.9 - # if suffix and s.endswith(suffix): - # return s[:-len(suffix)] - # return s - - #@staticmethod - #def _parse_key(key: str) -> str: - # key = ResultCollection.removesuffix(key, '_epoch') - # key = ResultCollection.removesuffix(key, '_step') - # return key - - #def __getitem__(self, key: str) -> Result: - # if not isinstance(key, str): - # raise ValueError(f'`Result` keys must be `str`, found: {key}') - # if key in self: - # return super().__getitem__(key) - # # try removing `_epoch` and `_step` suffixes - # key = self._parse_key(key) - # return super().__getitem__(key) - - def get_callback_metrics(self): - return self.items() - - def get_logger_metrics(self): - pass - #ret = {} - #for item in self.items(): - # for name in item.names(): # names knows whether it is forked - # ret[name] = item.data - # checks whether is forked and returns all - # return self.items_prefixes() - @staticmethod def _sync( value, @@ -151,6 +169,7 @@ def _sync( def log( self, + hook_name: str, name: str, value: Any, prog_bar: bool = False, @@ -187,347 +206,112 @@ def log( if isinstance(value, torch.Tensor) and value.device.type == "xla": value = value.cpu() - result = Result( - value, - Metadata( - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - 
tbptt_reduce_fx=tbptt_reduce_fx, - tbptt_pad_token=tbptt_pad_token, - dataloader_idx=dataloader_idx, - ), - ) if batch_size is None: - batch_size = Result.extract_batch_size(value) - result.batch_sizes.append(batch_size) - self[name] = result - self.should_reduce_on_epoch_end |= on_epoch - - @staticmethod - def _add_dl_idx(key: str, dl_idx: Union[int, None]) -> str: - if dl_idx is not None: - return f"{key}/dataloader_idx_{dl_idx}" - return key - - @staticmethod - def _filter(self: 'Result', fields: List[str], add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: # TODO - result = {} - for k, item in self.items(): - # check if we need to add the suffix - if 'on_step' in fields and 'on_epoch' not in fields: - k += '_step' - elif 'on_step' not in fields and 'on_epoch' in fields: - k += '_epoch' - - if all(getattr(item.meta, f, False) for f in fields): - k = Result._add_dl_idx(k, item.meta.dataloader_idx) - result[k] = item.data - return result - - def get_batch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: - """Gets the metrics to log at the end of the batch""" - # TODO: remove dl idx - results = self._filter(self, ['logger', 'on_step'], add_dataloader_idx=add_dataloader_idx) - for k, v in results: - if isinstance(v, Metric) and v._forward_cache is not None: - results[k] = v._foward_cache.detach() - return results - - def get_epoch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: - """Gets the metrics to log at the end of epoch""" - results = self._filter(self, ['logger', 'on_epoch'], add_dataloader_idx=add_dataloader_idx) - for k, v in results: - if isinstance(v, Metric) and v._forward_cache is not None: - results[k] = v._foward_cache.compute().detach() - # TODO: this? - # if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # # compute for reuse later - # self[k].compute() - return results - - def get_batch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: - """Gets the metrics to include in the progress_bar at the end of epoch""" - results = self._filter(self, ['prog_bar', 'on_step'], add_dataloader_idx=add_dataloader_idx) - for k, v in results: - if isinstance(v, Metric) and v._forward_cache is not None: - results[k] = v._foward_cache.detach() - return results - - def get_epoch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: - """Gets the metrics to include in the progress_bar at the end of epoch""" - results = self._filter(self, ['prog_bar', 'on_epoch'], add_dataloader_idx=add_dataloader_idx) - for k, v in results: - if isinstance(v, Metric) and v._forward_cache is not None: - results[k] = v._foward_cache.compute().detach() - # TODO: this? 
- # if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # # compute for reuse later - # self[k].compute() - return results - - def get_forked_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, '_METRIC']: - results = self._filter(self, [], add_dataloader_idx=add_dataloader_idx) - for k, v in results: - if isinstance(v, Metric) and v._forward_cache is not None: - results[k] = v._foward_cache.compute().detach() - return results - - def detach(self) -> 'ResultCollection': - for k, item in self.items(): - if isinstance(item.data, torch.Tensor): - item.data = item.data.detach() - return self - - def to(self, *args, **kwargs) -> 'ResultCollection': - """Move all data to the given device.""" - for k, item in self.items(): - if isinstance(item.data, torch.Tensor): - item.data = item.data.to(*args, **kwargs) - return self - - def cpu(self) -> 'Result': - """Move all data to CPU.""" - return self.to(device="cpu") - - # TODO: need this with detach? - #def __copy__(self): - # newone = type(self)() - # for k, v in self.items(): - # if isinstance(v, torch.Tensor): - # v = v.detach() - # newone[k] = copy(v) - # return newone - - @classmethod - def gather(cls, outputs): - meta = outputs[0].get('meta') - result = cls() - result = recursive_gather(outputs, result) - recursive_stack(result) - - if meta: - result['meta'] = meta - return result - - @classmethod - def padded_gather(cls, outputs): - meta = outputs[0].get('meta') - result = cls() - result = recursive_gather(outputs, result) - - # find the padding used for other values - default_padding_idx = 0 - for name, value in result.items(): - if name != 'minimize' and isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor): - default_padding_idx = meta[name]['tbptt_pad_token'] - break - - # pad across each key individually - for name, value in result.items(): - if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor): - padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token'] - padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key) - result[name] = padded - - # also update the result - if meta and name != "minimize": - meta[name]['value'] = padded - if meta: - result['meta'] = meta - return result - - @classmethod - def reduce_on_epoch_end(cls, outputs): - # get the batch sizes for all outputs - batch_sizes = [] - meta = {} - for x in outputs: - batch_sizes.append(x.get_batch_sizes()) - meta.update(x['meta']) - - batch_sizes = torch.stack(batch_sizes).view(-1) - - result = cls() - result = recursive_gather(outputs, result) - recursive_stack(result) - - for k, option in meta.items(): - if k == '_internal' or isinstance(result[k], Metric): + raise MisconfigurationException("`batch_size` should be provided.") + + storage_key = f"{hook_name}.{name}" + + if storage_key not in self: + result = ResultMetric( + Metadata( + hook_name=hook_name, + name=name, + on_prog_bar=prog_bar, + on_logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + tbptt_reduce_fx=tbptt_reduce_fx, + tbptt_pad_token=tbptt_pad_token, + dataloader_idx=dataloader_idx, + ) + ) + self[storage_key] = result + + self[storage_key](value, batch_size) + + def get_batch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: + metrics = deepcopy(self.default_metrics) + + for result_metric in self.values(): + if not result_metric.meta.on_step: continue - # for forked metrics don't reduce, just take the last val - if 
option['forked']: - result[k] = choose_last(result[k]) - continue - - if option['on_epoch']: - fx = option['reduce_fx'] - if fx == torch.mean: - if isinstance(result[k], list): - result[k] = torch.tensor(result[k]).float() - try: - reduced_val = weighted_mean(result[k], batch_sizes) - # todo: specify the expected Exceptions to come - except Exception: - reduced_val = torch.mean(result[k]) - else: - reduced_val = fx(result[k]) - - result[k] = reduced_val - else: - del result[k] + foward_cache: torch.Tensor = result_metric._forward_cache.detach() + name: str = result_metric.meta.name + name_forked: str = result_metric.meta.forked_step_name - result['meta'] = meta - return result + if result_metric.meta.on_prog_bar: + metrics[DefaultMetricsKeys.PBAR_METRICS][name_forked] = foward_cache - @classmethod - def reduce_across_time(cls, time_outputs: List['Result']) -> 'Result': - # auto-reduce across time for tbptt - meta = time_outputs[0].meta + if result_metric.meta.on_logger: + metrics[DefaultMetricsKeys.LOG_METRICS][name_forked] = foward_cache - result = cls() - result = recursive_gather(time_outputs, result) - recursive_stack(result) + metrics[DefaultMetricsKeys.CALLBACK_METRICS][name] = foward_cache - for k, value in result.items(): - if k in ['meta', 'extra'] or isinstance(value, Metric): - continue + return metrics - # pick the reduce fx - tbptt_reduce_fx = torch.mean if k == "minimize" else meta[k]['tbptt_reduce_fx'] + def get_batch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: + """Gets the metrics to log at the end of the batch""" + return self.get_batch_metrics()[DefaultMetricsKeys.LOG_METRICS] - if isinstance(value, list): - value = torch.tensor(value) + def get_batch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: + """Gets the metrics to include in the progress_bar at the end of the batch""" + return self.get_batch_metrics()[DefaultMetricsKeys.PBAR_METRICS] - if isinstance(value, dict): - # TODO: recursive reduce: - _recursive_fx_apply(value, tbptt_reduce_fx) - else: - result[k] = tbptt_reduce_fx(value.float()) + def get_batch_callback_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: + """Gets the metrics for the callbacks at the end of the batch""" + return self.get_batch_metrics()[DefaultMetricsKeys.CALLBACK_METRICS] - result['meta'] = meta - return result + def get_epoch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: + metrics = deepcopy(self.default_metrics) - def dp_reduce(self): - for k, value in self.items(): - if k == 'meta' or isinstance(value, Metric): + for result_metric in self.values(): + if not result_metric.meta.on_epoch: continue - if isinstance(value, list): - value = torch.tensor(value) - - self[k] = value.mean(dim=-1) - - def get_non_metrics_keys(self) -> List[str]: - """This function is used to filter metric keys for which the value isn't a Metric""" - return [k for k, v in self.items() if not isinstance(v.data, Metric)] - - def reset(self) -> None: - """Call at the end of epoch to reset all metric objects""" - for item in self.values(): - if isinstance(item.data, Metric): - item.data.reset() - - -def choose_last(x): - if isinstance(x, (torch.Tensor, list)): - return x[-1] - if isinstance(x, dict): - for k, v in x.items(): - x[k] = x[k][-1] - - -def recursive_gather(outputs: List[Result], result: Result) -> Result: - for out in outputs: - for k, item in out.items(): - if isinstance(item.data, dict): - in_d = result.get(k, Result()) - result[k] = 
recursive_gather([item], in_d) - elif isinstance(item.data, Metric): - # if v is a metric, just keep one of them, - # don't keep on adding a list of them - result[k] = item - else: - result.setdefault(k, []) - result[k].append(item) - return result - + if not result_metric._computed: + result_metric.compute() -def recursive_stack(result: MutableMapping): - for k, item in result.items(): - if isinstance(item.data, dict): - recursive_stack(item.data) + computed: torch.Tensor = result_metric._computed.detach() + name: str = result_metric.meta.name + name_forked: str = result_metric.meta.forked_epoch_name - result[k] = collate_tensors(v) + if result_metric.meta.on_prog_bar: + metrics[DefaultMetricsKeys.PBAR_METRICS][name_forked] = computed + if result_metric.meta.on_logger: + metrics[DefaultMetricsKeys.LOG_METRICS][name_forked] = computed -def _recursive_fx_apply(input: dict, fx): - for k, v in input.items(): - if isinstance(v, list): - v = torch.tensor(v) - - if isinstance(v, torch.Tensor): - v = fx(v.float()) - input[k] = v - else: - _recursive_fx_apply(v, fx) - - -def collate_tensors(items: Union[List, Tuple]) -> Union[Tensor, List, Tuple]: - if not items or not isinstance(items, (list, tuple)) or any(not isinstance(item, Tensor) for item in items): - # items is not a sequence, empty, or contains non-tensors - return items - - if all(item.ndim == 0 for item in items): - # all tensors are scalars, we need to stack - return torch.stack(items) - - if all(item.ndim >= 1 and item.shape[1:] == items[0].shape[1:] for item in items): - # we can concatenate along the first dimension - return torch.cat(items) - - return items - - -def weighted_mean(result, weights): - - if isinstance(result, dict): - _process_dataloader_aggregated_steps(result, weights) - else: - if isinstance(result, list): - result = torch.tensor(result) - - weights = weights.to(result.device)[:result.size(0)] - numerator = torch.dot(result.float(), weights.transpose(-1, 0).float()) - result = numerator / weights.sum().float() - return result + metrics[DefaultMetricsKeys.CALLBACK_METRICS][name] = computed + return metrics -def _process_dataloader_aggregated_steps(result, weights): - internal_keys = {'meta'} + def get_epoch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: + """Gets the metrics to log at the end of the epoch""" + return self.get_epoch_metrics()[DefaultMetricsKeys.LOG_METRICS] - moved = False + def get_epoch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: + """Gets the metrics to include in the progress_bar at the end of the epoch""" + return self.get_epoch_metrics()[DefaultMetricsKeys.PBAR_METRICS] - for k, v in result.items(): - if k in internal_keys: - continue + def get_epoch_callback_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: + """Gets the metrics for the callbacks at the end of the epoch""" + return self.get_epoch_metrics()[DefaultMetricsKeys.CALLBACK_METRICS] - # make sure v is a tensor - if not isinstance(v, torch.Tensor): - v = torch.tensor(v) - - # move to memory only once - if not moved: - weights = weights.to(v.device) - moved = True + def to(self, *args, **kwargs) -> 'ResultCollection': + """Move all data to the given device.""" + for item in self.values(): + if isinstance(item, ResultMetric): + item.to(*args, **kwargs) + return self - # move weights to same device as value to reduce - weights_t = weights[:v.size(0)] + def cpu(self) -> 'Result': + """Move all data to CPU.""" + return self.to(device="cpu") - # 
weighted mean - numerator = torch.dot(v.float(), weights_t.transpose(-1, 0).float()) - v = numerator / weights.sum().float() - result[k] = v + def reset(self) -> None: + """Call at the end of epoch to reset all metric objects""" + for item in self.values(): + if isinstance(item, Metric): + item.reset() diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py new file mode 100644 index 0000000000000..fceee0da4872e --- /dev/null +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -0,0 +1,132 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import tensor + +from pytorch_lightning import seed_everything +from pytorch_lightning.core.step_result import DefaultMetricsKeys, ResultCollection + + +def test_result_collection(): + + seed_everything(42) + + result_collection = ResultCollection() + + for i in range(1, 10): + for prob_bar in [False, True]: + for logger in [False, True]: + result_collection.log( + "training_step", + f"loss_1_{int(prob_bar)}_{int(logger)}", + torch.tensor(i), + on_step=True, + on_epoch=True, + batch_size=i**2, + prog_bar=prob_bar, + logger=logger + ) + result_collection.log( + "training_step", + f"loss_2_{int(prob_bar)}_{int(logger)}", + torch.tensor(i), + on_step=False, + on_epoch=True, + batch_size=i**2, + prog_bar=prob_bar, + logger=logger + ) + result_collection.log( + "training_step", + f"loss_3_{int(prob_bar)}_{int(logger)}", + torch.tensor(i), + on_step=True, + on_epoch=False, + batch_size=i**2, + prog_bar=prob_bar, + logger=logger + ) + result_collection.log( + "training_step", + f"loss_4_{int(prob_bar)}_{int(logger)}", + torch.tensor(i), + on_step=False, + on_epoch=False, + batch_size=i**2, + prog_bar=prob_bar, + logger=logger + ) + + excepted_values = [ + tensor(1), tensor(2), + tensor(3), tensor(4), + tensor(5), tensor(6), + tensor(7), tensor(8), + tensor(9) + ] + assert result_collection["training_step.loss_1_0_0"].values == excepted_values + excepted_batches = [1, 4, 9, 16, 25, 36, 49, 64, 81] + assert result_collection["training_step.loss_1_0_0"].batch_sizes == excepted_batches + + batch_metrics = result_collection.get_batch_metrics() + + expected = { + 'loss_1_1_0_step': tensor([9.]), + 'loss_3_1_0': tensor([9.]), + 'loss_1_1_1_step': tensor([9.]), + 'loss_3_1_1': tensor([9.]) + } + assert batch_metrics[DefaultMetricsKeys.PBAR_METRICS] == expected + + excepted = { + 'loss_1_0_1_step': tensor([9.]), + 'loss_3_0_1': tensor([9.]), + 'loss_1_1_1_step': tensor([9.]), + 'loss_3_1_1': tensor([9.]) + } + assert batch_metrics[DefaultMetricsKeys.LOG_METRICS] == excepted + + excepted = { + 'loss_1_0_0': tensor([9.]), + 'loss_3_0_0': tensor([9.]), + 'loss_1_0_1': tensor([9.]), + 'loss_3_0_1': tensor([9.]), + 'loss_1_1_0': tensor([9.]), + 'loss_3_1_0': tensor([9.]), + 'loss_1_1_1': tensor([9.]), + 'loss_3_1_1': tensor([9.]) + } + assert batch_metrics[DefaultMetricsKeys.CALLBACK_METRICS] == excepted + + epoch_metrics = 
result_collection.get_epoch_metrics() + + mean = (tensor(excepted_values) * tensor(excepted_batches)).sum() / sum(excepted_batches) + + expected = {'loss_1_1_0_epoch': mean, 'loss_2_1_0': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} + assert epoch_metrics[DefaultMetricsKeys.PBAR_METRICS] == expected + + excepted = {'loss_1_0_1_epoch': mean, 'loss_2_0_1': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} + assert epoch_metrics[DefaultMetricsKeys.LOG_METRICS] == excepted + + excepted = { + 'loss_1_0_0': mean, + 'loss_2_0_0': mean, + 'loss_1_0_1': mean, + 'loss_2_0_1': mean, + 'loss_1_1_0': mean, + 'loss_2_1_0': mean, + 'loss_1_1_1': mean, + 'loss_2_1_1': mean + } + assert epoch_metrics[DefaultMetricsKeys.CALLBACK_METRICS] == excepted From f166edc7058ecda32a7182bb716a146c5bbf2250 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 22 May 2021 02:31:12 +0200 Subject: [PATCH 084/455] WIP --- pytorch_lightning/core/step_result.py | 145 ++++++++++++-------------- 1 file changed, 65 insertions(+), 80 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index a689fce405a4a..c0086b12b399e 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -13,41 +13,39 @@ # limitations under the License. import numbers from copy import deepcopy -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, Optional, Union import torch -from torch import Tensor, tensor +from torch import Tensor from torchmetrics import Metric from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _METRIC class DefaultMetricsKeys(LightningEnum): - - CALLBACK_METRICS = "callback_metrics" - PBAR_METRICS = "pbar_metrics" - LOG_METRICS = "log_metrics" + CALLBACK = "callback" + PBAR = "pbar" + LOG = "log" +# TODO: remove class Result: pass @dataclass class Metadata: - - hook_name: str + fx: str # TODO: distinction? name: str - on_prog_bar: bool - on_logger: bool = True + prog_bar: bool = False + logger: bool = True on_step: bool = False on_epoch: bool = True reduce_fx: Callable = torch.mean - tbptt_reduce_fx: Callable = torch.mean - tbptt_pad_token: int = 0 dataloader_idx: Optional[int] = None is_tensor: bool = True @@ -70,17 +68,19 @@ def forked_epoch_name(self) -> str: class ResultMetric(Metric): - def __init__(self, metatata: Metadata): + def __init__(self, metadata: Metadata) -> None: super().__init__() - - self.meta = metatata + self.meta = metadata if self.meta.is_tensor: + # TODO: dist_reduce_fx? 
self.add_state("values", []) self.add_state("batch_sizes", []) - def update(self, value: Union[torch.Tensor, Metric], batch_size: int) -> None: + def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: if self.meta.is_tensor: self.values.append(value) + if batch_size is None: + batch_size = self.extract_batch_size(value) self.batch_sizes.append(batch_size) else: self.value = value @@ -88,7 +88,7 @@ def update(self, value: Union[torch.Tensor, Metric], batch_size: int) -> None: def compute(self) -> torch.Tensor: if self.meta.is_tensor: if self.reduce_fx == torch.mean: - return (tensor(self.values) * tensor(self.batch_sizes)).sum() / sum(self.batch_sizes) + return (torch.tensor(self.values) * torch.tensor(self.batch_sizes)).sum() / sum(self.batch_sizes) elif self.reduce_fx == torch.max: return max(self.values) else: @@ -96,27 +96,10 @@ def compute(self) -> torch.Tensor: else: return self.value.compute() - -class ResultCollection(dict): - - def __init__(self) -> None: - super().__init__() - self.minimize: Optional[Tensor] = None - self.on_epoch_end_reached: bool = False - self.default_metrics = { - DefaultMetricsKeys.CALLBACK_METRICS: {}, - DefaultMetricsKeys.PBAR_METRICS: {}, - DefaultMetricsKeys.LOG_METRICS: {}, - } - - @property - def metrics_fn(self): - return self.get_epoch_metrics if self.on_epoch_end_reached else self.get_batch_metrics - @staticmethod def extract_batch_size(batch: Any) -> int: try: - return ResultCollection._extract_batch_size(batch) + return ResultMetric._extract_batch_size(batch) except RecursionError: return 1 @@ -134,14 +117,27 @@ def _extract_batch_size(batch: Any) -> int: return len(batch) elif isinstance(batch, dict): sample = next(iter(batch.values()), 1) - size = ResultCollection._extract_batch_size(sample) + size = ResultMetric._extract_batch_size(sample) elif isinstance(batch, Iterable): sample = next(iter(batch), 1) - size = ResultCollection._extract_batch_size(sample) + size = ResultMetric._extract_batch_size(sample) else: size = 1 return size + +class ResultCollection(dict): + + def __init__(self) -> None: + super().__init__() + self.minimize: Optional[Tensor] = None + self.on_epoch_end_reached: bool = False + self.default_metrics = {k: {} for k in DefaultMetricsKeys} + + @property + def metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: + return self.get_epoch_metrics() if self.on_epoch_end_reached else self.get_batch_metrics() + @staticmethod def _sync( value, @@ -190,8 +186,6 @@ def log( on_step: bool = False, on_epoch: bool = True, reduce_fx: Callable = torch.mean, - tbptt_reduce_fx: Callable = torch.mean, - tbptt_pad_token: int = 0, enable_graph: bool = False, sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', @@ -219,31 +213,27 @@ def log( if isinstance(value, torch.Tensor) and value.device.type == "xla": value = value.cpu() - if batch_size is None: - raise MisconfigurationException("`batch_size` should be provided.") - - storage_key = f"{hook_name}.{name}" + key = f"{hook_name}.{name}" - if storage_key not in self: + if key not in self: result = ResultMetric( Metadata( - hook_name=hook_name, + fx=hook_name, name=name, - on_prog_bar=prog_bar, - on_logger=logger, + prog_bar=prog_bar, + logger=logger, on_step=on_step, on_epoch=on_epoch, reduce_fx=reduce_fx, - tbptt_reduce_fx=tbptt_reduce_fx, - tbptt_pad_token=tbptt_pad_token, dataloader_idx=dataloader_idx, ) ) - self[storage_key] = result + self[key] = result - self[storage_key](value, batch_size) + self[key](value, batch_size) def get_batch_metrics(self) -> 
Dict[str, Dict[str, torch.Tensor]]: + # TODO: do we need deepcopy? metrics = deepcopy(self.default_metrics) for result_metric in self.values(): @@ -251,30 +241,28 @@ def get_batch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: continue foward_cache: torch.Tensor = result_metric._forward_cache.detach() - name: str = result_metric.meta.name - name_forked: str = result_metric.meta.forked_step_name - if result_metric.meta.on_prog_bar: - metrics[DefaultMetricsKeys.PBAR_METRICS][name_forked] = foward_cache - - if result_metric.meta.on_logger: - metrics[DefaultMetricsKeys.LOG_METRICS][name_forked] = foward_cache - - metrics[DefaultMetricsKeys.CALLBACK_METRICS][name] = foward_cache + name_forked = result_metric.meta.forked_step_name + if result_metric.meta.prog_bar: + metrics[DefaultMetricsKeys.PBAR][name_forked] = foward_cache + if result_metric.meta.logger: + metrics[DefaultMetricsKeys.LOG][name_forked] = foward_cache + metrics[DefaultMetricsKeys.CALLBACK][result_metric.meta.name] = foward_cache return metrics + # TODO: add_dataloader_idx? def get_batch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: """Gets the metrics to log at the end of the batch""" - return self.get_batch_metrics()[DefaultMetricsKeys.LOG_METRICS] + return self.get_batch_metrics()[DefaultMetricsKeys.LOG] def get_batch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: """Gets the metrics to include in the progress_bar at the end of the batch""" - return self.get_batch_metrics()[DefaultMetricsKeys.PBAR_METRICS] + return self.get_batch_metrics()[DefaultMetricsKeys.PBAR] def get_batch_callback_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: """Gets the metrics for the callbacks at the end of the batch""" - return self.get_batch_metrics()[DefaultMetricsKeys.CALLBACK_METRICS] + return self.get_batch_metrics()[DefaultMetricsKeys.CALLBACK] def get_epoch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: metrics = deepcopy(self.default_metrics) @@ -287,39 +275,36 @@ def get_epoch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: result_metric.compute() computed: torch.Tensor = result_metric._computed.detach() - name: str = result_metric.meta.name - name_forked: str = result_metric.meta.forked_epoch_name - - if result_metric.meta.on_prog_bar: - metrics[DefaultMetricsKeys.PBAR_METRICS][name_forked] = computed - if result_metric.meta.on_logger: - metrics[DefaultMetricsKeys.LOG_METRICS][name_forked] = computed - - metrics[DefaultMetricsKeys.CALLBACK_METRICS][name] = computed + name_forked: str = result_metric.meta.forked_epoch_name + if result_metric.meta.prog_bar: + metrics[DefaultMetricsKeys.PBAR][name_forked] = computed + if result_metric.meta.logger: + metrics[DefaultMetricsKeys.LOG][name_forked] = computed + metrics[DefaultMetricsKeys.CALLBACK][result_metric.meta.name] = computed return metrics def get_epoch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: """Gets the metrics to log at the end of the epoch""" - return self.get_epoch_metrics()[DefaultMetricsKeys.LOG_METRICS] + return self.get_epoch_metrics()[DefaultMetricsKeys.LOG] def get_epoch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: """Gets the metrics to include in the progress_bar at the end of the epoch""" - return self.get_epoch_metrics()[DefaultMetricsKeys.PBAR_METRICS] + return self.get_epoch_metrics()[DefaultMetricsKeys.PBAR] def get_epoch_callback_metrics(self, add_dataloader_idx: bool = False) -> 
Dict[str, torch.Tensor]: """Gets the metrics for the callbacks at the end of the epoch""" - return self.get_epoch_metrics()[DefaultMetricsKeys.CALLBACK_METRICS] + return self.get_epoch_metrics()[DefaultMetricsKeys.CALLBACK] def to(self, *args, **kwargs) -> 'ResultCollection': """Move all data to the given device.""" - for item in self.values(): - if isinstance(item, ResultMetric): - item.to(*args, **kwargs) + for k, v in self.items(): + if isinstance(v, (torch.Tensor, Metric)): + self[k] = v.to(*args, **kwargs) return self - def cpu(self) -> 'Result': + def cpu(self) -> 'ResultCollection': """Move all data to CPU.""" return self.to(device="cpu") From 1e77d02390eb0a6953dfc036a65742bdffb208c8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 22 May 2021 02:51:00 +0200 Subject: [PATCH 085/455] Move sync code from step result to lightning module --- pytorch_lightning/core/lightning.py | 43 +++++++++++++++++++++++---- pytorch_lightning/core/step_result.py | 20 ------------- pytorch_lightning/utilities/types.py | 3 +- 3 files changed, 39 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b974f57741ad2..80f902e422504 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -17,6 +17,7 @@ import copy import inspect import logging +import numbers import os import tempfile import types @@ -42,10 +43,11 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -325,6 +327,15 @@ def log( f"Logged key: {name} should not contain information about dataloader_idx." 
) + value = self._sync( + value, + sync_fn=self.trainer.training_type_plugin.reduce, + sync_dist=sync_dist, + sync_dist_op=sync_dist_op, + sync_dist_group=sync_dist_group, + device=self.device, + ) + self._results.log( name, value, @@ -336,12 +347,7 @@ def log( tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, enable_graph=enable_graph, - sync_dist=sync_dist, - sync_dist_op=sync_dist_op, - sync_dist_group=sync_dist_group, - sync_fn=self.trainer.training_type_plugin.reduce, dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), - device=self.device, ) def log_dict( @@ -403,6 +409,31 @@ def log_dict( add_dataloader_idx=add_dataloader_idx ) + @staticmethod + def __sync( + value: _METRIC, + sync_fn: Optional[Callable] = None, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, + device: torch.device = None, + ) -> _METRIC: + """Sync across workers when using distributed training""" + if not isinstance(value, (torch.Tensor, numbers.Number)): + return value + + sync_fn = sync_fn or sync_ddp_if_available + dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() + if not sync_dist or not dist_available: + return value + + # TODO: Find a way to make the reduction only once, so we don't need to clone. + if isinstance(value, torch.Tensor): + value = value.clone() + else: + value = torch.tensor(value, device=device, dtype=torch.float) + return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) + def write_prediction( self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt' ): diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index acf06cc858eb9..e77558a5abb4e 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -13,7 +13,6 @@ # limitations under the License. """Result class for easier logging and epoch-wise reduction.""" -import numbers from copy import copy from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union @@ -21,8 +20,6 @@ from torch import Tensor from torchmetrics import Metric -from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed - class Result(Dict): @@ -88,29 +85,12 @@ def log( tbptt_reduce_fx: Callable = torch.mean, tbptt_pad_token: int = 0, enable_graph: bool = False, - sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', - sync_dist_group: Optional[Any] = None, - sync_fn: Callable = None, dataloader_idx: Optional[int] = None, - device: torch.device = None, ): # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() - # sync across workers when using distributed training - sync_fn = sync_fn or sync_ddp_if_available - - if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)): - is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized() - # TODO: Find a way to make the reduction only once, so we don't need to clone. 
- if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor): - value = value.clone() - else: - value = torch.tensor(value, device=device, dtype=torch.float) - value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) - if isinstance(value, torch.Tensor) and value.device.type == "xla": value = value.cpu() diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index fdfdb95b08692..8a81040af07db 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -16,12 +16,13 @@ - Do not include any `_TYPE` suffix - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`) """ +from numbers import Number from typing import Any, Dict, Iterator, List, Union import torch from torchmetrics import Metric -_METRIC = Union[Metric, torch.Tensor, int, float] +_METRIC = Union[Metric, torch.Tensor, Number] STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]] EPOCH_OUTPUT = List[STEP_OUTPUT] _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader From 2f7ab755d6d605f11ec0385130aa8770ee37a95c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 22 May 2021 02:55:10 +0200 Subject: [PATCH 086/455] Typo --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 80f902e422504..fca96a64ea6fa 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -327,7 +327,7 @@ def log( f"Logged key: {name} should not contain information about dataloader_idx." ) - value = self._sync( + value = self.__sync( value, sync_fn=self.trainer.training_type_plugin.reduce, sync_dist=sync_dist, From ec9da3b66e859e0aea9b15cd9ab362bcb09d2309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 24 May 2021 10:28:56 +0200 Subject: [PATCH 087/455] integrate #7563 --- pytorch_lightning/loops/batch_loop.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index ed35c12d711fa..f014a1c04d45a 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -124,7 +124,6 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi # ------------------- # calculate loss (train step + train step end) # ------------------- - # automatic_optimization=True: perform ddp sync only when performing optimizer_step # automatic_optimization=False: don't block synchronization here with self.block_ddp_sync_behaviour(): @@ -137,6 +136,9 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi else: if self.trainer.lightning_module.automatic_optimization: self.optimizer_step(optimizer, opt_idx, batch_idx, closure) + if len(self.trainer.optimizers) > 1: + # revert back to previous state + self.trainer.lightning_module.untoggle_optimizer(opt_idx) else: result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens) @@ -448,10 +450,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, "training_step returned None. If this was on purpose, ignore this warning..." 
) - if len(self.trainer.optimizers) > 1: - # revert back to previous state - self.trainer.lightning_module.untoggle_optimizer(opt_idx) - return result def _check_finite(self, loss: torch.Tensor) -> None: From 82dbcd3f2ecbc8deb9efa513c413d31d2805a798 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 24 May 2021 14:24:39 +0100 Subject: [PATCH 088/455] update poc --- pytorch_lightning/core/lightning.py | 35 +- pytorch_lightning/core/step_result.py | 342 +++++++++++------ .../logger_connector/logger_connector.py | 349 +++++++----------- pytorch_lightning/trainer/evaluation_loop.py | 20 +- pytorch_lightning/trainer/properties.py | 12 + pytorch_lightning/trainer/trainer.py | 28 +- pytorch_lightning/trainer/training_loop.py | 44 +-- pytorch_lightning/utilities/apply_func.py | 55 +++ .../connectors/test_logger_connectors.py | 19 +- .../logging_/test_train_loop_logging.py | 3 +- 10 files changed, 488 insertions(+), 419 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 74c1ef442f993..0e1dc7d224efc 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,7 +38,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.core.step_result import Result, ResultCollection from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -80,6 +80,7 @@ class LightningModule( "model_size", "automatic_optimization", "truncated_bptt_steps", + "_results", ] + DeviceDtypeModuleMixin.__jit_unused_properties__ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -106,7 +107,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # optionally can be set by user self._example_input_array = None self._datamodule = None - self._results: Optional[Result] = None self._current_fx_name: Optional[str] = None self._running_manual_backward: bool = False self._current_dataloader_idx: Optional[int] = None @@ -216,6 +216,11 @@ def logger(self): """ Reference to the logger object in the Trainer. """ return self.trainer.logger if self.trainer else None + @property + def _results(self) -> 'Optional[ResultCollection]': + if hasattr(self, "trainer"): + return self.trainer.result_collections + def _apply_batch_transfer_handler( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None ) -> Any: @@ -272,6 +277,7 @@ def log( sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, + batch_size: Optional[int] = None, ) -> None: """ Log a key, value @@ -308,6 +314,8 @@ def log( add_dataloader_idx: if True, appends the index of the current dataloader to the name (when using multiple). If False, user needs to give unique names for each dataloader to not mix values + batch_size: Current batch_size. This will be directly inferred from the loaded batch, + but some esoteric data type such as graph might need to explicitly provide the batch_size. """ if tbptt_reduce_fx is not None: rank_zero_deprecation( @@ -338,8 +346,8 @@ def log( f"Logged key: {name} should not contain information about dataloader_idx." 
) - value = self.__sync( - value, + sync_fn = partial( + self.__sync, sync_fn=self.trainer.training_type_plugin.reduce, sync_dist=sync_dist, sync_dist_op=sync_dist_op, @@ -347,7 +355,14 @@ def log( device=self.device, ) + value = apply_to_collection(value, ( + torch.Tensor, + float, + int, + ), sync_fn) + self._results.log( + self._current_fx_name, name, value, prog_bar=prog_bar, @@ -357,6 +372,7 @@ def log( reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), + batch_size=batch_size ) def log_dict( @@ -429,16 +445,17 @@ def __sync( if not isinstance(value, (torch.Tensor, numbers.Number)): return value + # TODO: Find a way to make the reduction only once, so we don't need to clone. + if isinstance(value, torch.Tensor): + value = value.clone() + else: + return torch.tensor(value, device=device, dtype=torch.float) + sync_fn = sync_fn or sync_ddp_if_available dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() if not sync_dist or not dist_available: return value - # TODO: Find a way to make the reduction only once, so we don't need to clone. - if isinstance(value, torch.Tensor): - value = value.clone() - else: - value = torch.tensor(value, device=device, dtype=torch.float) return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) def write_prediction( diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 041ba17349ad1..fca5f37dd35b6 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -11,19 +11,78 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy +from collections.abc import Mapping, Sequence from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, Optional +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch from torch import Tensor from torchmetrics import Metric +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _METRIC +def apply_to_metrics_collection( + data: Any, + dtype: Union[type, tuple], + function: Callable, + *args, + wrong_dtype: Optional[Union[type, tuple]] = None, + **kwargs +) -> Any: + """ + Recursively applies a function to all elements of a certain dtype. 
+ + Args: + data: the collection to apply the function to + dtype: the given function will be applied to all elements of this dtype + function: the function to apply + *args: positional arguments (will be forwarded to calls of ``function``) + wrong_dtype: the given function won't be applied if this type is specified and the given collections is of + the :attr:`wrong_type` even if it is of type :attr`dtype` + **kwargs: keyword arguments (will be forwarded to calls of ``function``) + + Returns: + the resulting collection + """ + elem_type = type(data) + + # Breaking condition + if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): + return function(data, *args, **kwargs) + + # Recursively apply to collection items + if isinstance(data, Mapping): + _out = {} + for k, v in data.items(): + v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if v is not None: + _out[k] = v + return elem_type(_out) + + if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple + _out = [] + for d in data: + v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if v is not None: + _out.append(v) + return elem_type(*_out) + + if isinstance(data, Sequence) and not isinstance(data, str): + _out = [] + for d in data: + v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if v is not None: + _out.append(v) + return elem_type(_out) + + # data is neither of dtype, nor a collection + return data + + class DefaultMetricsKeys(LightningEnum): CALLBACK = "callback" PBAR = "pbar" @@ -63,6 +122,18 @@ def forked_epoch_name(self) -> str: return self.name + "_epoch" return self.name + @property + def is_tensor_and_mean_reduction(self) -> bool: + return self.is_tensor and self.reduce_fx == torch.mean + + @property + def is_tensor_and_max_reduction(self) -> bool: + return self.is_tensor and (self.reduce_fx in (torch.max, max)) + + @property + def is_tensor_and_min_reduction(self) -> bool: + return self.is_tensor and (self.reduce_fx in (torch.min, min)) + class ResultMetric(Metric): @@ -70,67 +141,49 @@ def __init__(self, metadata: Metadata) -> None: super().__init__() self.meta = metadata if self.meta.is_tensor: - # TODO: dist_reduce_fx? 
- self.add_state("values", []) - self.add_state("batch_sizes", []) + self.add_state("value", torch.tensor(.0)) + if self.meta.is_tensor_and_mean_reduction: + self.add_state("cumulated_batch_size", torch.tensor(.0)) def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: - if self.meta.is_tensor: - self.values.append(value) - if batch_size is None: - batch_size = self.extract_batch_size(value) - self.batch_sizes.append(batch_size) + if self.meta.is_tensor_and_mean_reduction: + self.value += value.float().mean() * batch_size + self.cumulated_batch_size += batch_size + + elif self.meta.is_tensor_and_max_reduction: + self.value = max(self.value, value.float().mean()) + + elif self.meta.is_tensor_and_min_reduction: + self.value = min(self.value, value.float().mean()) + else: self.value = value def compute(self) -> torch.Tensor: if self.meta.is_tensor: - if self.reduce_fx == torch.mean: - return (torch.tensor(self.values) * torch.tensor(self.batch_sizes)).sum() / sum(self.batch_sizes) - elif self.reduce_fx == torch.max: - return max(self.values) + if self.meta.is_tensor_and_mean_reduction: + return self.value / self.cumulated_batch_size + elif self.meta.is_tensor_and_max_reduction or self.meta.is_tensor_and_min_reduction: + return self.value else: raise MisconfigurationException("Only mean, max are supported.") else: return self.value.compute() - @staticmethod - def extract_batch_size(batch: Any) -> int: - try: - return ResultMetric._extract_batch_size(batch) - except RecursionError: - return 1 - - @staticmethod - def _extract_batch_size(batch: Any) -> int: - """ - Recursively unpack a batch to find a torch.Tensor. - - Returns: - ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable. - """ - if isinstance(batch, torch.Tensor): - size = batch.size(0) - elif isinstance(batch, str): - return len(batch) - elif isinstance(batch, dict): - sample = next(iter(batch.values()), 1) - size = ResultMetric._extract_batch_size(sample) - elif isinstance(batch, Iterable): - sample = next(iter(batch), 1) - size = ResultMetric._extract_batch_size(sample) - else: - size = 1 - return size - class ResultCollection(dict): def __init__(self) -> None: super().__init__() - self.minimize: Optional[Tensor] = None - self.on_epoch_end_reached: bool = False - self.default_metrics = {k: {} for k in DefaultMetricsKeys} + self.reset() + + @property + def batch_size(self) -> int: + return self._batch_size + + @batch_size.setter + def batch_size(self, batch_size: int) -> None: + self._batch_size = batch_size @property def metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: @@ -138,16 +191,24 @@ def metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: @property def minimize(self) -> Optional[Tensor]: - return self.get('minimize', None) + return self._minimize @minimize.setter - def minimize(self, val: Optional[torch.Tensor]) -> None: - if val is not None: - if not isinstance(val, Tensor): - raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {val}") - if val.grad_fn is None: + def minimize(self, loss: Optional[torch.Tensor]) -> None: + if loss is not None: + if not isinstance(loss, Tensor): + raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") + if loss.grad_fn is None: raise RuntimeError("`Result.minimize` must have a `grad_fn`") - self['minimize'] = val + self._minimize = loss + + @property + def extra(self) -> Dict: + return self.get('extra', {}) + + @extra.setter + def extra(self, extra: Dict) -> None: + self['extra'] = extra def log( self, 
@@ -165,6 +226,12 @@ def log( ): """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs + + batch_size = batch_size or self._batch_size + + if not batch_size: + raise MisconfigurationException("batch_size should be provided to ResultCollection.log function.") + if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() @@ -174,86 +241,105 @@ def log( key = f"{hook_name}.{name}" if key not in self: - result = ResultMetric( - Metadata( - fx=hook_name, - name=name, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - dataloader_idx=dataloader_idx, - ) + meta = Metadata( + fx=hook_name, + name=name, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + dataloader_idx=dataloader_idx, ) - self[key] = result + self.instance_result_metric(key, meta, value) - self[key](value, batch_size) + self.update_metrics(key, value, batch_size) - def get_batch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: - # TODO: do we need deepcopy? - metrics = deepcopy(self.default_metrics) + def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: - for result_metric in self.values(): - if not result_metric.meta.on_step: - continue + def fn(*_): + return ResultMetric(meta) - foward_cache: torch.Tensor = result_metric._forward_cache.detach() + self[key] = apply_to_collection(value, torch.Tensor, fn) + self[key + '.forked'] = meta.forked + self[key + '.logger'] = meta.logger + self[key + '.prog_bar'] = meta.prog_bar - name_forked = result_metric.meta.forked_step_name - if result_metric.meta.prog_bar: - metrics[DefaultMetricsKeys.PBAR][name_forked] = foward_cache - if result_metric.meta.logger: - metrics[DefaultMetricsKeys.LOG][name_forked] = foward_cache - metrics[DefaultMetricsKeys.CALLBACK][result_metric.meta.name] = foward_cache + def update_metrics(self, key: str, value: Union[Dict, torch.Tensor], batch_size) -> None: - return metrics + def fn(result_metric, v): + assert torch.is_tensor(v) + result_metric(v, batch_size) - # TODO: add_dataloader_idx? 
- def get_batch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: - """Gets the metrics to log at the end of the batch""" - return self.get_batch_metrics()[DefaultMetricsKeys.LOG] + apply_to_collections(self[key], value, ResultMetric, fn) - def get_batch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: - """Gets the metrics to include in the progress_bar at the end of the batch""" - return self.get_batch_metrics()[DefaultMetricsKeys.PBAR] + @staticmethod + def _get_forward_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: + if not result_metric.meta.on_step: + return - def get_batch_callback_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: - """Gets the metrics for the callbacks at the end of the batch""" - return self.get_batch_metrics()[DefaultMetricsKeys.CALLBACK] + return result_metric._forward_cache.detach() - def get_epoch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: - metrics = deepcopy(self.default_metrics) + @staticmethod + def _to_item(forward_cache: torch.Tensor) -> float: + return forward_cache.item() - for result_metric in self.values(): - if not result_metric.meta.on_epoch: + def valid_metrics(self) -> Tuple[str, Any]: + for key, result_metric in self.items(): + if isinstance(result_metric, bool) or key == "extra": + continue + yield (key, result_metric) + + def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) -> Tuple: + if isinstance(result_metric, ResultMetric): + name = result_metric.meta.name + name_forked = result_metric.meta.forked_step_name if on_step else result_metric.meta.forked_epoch_name + logger = result_metric.meta.logger + prog_bar = result_metric.meta.prog_bar + else: + name = key.split('.')[-1] + name_forked = name + suffix if self[key + '.forked'] else name + logger = self[key + '.logger'] + prog_bar = self[key + '.prog_bar'] + return name, name_forked, logger, prog_bar + + def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: + metrics = {k: {} for k in DefaultMetricsKeys} + fn = self._get_forward_cache if on_step else self._get_computed_cache + suffix = "_step" if on_step else "_epoch" + + for key, result_metric in self.valid_metrics(): + value = apply_to_metrics_collection(result_metric, ResultMetric, fn) + if value is None: continue - if not result_metric._computed: - result_metric.compute() - - computed: torch.Tensor = result_metric._computed.detach() + name, name_forked, logger, prog_bar = self._extract_metadata(key, result_metric, on_step, suffix) - name_forked: str = result_metric.meta.forked_epoch_name - if result_metric.meta.prog_bar: - metrics[DefaultMetricsKeys.PBAR][name_forked] = computed - if result_metric.meta.logger: - metrics[DefaultMetricsKeys.LOG][name_forked] = computed - metrics[DefaultMetricsKeys.CALLBACK][result_metric.meta.name] = computed + if logger: + metrics[DefaultMetricsKeys.LOG][name_forked] = value + metrics[DefaultMetricsKeys.CALLBACK][name] = value + metrics[DefaultMetricsKeys.CALLBACK][name_forked] = value + if prog_bar: + value = apply_to_metrics_collection(result_metric, torch.Tensor, self._to_item) + metrics[DefaultMetricsKeys.PBAR][name_forked] = value return metrics - def get_epoch_log_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: - """Gets the metrics to log at the end of the epoch""" - return self.get_epoch_metrics()[DefaultMetricsKeys.LOG] + def get_batch_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, Dict[str, 
torch.Tensor]]: + return self.get_metrics(on_step=True) + + @staticmethod + def _get_computed_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: + if not result_metric.meta.on_epoch: + return - def get_epoch_pbar_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: - """Gets the metrics to include in the progress_bar at the end of the epoch""" - return self.get_epoch_metrics()[DefaultMetricsKeys.PBAR] + if not result_metric._computed: + result_metric.compute() - def get_epoch_callback_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, torch.Tensor]: - """Gets the metrics for the callbacks at the end of the epoch""" - return self.get_epoch_metrics()[DefaultMetricsKeys.CALLBACK] + return result_metric._computed.detach() + + def get_epoch_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, Dict[str, torch.Tensor]]: + return self.get_metrics(on_step=False) def to(self, *args, **kwargs) -> 'ResultCollection': """Move all data to the given device.""" @@ -271,3 +357,33 @@ def reset(self) -> None: for item in self.values(): if isinstance(item, Metric): item.reset() + self._batch_size: int = 1 + self.on_epoch_end_reached: bool = False + self._minimize: Optional[Tensor] = None + + def extract_batch_size(self, batch: Any) -> None: + try: + self._batch_size = self._extract_batch_size(batch) + except RecursionError: + self._batch_size = 1 + + def _extract_batch_size(self, batch: Any) -> int: + """ + Recursively unpack a batch to find a torch.Tensor. + + Returns: + ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable. + """ + if isinstance(batch, torch.Tensor): + size = batch.size(0) + elif isinstance(batch, str): + return len(batch) + elif isinstance(batch, dict): + sample = next(iter(batch.values()), 1) + size = self._extract_batch_size(sample) + elif isinstance(batch, Iterable): + sample = next(iter(batch), 1) + size = self._extract_batch_size(sample) + else: + size = 1 + return size diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a16f5119abff2..6a60e33650476 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -14,18 +14,19 @@ import os from copy import deepcopy from pprint import pprint -from typing import Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union import torch from pytorch_lightning.core import memory -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.core.step_result import DefaultMetricsKeys, Result, ResultCollection from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT @@ -35,88 +36,13 @@ class LoggerConnector: def __init__(self, trainer, log_gpu_memory: Optional[str] = None): 
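Returning to the `_extract_batch_size` helper added to `ResultCollection` above: a free-function restatement that can be run on its own to see how nested batches resolve. This is a simplified sketch (the `RecursionError` guard of the real method is omitted), not the code shipped in the patch.

from collections.abc import Iterable
from typing import Any

import torch


def extract_batch_size(batch: Any) -> int:
    # tensors report their first dimension
    if isinstance(batch, torch.Tensor):
        return batch.size(0)
    if isinstance(batch, str):
        return len(batch)
    # mappings and other iterables are resolved through their first element
    if isinstance(batch, dict):
        return extract_batch_size(next(iter(batch.values()), 1))
    if isinstance(batch, Iterable):
        return extract_batch_size(next(iter(batch), 1))
    # nothing tensor-like found: fall back to a batch size of 1
    return 1


assert extract_batch_size({"x": torch.zeros(8, 3), "y": torch.zeros(8)}) == 8
assert extract_batch_size([(torch.zeros(4, 2), torch.zeros(4))]) == 4
assert extract_batch_size(42) == 1
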
self.trainer = trainer self.log_gpu_memory = log_gpu_memory - self._callback_metrics = MetricsHolder() - self._evaluation_callback_metrics = MetricsHolder(to_float=True) - self._logged_metrics = MetricsHolder() - self._progress_bar_metrics = MetricsHolder(to_float=True) self.eval_loop_results = [] - self._cached_results = {stage: EpochResultStore(trainer) for stage in RunningStage} - self._cached_results[None] = EpochResultStore(trainer) self._fx_validator = FxValidator() self._val_log_step: int = 0 self._test_log_step: int = 0 - - @property - def callback_metrics(self) -> Dict: - return self.get_metrics("callback_metrics") - - @callback_metrics.setter - def callback_metrics(self, callback_metrics: Dict) -> None: - self.set_metrics("callback_metrics", callback_metrics) - - @property - def evaluation_callback_metrics(self) -> Dict: - return self.get_metrics("evaluation_callback_metrics") - - @evaluation_callback_metrics.setter - def evaluation_callback_metrics(self, evaluation_callback_metrics: Dict) -> None: - self.set_metrics("evaluation_callback_metrics", evaluation_callback_metrics) - - @property - def logged_metrics(self) -> Dict: - return self.get_metrics("logged_metrics") - - @logged_metrics.setter - def logged_metrics(self, logged_metrics: Dict) -> None: - self.set_metrics("logged_metrics", logged_metrics) - - @property - def progress_bar_metrics(self) -> Dict: - return self.get_metrics("progress_bar_metrics") - - @progress_bar_metrics.setter - def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None: - self.set_metrics("progress_bar_metrics", progress_bar_metrics) - - @property - def cached_results(self) -> Union[EpochResultStore, None]: - return self._cached_results.get(self.trainer.state.stage) - - def get_metrics(self, key: str) -> Dict: - metrics_holder: MetricsHolder = getattr(self, f"_{key}") - model = self.trainer.lightning_module - metrics_holder.convert(model.device if model is not None else None) - return metrics_holder.metrics - - def set_metrics(self, key: str, val: Dict) -> None: - metrics_holder: MetricsHolder = getattr(self, f"_{key}") - metrics_holder.reset(val) - - def reset(self) -> None: - self.cached_results.reset() - - def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None: - self._fx_validator.check_logging(fx_name=fx_name, on_step=on_step, on_epoch=on_epoch) - - def on_evaluation_batch_start(self, batch, dataloader_idx, num_dataloaders): - model = self.trainer.lightning_module - # set dataloader_idx only if multiple ones - model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - # track batch_size - self.cached_results._batch_size = Result.extract_batch_size(batch) - - def on_train_split_start(self, split_idx: int, opt_idx: int, split_batch) -> None: - self.cached_results._split_idx = split_idx - self.cached_results._opt_idx = opt_idx - self.cached_results._batch_size = Result.extract_batch_size(split_batch) - - def on_train_batch_end(self) -> None: - self.cached_results._split_idx = None - self.cached_results._opt_idx = None - self.cached_results._batch_size = None - - def cache_logged_metrics(self): - self._cached_results[self.trainer.state.stage].cache_result() + self._progress_bar_metrics: Dict[str, float] = {} + self._logged_metrics: Dict[str, float] = {} + self._callback_metrics: Dict[str, float] = {} def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): # logging @@ -125,16 +51,6 @@ def on_trainer_init(self, logger, 
flush_logs_every_n_steps: int, log_every_n_ste self.trainer.log_every_n_steps = log_every_n_steps self.trainer.move_metrics_to_cpu = move_metrics_to_cpu - @property - def should_flush_logs(self): - should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 - return should_flush or self.trainer.should_stop - - @property - def should_update_logs(self): - should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 - return should_log_every_n_steps or self.trainer.should_stop - def configure_logger(self, logger): if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) @@ -151,46 +67,26 @@ def configure_logger(self, logger): else: self.trainer.logger = logger - def cache_training_step_metrics(self, opt_closure_result): - """ - This function is responsible to update - logger_connector internals metrics holder based for depreceated logging - """ - using_results_obj = isinstance(opt_closure_result.training_step_output, Result) - - # temporary dict to collect metrics - logged_metrics_tmp = {} - pbar_metrics_tmp = {} - callback_metrics_tmp = {} - - if using_results_obj: - batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics( - include_forked_originals=False - ) - logged_metrics_tmp.update(batch_log_metrics) - - batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics( - include_forked_originals=False - ) - pbar_metrics_tmp.update(batch_pbar_metrics) - - forked_metrics = opt_closure_result.training_step_output.get_forked_metrics() - callback_metrics_tmp.update(forked_metrics) - callback_metrics_tmp.update(logged_metrics_tmp) + def on_evaluation_batch_start(self, batch, dataloader_idx, num_dataloaders): + if self.trainer.sanity_checking: + return - else: - batch_log_metrics = opt_closure_result.training_step_output.log_metrics - logged_metrics_tmp.update(batch_log_metrics) + model = self.trainer.lightning_module + # set dataloader_idx only if multiple ones + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end - pbar_metrics_tmp.update(batch_pbar_metrics) + # track batch_size + self.trainer.result_collections.extract_batch_size(batch) - # track progress bar metrics - if len(pbar_metrics_tmp) > 0: - self.add_progress_bar_metrics(pbar_metrics_tmp) + @property + def should_flush_logs(self): + should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 + return should_flush or self.trainer.should_stop - self._callback_metrics.update(callback_metrics_tmp) - self._logged_metrics.update(logged_metrics_tmp) + @property + def should_update_logs(self): + should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 + return should_log_every_n_steps or self.trainer.should_stop def log_metrics(self, metrics, grad_norm_dict, step=None): """Logs the metric dict passed in. 
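The relocated `should_flush_logs` / `should_update_logs` properties encode a simple step cadence: metrics go to the logger every `log_every_n_steps` (or `flush_logs_every_n_steps`) global steps, and unconditionally when training is about to stop. A tiny standalone sketch of that gating, with illustrative names:

def should_update_logs(global_step: int, log_every_n_steps: int, should_stop: bool = False) -> bool:
    # log on every `log_every_n_steps`-th global step, and always when a stop was requested
    return (global_step + 1) % log_every_n_steps == 0 or should_stop


# with log_every_n_steps=50, the 0-indexed global steps 49, 99, 149, ... trigger logging
assert [step for step in range(200) if should_update_logs(step, 50)] == [49, 99, 149, 199]
# a pending stop forces a final update regardless of the cadence
assert should_update_logs(123, 50, should_stop=True)
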
@@ -228,30 +124,23 @@ def log_metrics(self, metrics, grad_norm_dict, step=None): self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step) self.trainer.logger.save() - # track the logged metrics - self.logged_metrics.update(scalar_metrics) - self.trainer.dev_debugger.track_logged_metrics_history(scalar_metrics) - - def add_progress_bar_metrics(self, metrics): - for k, v in metrics.items(): - if isinstance(v, torch.Tensor): - v = v.item() - - self._progress_bar_metrics.metrics[k] = v - - self.trainer.dev_debugger.track_pbar_metrics_history(metrics) + self.add_logged_metrics(scalar_metrics) def evaluation_epoch_end(self): + if self.trainer.sanity_checking: + return + # reset dataloader idx model_ref = self.trainer.lightning_module model_ref._current_dataloader_idx = None - - # setting `has_batch_loop_finished` to True - # will perform Results reduction accross entire epoch. - self.cached_results.has_batch_loop_finished = True + self.trainer.result_collections.on_epoch_end_reached = True def add_to_eval_loop_results(self, dl_idx, has_been_initialized): - callback_metrics = deepcopy(self.evaluation_callback_metrics) + if self.trainer.sanity_checking: + return + + callback_metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.CALLBACK] + callback_metrics = deepcopy(callback_metrics) for key in list(callback_metrics.keys()): if "dataloader_idx" in key: if f"dataloader_idx_{dl_idx}" not in key: @@ -271,7 +160,7 @@ def prepare_eval_loop_results(self): def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if not self.trainer.sanity_checking: # log all the metrics as a single dict - metrics_to_log = self.cached_results.get_epoch_log_metrics() + metrics_to_log = self.trainer.result_collections.metrics[DefaultMetricsKeys.LOG] if len(metrics_to_log) > 0: self.log_metrics(metrics_to_log, {}) @@ -297,81 +186,6 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: self.eval_loop_results = [] return results - def on_train_epoch_end(self): - # inform cached logger connector epoch finished - self.cached_results.has_batch_loop_finished = True - - def log_train_epoch_end_metrics(self, epoch_output: List[List[List[Result]]]) -> None: - # epoch output is a list. 
Each item in that list has all the outputs per optimizer - # epoch_output[optimizer_idx][training_step_idx][tbptt_index] - # remember that not using truncated backprop is equivalent with truncated back prop of len(1) - - # log/aggregate metrics automatically - epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output) - - # it will perform reduction over epoch and return log metrics - cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics() - cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics() - - # update - epoch_log_metrics.update(cached_epoch_log_metrics) - epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics) - - # -------------------------- - # track results - # -------------------------- - # add the metrics to the loggers and callbacks - if epoch_log_metrics and len(epoch_log_metrics) > 0: - self.log_metrics(epoch_log_metrics, {}) - self._callback_metrics.update(epoch_log_metrics) - - # add metrics to progress_bar and callbacks - if len(epoch_progress_bar_metrics) > 0: - self.add_progress_bar_metrics(epoch_progress_bar_metrics) - self._callback_metrics.update(epoch_progress_bar_metrics) - - # reset epoch loop result for next epoch - self.cached_results.reset() - - def __auto_reduce_results_on_epoch_end(self, epoch_output): - epoch_log_metrics = {} - epoch_progress_bar_metrics = {} - for opt_outputs in epoch_output: - # reduce across time first - time_reduced_outputs = [] - for tbptt_outs in opt_outputs: - tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs) - if len(tbptt_outs) > 1: - time_reduced_outputs.append(tbptt_outs) - - if len(time_reduced_outputs) == 0: - continue - - # reduce across training steps - opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs) - - # with manual opt need 1 + metrics because meta is always there - if opt_outputs.minimize is not None: - opt_outputs.minimize = opt_outputs.minimize.mean() - epoch_log_metrics.update(opt_outputs.epoch_log_metrics) - epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics) - - return epoch_log_metrics, epoch_progress_bar_metrics - - def log_train_step_metrics(self, batch_output): - if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: - return - _, batch_log_metrics = self.cached_results.update_logger_connector() - # when metrics should be logged - if self.should_update_logs or self.trainer.fast_dev_run is True: - # logs user requested information to logger - grad_norm_dict = batch_output.grad_norm_dict - if grad_norm_dict is None: - grad_norm_dict = {} - if len(batch_log_metrics) > 0 or len(grad_norm_dict) > 0: - self.log_metrics(batch_log_metrics, grad_norm_dict) - self._callback_metrics.update(batch_log_metrics) - @property def evaluation_log_step(self) -> Optional[int]: if self.trainer.state.stage is RunningStage.VALIDATING: @@ -390,7 +204,9 @@ def increment_evaluation_log_step(self) -> None: def log_evaluation_step_metrics(self) -> None: if self.trainer.sanity_checking: return - _, batch_log_metrics = self.cached_results.update_logger_connector() + + metrics = self.trainer.result_collections.metrics + batch_log_metrics = metrics[DefaultMetricsKeys.LOG] # logs user requested information to logger if len(batch_log_metrics) > 0: @@ -399,3 +215,96 @@ def log_evaluation_step_metrics(self) -> None: # increment the step even if nothing was logged self.increment_evaluation_log_step() + + ############## TRAIN METRICS UPDATES START 
############## + + def on_train_split_start(self, split_batch: Any) -> None: + self.trainer.result_collections.extract_batch_size(split_batch) + + def on_train_batch_end(self) -> None: + self.trainer.result_collections.batch_size = 1 + + def update_train_step_metrics(self, batch_output): + metrics = self.trainer.result_collections.metrics + + # update metrics + self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) + self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) + + if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: + return + + batch_log_metrics = metrics[DefaultMetricsKeys.LOG] + + # when metrics should be logged + if self.should_update_logs or self.trainer.fast_dev_run is True: + # logs user requested information to logger + grad_norm_dict = batch_output.grad_norm_dict + if grad_norm_dict is None: + grad_norm_dict = {} + if len(batch_log_metrics) > 0 or len(grad_norm_dict) > 0: + self.log_metrics(batch_log_metrics, grad_norm_dict) + + def on_train_epoch_end(self): + # inform cached logger connector epoch finished + self.trainer.result_collections.on_epoch_end_reached = True + + def update_train_epoch_metrics(self) -> None: + + metrics = self.trainer.result_collections.metrics + + # update metrics + self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) + + callback_metrics = metrics[DefaultMetricsKeys.CALLBACK] + if os.getenv("PL_DEV_DEBUG", '0') == '1': + callback_metrics["debug_epoch"] = self.trainer.current_epoch + + self._callback_metrics.update(callback_metrics) + + epoch_log_metrics = metrics[DefaultMetricsKeys.LOG] + epoch_log_metrics["epoch"] = self.trainer.current_epoch + self._logged_metrics.update(epoch_log_metrics) + + # add the metrics to the loggers + if epoch_log_metrics and len(epoch_log_metrics) > 0: + self.log_metrics(epoch_log_metrics, {}) + + # reset result collection for next epoch + self.trainer.result_collections.reset() + + ############## TRAIN METRICS UPDATES END ############## + + ############## UTILS START ############## + + @property + def callback_metrics(self) -> Dict: + return self._callback_metrics + + @property + def logged_metrics(self) -> Dict: + return self._logged_metrics + + @property + def progress_bar_metrics(self) -> Dict: + return self._progress_bar_metrics + + def add_progress_bar_metrics(self, metrics): + self._progress_bar_metrics.update(metrics) + self.trainer.dev_debugger.track_pbar_metrics_history(metrics) + + def add_logged_metrics(self, metrics): + self._logged_metrics.update(metrics) + self.trainer.dev_debugger.track_logged_metrics_history(metrics) + + def add_callback_metrics(self, metrics): + self._callback_metrics.update(metrics) + + def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None: + self._fx_validator.check_logging(fx_name=fx_name, on_step=on_step, on_epoch=on_epoch) + + def reset(self): + if self.trainer.result_collections: + self.trainer.result_collections.reset() + + ############## UTILS END ############## diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f048297892533..45957948d77b7 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.core.step_result import Result, ResultCollection from pytorch_lightning.trainer.states import 
TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden @@ -34,6 +34,8 @@ def __init__(self, trainer: 'pl.Trainer'): self.max_batches: Optional[List[Union[int, float]]] = None self.warning_cache = WarningCache() self.num_dataloaders: Optional[int] = None + self.validation_results = ResultCollection() + self.test_results = ResultCollection() def on_trainer_init(self) -> None: self.trainer.num_sanity_val_batches = [] @@ -162,24 +164,15 @@ def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Op # configure step_kwargs step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - model_ref = self.trainer.lightning_module - model_ref._results = Result() - if self.trainer.testing: - model_ref._current_fx_name = "test_step" + self.trainer.lightning_module._current_fx_name = "test_step" with self.trainer.profiler.profile("test_step"): output = self.trainer.accelerator.test_step(step_kwargs) else: - model_ref._current_fx_name = "validation_step" + self.trainer.lightning_module._current_fx_name = "validation_step" with self.trainer.profiler.profile("validation_step"): output = self.trainer.accelerator.validation_step(step_kwargs) - # capture any logged information - self.trainer.logger_connector.cache_logged_metrics() - # track batch size for weighted average - if isinstance(output, Result): - output.track_batch_size(batch) - return output def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: @@ -213,9 +206,6 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: model._current_fx_name = 'validation_epoch_end' model.validation_epoch_end(outputs) - # capture logging - self.trainer.logger_connector.cache_logged_metrics() - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: # set dataloader_idx to model and track batch_size self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index ff12e5c6e9053..3c98866c799cb 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -26,6 +26,7 @@ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.core.step_result import ResultCollection from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin @@ -42,6 +43,7 @@ parse_env_variables, ) from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -511,6 +513,16 @@ def max_steps(self) -> Optional[int]: def min_steps(self) -> Optional[int]: return self.train_loop.min_steps + @property + def result_collections(self) -> Optional[ResultCollection]: + if self.training: + return self.train_loop.train_results + elif self.validating: + return self.evaluation_loop.validation_results + elif self.testing: + return self.evaluation_loop.test_results + return None + # Used to represent the concrete type TrainerProperties class methods are called on. 
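The new `result_collections` property simply routes to one metrics container per stage. A small illustrative sketch of the same routing pattern, deliberately not the Trainer implementation, using a plain dict in place of `ResultCollection`:

from enum import Enum
from typing import Dict, Optional


class Stage(Enum):
    TRAINING = "train"
    VALIDATING = "validate"
    TESTING = "test"


class MiniTrainer:
    def __init__(self) -> None:
        self.stage: Optional[Stage] = None
        # one metrics container per stage, owned by the corresponding loop in Lightning
        self._results: Dict[Stage, dict] = {stage: {} for stage in Stage}

    @property
    def result_collections(self) -> Optional[dict]:
        # None outside of fit/validate/test, mirroring the early returns above
        return self._results.get(self.stage)


trainer = MiniTrainer()
trainer.stage = Stage.VALIDATING
trainer.result_collections["val_loss"] = 0.1
assert trainer._results[Stage.VALIDATING] == {"val_loss": 0.1}
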
_T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6a20625978e39..3fe9d48c8d0ad 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1201,32 +1201,13 @@ def _call_teardown_hook(self, model: LightningModule) -> None: model._current_fx_name = None model._current_dataloader_idx = None - def _reset_result_and_set_fx_name(self, hook_name: str) -> bool: - # on_before_zero_grad is called within training_step - # TODO(@carmocca): Result should handle this logic - if "batch_start" in hook_name or hook_name in ("on_before_zero_grad", "on_after_backward"): - return True - model_ref = self.lightning_module - if model_ref is not None: - # used to track current hook name called - model_ref._results = Result() - model_ref._current_fx_name = hook_name - return False - - def _cache_logged_metrics(self): - model_ref = self.lightning_module - if model_ref is not None: - # capture logging for this hook - self.logger_connector.cache_logged_metrics() - def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook # This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end # If making changes to this function, ensure that those changes are also made to # TrainLoop._on_train_epoch_end_hook - - # set hook_name to model + reset Result obj - skip = self._reset_result_and_set_fx_name(hook_name) + if self.lightning_module: + self.lightning_module._current_fx_name = hook_name # always profile hooks with self.profiler.profile(hook_name): @@ -1249,8 +1230,9 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: accelerator_hook = getattr(self.accelerator, hook_name) output = accelerator_hook(*args, **kwargs) - if not skip: - self._cache_logged_metrics() + if self.lightning_module: + self.lightning_module._current_fx_name = None + return output @staticmethod diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 84d69765c7c36..70ac7e236dd83 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -23,7 +23,7 @@ from torch.optim import Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.core.step_result import Result, ResultCollection from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType @@ -83,6 +83,8 @@ def __init__( else: self.trainer.num_sanity_val_steps = num_sanity_val_steps + self.train_results = ResultCollection() + @property def num_active_optimizers(self) -> int: return len(self.get_active_optimizers()) @@ -256,8 +258,6 @@ def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[i return [(opt_idx, self.trainer.optimizers[opt_idx])] def on_after_backward(self, training_step_output, batch_idx, untouched_loss): - training_step_output.detach() - # insert after step hook self.trainer.call_hook("on_after_backward") @@ -279,13 +279,10 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # manually capture logged metrics model_ref._current_fx_name = 'training_step' - model_ref._results = Result() with self.trainer.profiler.profile("training_step"): training_step_output = 
self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() - self.trainer.logger_connector.cache_logged_metrics() - self._check_training_step_output(training_step_output) training_step_output = self.trainer.call_hook("training_step_end", training_step_output) @@ -345,12 +342,7 @@ def _process_training_step_output(self, training_step_output, split_batch): result.minimize = loss self._hiddens = hiddens - # track batch for manual reduction with result - result.track_batch_size(len(split_batch)) - - # track metrics without grads for epoch reduction training_step_output_for_epoch_end = copy(result) - training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() if self.trainer.move_metrics_to_cpu: training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() @@ -388,6 +380,9 @@ def _prepare_outputs( for batch_outputs in opt_outputs: processed_tbptt_outputs = [] + if isinstance(batch_outputs, ResultCollection): + batch_outputs = [batch_outputs] + for tbptt_output in batch_outputs: out = tbptt_output.extra out['loss'] = tbptt_output.minimize @@ -495,6 +490,11 @@ def run_training_epoch(self): if batch_output.signal == -1: break + # ----------------------------------------- + # SAVE METRICS TO LOGGERS AND PROGRESS_BAR + # ----------------------------------------- + self.trainer.logger_connector.update_train_step_metrics(batch_output) + # hook # TODO: add outputs to batches self.on_train_batch_end( @@ -505,11 +505,6 @@ def run_training_epoch(self): dataloader_idx, ) - # ----------------------------------------- - # SAVE METRICS TO LOGGERS - # ----------------------------------------- - self.trainer.logger_connector.log_train_step_metrics(batch_output) - # ----------------------------------------- # VALIDATE IF NEEDED # ----------------------------------------- @@ -559,7 +554,7 @@ def run_training_epoch(self): self.on_train_epoch_end(epoch_output) # log epoch metrics - self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) + self.trainer.logger_connector.update_train_epoch_metrics() should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) @@ -603,9 +598,6 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: 'HINT: remove the return statement in training_epoch_end' ) - # capture logging - self.trainer.logger_connector.cache_logged_metrics() - # call train epoch end hooks self._on_train_epoch_end_hook(processed_epoch_output) self.trainer.call_hook('on_epoch_end') @@ -617,9 +609,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # This implementation is copied from Trainer.call_hook hook_name = "on_train_epoch_end" - - # set hook_name to model + reset Result obj - skip = self.trainer._reset_result_and_set_fx_name(hook_name) + self.trainer.lightning_module._current_fx_name = hook_name # always profile hooks with self.trainer.profiler.profile(hook_name): @@ -649,8 +639,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: accelerator_hook = getattr(self.trainer.accelerator, hook_name) accelerator_hook() - if not skip: - self.trainer._cache_logged_metrics() + self.trainer.lightning_module._current_fx_name = None def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms @@ -798,9 +787,6 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) - if not opt_closure_result: 
return - # cache metrics - self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) - # check if loss or model weights are nan if self.trainer.terminate_on_nan: self._check_finite(opt_closure_result.loss) @@ -981,7 +967,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally - self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) + self.trainer.logger_connector.on_train_split_start(split_batch) def update_running_loss(self, current_loss: torch.Tensor) -> None: if self.trainer.lightning_module.automatic_optimization: diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 1cbab2fb8dee9..fc88eb24ec442 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -104,6 +104,61 @@ def apply_to_collection( return data +def apply_to_collections( + data1: Any, + data2: Any, + dtype: Union[type, tuple], + function: Callable, + *args, + wrong_dtype: Optional[Union[type, tuple]] = None, + **kwargs +) -> Any: + """ + Recursively applies a function to all elements of a certain dtype. + + Args: + data: the collection to apply the function to + dtype: the given function will be applied to all elements of this dtype + function: the function to apply + *args: positional arguments (will be forwarded to calls of ``function``) + wrong_dtype: the given function won't be applied if this type is specified and the given collections is of + the :attr:`wrong_type` even if it is of type :attr`dtype` + **kwargs: keyword arguments (will be forwarded to calls of ``function``) + + Returns: + the resulting collection + """ + elem_type_1 = type(data1) + + # Breaking condition + if isinstance(data1, dtype) and (wrong_dtype is None or not isinstance(data1, wrong_dtype)): + return function(data1, data2, *args, **kwargs) + + # Recursively apply to collection items + if isinstance(data1, Mapping): + return elem_type_1({ + k1: apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for (k1, v1), (k1, v2) in zip(data1.items(), data2.items()) + }) + + if isinstance(data1, tuple) and hasattr(data1, '_fields'): # named tuple + return elem_type_1( + *( + apply_to_collections(d1, d2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for d1, d2 in zip(data1, data2) + ) + ) + + if isinstance(data1, Sequence) and not isinstance(data1, str): + return elem_type_1([ + apply_to_collections(d1, d2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for d1, d2 in zip(data1, data2) + ]) + + # data is neither of dtype, nor a collection + return data1 + + class TransferableDataType(ABC): """ A custom type for data that can be moved to a torch device via `.to(...)`. 
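`apply_to_collections` walks two collections with the same structure in lock-step and applies `function` to matching leaves, which is how `ResultCollection.update_metrics` pairs each stored `ResultMetric` with the freshly logged value. A minimal standalone illustration of that idea for flat dicts of tensors (the real helper also handles sequences and namedtuples); `zip_apply` is a hypothetical name used only here:

from typing import Callable, Dict

import torch


def zip_apply(
    d1: Dict[str, torch.Tensor],
    d2: Dict[str, torch.Tensor],
    function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> Dict[str, torch.Tensor]:
    # both dicts are assumed to share the same keys, as ResultCollection guarantees when it
    # pairs the stored ResultMetric objects with the newly logged values
    return {key: function(d1[key], d2[key]) for key in d1}


metrics = {"loss": torch.tensor(1.0), "acc": torch.tensor(0.5)}
updates = {"loss": torch.tensor(3.0), "acc": torch.tensor(0.7)}
summed = zip_apply(metrics, updates, torch.add)
assert summed["loss"].item() == 4.0 and abs(summed["acc"].item() - 1.2) < 1e-6
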
diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py index fceee0da4872e..a0a01df0ff858 100644 --- a/tests/trainer/connectors/test_logger_connectors.py +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -18,7 +18,7 @@ from pytorch_lightning.core.step_result import DefaultMetricsKeys, ResultCollection -def test_result_collection(): +def test_result_collection_on_tensor_with_mean_reduction(): seed_everything(42) @@ -75,9 +75,10 @@ def test_result_collection(): tensor(7), tensor(8), tensor(9) ] - assert result_collection["training_step.loss_1_0_0"].values == excepted_values excepted_batches = [1, 4, 9, 16, 25, 36, 49, 64, 81] - assert result_collection["training_step.loss_1_0_0"].batch_sizes == excepted_batches + total_value = tensor(excepted_values) * tensor(excepted_batches) + assert result_collection["training_step.loss_1_0_0"].value == sum(total_value) + assert result_collection["training_step.loss_1_0_0"].cumulated_batch_size == sum(excepted_batches) batch_metrics = result_collection.get_batch_metrics() @@ -87,7 +88,7 @@ def test_result_collection(): 'loss_1_1_1_step': tensor([9.]), 'loss_3_1_1': tensor([9.]) } - assert batch_metrics[DefaultMetricsKeys.PBAR_METRICS] == expected + assert batch_metrics[DefaultMetricsKeys.PBAR] == expected excepted = { 'loss_1_0_1_step': tensor([9.]), @@ -95,7 +96,7 @@ def test_result_collection(): 'loss_1_1_1_step': tensor([9.]), 'loss_3_1_1': tensor([9.]) } - assert batch_metrics[DefaultMetricsKeys.LOG_METRICS] == excepted + assert batch_metrics[DefaultMetricsKeys.LOG] == excepted excepted = { 'loss_1_0_0': tensor([9.]), @@ -107,17 +108,17 @@ def test_result_collection(): 'loss_1_1_1': tensor([9.]), 'loss_3_1_1': tensor([9.]) } - assert batch_metrics[DefaultMetricsKeys.CALLBACK_METRICS] == excepted + assert batch_metrics[DefaultMetricsKeys.CALLBACK] == excepted epoch_metrics = result_collection.get_epoch_metrics() mean = (tensor(excepted_values) * tensor(excepted_batches)).sum() / sum(excepted_batches) expected = {'loss_1_1_0_epoch': mean, 'loss_2_1_0': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} - assert epoch_metrics[DefaultMetricsKeys.PBAR_METRICS] == expected + assert epoch_metrics[DefaultMetricsKeys.PBAR] == expected excepted = {'loss_1_0_1_epoch': mean, 'loss_2_0_1': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} - assert epoch_metrics[DefaultMetricsKeys.LOG_METRICS] == excepted + assert epoch_metrics[DefaultMetricsKeys.LOG] == excepted excepted = { 'loss_1_0_0': mean, @@ -129,4 +130,4 @@ def test_result_collection(): 'loss_1_1_1': mean, 'loss_2_1_1': mean } - assert epoch_metrics[DefaultMetricsKeys.CALLBACK_METRICS] == excepted + assert epoch_metrics[DefaultMetricsKeys.CALLBACK] == excepted diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 546fb9ff8fdac..081b8cc7237fc 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -175,7 +175,7 @@ def backward(self, loss, optimizer, optimizer_idx): # make sure all the metrics are available for callbacks logged_metrics = set(trainer.logged_metrics.keys()) - expected_logged_metrics = {'epoch', 'a_step', 'a_epoch', 'b', 'b1', 'a1', 'a2'} + expected_logged_metrics = {'a1', 'b', 'epoch', 'a_step', 'b1', 'a_epoch', 'a2'} assert logged_metrics == expected_logged_metrics pbar_metrics = set(trainer.progress_bar_metrics.keys()) @@ -397,6 +397,7 @@ def val_dataloader(self): limit_val_batches=2, 
max_epochs=1, weights_summary=None, + fast_dev_run=True, ) trainer.fit(model) From c1dc5a54e3502aa44905d8505bc36e6978c067cf Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 24 May 2021 15:55:27 +0100 Subject: [PATCH 089/455] update --- pytorch_lightning/core/step_result.py | 111 ++++++------------ .../logger_connector/logger_connector.py | 9 ++ pytorch_lightning/utilities/apply_func.py | 33 ++++-- .../logging_/test_eval_loop_logging.py | 19 +-- .../logging_/test_train_loop_logging.py | 4 +- 5 files changed, 84 insertions(+), 92 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index fca5f37dd35b6..5fa208ce455ef 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -25,64 +25,6 @@ from pytorch_lightning.utilities.types import _METRIC -def apply_to_metrics_collection( - data: Any, - dtype: Union[type, tuple], - function: Callable, - *args, - wrong_dtype: Optional[Union[type, tuple]] = None, - **kwargs -) -> Any: - """ - Recursively applies a function to all elements of a certain dtype. - - Args: - data: the collection to apply the function to - dtype: the given function will be applied to all elements of this dtype - function: the function to apply - *args: positional arguments (will be forwarded to calls of ``function``) - wrong_dtype: the given function won't be applied if this type is specified and the given collections is of - the :attr:`wrong_type` even if it is of type :attr`dtype` - **kwargs: keyword arguments (will be forwarded to calls of ``function``) - - Returns: - the resulting collection - """ - elem_type = type(data) - - # Breaking condition - if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): - return function(data, *args, **kwargs) - - # Recursively apply to collection items - if isinstance(data, Mapping): - _out = {} - for k, v in data.items(): - v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - if v is not None: - _out[k] = v - return elem_type(_out) - - if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple - _out = [] - for d in data: - v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - if v is not None: - _out.append(v) - return elem_type(*_out) - - if isinstance(data, Sequence) and not isinstance(data, str): - _out = [] - for d in data: - v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - if v is not None: - _out.append(v) - return elem_type(_out) - - # data is neither of dtype, nor a collection - return data - - class DefaultMetricsKeys(LightningEnum): CALLBACK = "callback" PBAR = "pbar" @@ -105,6 +47,7 @@ class Metadata: reduce_fx: Callable = torch.mean dataloader_idx: Optional[int] = None is_tensor: bool = True + should_reset: bool = True @property def forked(self) -> bool: @@ -170,6 +113,13 @@ def compute(self) -> torch.Tensor: else: return self.value.compute() + def __repr__(self) -> str: + if self.meta.is_tensor_and_mean_reduction: + attr = f"value={self.value}, cumulated_batch_size={self.cumulated_batch_size}" + else: + attr = f"value={self.value}" + return f"{self.__class__.__name__}({attr})" + class ResultCollection(dict): @@ -185,6 +135,15 @@ def batch_size(self) -> int: def batch_size(self, batch_size: int) -> None: self._batch_size = batch_size + @property + def on_epoch_end_reached(self) -> bool: + return self._on_epoch_end_reached + + @on_epoch_end_reached.setter + def 
on_epoch_end_reached(self, on_epoch_end_reached): + self._on_epoch_end_reached = on_epoch_end_reached + self._batch_size = None + @property def metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: return self.get_epoch_metrics() if self.on_epoch_end_reached else self.get_batch_metrics() @@ -210,6 +169,9 @@ def extra(self) -> Dict: def extra(self, extra: Dict) -> None: self['extra'] = extra + def should_reset(self, hook_name): + return hook_name not in ("on_train_start") + def log( self, hook_name: str, @@ -227,11 +189,6 @@ def log( """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs - batch_size = batch_size or self._batch_size - - if not batch_size: - raise MisconfigurationException("batch_size should be provided to ResultCollection.log function.") - if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() @@ -250,10 +207,11 @@ def log( on_epoch=on_epoch, reduce_fx=reduce_fx, dataloader_idx=dataloader_idx, + should_reset=self.should_reset(hook_name), ) self.instance_result_metric(key, meta, value) - self.update_metrics(key, value, batch_size) + self.update_metrics(key, value, batch_size or torch.tensor(1.)) def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: @@ -261,9 +219,11 @@ def fn(*_): return ResultMetric(meta) self[key] = apply_to_collection(value, torch.Tensor, fn) - self[key + '.forked'] = meta.forked - self[key + '.logger'] = meta.logger - self[key + '.prog_bar'] = meta.prog_bar + # cache the meta for reduction + if not isinstance(self[key], ResultMetric): + self[key + '.forked'] = meta.forked + self[key + '.logger'] = meta.logger + self[key + '.prog_bar'] = meta.prog_bar def update_metrics(self, key: str, value: Union[Dict, torch.Tensor], batch_size) -> None: @@ -309,7 +269,7 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: suffix = "_step" if on_step else "_epoch" for key, result_metric in self.valid_metrics(): - value = apply_to_metrics_collection(result_metric, ResultMetric, fn) + value = apply_to_collection(result_metric, ResultMetric, fn, remove_none=True) if value is None: continue @@ -321,7 +281,7 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: metrics[DefaultMetricsKeys.CALLBACK][name_forked] = value if prog_bar: - value = apply_to_metrics_collection(result_metric, torch.Tensor, self._to_item) + value = apply_to_collection(value, torch.Tensor, self._to_item, remove_none=True) metrics[DefaultMetricsKeys.PBAR][name_forked] = value return metrics @@ -355,10 +315,10 @@ def cpu(self) -> 'ResultCollection': def reset(self) -> None: """Call at the end of epoch to reset all metric objects""" for item in self.values(): - if isinstance(item, Metric): + if isinstance(item, ResultMetric) and item.meta.should_reset: item.reset() - self._batch_size: int = 1 - self.on_epoch_end_reached: bool = False + self._batch_size: Optional[int] = None + self._on_epoch_end_reached: bool = False self._minimize: Optional[Tensor] = None def extract_batch_size(self, batch: Any) -> None: @@ -387,3 +347,10 @@ def _extract_batch_size(self, batch: Any) -> int: else: size = 1 return size + + def __repr__(self) -> str: + repr = f'{self.__class__.__name__}' + '{\n' + for k in sorted(self.keys()): + v = self[k] + repr += f" {k}: {v},\n" + return repr[:-1] + '\n}' diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py 
b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 6a60e33650476..227ce08d9a2ef 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -279,14 +279,23 @@ def update_train_epoch_metrics(self) -> None: @property def callback_metrics(self) -> Dict: + if self.trainer.result_collections: + metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.CALLBACK] + self._callback_metrics.update(metrics) return self._callback_metrics @property def logged_metrics(self) -> Dict: + if self.trainer.result_collections: + metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.LOG] + self._logged_metrics.update(metrics) return self._logged_metrics @property def progress_bar_metrics(self) -> Dict: + if self.trainer.result_collections: + metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.PBAR] + self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics def add_progress_bar_metrics(self, metrics): diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index fc88eb24ec442..c46d0ad525d33 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -60,6 +60,7 @@ def apply_to_collection( function: Callable, *args, wrong_dtype: Optional[Union[type, tuple]] = None, + remove_none: bool = False, **kwargs ) -> Any: """ @@ -72,6 +73,7 @@ def apply_to_collection( *args: positional arguments (will be forwarded to calls of ``function``) wrong_dtype: the given function won't be applied if this type is specified and the given collections is of the :attr:`wrong_type` even if it is of type :attr`dtype` + remove_none: Whether to skip an element if the output of function is ``None`` while applying onto the collection. 
**kwargs: keyword arguments (will be forwarded to calls of ``function``) Returns: @@ -85,20 +87,31 @@ def apply_to_collection( # Recursively apply to collection items if isinstance(data, Mapping): - return elem_type({ - k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - for k, v in data.items() - }) + _out = {} + for k, v in data.items(): + v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if remove_none and v is None: + continue + _out[k] = v + return elem_type(_out) if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple - return elem_type( - *(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data) - ) + _out = [] + for d in data: + v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if remove_none and v is None: + continue + _out.append(v) + return elem_type(*_out) if isinstance(data, Sequence) and not isinstance(data, str): - return elem_type([ - apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data - ]) + _out = [] + for d in data: + v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if remove_none and v is None: + continue + _out.append(v) + return elem_type(_out) # data is neither of dtype, nor a collection return data diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 331734aa9b412..c583a47c26848 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -207,7 +207,7 @@ def validation_epoch_end(self, outputs): assert callback_metrics == expected_callback_metrics # assert the loggers received the expected number - assert len(trainer.dev_debugger.logged_metrics) == max_epochs + assert len(trainer.dev_debugger.logged_metrics) == max_epochs * 2 def test_eval_float_logging(tmpdir): @@ -264,7 +264,7 @@ def validation_step(self, batch, batch_idx): output = self.layer(batch) loss = self.loss(batch, output) self.seen_vals.append(loss) - self.log('val_loss', loss, on_epoch=True, on_step=True, prog_bar=True) + self.log('val_loss', loss, on_epoch=True, on_step=True, prog_bar=True, logger=True) return {"x": loss} def validation_epoch_end(self, outputs) -> None: @@ -292,21 +292,24 @@ def validation_epoch_end(self, outputs) -> None: # make sure values are correct assert trainer.logged_metrics['val_loss_epoch'] == manual_mean - assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step'] + assert trainer.callback_metrics['val_loss_epoch'] == manual_mean + assert trainer.callback_metrics['val_loss'] == manual_mean + assert trainer.logged_metrics["val_loss_step"] == model.seen_vals[-1] # make sure correct values were logged logged_val = trainer.dev_debugger.logged_metrics + assert trainer.logged_metrics["val_loss_step"] == model.seen_vals[-1] # 3 val batches - assert logged_val[0]['val_loss_step'] == model.seen_vals[0] - assert logged_val[1]['val_loss_step'] == model.seen_vals[1] - assert logged_val[2]['val_loss_step'] == model.seen_vals[2] + assert logged_val[1]['val_loss_step'] == model.seen_vals[0] + assert logged_val[2]['val_loss_step'] == model.seen_vals[1] + assert logged_val[3]['val_loss_step'] == model.seen_vals[2] # epoch mean - assert logged_val[3]['val_loss_epoch'] == model.manual_epoch_end_mean + assert logged_val[4]['val_loss_epoch'] == model.manual_epoch_end_mean # only those logged - assert 
len(logged_val) == 4 + assert len(logged_val) == 5 @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 081b8cc7237fc..7f9acb4cac8a1 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -618,7 +618,7 @@ def training_step(self, batch, batch_idx): output = self.layer(batch) loss = self.loss(batch, output) self.manual_loss.append(loss) - self.log('train_loss', loss) + self.log('train_loss', loss, prog_bar=True) return {"loss": loss} max_epochs = 2 @@ -669,7 +669,7 @@ def get_expected_output(func_attr, original_values): # Make sure the func_name output equals the average from all logged values when on_epoch true # pop extra keys trainer.callback_metrics.pop("debug_epoch") - assert trainer.logged_metrics["train_loss"] == model.manual_loss[-1] + assert trainer.progress_bar_dict["train_loss"] == model.manual_loss[-1] assert trainer.callback_metrics["train_loss"] == model.manual_loss[-1] trainer.callback_metrics.pop("train_loss") From 9280a06bff055fd355bf287e657ae9650a788118 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 24 May 2021 17:29:36 +0100 Subject: [PATCH 090/455] update --- pytorch_lightning/core/step_result.py | 7 +- .../logger_connector/epoch_result_store.py | 493 ------------------ .../logger_connector/logger_connector.py | 16 +- .../logger_connector/metrics_holder.py | 82 --- pytorch_lightning/trainer/evaluation_loop.py | 4 +- pytorch_lightning/trainer/training_loop.py | 2 +- .../logging_/test_eval_loop_logging.py | 45 +- 7 files changed, 45 insertions(+), 604 deletions(-) delete mode 100644 pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py delete mode 100644 pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 5fa208ce455ef..a170678d46040 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -123,8 +123,9 @@ def __repr__(self) -> str: class ResultCollection(dict): - def __init__(self) -> None: + def __init__(self, is_train: bool) -> None: super().__init__() + self.is_train = is_train self.reset() @property @@ -278,7 +279,9 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: if logger: metrics[DefaultMetricsKeys.LOG][name_forked] = value metrics[DefaultMetricsKeys.CALLBACK][name] = value - metrics[DefaultMetricsKeys.CALLBACK][name_forked] = value + + if self.is_train or (not self.is_train and not on_step): + metrics[DefaultMetricsKeys.CALLBACK][name_forked] = value if prog_bar: value = apply_to_collection(value, torch.Tensor, self._to_item, remove_none=True) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py deleted file mode 100644 index 3d6370e3eb658..0000000000000 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ /dev/null @@ -1,493 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple -from weakref import proxy - -import torch - -import pytorch_lightning as pl -from pytorch_lightning.core.step_result import Result -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import DistributedType, LightningEnum - - -class ResultStoreType(LightningEnum): - INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" - OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop" - - -class HookResultStore: - """ - This class is defined for internal usage. - It holds all metrics logged using the self.log function - in the scope of ModelHooks or Callback functions. - - We need to differentiate 3 different scenarios: - - (1): We are outside of a batch loop - * It means no dataloader_idx, no optimizer idx, etc.. - - (2): We are inside the training batch loop - * We have an optimizer idx and split idx to track - - (3): We are inside the evaluation loop - * We have a dataloader_idx to track - - The data store `Result` objects for those 3 scenarios in `self._internals`. - - (1): self._internals = {dataloader_idx: [Result(), ..., Result()]} - * dataloader_idx not being defined, it is set to 0 b default - (2): self._internals = {dataloader_idx: {optimizer_idx: {batch_idx: [Result(), ..., Result()]}}} - (3): Same as (1) for simplicity - - Those data structures enables us to reduce properly Result object when batch loop is finished. - """ - - def __init__(self, fx_name: str) -> None: - self._fx_name = fx_name - self._internals = {} - self._internals_reduced = {} - self._internal_type: Optional[ResultStoreType] = None - self.has_reduced = False - self._latest_ref = {} - - @property - def num_dataloaders(self) -> int: - return len(self._internals_reduced if self.has_reduced else self._internals) - - def check_dataloader_idx(self, result: Result) -> bool: - random_key = list(result.keys())[-1] - return result["meta"][random_key]["dataloader_idx"] is not None - - def get_latest_from_func_name(self, latest_result_opt, func_name: str, *args, **kwargs) -> Dict: - results = {} - for opt_idx in latest_result_opt: - latest_result = latest_result_opt[opt_idx] - add_dataloader_idx = self.check_dataloader_idx(latest_result) - func = getattr(latest_result, func_name) - results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs)) - return results - - def run_latest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]: - """ - This function used cache_ref and cache_result to optimize loading metrics - - Context: As we update the logger_connector metrics on every `self.log` call, - and it can be pretty time consuming, especially when logging outside batch loop. 
- - HookResultStore keeps track of its latest added result object, - and cache its pbar and log metrics if already called on, - """ - return [ - self.get_latest_from_func_name(self._latest_ref[dl_idx], func_name, *args, **kwargs) - for dl_idx in range(self.num_dataloaders) - ] - - def get_batch_pbar_metrics(self, *args, **kwargs): - return self.run_latest_batch_metrics_with_func_name("get_batch_pbar_metrics", *args, **kwargs) - - def get_batch_log_metrics(self, *args, **kwargs): - return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics", *args, **kwargs) - - def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: - if not isinstance(opt_metric, Result): - raise Exception("The provided opt_metric should be a Result Object. Something is wrong") - - func = getattr(opt_metric, func_name) - metrics_to_log = func(*args, add_dataloader_idx=self.num_dataloaders > 1, **kwargs) - - results.append(metrics_to_log) - - def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]: - results = [] - for dl_idx in range(self.num_dataloaders): - opt_metrics = self._internals_reduced[dl_idx] - if isinstance(opt_metrics, defaultdict): - for opt_metric in opt_metrics.values(): - self.run_epoch_func(results, opt_metric, func_name, *args, **kwargs) - else: - self.run_epoch_func(results, opt_metrics, func_name, *args, **kwargs) - return results - - def get_epoch_pbar_metrics(self, *_, **__) -> List[Dict]: - return self.get_epoch_from_func_name("get_epoch_pbar_metrics") - - def get_epoch_log_metrics(self, *_, **__) -> List[Dict]: - return self.get_epoch_from_func_name("get_epoch_log_metrics") - - def get_forked_metrics(self, *_, **__) -> List[Dict]: - return self.get_epoch_from_func_name("get_forked_metrics") - - def append(self, result: Result, info: Dict) -> None: - dataloader_idx = info["dataloader_idx"] - self._internal_type = info["type"] - opt_idx = info["opt_idx"] - - if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: - if dataloader_idx not in self._internals: - self._internals_reduced[dataloader_idx] = defaultdict(dict) - self._latest_ref[dataloader_idx] = {} - self._internals.setdefault(dataloader_idx, {}) - - batch_idx = info["batch_idx"] - self._internals[dataloader_idx].setdefault(opt_idx, {}) - self._internals[dataloader_idx][opt_idx].setdefault(batch_idx, []) - self._internals[dataloader_idx][opt_idx][batch_idx].append(result) - else: - self._internals.setdefault(dataloader_idx, []) - self._internals[dataloader_idx].append(result) - self._latest_ref.setdefault(dataloader_idx, {}) - - self._latest_ref[dataloader_idx].setdefault(opt_idx, {}) - self._latest_ref[dataloader_idx][opt_idx] = result - - def auto_reduce_results_on_epoch_end(self) -> None: - """ - This function is called to reduce `self._internals` Result object. - The reduced Result object will be saved into `self._internals_reduced` - The `self._internals` stored Result objects will be deleted to save memory. 
- """ - if self.has_reduced: - return - for dl_idx in range(self.num_dataloaders): - epoch_metrics = self._internals[dl_idx] - - if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: - for opt_idx in list(epoch_metrics): - # TODO: Figure out to reduce memory - # TODO: How to start training in middle of epoch - outputs = epoch_metrics[opt_idx] - # reduce across time first - time_reduced_outputs = [] - for tbptt_outputs in outputs.values(): - tbptt_outputs = type(tbptt_outputs[0]).reduce_across_time(tbptt_outputs) - if len(tbptt_outputs) > 1: - time_reduced_outputs.append(tbptt_outputs) - - if len(time_reduced_outputs) == 0: - continue - - # reduce across training steps - outputs = type(time_reduced_outputs[0]).reduce_on_epoch_end(time_reduced_outputs) - - # with manual opt need 1 + metrics because meta is always there - if outputs.minimize is not None: - outputs.minimize = outputs.minimize.mean() - - self._internals_reduced[dl_idx][opt_idx] = outputs - - # free memory - del self._internals[dl_idx][opt_idx] - else: - reduced_epoch_metrics = epoch_metrics[0] - if len(epoch_metrics) != 1: - reduced_epoch_metrics = type(reduced_epoch_metrics).reduce_on_epoch_end(epoch_metrics) - - self._internals_reduced[dl_idx] = reduced_epoch_metrics - - # free memory - del self._internals[dl_idx] - - self.has_reduced = True - - def reset(self) -> None: - """ - Call at the end of epoch to reset Result objects - """ - for dl_idx in range(self.num_dataloaders): - epoch_metrics = self._internals[dl_idx] if not self.has_reduced else self._internals_reduced[dl_idx] - if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: - for opt_idx in list(epoch_metrics): - epoch_metrics[opt_idx].reset() - else: - epoch_metrics.reset() - - def __getitem__(self, key: str) -> Any: - return self._internals.get(key, None) - - def __repr__(self): - return self._internals.__repr__() - - -class EpochResultStore: - """ - This class is defined for internal usage. - It holds all metrics logged using the self.log function inside `HookResultStore` objects. - - The internal data-structure is as follow: - self._internals = {"fx_name_0": HookResultStore(), ..., "fx_name_n": HookResultStore()} - - ..example:: - - model._results = Result() - model._current_fx_name = 'something' - model.log('a', ...) 
- epoch_result_store.cache_result() - """ - - def __init__(self, trainer: 'pl.Trainer') -> None: - self.trainer = proxy(trainer) - self._internals = {} - self.reset() - - def __getitem__(self, key: str) -> Any: - return self._internals.get(key, None) - - @property - def info(self): - """ - This function provides necessary parameters to properly configure HookResultStore obj - """ - model_ref = self.trainer.lightning_module - return { - "batch_idx": self.trainer.train_loop.batch_idx, - "fx_name": model_ref._current_fx_name, - "dataloader_idx": model_ref._current_dataloader_idx or 0, - "opt_idx": self._opt_idx or 0, - "split_idx": self._split_idx or 0, - "type": ( - ResultStoreType.INSIDE_BATCH_TRAIN_LOOP if self._opt_idx is not None and self._split_idx is not None - else ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP - ) - } - - def reset_model(self): - """ - This function is used to reset model state at the end of the capture - """ - model_ref = self.trainer.lightning_module - model_ref._results = Result() - model_ref._current_fx_name = None - - def cache_result(self) -> None: - """ - This function is called after every hook and stores the result object - """ - with self.trainer.profiler.profile("cache_result"): - model_ref = self.trainer.lightning_module - - # extract hook results - hook_result = model_ref._results - - if len(hook_result) == 1: - model_ref._current_fx_name = None - return - - info = self.info - fx_name = info["fx_name"] - - self._internals.setdefault(fx_name, HookResultStore(fx_name)) - - # attach capture batch_size - Result.attach_batch_size(self._batch_size, hook_result) - - hook_result = hook_result.detach() - if self.trainer.move_metrics_to_cpu: - hook_result = hook_result.cpu() - elif self.trainer._distrib_type == DistributedType.DP: - hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu)) - - self._internals[fx_name].append(hook_result, info) - - # update logged_metrics, progress_bar_metrics, callback_metrics - if "epoch_end" in fx_name: - self.update_logger_connector() - - self.reset_model() - - def update_logger_connector(self) -> Tuple[Dict, Dict]: - """ - This function is called every time we capture a hook - It automatically updates the logger_connector followings: - - progress_bar_metrics with pbar_metrics - - logged_metrics with log_metrics - - callback_metrics with progress_bar_metrics + logged_metrics - """ - - logger_connector = self.trainer.logger_connector - - callback_metrics = {} - batch_pbar_metrics = {} - batch_log_metrics = {} - - if not self._has_batch_loop_finished: - # get pbar - batch_pbar_metrics = self.get_latest_batch_pbar_metrics() - logger_connector.add_progress_bar_metrics(batch_pbar_metrics) - batch_log_metrics = self.get_latest_batch_log_metrics() - - if self.trainer.training: - logger_connector._logged_metrics.update(batch_log_metrics) - callback_metrics.update(batch_pbar_metrics) - callback_metrics.update(batch_log_metrics) - else: - # get pbar - epoch_pbar_metrics = self.get_epoch_pbar_metrics() - logger_connector.add_progress_bar_metrics(epoch_pbar_metrics) - - # get logged_metrics - epoch_log_metrics = self.get_epoch_log_metrics() - logger_connector._logged_metrics.update(epoch_log_metrics) - logger_connector._logged_metrics.update({"epoch": self.trainer.current_epoch}) - - # get forked_metrics - forked_metrics = self.get_forked_metrics() - - callback_metrics.update(epoch_pbar_metrics) - callback_metrics.update(epoch_log_metrics) - callback_metrics.update(forked_metrics) - - # TODO(carmocca): when we implement flushing 
the logger connector metrics after - # the trainer.state changes, this should check trainer.evaluating instead - if self.trainer.state.fn in (TrainerFn.TESTING, TrainerFn.VALIDATING): - logger_connector.evaluation_callback_metrics.update(callback_metrics) - - # update callback_metrics - logger_connector._callback_metrics.update(callback_metrics) - - batch_pbar_metrics.pop("debug_epoch", None) - return batch_pbar_metrics, batch_log_metrics - - def run_batch_from_func_name(self, func_name) -> Dict: - results = [getattr(hook_result, func_name) for hook_result in self._internals.values()] - results = [func(include_forked_originals=False) for func in results] - return {k: v for d in sum(results, []) for k, v in d.items()} # List[List[dict]] -> dict - - def get_latest_batch_log_metrics(self) -> Dict: - batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics") - return batch_log_metrics - - def get_latest_batch_pbar_metrics(self) -> Dict: - batch_pbar_metrics = self.run_batch_from_func_name("get_batch_pbar_metrics") - return batch_pbar_metrics - - @property - def has_reduced(self) -> bool: - hook_results = self._internals.values() - return len(hook_results) == sum(h.has_reduced for h in hook_results) - - def auto_reduce_results_on_epoch_end(self) -> None: - if not self.has_reduced: - for hook_result in self._internals.values(): - hook_result.auto_reduce_results_on_epoch_end() - - @property - def has_batch_loop_finished(self) -> bool: - return self._has_batch_loop_finished - - @has_batch_loop_finished.setter - def has_batch_loop_finished(self, has_batch_loop_finished): - if has_batch_loop_finished: - # If batch loop has finished, reduce metrics - self.auto_reduce_results_on_epoch_end() - - # batch_size should be none as we finished batch loop - self._batch_size = None - - self._has_batch_loop_finished = has_batch_loop_finished - self.update_logger_connector() - - def run_epoch_by_func_name(self, func_name) -> Dict: - if not self.has_reduced: - self.auto_reduce_results_on_epoch_end() - results = [getattr(hook_result, func_name) for hook_result in self._internals.values()] - results = [func() for func in results] - return {k: v for d in sum(results, []) for k, v in d.items()} # List[List[dict]] -> dict - - def get_epoch_pbar_metrics(self) -> Dict: - return self.run_epoch_by_func_name("get_epoch_pbar_metrics") - - def get_epoch_log_metrics(self) -> Dict: - return self.run_epoch_by_func_name("get_epoch_log_metrics") - - def get_forked_metrics(self) -> Dict: - return self.run_epoch_by_func_name("get_forked_metrics") - - def reset(self) -> None: - for value in self._internals.values(): - value.reset() - self._internals = {} - self._dataloader_idx: Optional[int] = None - self._split_idx: Optional[int] = None - self._opt_idx: Optional[int] = None - self._batch_size: Optional[int] = None - self._has_batch_loop_finished = False - - def __call__( - self, - fx_name: str, - dl_idx: Optional[int] = None, - opt_idx: Optional[int] = None, - batch_idx: Optional[int] = None, - split_idx: Optional[int] = None, - reduced: bool = False, - ): - """ - This function is a helper to access stored data - - It access data from the HookResultStore. 
Please, - check its data structure for better understanding - - Data can be accessed with the following chains: - - IF REDUCED: - * IF accessing a fx_name defined in batch training loop: - fx_name -> dl_idx -> opt_idx -> batch_idx -> split_idx - * ELSE fx_name -> dl_idx -> batch_idx - ELSE: - * IF accessing a fx_name defined in batch training loop: - fx_name -> dl_idx -> opt_idx - * ELSE fx_name -> dl_idx - - Note: - As soon as a param is None, it breaks the chain and returns associated stored data. - - Example:: - - result: Result = self(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True) - result['train_loss_epoch'] # aggregated train_loss over one epoch. - - Args: - - fx_name: Hook name from ModelHooks or Callback. Example: ``"training_step"`` - - dl_idx: Dataloader index in short. From ``0`` to ``num_dataloaders - 1`` - - opt_idx: Optimizer index in short. From ``0`` to ``num_optimizers - 1`` - - batch_idx: Batch index seen during batch training or evaluation. - Works only with ``reduced=False`` - - split_idx: Index of split idx in training loop when tbptt is used. - - reduced: Data are being aggregated on on_epoch_end. - Indicates if we want to access the aggregated Result or not. - """ - hook_result = self[fx_name] - internal_type = hook_result._internal_type - result = hook_result._internals_reduced if reduced else hook_result._internals - - if dl_idx is not None: - result = result[dl_idx] - if internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: - if opt_idx is not None: - result = result[opt_idx] - if not reduced and batch_idx is not None: - result = result[batch_idx] - if split_idx is not None: - result = result[split_idx] - elif not reduced and batch_idx is not None: - result = result[batch_idx] - return result - - def __repr__(self): - return f"{self.__class__.__name__}(internals={self._internals})" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 227ce08d9a2ef..35e6b64c50ef8 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -14,19 +14,16 @@ import os from copy import deepcopy from pprint import pprint -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, Optional import torch from pytorch_lightning.core import memory -from pytorch_lightning.core.step_result import DefaultMetricsKeys, Result, ResultCollection +from pytorch_lightning.core.step_result import DefaultMetricsKeys from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator -from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT @@ -206,6 +203,11 @@ def log_evaluation_step_metrics(self) -> None: return metrics = self.trainer.result_collections.metrics + + # update metrics + self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) + 
self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) + batch_log_metrics = metrics[DefaultMetricsKeys.LOG] # logs user requested information to logger @@ -263,11 +265,11 @@ def update_train_epoch_metrics(self) -> None: self._callback_metrics.update(callback_metrics) epoch_log_metrics = metrics[DefaultMetricsKeys.LOG] - epoch_log_metrics["epoch"] = self.trainer.current_epoch - self._logged_metrics.update(epoch_log_metrics) # add the metrics to the loggers if epoch_log_metrics and len(epoch_log_metrics) > 0: + epoch_log_metrics["epoch"] = self.trainer.current_epoch + self._logged_metrics.update(epoch_log_metrics) self.log_metrics(epoch_log_metrics, {}) # reset result collection for next epoch diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py deleted file mode 100644 index 8f12f57c640b0..0000000000000 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numbers -from typing import Dict, Optional - -import torch -from torchmetrics import Metric - -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _METRIC - - -class MetricsHolder: - """ - This class acts as a dictionary holder. - It holds metrics and implements conversion functions. - Those functions will be triggered within LoggerConnector - when the property is being requested from the user. - """ - - def __init__(self, to_float: bool = False) -> None: - self.metrics: Dict[str, _METRIC] = {} - self._to_float = to_float - - def update(self, metrics: dict) -> None: - self.metrics.update(metrics) - - def pop(self, key: str, default: _METRIC) -> _METRIC: - return self.metrics.pop(key, default) - - def reset(self, metrics: Dict[str, _METRIC]) -> None: - self.metrics = metrics - - def convert(self, device: Optional[torch.device]) -> None: - for key, value in self.metrics.items(): - if self._to_float: - if isinstance(value, torch.Tensor) and value.numel() != 1: - raise MisconfigurationException( - f"The metric `{key}` does not contain a single element" - f" thus it cannot be converted to float. 
Found `{value}`" - ) - converted = self._convert_to_float(value) - else: - converted = self._convert_to_tensor(value, device) - self.metrics[key] = converted - - @staticmethod - def _convert_to_float(current: _METRIC) -> float: - if isinstance(current, Metric): - current = current.compute().detach() - - if isinstance(current, torch.Tensor): - current = float(current.item()) - - elif isinstance(current, int): - current = float(current) - - return current - - @staticmethod - def _convert_to_tensor(current: _METRIC, device: Optional[torch.device]) -> torch.Tensor: - if isinstance(current, Metric): - current = current.compute().detach() - - elif isinstance(current, numbers.Number): - current = torch.tensor(current, device=device, dtype=torch.float) - - if isinstance(current, torch.Tensor) and current.device.type == "xla": - current = current.cpu() - - return current diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 45957948d77b7..a73cb341ff4a0 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -34,8 +34,8 @@ def __init__(self, trainer: 'pl.Trainer'): self.max_batches: Optional[List[Union[int, float]]] = None self.warning_cache = WarningCache() self.num_dataloaders: Optional[int] = None - self.validation_results = ResultCollection() - self.test_results = ResultCollection() + self.validation_results = ResultCollection(False) + self.test_results = ResultCollection(False) def on_trainer_init(self) -> None: self.trainer.num_sanity_val_batches = [] diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 70ac7e236dd83..618faeff227c9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -83,7 +83,7 @@ def __init__( else: self.trainer.num_sanity_val_steps = num_sanity_val_steps - self.train_results = ResultCollection() + self.train_results = ResultCollection(True) @property def num_active_optimizers(self) -> int: diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index c583a47c26848..c21ebfc94e21b 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -207,7 +207,7 @@ def validation_epoch_end(self, outputs): assert callback_metrics == expected_callback_metrics # assert the loggers received the expected number - assert len(trainer.dev_debugger.logged_metrics) == max_epochs * 2 + assert len(trainer.dev_debugger.logged_metrics) == max_epochs def test_eval_float_logging(tmpdir): @@ -301,15 +301,15 @@ def validation_epoch_end(self, outputs) -> None: assert trainer.logged_metrics["val_loss_step"] == model.seen_vals[-1] # 3 val batches - assert logged_val[1]['val_loss_step'] == model.seen_vals[0] - assert logged_val[2]['val_loss_step'] == model.seen_vals[1] - assert logged_val[3]['val_loss_step'] == model.seen_vals[2] + assert logged_val[0]['val_loss_step'] == model.seen_vals[0] + assert logged_val[1]['val_loss_step'] == model.seen_vals[1] + assert logged_val[2]['val_loss_step'] == model.seen_vals[2] # epoch mean - assert logged_val[4]['val_loss_epoch'] == model.manual_epoch_end_mean + assert logged_val[3]['val_loss_epoch'] == model.manual_epoch_end_mean # only those logged - assert len(logged_val) == 5 + assert len(logged_val) == 4 @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) @@ -449,7 +449,8 @@ def 
make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[] "on_epoch": on_epoch, "prog_bar": prog_bar, "forked": on_step and on_epoch, - "func_name": func_name + "func_name": func_name, + "training": self.log.__self__.trainer.training } if on_step and on_epoch: @@ -458,7 +459,8 @@ def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[] "on_epoch": False, "prog_bar": prog_bar, "forked": False, - "func_name": func_name + "func_name": func_name, + "training": self.log.__self__.trainer.training } self.funcs_attr[f"{custom_func_name}_epoch"] = { @@ -466,7 +468,8 @@ def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[] "on_epoch": True, "prog_bar": prog_bar, "forked": False, - "func_name": func_name + "func_name": func_name, + "training": self.log.__self__.trainer.training } def on_validation_start(self, trainer, pl_module): @@ -542,6 +545,9 @@ def validation_step(self, batch, batch_idx): loss = self.loss(batch, output) self.log('val_loss', loss) + def on_validation_end(self) -> None: + print(self.trainer.result_collections) + max_epochs = 1 model = TestModel() model.validation_epoch_end = None @@ -580,13 +586,12 @@ def validation_step(self, batch, batch_idx): # function used to describe expected return logic def get_expected_output(func_attr, original_values): - - if func_attr["on_epoch"] and not func_attr["on_step"]: - # Apply mean on values - expected_output = np.mean(original_values) - else: + if func_attr["on_step"] and not func_attr["on_epoch"]: # Keep the latest value expected_output = np.max(original_values) + else: + # Apply mean on values + expected_output = np.mean(original_values) return expected_output # Make sure the func_name output equals the average from all logged values when on_epoch true @@ -612,10 +617,16 @@ def get_expected_output(func_attr, original_values): assert float(output_value) == float(expected_output) for func_name, func_attr in test_callback.funcs_attr.items(): - if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: - assert func_name in trainer.logger_connector.progress_bar_metrics + if "on_batch_end" in func_name: + continue + if func_attr["prog_bar"] and (func_attr["on_epoch"] or func_attr["on_step"]): + try: + assert func_name in trainer.progress_bar_metrics + except: + import pdb + pdb.set_trace() else: - assert func_name not in trainer.logger_connector.progress_bar_metrics + assert func_name not in trainer.progress_bar_metrics @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) From 1ac2e7f73e398dff11d10508ba651127c297f2b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 24 May 2021 18:55:09 +0200 Subject: [PATCH 091/455] add note about #7677 --- pytorch_lightning/loops/epoch_loop.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 9a366c85705d8..51d9273378ce3 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -158,9 +158,8 @@ def on_advance_end(self): self.trainer._run_evaluation(on_epoch=True) self.trainer.training = True - # increment the global step once - # progress global step according to grads progress # TODO: move inside training_loop.on_run_end? equivalent? order? + # Needs to check batch_output signal -1, see #7677 self.training_loop.increment_accumulated_grad_global_step() # why is this not the same as the old on_train_epoch_end? 
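The "batch_output signal -1" note added above refers to the convention that a training batch can request that the rest of the epoch be skipped by returning ``-1``; the accumulated-grad global step should then only advance for batches that actually ran. The snippet below is a minimal, stand-alone sketch of that control flow, not the Lightning implementation: ``BatchOutput``, ``advance_batch`` and ``run_epoch`` are names invented for this example::

    # Hypothetical illustration of the "-1 signal" convention referenced in the
    # comment above. None of these names exist in the Lightning code base.
    from dataclasses import dataclass


    @dataclass
    class BatchOutput:
        signal: int = 0    # 0 = continue, -1 = stop the current epoch early
        loss: float = 0.0


    def run_epoch(batches, advance_batch):
        global_step = 0
        for batch in batches:
            output: BatchOutput = advance_batch(batch)
            if output.signal == -1:
                # the batch asked to stop: skip the remaining batches and do not
                # advance the step counter for work that never happened
                break
            global_step += 1
        return global_step


    if __name__ == "__main__":
        # toy "training step": signal -1 as soon as a batch value is negative
        steps = run_epoch(
            [1, 2, -3, 4],
            lambda b: BatchOutput(signal=-1 if b < 0 else 0, loss=float(b)),
        )
        print(steps)  # 2: only the first two batches count towards the step

Whether that check belongs in ``on_advance_end`` or inside ``training_loop.on_run_end`` is exactly the open question raised by the TODO above.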
From 19f61e0fdd27620bf8b15d51b9a7ddebef9d4618 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 24 May 2021 20:24:39 +0100 Subject: [PATCH 092/455] update --- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/core/step_result.py | 97 +++++- .../plugins/training_type/ddp2.py | 1 - pytorch_lightning/plugins/training_type/dp.py | 1 - .../logger_connector/logger_connector.py | 20 +- pytorch_lightning/trainer/evaluation_loop.py | 7 +- pytorch_lightning/trainer/trainer.py | 18 +- pytorch_lightning/trainer/training_loop.py | 15 +- tests/core/test_metric_result_integration.py | 1 - tests/core/test_results.py | 1 - tests/models/test_tpu.py | 1 - .../logging_/test_eval_loop_logging.py | 34 +- .../trainer/logging_/test_logger_connector.py | 317 ++---------------- 13 files changed, 145 insertions(+), 370 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 02d2cdb43a572..d96f69a2e78fb 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,7 +38,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES -from pytorch_lightning.core.step_result import Result, ResultCollection +from pytorch_lightning.core.step_result import ResultCollection from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index a170678d46040..7213fbc4affac 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Mapping, Sequence +from copy import deepcopy from dataclasses import dataclass from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union @@ -31,11 +32,6 @@ class DefaultMetricsKeys(LightningEnum): LOG = "log" -# TODO: remove -class Result: - pass - - @dataclass class Metadata: fx: str # TODO: distinction? @@ -81,7 +77,7 @@ def is_tensor_and_min_reduction(self) -> bool: class ResultMetric(Metric): def __init__(self, metadata: Metadata) -> None: - super().__init__() + super().__init__(compute_on_step=metadata.is_tensor) self.meta = metadata if self.meta.is_tensor: self.add_state("value", torch.tensor(.0)) @@ -101,6 +97,7 @@ def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: else: self.value = value + self._forward_cache = value._forward_cache def compute(self) -> torch.Tensor: if self.meta.is_tensor: @@ -111,7 +108,10 @@ def compute(self) -> torch.Tensor: else: raise MisconfigurationException("Only mean, max are supported.") else: - return self.value.compute() + try: + return self.value.compute() + except RuntimeError: + return torch.tensor(0.) def __repr__(self) -> str: if self.meta.is_tensor_and_mean_reduction: @@ -120,9 +120,47 @@ def __repr__(self) -> str: attr = f"value={self.value}" return f"{self.__class__.__name__}({attr})" + def reset(self): + if self.meta.is_tensor: + super().reset() + else: + print(self.meta.fx, self.meta.name) + self.value.reset() + + def forward(self, *args, **kwargs): + """ + Automatically calls ``update()``. 
Returns the metric value over inputs if ``compute_on_step`` is True. + """ + # add current step + with torch.no_grad(): + self.update(*args, **kwargs) + + if self.compute_on_step: + self._to_sync = self.dist_sync_on_step + + # save context before switch + cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} + + # call reset, update, compute, on single batch + self.reset() + self.update(*args, **kwargs) + self._forward_cache = self.compute() + + # restore context + for attr, val in cache.items(): + setattr(self, attr, val) + self._to_sync = True + self._computed = None + + return self._forward_cache + class ResultCollection(dict): + STEP_SUFFIX = "_step" + EPOCH_SUFFIX = "_epoch" + DATALOADER_SUFFIX = "/dataloader_idx_{}" + def __init__(self, is_train: bool) -> None: super().__init__() self.is_train = is_train @@ -198,6 +236,9 @@ def log( key = f"{hook_name}.{name}" + if dataloader_idx: + key += f'.{dataloader_idx}' + if key not in self: meta = Metadata( fx=hook_name, @@ -208,7 +249,7 @@ def log( on_epoch=on_epoch, reduce_fx=reduce_fx, dataloader_idx=dataloader_idx, - should_reset=self.should_reset(hook_name), + should_reset=self.should_reset(hook_name) ) self.instance_result_metric(key, meta, value) @@ -216,20 +257,25 @@ def log( def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: - def fn(*_): + def fn(v): + nonlocal meta + meta = deepcopy(meta) + meta.is_tensor = torch.is_tensor(v) return ResultMetric(meta) - self[key] = apply_to_collection(value, torch.Tensor, fn) + self[key] = apply_to_collection(value, (torch.Tensor, Metric), fn) # cache the meta for reduction if not isinstance(self[key], ResultMetric): self[key + '.forked'] = meta.forked self[key + '.logger'] = meta.logger self[key + '.prog_bar'] = meta.prog_bar + self[key + '.on_epoch'] = meta.on_epoch + self[key + '.dataloader_idx'] = meta.dataloader_idx def update_metrics(self, key: str, value: Union[Dict, torch.Tensor], batch_size) -> None: def fn(result_metric, v): - assert torch.is_tensor(v) + assert isinstance(v, (torch.Tensor, Metric)) result_metric(v, batch_size) apply_to_collections(self[key], value, ResultMetric, fn) @@ -257,38 +303,53 @@ def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) name_forked = result_metric.meta.forked_step_name if on_step else result_metric.meta.forked_epoch_name logger = result_metric.meta.logger prog_bar = result_metric.meta.prog_bar + metric_on_epoch = result_metric.meta.on_epoch + dataloader_idx = result_metric.meta.dataloader_idx else: name = key.split('.')[-1] name_forked = name + suffix if self[key + '.forked'] else name logger = self[key + '.logger'] prog_bar = self[key + '.prog_bar'] - return name, name_forked, logger, prog_bar + metric_on_epoch = self[key + '.on_epoch'] + dataloader_idx = self[key + '.dataloader_idx'] + + if dataloader_idx is not None: + dataloader_suffix = self.DATALOADER_SUFFIX.format(dataloader_idx) + name += dataloader_suffix + name_forked += dataloader_suffix + + return name, name_forked, logger, prog_bar, metric_on_epoch def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: metrics = {k: {} for k in DefaultMetricsKeys} fn = self._get_forward_cache if on_step else self._get_computed_cache - suffix = "_step" if on_step else "_epoch" + suffix = self.STEP_SUFFIX if on_step else self.EPOCH_SUFFIX for key, result_metric in self.valid_metrics(): value = apply_to_collection(result_metric, ResultMetric, fn, remove_none=True) if value is None: continue - 
name, name_forked, logger, prog_bar = self._extract_metadata(key, result_metric, on_step, suffix) + name, name_forked, logger, prog_bar, metric_on_epoch = self._extract_metadata( + key, result_metric, on_step, suffix + ) if logger: metrics[DefaultMetricsKeys.LOG][name_forked] = value - metrics[DefaultMetricsKeys.CALLBACK][name] = value - if self.is_train or (not self.is_train and not on_step): + if not self.is_train and (not metric_on_epoch or on_step): + pass + else: + metrics[DefaultMetricsKeys.CALLBACK][name] = value metrics[DefaultMetricsKeys.CALLBACK][name_forked] = value if prog_bar: value = apply_to_collection(value, torch.Tensor, self._to_item, remove_none=True) metrics[DefaultMetricsKeys.PBAR][name_forked] = value + return metrics - def get_batch_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, Dict[str, torch.Tensor]]: + def get_batch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: return self.get_metrics(on_step=True) @staticmethod @@ -301,7 +362,7 @@ def _get_computed_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: return result_metric._computed.detach() - def get_epoch_metrics(self, add_dataloader_idx: bool = False) -> Dict[str, Dict[str, torch.Tensor]]: + def get_epoch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: return self.get_metrics(on_step=False) def to(self, *args, **kwargs) -> 'ResultCollection': diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index b6d21904d1933..e66c7a5243d6a 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -13,7 +13,6 @@ # limitations under the License. import torch -from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins.training_type.ddp import DDPPlugin diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 08caa7398ab8c..01820c7f6b614 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -16,7 +16,6 @@ import torch from torch.nn import DataParallel -from pytorch_lightning.core.step_result import Result from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 35e6b64c50ef8..d235e6c47d518 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -137,6 +137,8 @@ def add_to_eval_loop_results(self, dl_idx, has_been_initialized): return callback_metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.CALLBACK] + if os.getenv("PL_DEV_DEBUG", '0') == '1': + callback_metrics["debug_epoch"] = self.trainer.current_epoch callback_metrics = deepcopy(callback_metrics) for key in list(callback_metrics.keys()): if "dataloader_idx" in key: @@ -156,8 +158,14 @@ def prepare_eval_loop_results(self): def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if not self.trainer.sanity_checking: + + metrics = self.trainer.result_collections.metrics + # update metrics + self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) + self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) + # 
log all the metrics as a single dict - metrics_to_log = self.trainer.result_collections.metrics[DefaultMetricsKeys.LOG] + metrics_to_log = metrics[DefaultMetricsKeys.LOG] if len(metrics_to_log) > 0: self.log_metrics(metrics_to_log, {}) @@ -198,7 +206,7 @@ def increment_evaluation_log_step(self) -> None: elif self.trainer.state.stage is RunningStage.TESTING: self._test_log_step += 1 - def log_evaluation_step_metrics(self) -> None: + def update_evaluation_step_metrics(self) -> None: if self.trainer.sanity_checking: return @@ -259,8 +267,6 @@ def update_train_epoch_metrics(self) -> None: self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) callback_metrics = metrics[DefaultMetricsKeys.CALLBACK] - if os.getenv("PL_DEV_DEBUG", '0') == '1': - callback_metrics["debug_epoch"] = self.trainer.current_epoch self._callback_metrics.update(callback_metrics) @@ -284,6 +290,8 @@ def callback_metrics(self) -> Dict: if self.trainer.result_collections: metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.CALLBACK] self._callback_metrics.update(metrics) + if os.getenv("PL_DEV_DEBUG", '0') == '1': + self._callback_metrics["debug_epoch"] = self.trainer.current_epoch return self._callback_metrics @property @@ -314,8 +322,4 @@ def add_callback_metrics(self, metrics): def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None: self._fx_validator.check_logging(fx_name=fx_name, on_step=on_step, on_epoch=on_epoch) - def reset(self): - if self.trainer.result_collections: - self.trainer.result_collections.reset() - ############## UTILS END ############## diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index a73cb341ff4a0..fb80cc927893b 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.core.step_result import Result, ResultCollection +from pytorch_lightning.core.step_result import ResultCollection from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden @@ -99,6 +99,9 @@ def on_evaluation_model_train(self) -> None: model_ref.on_validation_model_train() def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: + if self.trainer.result_collections: + self.trainer.result_collections.reset() + if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) else: @@ -233,7 +236,7 @@ def on_evaluation_batch_end( def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: # Add step predictions to prediction collection to write later if output is not None and self.predictions is not None: - if isinstance(output, Result) and self.trainer.testing: + if isinstance(output, ResultCollection) and self.trainer.testing: self.predictions.add(output.pop('predictions', None)) # track debug metrics diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5310c8e724ad4..973dc6444263e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -22,13 +22,14 @@ import torch from torch.utils.data import DataLoader +from torchmetrics.metric import Metric from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import 
LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.core.step_result import ResultCollection from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment @@ -986,7 +987,7 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) # log batch metrics - self.logger_connector.log_evaluation_step_metrics() + self.logger_connector.update_evaluation_step_metrics() # track epoch level outputs dl_outputs = self._track_output_for_epoch_end(dl_outputs, output) @@ -1033,16 +1034,13 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: # enable train mode again self.evaluation_loop.on_evaluation_model_train() - # reset cached results - self.logger_connector.reset() - torch.set_grad_enabled(True) return eval_loop_results def _track_output_for_epoch_end(self, outputs, output): if output is not None: - if isinstance(output, Result): + if isinstance(output, ResultCollection): output = output.detach() if self.move_metrics_to_cpu: output = output.cpu() @@ -1129,10 +1127,18 @@ def _run_sanity_check(self, ref_model): self.state.stage = stage + # reset metrics + self._reset_metrics(ref_model) + # reset the seed to what it was before sanity check # prevents sanity check to affect random sampling in training reset_seed() + def _reset_metrics(self, ref_model): + for module in ref_model.modules(): + if isinstance(module, Metric): + module.reset() + def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: if ckpt_path is None: return diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 58b877f2d471c..b4d6bb70ce9b9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -23,7 +23,7 @@ from torch.optim import Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.step_result import Result, ResultCollection +from pytorch_lightning.core.step_result import ResultCollection from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType @@ -206,17 +206,14 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(batch_end_outputs): - sample_output = opt_outputs[-1] - - # decide if we need to reduce at the end of the epoch automatically - auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - if not (hook_overridden or auto_reduce_tng_result): + if not hook_overridden: continue # with 1 step (no tbptt) don't use a sequence at epoch end - if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): + if isinstance(opt_outputs, + list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], ResultCollection): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) @@ -350,7 +347,7 @@ def 
_process_training_step_output(self, training_step_output, split_batch): @staticmethod def _prepare_outputs( - outputs: List[List[List[Result]]], + outputs: List[List[List['ResultCollection']]], batch_mode: bool, ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: """ @@ -576,7 +573,7 @@ def run_training_epoch(self): # progress global step according to grads progress self.increment_accumulated_grad_global_step() - def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: + def on_train_epoch_end(self, epoch_output: List[List[List['ResultCollection']]]) -> None: # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 734b9e7f56152..d54c4cfdbb043 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -18,7 +18,6 @@ from torchmetrics import Metric import tests.helpers.utils as tutils -from pytorch_lightning.core.step_result import Result from tests.helpers.runif import RunIf diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 02d30d9f79ee3..2ce74fe8598c5 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -22,7 +22,6 @@ import tests.helpers.utils as tutils from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.core.step_result import Result from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index f7d0aea829ced..f66229b959554 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -24,7 +24,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index c21ebfc94e21b..575eb91db8c03 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -28,6 +28,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.utilities import exceptions from tests.helpers import BoringModel, RandomDataset from tests.helpers.deterministic_model import DeterministicModel @@ -503,11 +504,6 @@ def on_validation_epoch_start(self, trainer, pl_module): prob_bars=self.choices ) - def on_batch_end(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_batch_end', 6, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices - ) - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self.make_logging( pl_module, @@ -567,7 +563,6 @@ def on_validation_end(self) -> None: assert test_callback.funcs_called_count["on_epoch_start"] == 1 # assert test_callback.funcs_called_count["on_batch_start"] == 1 - assert test_callback.funcs_called_count["on_batch_end"] == 1 assert test_callback.funcs_called_count["on_validation_start"] == 1 assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1 # 
assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 @@ -617,14 +612,8 @@ def get_expected_output(func_attr, original_values): assert float(output_value) == float(expected_output) for func_name, func_attr in test_callback.funcs_attr.items(): - if "on_batch_end" in func_name: - continue - if func_attr["prog_bar"] and (func_attr["on_epoch"] or func_attr["on_step"]): - try: - assert func_name in trainer.progress_bar_metrics - except: - import pdb - pdb.set_trace() + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: + assert func_name in trainer.progress_bar_metrics else: assert func_name not in trainer.progress_bar_metrics @@ -776,17 +765,16 @@ def test_dataloader(self): # function used to describe expected return logic def get_expected_output(func_attr, original_values): - # Apply mean on values - if func_attr["on_epoch"] and not func_attr["on_step"]: - expected_output = np.mean(original_values) - else: + if func_attr["on_step"] and not func_attr["on_epoch"]: expected_output = np.max(original_values) + else: + expected_output = np.mean(original_values) return expected_output # Make sure the func_name output equals the average from all logged values when on_epoch true # pop extra keys - assert "debug_epoch" in trainer.callback_metrics - trainer.callback_metrics.pop("debug_epoch") + #assert "debug_epoch" in trainer.callback_metrics + #trainer.callback_metrics.pop("debug_epoch") for dl_idx in range(num_dataloaders): key = f"test_loss/dataloader_idx_{dl_idx}" @@ -915,13 +903,13 @@ def get_metrics_at_idx(idx): assert get_metrics_at_idx(6)["valid_loss_1"] == expected results = trainer.test(model) - expected_callback_metrics = { + expected_callback_metrics = set({ 'train_loss', 'valid_loss_0_epoch', 'valid_loss_0', 'debug_epoch', 'valid_loss_1', 'test_loss', - } - assert set(trainer.callback_metrics) == expected_callback_metrics + }) + assert sorted(trainer.callback_metrics) == sorted(expected_callback_metrics) assert set(results[0]) == {'test_loss', 'debug_epoch'} diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index e0e1c3cdf42ec..337423a06e5ca 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -26,251 +26,13 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator -from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf -def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable: - - def decorator(func: Callable) -> Callable: - - def wrapper(self, *args, **kwargs) -> Any: - # Set information - self._current_fx_name = fx_name - self._current_hook_fx_name = hook_fx_name - self._results = Result() - - result = func(self, *args, **kwargs) - - # cache metrics - self.trainer.logger_connector.cache_logged_metrics() - return result - - return wrapper - - return decorator - - -def test__logger_connector__epoch_result_store__train(tmpdir): - """ - Tests that LoggerConnector will properly capture logged information - and reduce 
them - """ - - class TestModel(BoringModel): - - train_losses = [] - - @decorator_with_arguments(fx_name="training_step") - def training_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - - self.train_losses.append(loss) - - self.log("train_loss", loss, on_step=True, on_epoch=True) - - return {"loss": loss} - - def training_step_end(self, *_): - self.train_results = deepcopy(self.trainer.logger_connector.cached_results) - - model = TestModel() - model.training_epoch_end = None - model.val_dataloader = None - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=4, - max_epochs=1, - log_every_n_steps=1, - weights_summary=None, - ) - trainer.fit(model) - - train_results = model.train_results - - assert len(train_results(fx_name="training_step", dl_idx=0, opt_idx=0)) == 2 - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=0, split_idx=0)["train_loss"] - assert generated == model.train_losses[0] - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=1, split_idx=0)["train_loss"] - assert generated == model.train_losses[1] - - assert train_results.has_reduced is not True - - train_results.has_batch_loop_finished = True - - assert train_results.has_reduced is True - - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['train_loss_epoch'].item() - excepted = torch.stack(model.train_losses).mean().item() - assert generated == excepted - - -def test__logger_connector__epoch_result_store__train__tbptt(tmpdir): - """ - Tests that LoggerConnector will properly capture logged information with ttbt - and reduce them - """ - truncated_bptt_steps = 2 - sequence_size = 30 - batch_size = 30 - - x_seq = torch.rand(batch_size, sequence_size, 1) - y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() - - class MockSeq2SeqDataset(torch.utils.data.Dataset): - - def __getitem__(self, i): - return x_seq, y_seq_list - - def __len__(self): - return 1 - - class TestModel(BoringModel): - - train_losses = [] - - def __init__(self): - super().__init__() - self.test_hidden = None - self.layer = torch.nn.Linear(2, 2) - - @decorator_with_arguments(fx_name="training_step") - def training_step(self, batch, batch_idx, hiddens): - assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" - self.test_hidden = torch.rand(1) - - x_tensor, y_list = batch - assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed" - - y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) - assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" - - pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) - loss = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps)) - - self.train_losses.append(loss) - - self.log('a', loss, on_epoch=True) - - return {'loss': loss, 'hiddens': self.test_hidden} - - def on_train_epoch_start(self) -> None: - self.test_hidden = None - - def train_dataloader(self): - return torch.utils.data.DataLoader( - dataset=MockSeq2SeqDataset(), - batch_size=batch_size, - shuffle=False, - sampler=None, - ) - - def training_step_end(self, training_step_output): - self.train_results = deepcopy(self.trainer.logger_connector.cached_results) - # must return - return training_step_output - - model = TestModel() - model.training_epoch_end = None - model.example_input_array = torch.randn(5, truncated_bptt_steps) - - trainer = Trainer( - 
default_root_dir=tmpdir, - limit_train_batches=10, - limit_val_batches=0, - truncated_bptt_steps=truncated_bptt_steps, - max_epochs=1, - log_every_n_steps=1, - weights_summary=None, - ) - trainer.fit(model) - - train_results = model.train_results - - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=0) - assert len(generated) == len(model.train_losses) - - # assert reduction didn't happen yet - assert train_results.has_reduced is False - - # Launch reduction - train_results.has_batch_loop_finished = True - - # assert reduction did happen - assert train_results.has_reduced is True - - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['a_epoch'].item() - assert generated == torch.stack(model.train_losses).mean().item() - - -@pytest.mark.parametrize('num_dataloaders', [1, 2]) -def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, num_dataloaders): - """ - Tests that LoggerConnector will properly capture logged information in multi dataloaders scenario - """ - - class TestModel(BoringModel): - test_losses = {dl_idx: [] for dl_idx in range(num_dataloaders)} - - @decorator_with_arguments(fx_name="test_step") - def test_step(self, batch, batch_idx, dl_idx=0): - output = self.layer(batch) - loss = self.loss(batch, output) - self.test_losses[dl_idx].append(loss) - self.log("test_loss", loss, on_step=True, on_epoch=True) - return {"test_loss": loss} - - def on_test_batch_end(self, *args, **kwargs): - # save objects as it will be reset at the end of epoch. - self.batch_results = deepcopy(self.trainer.logger_connector.cached_results) - - def on_test_epoch_end(self): - # save objects as it will be reset at the end of epoch. - self.reduce_results = deepcopy(self.trainer.logger_connector.cached_results) - - def test_dataloader(self): - return [super().test_dataloader()] * num_dataloaders - - model = TestModel() - model.test_epoch_end = None - limit_test_batches = 4 - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=0, - limit_val_batches=0, - limit_test_batches=limit_test_batches, - max_epochs=1, - log_every_n_steps=1, - weights_summary=None, - ) - trainer.test(model) - - test_results = model.batch_results - - generated = test_results(fx_name="test_step") - assert len(generated) == num_dataloaders - - for dl_idx in range(num_dataloaders): - generated = test_results(fx_name="test_step", dl_idx=dl_idx) - assert len(generated) == limit_test_batches - - test_results = model.reduce_results - - for dl_idx in range(num_dataloaders): - expected = torch.stack(model.test_losses[dl_idx]).mean() - generated = test_results(fx_name="test_step", dl_idx=dl_idx, reduced=True)["test_loss_epoch"] - torch.testing.assert_allclose(generated, expected) - - def test_fx_validator(tmpdir): funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')]) @@ -444,56 +206,6 @@ def test_dataloader(self): trainer.test(model, ckpt_path=None) -@pytest.mark.parametrize('to_float', [False, True]) -def test_metrics_holder(to_float, tmpdir): - - device = "cuda" if torch.cuda.is_available() else "cpu" - preds = torch.tensor([[0.9, 0.1]], device=device) - - def is_float(value: Any) -> bool: - return isinstance(value, float) - - excepted_function = is_float if to_float else torch.is_tensor - targets = torch.tensor([1], device=device) - acc = Accuracy().to(device) - metric_holder = MetricsHolder(to_float=to_float) - metric_holder.update({ - "x": 1, - "y": torch.tensor(2), - "z": acc(preds, targets), - }) - 
metric_holder.convert(device) - metrics = metric_holder.metrics - assert excepted_function(metrics["x"]) - assert excepted_function(metrics["y"]) - assert excepted_function(metrics["z"]) - - -def test_metric_holder_raises(tmpdir): - """Check that an error is raised when trying to convert non-scalar tensors""" - - class TestModel(BoringModel): - - def validation_step(self, batch, *args, **kwargs): - output = self(batch) - self.log('test', output) - - def test_step(self, *args, **kwargs): - return self.validation_step(*args, **kwargs) - - model = TestModel() - model.validation_epoch_end = None - model.test_epoch_end = None - - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - - match = "The metric `test` does not contain a single element" - with pytest.raises(MisconfigurationException, match=match): - trainer.validate(model) - with pytest.raises(MisconfigurationException, match=match): - trainer.test(model) - - def test_can_return_tensor_with_more_than_one_element(tmpdir): """Ensure {validation,test}_step return values are not included as callback metrics. #6623""" @@ -562,7 +274,13 @@ def validation_step(self, *args, **kwargs): model = TestModel() model.validation_epoch_end = None - trainer = Trainer(default_root_dir=tmpdir, max_steps=5) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=2, + limit_val_batches=2, + num_sanity_val_steps=0, + ) trainer.fit(model) logged = trainer.logged_metrics @@ -672,26 +390,28 @@ def _assert_epoch_end(self, stage): acc = self._modules[f"acc_{stage}"] ap = self._modules[f"ap_{stage}"] - acc.reset.asset_not_called() - ap.reset.assert_not_called() + acc.reset.assert_called_once() + ap.reset.assert_called_once() - def on_train_epoch_end(self): + def on_train_end(self): self._assert_epoch_end('train') - def on_validation_epoch_end(self): - self._assert_epoch_end('val') + def on_validation_end(self): + if not self.trainer.sanity_checking: + self._assert_epoch_end('val') - def on_test_epoch_end(self): - self._assert_epoch_end('test') + def on_test_end(self): + if not self.trainer.sanity_checking: + self._assert_epoch_end('test') def _assert_called(model, stage): acc = model._modules[f"acc_{stage}"] ap = model._modules[f"ap_{stage}"] - acc.reset.assert_called_once() + assert acc.reset.call_count == 1 acc.reset.reset_mock() - ap.reset.assert_called_once() + assert ap.reset.call_count == 1 ap.reset.reset_mock() model = TestModel() @@ -702,6 +422,7 @@ def _assert_called(model, stage): limit_test_batches=2, max_epochs=1, progress_bar_refresh_rate=0, + num_sanity_val_steps=0, ) trainer.fit(model) From 35576918947b530fee49b95a7b099c14186c2385 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 24 May 2021 20:43:25 +0100 Subject: [PATCH 093/455] update --- pytorch_lightning/core/step_result.py | 1 - pytorch_lightning/plugins/training_type/ddp2.py | 10 ++++++---- pytorch_lightning/plugins/training_type/dp.py | 12 ++++-------- .../connectors/logger_connector/logger_connector.py | 8 ++++---- pytorch_lightning/trainer/properties.py | 1 - pytorch_lightning/utilities/apply_func.py | 7 ++++--- tests/core/test_metric_result_integration.py | 11 ++++++----- tests/core/test_results.py | 8 -------- tests/models/test_tpu.py | 3 ++- tests/trainer/logging_/test_eval_loop_logging.py | 3 --- tests/trainer/logging_/test_logger_connector.py | 4 +--- 11 files changed, 27 insertions(+), 41 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 7213fbc4affac..991ad63a70027 100644 --- 
a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index e66c7a5243d6a..13e6530270d4e 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -14,6 +14,7 @@ import torch from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection class DDP2Plugin(DDPPlugin): @@ -46,11 +47,12 @@ def reduce(self, tensor, *args, **kwargs): Return: reduced value, except when the input was not a tensor the output remains is unchanged """ - if isinstance(tensor, Result): - tensor.dp_reduce() - elif isinstance(tensor, torch.Tensor): - tensor = tensor.mean() + def _reduce(t: torch.Tensor): + dtype_tensor = t.dtype + return t.float().mean().type(dtype_tensor) + + tensor = apply_to_collection(tensor, torch.Tensor, _reduce) return tensor diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 01820c7f6b614..9420409b6d10a 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -63,16 +63,12 @@ def reduce(self, tensor, *args, **kwargs): Return: reduced value, except when the input was not a tensor the output remains is unchanged """ - if isinstance(tensor, Result): - tensor.dp_reduce() - else: + def _reduce(t: torch.Tensor): + dtype_tensor = t.dtype + return t.float().mean().type(dtype_tensor) - def _reduce(t: torch.Tensor): - dtype_tensor = t.dtype - return t.float().mean().type(dtype_tensor) - - tensor = apply_to_collection(tensor, torch.Tensor, _reduce) + tensor = apply_to_collection(tensor, torch.Tensor, _reduce) return tensor diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index d235e6c47d518..15419fd8b8ec7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -226,7 +226,7 @@ def update_evaluation_step_metrics(self) -> None: # increment the step even if nothing was logged self.increment_evaluation_log_step() - ############## TRAIN METRICS UPDATES START ############## + ############## TRAIN METRICS UPDATES START ############## # noqa E266 def on_train_split_start(self, split_batch: Any) -> None: self.trainer.result_collections.extract_batch_size(split_batch) @@ -281,9 +281,9 @@ def update_train_epoch_metrics(self) -> None: # reset result collection for next epoch self.trainer.result_collections.reset() - ############## TRAIN METRICS UPDATES END ############## + ############## TRAIN METRICS UPDATES END ############## # noqa E266 - ############## UTILS START ############## + ############## UTILS START ############## # noqa E266 @property def callback_metrics(self) -> Dict: @@ -322,4 +322,4 @@ def add_callback_metrics(self, metrics): def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None: 
self._fx_validator.check_logging(fx_name=fx_name, on_step=on_step, on_epoch=on_epoch) - ############## UTILS END ############## + ############## UTILS END ############## # noqa E266 diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 3c98866c799cb..57d06d99dd773 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -43,7 +43,6 @@ parse_env_variables, ) from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index c46d0ad525d33..af125e5ee5a95 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -71,9 +71,10 @@ def apply_to_collection( dtype: the given function will be applied to all elements of this dtype function: the function to apply *args: positional arguments (will be forwarded to calls of ``function``) - wrong_dtype: the given function won't be applied if this type is specified and the given collections is of - the :attr:`wrong_type` even if it is of type :attr`dtype` - remove_none: Whether to skip an element if the output of function is ``None`` while applying onto the collection. + wrong_dtype: the given function won't be applied if this type is specified and the given collections + is of the :attr:`wrong_type` even if it is of type :attr`dtype` + remove_none: Whether to skip an element if the output of function is ``None`` + while applying onto the collection. **kwargs: keyword arguments (will be forwarded to calls of ``function``) Returns: diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index d54c4cfdbb043..ac6fd852f521f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -18,6 +18,7 @@ from torchmetrics import Metric import tests.helpers.utils as tutils +from pytorch_lightning.core.step_result import DefaultMetricsKeys, ResultCollection from tests.helpers.runif import RunIf @@ -52,7 +53,7 @@ def _ddp_test_fn(rank, worldsize): metric_c = DummyMetric() # dist_sync_on_step is False by default - result = Result() + result = ResultCollection() for epoch in range(3): cumulative_sum = 0 @@ -74,7 +75,7 @@ def _ddp_test_fn(rank, worldsize): for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] - epoch_log = result.get_epoch_log_metrics() + epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] result.reset() # assert metric state reset to default values @@ -103,7 +104,7 @@ def test_result_metric_integration(): metric_b = DummyMetric() metric_c = DummyMetric() - result = Result() + result = ResultCollection() for epoch in range(3): cumulative_sum = 0 @@ -119,13 +120,13 @@ def test_result_metric_integration(): result.log('b', metric_b, on_step=False, on_epoch=True) result.log('c', metric_c, on_step=True, on_epoch=False) - batch_log = result.get_batch_log_metrics() + batch_log = result.get_batch_metrics()[DefaultMetricsKeys.LOG] batch_expected = {"a_step": i, "a": i, "c": i} assert set(batch_log.keys()) == set(batch_expected.keys()) for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] - epoch_log = result.get_epoch_log_metrics() + epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] result.reset() # assert metric state reset 
to default values diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 2ce74fe8598c5..08ab51ad94b86 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -174,11 +174,3 @@ def test_dataloader(self): assert prediction_file.exists() predictions = torch.load(prediction_file) assert len(predictions) == len(dm.random_test) - - -def test_result_retrieve_last_logged_item(): - result = Result() - result.log('a', 5., on_step=True, on_epoch=True) - assert result['a_epoch'] == 5. - assert result['a_step'] == 5. - assert result['a'] == 5. diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index f66229b959554..9b978e66512e7 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -24,6 +24,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.core.step_result import ResultCollection from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp @@ -427,7 +428,7 @@ def test_sync_dist(rank): tensor = torch.tensor([1.0]) training_type_plugin = TPUSpawnPlugin() - res = Result() + res = ResultCollection() res.log( "test_tensor", tensor, diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 575eb91db8c03..83523addff99a 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -28,7 +28,6 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.utilities import exceptions from tests.helpers import BoringModel, RandomDataset from tests.helpers.deterministic_model import DeterministicModel @@ -773,8 +772,6 @@ def get_expected_output(func_attr, original_values): # Make sure the func_name output equals the average from all logged values when on_epoch true # pop extra keys - #assert "debug_epoch" in trainer.callback_metrics - #trainer.callback_metrics.pop("debug_epoch") for dl_idx in range(num_dataloaders): key = f"test_loss/dataloader_idx_{dl_idx}" diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 337423a06e5ca..4c2185d38295e 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -15,8 +15,6 @@ Tests to ensure that the training loop works with a dict (1.0) """ import os -from copy import deepcopy -from typing import Any, Callable from unittest import mock import pytest @@ -422,7 +420,7 @@ def _assert_called(model, stage): limit_test_batches=2, max_epochs=1, progress_bar_refresh_rate=0, - num_sanity_val_steps=0, + num_sanity_val_steps=2, ) trainer.fit(model) From eb0b34c44ad53d76814351ce257b5e44139f163c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 01:41:16 +0200 Subject: [PATCH 094/455] fix current_epoch counting --- pytorch_lightning/loops/epoch_loop.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 51d9273378ce3..1496363b205d4 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -32,10 +32,16 @@ def 
__init__(self, min_epochs, max_epochs, min_steps, max_steps): # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.current_epoch = 0 - self.training_loop = TrainingLoop(min_steps, max_steps) + @property + def current_epoch(self) -> int: + return self.iteration_count + + @current_epoch.setter + def current_epoch(self, value: int): + self.iteration_count = value + @property def global_step(self): return self.training_loop.global_step @@ -102,23 +108,17 @@ def on_run_start(self): # hook self.trainer.call_hook("on_train_start") - def on_advance_start(self): # equal to on train epoch start - # implemented here since this code has to be run always no matter the actual epoch implementation - epoch = self.iteration_count + 1 - - # update training progress in trainer - self.current_epoch = epoch - + def on_advance_start(self): # equal to old on_train_epoch_start model = self.trainer.lightning_module # reset train dataloader - if epoch != 0 and self.trainer.reload_dataloaders_every_epoch: + if self.current_epoch != 0 and self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_train_dataloader(model) # todo: specify the possible exception with suppress(Exception): # set seed for distributed sampler (enables shuffling for each epoch) - self.trainer.train_dataloader.sampler.set_epoch(epoch) + self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch) # changing gradient according accumulation_scheduler self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) @@ -211,4 +211,3 @@ def check_checkpoint_callback(self, should_update, is_last=False): for cb in callbacks: cb.on_validation_end(self.trainer, model) - From f283371d5b8bc0273524244187b5c7fb5968eee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 01:46:56 +0200 Subject: [PATCH 095/455] fix hiddens update --- pytorch_lightning/loops/batch_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index f014a1c04d45a..1066d7554342c 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -271,7 +271,7 @@ def _process_training_step_output(self, training_step_output, split_batch): # map to results under the hood result.minimize = loss - self.trainer.hiddens = hiddens + self._hiddens = hiddens # track batch for manual reduction with result result.track_batch_size(len(split_batch)) From c3f25e902abbca770ca384c7e17adf0d56ac6bc4 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 04:13:29 -0400 Subject: [PATCH 096/455] update --- pytorch_lightning/core/step_result.py | 12 ++++++------ .../connectors/logger_connector/logger_connector.py | 1 + tests/accelerators/test_common.py | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 991ad63a70027..a287faa520aed 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -79,9 +79,9 @@ def __init__(self, metadata: Metadata) -> None: super().__init__(compute_on_step=metadata.is_tensor) self.meta = metadata if self.meta.is_tensor: - self.add_state("value", torch.tensor(.0)) + self.add_state("value", torch.tensor(.0, dtype=torch.float64)) if self.meta.is_tensor_and_mean_reduction: - 
self.add_state("cumulated_batch_size", torch.tensor(.0)) + self.add_state("cumulated_batch_size", torch.tensor(.0, dtype=torch.float64)) def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: if self.meta.is_tensor_and_mean_reduction: @@ -101,9 +101,9 @@ def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: def compute(self) -> torch.Tensor: if self.meta.is_tensor: if self.meta.is_tensor_and_mean_reduction: - return self.value / self.cumulated_batch_size + return torch.sum(self.value) / torch.sum(self.cumulated_batch_size) elif self.meta.is_tensor_and_max_reduction or self.meta.is_tensor_and_min_reduction: - return self.value + return self.meta.fx(self.value) else: raise MisconfigurationException("Only mean, max are supported.") else: @@ -123,7 +123,6 @@ def reset(self): if self.meta.is_tensor: super().reset() else: - print(self.meta.fx, self.meta.name) self.value.reset() def forward(self, *args, **kwargs): @@ -260,7 +259,8 @@ def fn(v): nonlocal meta meta = deepcopy(meta) meta.is_tensor = torch.is_tensor(v) - return ResultMetric(meta) + metric = ResultMetric(meta) + return metric.to(v.device) self[key] = apply_to_collection(value, (torch.Tensor, Metric), fn) # cache the meta for reduction diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 15419fd8b8ec7..0b83aae77e4df 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -160,6 +160,7 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if not self.trainer.sanity_checking: metrics = self.trainer.result_collections.metrics + # update metrics self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index a538838150381..f675a82e29558 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -25,8 +25,8 @@ @pytest.mark.parametrize( "trainer_kwargs", ( - pytest.param(dict(gpus=1), marks=RunIf(min_gpus=1)), - pytest.param(dict(accelerator="dp", gpus=2), marks=RunIf(min_gpus=2)), + #pytest.param(dict(gpus=1), marks=RunIf(min_gpus=1)), + #pytest.param(dict(accelerator="dp", gpus=2), marks=RunIf(min_gpus=2)), pytest.param(dict(accelerator="ddp_spawn", gpus=2), marks=RunIf(min_gpus=2)), ) ) From e343b3460ebd65a11afe195c647402cce30b9deb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 08:14:20 +0000 Subject: [PATCH 097/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/accelerators/test_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index f675a82e29558..7fff90fa7a918 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -24,7 +24,8 @@ @pytest.mark.parametrize( - "trainer_kwargs", ( + "trainer_kwargs", + ( #pytest.param(dict(gpus=1), marks=RunIf(min_gpus=1)), #pytest.param(dict(accelerator="dp", gpus=2), marks=RunIf(min_gpus=2)), pytest.param(dict(accelerator="ddp_spawn", gpus=2), marks=RunIf(min_gpus=2)), From 3c67b01cbe13462e43e29d5585d8506a25dd1adf Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 11:21:49 +0200 Subject: [PATCH 098/455] move total batch idx to old place --- pytorch_lightning/loops/training_loop.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 3674971bb96d0..f5a542dc851f0 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -130,9 +130,6 @@ def advance(self): self.trainer.logger_connector.log_train_step_metrics(batch_output) def on_advance_end(self): - # TODO: where is the right place update this !!!!????? - self.total_batch_idx += 1 - # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- @@ -152,6 +149,8 @@ def on_advance_end(self): self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) self.trainer.checkpoint_connector.has_trained = True + self.total_batch_idx += 1 + if self.done: raise StopIteration From 5acd9fb67566b13220904d5831edad2998fb0de7 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 05:44:02 -0400 Subject: [PATCH 099/455] update --- pytorch_lightning/core/step_result.py | 54 +++++++++++++++----- pytorch_lightning/trainer/training_loop.py | 2 +- tests/accelerators/test_common.py | 4 +- tests/callbacks/test_progress_bar.py | 5 +- tests/core/test_metric_result_integration.py | 32 +++++++----- 5 files changed, 67 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index a287faa520aed..502358a93991c 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -13,7 +13,8 @@ # limitations under the License. 
from copy import deepcopy from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union +from collections.abc import Mapping, Sequence, Generator +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, NamedTuple import torch from torch import Tensor @@ -75,13 +76,15 @@ def is_tensor_and_min_reduction(self) -> bool: class ResultMetric(Metric): + DTYPE = torch.float32 + def __init__(self, metadata: Metadata) -> None: super().__init__(compute_on_step=metadata.is_tensor) self.meta = metadata if self.meta.is_tensor: - self.add_state("value", torch.tensor(.0, dtype=torch.float64)) + self.add_state("value", torch.tensor(.0, dtype=self.DTYPE)) if self.meta.is_tensor_and_mean_reduction: - self.add_state("cumulated_batch_size", torch.tensor(.0, dtype=torch.float64)) + self.add_state("cumulated_batch_size", torch.tensor(.0, dtype=self.DTYPE)) def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: if self.meta.is_tensor_and_mean_reduction: @@ -204,9 +207,12 @@ def extra(self) -> Dict: @extra.setter def extra(self, extra: Dict) -> None: + def detach_fn(v): + return v.detach() + extra = apply_to_collection(extra, torch.Tensor, detach_fn) self['extra'] = extra - def should_reset(self, hook_name): + def should_reset(self, hook_name: str) -> bool: return hook_name not in ("on_train_start") def log( @@ -260,7 +266,8 @@ def fn(v): meta = deepcopy(meta) meta.is_tensor = torch.is_tensor(v) metric = ResultMetric(meta) - return metric.to(v.device) + device = getattr(v, "device", torch.device("cpu")) + return metric.to(device) self[key] = apply_to_collection(value, (torch.Tensor, Metric), fn) # cache the meta for reduction @@ -290,11 +297,11 @@ def _get_forward_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: def _to_item(forward_cache: torch.Tensor) -> float: return forward_cache.item() - def valid_metrics(self) -> Tuple[str, Any]: - for key, result_metric in self.items(): - if isinstance(result_metric, bool) or key == "extra": + def valid_metrics(self) -> Generator: + for key, item in self.items(): + if item is None or isinstance(item, bool) or key == "extra": continue - yield (key, result_metric) + yield (key, item) def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) -> Tuple: if isinstance(result_metric, ResultMetric): @@ -324,24 +331,46 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: fn = self._get_forward_cache if on_step else self._get_computed_cache suffix = self.STEP_SUFFIX if on_step else self.EPOCH_SUFFIX + # iterate over all stored metrics. for key, result_metric in self.valid_metrics(): + + # extract forward_cache or computed from the ResultMetric + # ignore when the output of fn is None value = apply_to_collection(result_metric, ResultMetric, fn, remove_none=True) - if value is None: + + # detect if the value is None. This can be nested. + is_empty = True + + def is_empty_fn(v): + nonlocal is_empty + if v is not None: + is_empty = False + + # apply detection. + apply_to_collection(value, object, is_empty_fn, wrong_dtype=(Mapping, Sequence, NamedTuple)) + + # skip is the value was actually empty. + if is_empty: continue + # extract metadata name, name_forked, logger, prog_bar, metric_on_epoch = self._extract_metadata( key, result_metric, on_step, suffix ) + # populate logging metrics if logger: metrics[DefaultMetricsKeys.LOG][name_forked] = value + # populate callback metrics + # callback metrics don't take `_step` forked metrics. 
if not self.is_train and (not metric_on_epoch or on_step): pass else: metrics[DefaultMetricsKeys.CALLBACK][name] = value metrics[DefaultMetricsKeys.CALLBACK][name_forked] = value + # populate progress_bar metrics. By default, the value should be converted to a float. if prog_bar: value = apply_to_collection(value, torch.Tensor, self._to_item, remove_none=True) metrics[DefaultMetricsKeys.PBAR][name_forked] = value @@ -377,9 +406,10 @@ def cpu(self) -> 'ResultCollection': def reset(self) -> None: """Call at the end of epoch to reset all metric objects""" - for item in self.values(): - if isinstance(item, ResultMetric) and item.meta.should_reset: + def reset_fn(item: ResultMetric) -> None: + if item.meta.should_reset: item.reset() + apply_to_collection(dict(self.items()), ResultMetric, reset_fn) self._batch_size: Optional[int] = None self._on_epoch_end_reached: bool = False self._minimize: Optional[Tensor] = None diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b4d6bb70ce9b9..b17f80e6b6874 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -382,7 +382,7 @@ def _prepare_outputs( for tbptt_output in batch_outputs: out = tbptt_output.extra - out['loss'] = tbptt_output.minimize + out['loss'] = tbptt_output.minimize.detach() processed_tbptt_outputs.append(out) # if there was only one tbptt step then we can collapse that dimension diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index f675a82e29558..a538838150381 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -25,8 +25,8 @@ @pytest.mark.parametrize( "trainer_kwargs", ( - #pytest.param(dict(gpus=1), marks=RunIf(min_gpus=1)), - #pytest.param(dict(accelerator="dp", gpus=2), marks=RunIf(min_gpus=2)), + pytest.param(dict(gpus=1), marks=RunIf(min_gpus=1)), + pytest.param(dict(accelerator="dp", gpus=2), marks=RunIf(min_gpus=2)), pytest.param(dict(accelerator="ddp_spawn", gpus=2), marks=RunIf(min_gpus=2)), ) ) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index f4f8f34c1b4c1..f28178626b25e 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -388,6 +388,9 @@ def training_step(self, batch, batch_idx): self.log('bar', {"baz": torch.tensor([1])}, prog_bar=True) return super().training_step(batch, batch_idx) + def on_train_end(self) -> None: + print(self.trainer.result_collections) + trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -399,7 +402,7 @@ def training_step(self, batch, batch_idx): pbar = trainer.progress_bar_callback.main_progress_bar actual = str(pbar.postfix) - assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}") + assert actual.endswith("foo=0.123, bar={'baz': 1.0}") @pytest.mark.parametrize( diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ac6fd852f521f..1372855e2ae34 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -52,10 +52,14 @@ def _ddp_test_fn(rank, worldsize): metric_b = DummyMetric() metric_c = DummyMetric() + metric_a = metric_a.to(f"cuda:{rank}") + metric_b = metric_b.to(f"cuda:{rank}") + metric_c = metric_c.to(f"cuda:{rank}") + # dist_sync_on_step is False by default - result = ResultCollection() + result = ResultCollection(True) - for epoch in range(3): + for _ in range(3): cumulative_sum = 0 for i in range(5): @@ 
-65,12 +69,12 @@ def _ddp_test_fn(rank, worldsize): cumulative_sum += i - result.log('a', metric_a, on_step=True, on_epoch=True) - result.log('b', metric_b, on_step=False, on_epoch=True) - result.log('c', metric_c, on_step=True, on_epoch=False) + result.log('h', 'a', metric_a, on_step=True, on_epoch=True) + result.log('h', 'b', metric_b, on_step=False, on_epoch=True) + result.log('h', 'c', metric_c, on_step=True, on_epoch=False) - batch_log = result.get_batch_log_metrics() - batch_expected = {"a_step": i, "a": i, "c": i} + batch_log = result.get_batch_metrics()[DefaultMetricsKeys.LOG] + batch_expected = {"a_step": i, "c": i} assert set(batch_log.keys()) == set(batch_expected.keys()) for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] @@ -79,7 +83,7 @@ def _ddp_test_fn(rank, worldsize): result.reset() # assert metric state reset to default values - assert metric_a.x == metric_a._defaults['x'] + assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] @@ -104,9 +108,9 @@ def test_result_metric_integration(): metric_b = DummyMetric() metric_c = DummyMetric() - result = ResultCollection() + result = ResultCollection(True) - for epoch in range(3): + for _ in range(3): cumulative_sum = 0 for i in range(5): @@ -116,12 +120,12 @@ def test_result_metric_integration(): cumulative_sum += i - result.log('a', metric_a, on_step=True, on_epoch=True) - result.log('b', metric_b, on_step=False, on_epoch=True) - result.log('c', metric_c, on_step=True, on_epoch=False) + result.log('h', 'a', metric_a, on_step=True, on_epoch=True) + result.log('h', 'b', metric_b, on_step=False, on_epoch=True) + result.log('h', 'c', metric_c, on_step=True, on_epoch=False) batch_log = result.get_batch_metrics()[DefaultMetricsKeys.LOG] - batch_expected = {"a_step": i, "a": i, "c": i} + batch_expected = {"a_step": i, "c": i} assert set(batch_log.keys()) == set(batch_expected.keys()) for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] From 32971e61e84d1e8e8fb844316218198c12f5ab9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 09:45:16 +0000 Subject: [PATCH 100/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/step_result.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 502358a93991c..85e3d70178042 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from collections.abc import Generator, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from collections.abc import Mapping, Sequence, Generator -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, NamedTuple +from typing import Any, Callable, Dict, Iterable, NamedTuple, Optional, Tuple, Union import torch from torch import Tensor @@ -207,8 +207,10 @@ def extra(self) -> Dict: @extra.setter def extra(self, extra: Dict) -> None: + def detach_fn(v): return v.detach() + extra = apply_to_collection(extra, torch.Tensor, detach_fn) self['extra'] = extra @@ -333,20 +335,20 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: # iterate over all stored metrics. for key, result_metric in self.valid_metrics(): - + # extract forward_cache or computed from the ResultMetric # ignore when the output of fn is None value = apply_to_collection(result_metric, ResultMetric, fn, remove_none=True) - + # detect if the value is None. This can be nested. is_empty = True - + def is_empty_fn(v): nonlocal is_empty if v is not None: is_empty = False - # apply detection. + # apply detection. apply_to_collection(value, object, is_empty_fn, wrong_dtype=(Mapping, Sequence, NamedTuple)) # skip is the value was actually empty. @@ -406,9 +408,11 @@ def cpu(self) -> 'ResultCollection': def reset(self) -> None: """Call at the end of epoch to reset all metric objects""" + def reset_fn(item: ResultMetric) -> None: if item.meta.should_reset: item.reset() + apply_to_collection(dict(self.items()), ResultMetric, reset_fn) self._batch_size: Optional[int] = None self._on_epoch_end_reached: bool = False From 17d6f2a81faa72be9d5e10ec3ac1d68294e33b7a Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 05:59:12 -0400 Subject: [PATCH 101/455] update --- pytorch_lightning/core/step_result.py | 18 +++++++++++++++++- .../connectors/test_logger_connectors.py | 7 +++++++ .../optimization/test_multiple_optimizers.py | 12 ------------ 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 502358a93991c..af1227b9c9e0c 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -155,6 +155,12 @@ def forward(self, *args, **kwargs): return self._forward_cache + def state_dict(self): + return { + "meta": self.meta, + + } + class ResultCollection(dict): @@ -347,7 +353,7 @@ def is_empty_fn(v): is_empty = False # apply detection. - apply_to_collection(value, object, is_empty_fn, wrong_dtype=(Mapping, Sequence, NamedTuple)) + apply_to_collection(value, object, is_empty_fn, wrong_dtype=(Mapping, Sequence, NamedTuple,)) # skip is the value was actually empty. 
if is_empty: @@ -447,3 +453,13 @@ def __repr__(self) -> str: v = self[k] repr += f" {k}: {v},\n" return repr[:-1] + '\n}' + + def state_dict(self): + def get_state_dict(item: ResultMetric) -> Dict[str, Any]: + return item.state_dict() + + return { + k: apply_to_collection(v, ResultMetric, get_state_dict) + for k, v in self.items() + } + \ No newline at end of file diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py index a0a01df0ff858..b6f487b6d6a0c 100644 --- a/tests/trainer/connectors/test_logger_connectors.py +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -131,3 +131,10 @@ def test_result_collection_on_tensor_with_mean_reduction(): 'loss_2_1_1': mean } assert epoch_metrics[DefaultMetricsKeys.CALLBACK] == excepted + + +def test_result_collection_restoration(): + + result_collection = ResultCollection(True) + + result_collection \ No newline at end of file diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py index aba3b53248a57..c795107a36371 100644 --- a/tests/trainer/optimization/test_multiple_optimizers.py +++ b/tests/trainer/optimization/test_multiple_optimizers.py @@ -49,22 +49,12 @@ def training_step(self, batch, batch_idx, optimizer_idx): model = TestModel() model.training_epoch_end = None - class TestCallback(pl.Callback): - - def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dl_idx): - # when this is called, the EpochResultStore state has not been reset yet because we are still - # "INSIDE_BATCH_TRAIN_LOOP" and the LoggerConnector runs its `on_train_batch_end` after the - # Callback (see `TrainLoop.on_train_batch_end`). For this reason, opt_idx here is the index - # of the last optimizer updated (the second, index 1). This produced a KeyError as reported in #5459 - pl_module.log("test_train_batch_end", trainer.logger_connector.cached_results._opt_idx) - # Initialize a trainer trainer = pl.Trainer( default_root_dir=tmpdir, max_epochs=1, limit_train_batches=5, limit_val_batches=5, - callbacks=[TestCallback()], weights_summary=None, ) trainer.fit(model) @@ -74,8 +64,6 @@ def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dl_id # test loss is properly reduced torch.testing.assert_allclose(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean()) - assert trainer.callback_metrics["test_train_batch_end"] == len(model.optimizers()) - 1 - def test_multiple_optimizers(tmpdir): From 8f626690f7c810d776aed01f07e3b4fb9c69f2b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 10:00:51 +0000 Subject: [PATCH 102/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/step_result.py | 16 ++++++++-------- .../trainer/connectors/test_logger_connectors.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index ea0d094cac45b..00c88576b5824 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -158,7 +158,6 @@ def forward(self, *args, **kwargs): def state_dict(self): return { "meta": self.meta, - } @@ -354,8 +353,12 @@ def is_empty_fn(v): if v is not None: is_empty = False - # apply detection. 
- apply_to_collection(value, object, is_empty_fn, wrong_dtype=(Mapping, Sequence, NamedTuple,)) + # apply detection. + apply_to_collection(value, object, is_empty_fn, wrong_dtype=( + Mapping, + Sequence, + NamedTuple, + )) # skip is the value was actually empty. if is_empty: @@ -459,11 +462,8 @@ def __repr__(self) -> str: return repr[:-1] + '\n}' def state_dict(self): + def get_state_dict(item: ResultMetric) -> Dict[str, Any]: return item.state_dict() - return { - k: apply_to_collection(v, ResultMetric, get_state_dict) - for k, v in self.items() - } - \ No newline at end of file + return {k: apply_to_collection(v, ResultMetric, get_state_dict) for k, v in self.items()} diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py index b6f487b6d6a0c..f46c03107dc8c 100644 --- a/tests/trainer/connectors/test_logger_connectors.py +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -137,4 +137,4 @@ def test_result_collection_restoration(): result_collection = ResultCollection(True) - result_collection \ No newline at end of file + result_collection From 1f6ab04806331237a4c39c2a45d717293b1333f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 12:42:06 +0200 Subject: [PATCH 103/455] fix epoch end handling for empty training loop epoch --- pytorch_lightning/loops/epoch_loop.py | 7 +++++++ pytorch_lightning/loops/training_loop.py | 7 +++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 1496363b205d4..9e738333d6de4 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -136,12 +136,19 @@ def advance(self): # run train epoch epoch_output = self.training_loop.run() # log epoch metrics + + if epoch_output is None: + return + self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) def on_advance_end(self): # # handle epoch_output on epoch end # self.on_train_epoch_end(outputs) # Handled in on_run_end of training_loop now + if self.training_loop.batch_idx is None: + return + should_check_val = self.training_loop.should_check_val_fx(self.batch_idx, self.training_loop.is_last_batch, on_epoch=True) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index f5a542dc851f0..4975f07bde549 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -90,7 +90,7 @@ def on_run_start(self): self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) self._dataloader_idx = 0 self._should_stop = False - self.batch_idx = 0 + self.batch_idx = None self.batches_seen = 0 self.is_last_batch = False @@ -107,7 +107,6 @@ def advance(self): # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ with self.trainer.profiler.profile("run_training_batch"): - # batch_output = self.run_training_batch(batch, batch_idx, self._dataloader_idx) batch_output = self.batch_loop.run(batch, batch_idx, self._dataloader_idx) self.batches_seen += 1 @@ -159,6 +158,10 @@ def on_advance_end(self): # this is the old on train_epoch_end? 
def on_run_end(self): + if self.batch_idx is None: + # dataloader/iterator did not produce a batch + return + # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() From c62860fa0cdbedb28d29a72886900e48b0b9b2c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 10:53:25 +0000 Subject: [PATCH 104/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/base.py | 5 +++-- pytorch_lightning/loops/batch_loop.py | 11 +++++++---- pytorch_lightning/loops/cache.py | 2 -- pytorch_lightning/loops/epoch_loop.py | 8 ++++++-- pytorch_lightning/loops/training_loop.py | 9 +++++---- pytorch_lightning/trainer/trainer.py | 6 ++---- 6 files changed, 23 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 5346109a767d9..d4833505cfb5c 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -1,7 +1,8 @@ -from _weakref import proxy -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from typing import Any, Optional +from _weakref import proxy + import pytorch_lightning as pl diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 1066d7554342c..0211f408c2af2 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from copy import copy from functools import partial, update_wrapper -from typing import List, Any, Optional, Callable, Tuple +from typing import Any, Callable, List, Optional, Tuple import numpy as np import torch @@ -12,8 +12,8 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.supporters import TensorRunningAccum, prefetch_iterator -from pytorch_lightning.utilities import AttributeDict, DeviceType, AMPType, grad_norm +from pytorch_lightning.trainer.supporters import prefetch_iterator, TensorRunningAccum +from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE @@ -95,6 +95,7 @@ def advance(self, batch, batch_idx, dataloader_idx): if result: self.batch_outputs[0].append(result.training_step_output_for_epoch_end) + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ @@ -369,7 +370,9 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): " the old signature will be removed in v1.5", DeprecationWarning ) args.append(opt_idx) - elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.trainer.lightning_module.automatic_optimization: + elif not self.trainer.has_arg( + "training_step", "optimizer_idx" + ) and self.trainer.lightning_module.automatic_optimization: raise ValueError( f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" ' `training_step` is missing the `optimizer_idx` argument.' 
diff --git a/pytorch_lightning/loops/cache.py b/pytorch_lightning/loops/cache.py index 440d2b27d73aa..8a616a9f94779 100644 --- a/pytorch_lightning/loops/cache.py +++ b/pytorch_lightning/loops/cache.py @@ -22,5 +22,3 @@ def filter_by(self, tags: Tuple[str]): self.cache.add("abc", result, batch_idx=) self.cache.group_by("abc", ("batch_idx", "opt_idx")) - - diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 9e738333d6de4..6ea3468b9fde9 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -124,7 +124,9 @@ def on_advance_start(self): # equal to old on_train_epoch_start self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) # stores accumulated grad fractions per batch - self.training_loop.batch_loop.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) + self.training_loop.batch_loop.accumulated_loss = TensorRunningAccum( + window_length=self.trainer.accumulate_grad_batches + ) # hook self.trainer.call_hook("on_epoch_start") @@ -149,7 +151,9 @@ def on_advance_end(self): if self.training_loop.batch_idx is None: return - should_check_val = self.training_loop.should_check_val_fx(self.batch_idx, self.training_loop.is_last_batch, on_epoch=True) + should_check_val = self.training_loop.should_check_val_fx( + self.batch_idx, self.training_loop.is_last_batch, on_epoch=True + ) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 4975f07bde549..f4e81f053c438 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import List, Dict, Union +from typing import Dict, List, Union import pytorch_lightning as pl from pytorch_lightning.core.step_result import Result @@ -193,6 +193,7 @@ def on_run_end(self): self.trainer.call_hook('on_epoch_end') return self.epoch_output + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ @@ -287,8 +288,8 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): @staticmethod def _prepare_outputs( - outputs: List[List[List[Result]]], - batch_mode: bool, + outputs: List[List[List[Result]]], + batch_mode: bool, ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: """ Extract required information from batch or epoch end results. 
@@ -392,4 +393,4 @@ def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: - self.trainer.logger.save() \ No newline at end of file + self.trainer.logger.save() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 554f1639dee7a..9ca1b7626e4a7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -62,7 +62,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus, TrainerState +from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.lr_finder import _LRFinder @@ -391,9 +391,7 @@ def __init__( truncated_bptt_steps, terminate_on_nan, ) - self._setup_on_init( - num_sanity_val_steps, - ) + self._setup_on_init(num_sanity_val_steps, ) self.evaluation_loop.on_trainer_init() self.predict_loop.on_trainer_init() From 6a150569b8c9f832c43cfbb49b0c5219cd857ba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 13:41:54 +0200 Subject: [PATCH 105/455] typing for base loop --- pytorch_lightning/loops/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index d4833505cfb5c..b4e7096466f7a 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -19,10 +19,10 @@ def connect(self, trainer, *args, **kwargs): @property @abstractmethod - def done(self): + def done(self) -> bool: """Property indicating when loop is finished""" - def run(self, *args: Any, **kwargs: Any): + def run(self, *args: Any, **kwargs: Any) -> Any: self.on_run_start(*args, **kwargs) while not self.done: @@ -41,7 +41,7 @@ def on_advance_start(self, *args: Any, **kwargs: Any) -> None: pass @abstractmethod - def advance(self, *args: Any, **kwargs: Any): + def advance(self, *args: Any, **kwargs: Any) -> None: """What to do within a single step""" def on_advance_end(self) -> None: From ddc9108ccee0b8cc7b66f47263724f8db85afbdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 13:52:48 +0200 Subject: [PATCH 106/455] increment iteration method --- pytorch_lightning/loops/base.py | 7 +++++-- pytorch_lightning/loops/training_loop.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index b4e7096466f7a..f06620ada8c79 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -30,7 +30,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any: self.on_advance_start(*args, **kwargs) self.advance(*args, **kwargs) self.on_advance_end() - self.iteration_count += 1 + self.iteration_count = self.increment_iteration(self.iteration_count) return self.on_run_end() @@ -50,5 +50,8 @@ def on_advance_end(self) -> None: def on_run_end(self) -> Any: pass - def state_dict(self): + def increment_iteration(self, iteration: int) -> int: + return iteration + 1 + + def state_dict(self) -> dict: return 
dict() diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index f4e81f053c438..8d8b197949bba 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -78,7 +78,7 @@ def run(self, *args, **kwargs): except StopIteration: break - self.iteration_count += 1 + self.iteration_count = self.increment_iteration(self.iteration_count) return self.on_run_end() From afc986df7ac79d031bc2f36fb249c1e644cb462c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 13:53:01 +0200 Subject: [PATCH 107/455] cache notes commented --- pytorch_lightning/loops/cache.py | 48 ++++++++++++++++---------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/loops/cache.py b/pytorch_lightning/loops/cache.py index 8a616a9f94779..c8c35f2f0b940 100644 --- a/pytorch_lightning/loops/cache.py +++ b/pytorch_lightning/loops/cache.py @@ -1,24 +1,24 @@ -from typing import Tuple - - -class Cache: - - def __init__(self): - self._store = ... - - def add(self, obj: object, **tags): - pass - - def merge(self, cache: "Cache"): - pass - - def filter_by(self, tags: Tuple[str]): - pass - - - -self.cache = Cache() -self.cache.add("abc", result, batch_idx=, opt_idx=..) -self.cache.add("abc", result, batch_idx=) - -self.cache.group_by("abc", ("batch_idx", "opt_idx")) +# from typing import Tuple +# +# +# class Cache: +# +# def __init__(self): +# self._store = ... +# +# def add(self, obj: object, **tags): +# pass +# +# def merge(self, cache: "Cache"): +# pass +# +# def filter_by(self, tags: Tuple[str]): +# pass +# +# +# +# self.cache = Cache() +# self.cache.add("abc", result, batch_idx=, opt_idx=..) +# self.cache.add("abc", result, batch_idx=) +# +# self.cache.group_by("abc", ("batch_idx", "opt_idx")) From 7151442852a9cbfd4ebe85c0b9fdab24b114fc68 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 11:53:55 +0000 Subject: [PATCH 108/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/cache.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/loops/cache.py b/pytorch_lightning/loops/cache.py index c8c35f2f0b940..da0080de8744f 100644 --- a/pytorch_lightning/loops/cache.py +++ b/pytorch_lightning/loops/cache.py @@ -1,24 +1,24 @@ # from typing import Tuple -# -# +# +# # class Cache: -# +# # def __init__(self): # self._store = ... -# +# # def add(self, obj: object, **tags): # pass -# +# # def merge(self, cache: "Cache"): # pass -# +# # def filter_by(self, tags: Tuple[str]): # pass -# -# -# +# +# +# # self.cache = Cache() # self.cache.add("abc", result, batch_idx=, opt_idx=..) 
# self.cache.add("abc", result, batch_idx=) -# +# # self.cache.group_by("abc", ("batch_idx", "opt_idx")) From cfdd17435c656dad10e182c71f2cdede9bbe16f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 13:54:52 +0200 Subject: [PATCH 109/455] todos on progress tracking --- pytorch_lightning/loops/batch_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 0211f408c2af2..15ecc34bbc253 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -338,9 +338,11 @@ def _track_gradient_norm(self): return grad_norm_dict def _accumulated_batches_reached(self): + # TODO: use progress tracking of batches instead of iteration count, because iteration count may reset return (self.iteration_count + 1) % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self, is_last_batch=False): + # TODO: use progress tracking of batches instead of iteration count, because iteration count may reset return (self.iteration_count + 1) == self.trainer.num_training_batches or is_last_batch def should_accumulate(self): From 288a911b7963e77d85577a2982c7385833f4b7cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 14:07:09 +0200 Subject: [PATCH 110/455] incorporate #7682 --- pytorch_lightning/loops/epoch_loop.py | 2 +- pytorch_lightning/loops/training_loop.py | 42 ++++++++++++------------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 6ea3468b9fde9..c2f3e9feb117d 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -159,7 +159,7 @@ def on_advance_end(self): # update epoch level lr_schedulers if no val loop outside train loop is triggered if not should_check_val or should_train_only: - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + self.training_loop.update_lr_schedulers("epoch") if should_train_only: self.check_checkpoint_callback(True) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 8d8b197949bba..916df3b861ccf 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Dict, List, Union import pytorch_lightning as pl @@ -144,8 +143,7 @@ def on_advance_end(self): self.save_loggers_on_train_batch_end() # update LR schedulers - monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) - self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) + self.update_lr_schedulers('step') self.trainer.checkpoint_connector.has_trained = True self.total_batch_idx += 1 @@ -338,17 +336,16 @@ def _prepare_outputs( processed_outputs = processed_outputs[0] return processed_outputs - def update_train_loop_lr_schedulers(self, monitor_metrics=None): - num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached() - num_training_batches_reached = self._num_training_batches_reached() - - if num_accumulated_batches_reached or num_training_batches_reached: - # update lr - self.trainer.optimizer_connector.update_learning_rates( - interval="step", - monitor_metrics=monitor_metrics, - opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)], - ) + def update_lr_schedulers(self, interval: str) -> None: + if interval 
== "step": + finished_accumulation = self.batch_loop._accumulated_batches_reached() + finished_epoch = self._num_training_batches_reached() + if not finished_accumulation and not finished_epoch: + return + self.trainer.optimizer_connector.update_learning_rates( + interval=interval, + opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)], + ) def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached() @@ -362,15 +359,21 @@ def increment_accumulated_grad_global_step(self): def should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: """ Decide if we should run validation. """ - if not self.trainer.enable_validation: return False - # check if this epoch is eligible to run validation - if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: + is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 + if not is_val_check_epoch: return False # val_check_batch is inf for iterable datasets with no length defined + is_infinite_dataset = self.trainer.val_check_batch == float('inf') + if on_epoch and is_last_batch and is_infinite_dataset: + return True + + if on_epoch and self.trainer.should_stop: + return True + # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = False if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): @@ -380,12 +383,9 @@ def should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: boo # Note: num_training_batches is also inf for iterable datasets with no length defined epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 - is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") if on_epoch: - return ( - is_val_check_batch and epoch_end_val_check - ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset + return is_val_check_batch and epoch_end_val_check else: return is_val_check_batch and not epoch_end_val_check From c3e951ee502c6bb174d0a702c583535df1fdcfea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 14:16:56 +0200 Subject: [PATCH 111/455] add missing reference to running_loss --- pytorch_lightning/loops/epoch_loop.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index c2f3e9feb117d..79f0dea990533 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -24,7 +24,6 @@ class EpochLoop(Loop): def __init__(self, min_epochs, max_epochs, min_steps, max_steps): super().__init__() - self.running_loss = torch.tensor(0.0) # dummy TODO: self._teardown_already_run = False # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 @@ -75,6 +74,10 @@ def max_steps(self, value): # TODO: This setter is required by debugging connector (fast dev run) self.training_loop.max_steps = value + @property + def running_loss(self): + return self.training_loop.batch_loop.running_loss + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer self.training_loop.connect(trainer) From 44b3f6175caf5e354c30117c8a153942891bac2a Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 08:25:19 -0400 Subject: [PATCH 112/455] update --- 
pytorch_lightning/core/step_result.py | 33 ++++++-- tests/core/test_metric_result_integration.py | 77 +++++++++++++++++++ .../connectors/test_logger_connectors.py | 9 +-- 3 files changed, 104 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index ea0d094cac45b..5e205da209d59 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -14,6 +14,7 @@ from collections.abc import Generator, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass +from weakref import proxy from typing import Any, Callable, Dict, Iterable, NamedTuple, Optional, Tuple, Union import torch @@ -98,7 +99,7 @@ def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: self.value = min(self.value, value.float().mean()) else: - self.value = value + self.value = proxy(value) self._forward_cache = value._forward_cache def compute(self) -> torch.Tensor: @@ -119,7 +120,7 @@ def __repr__(self) -> str: if self.meta.is_tensor_and_mean_reduction: attr = f"value={self.value}, cumulated_batch_size={self.cumulated_batch_size}" else: - attr = f"value={self.value}" + attr = f"value={getattr(self, 'value', None)}" return f"{self.__class__.__name__}({attr})" def reset(self): @@ -155,11 +156,9 @@ def forward(self, *args, **kwargs): return self._forward_cache - def state_dict(self): - return { - "meta": self.meta, - } +class ResultMeta(Dict): + pass class ResultCollection(dict): @@ -460,10 +459,30 @@ def __repr__(self) -> str: def state_dict(self): def get_state_dict(item: ResultMetric) -> Dict[str, Any]: - return item.state_dict() + state = item.__getstate__() + # delete reference to TorchMetrics Metric + state = deepcopy(state) + if 'value' in state['_modules'] and isinstance(state['_modules']["value"], Metric): + del state['_modules']["value"] + return ResultMeta(**state) return { k: apply_to_collection(v, ResultMetric, get_state_dict) for k, v in self.items() } + + def load_from_state_dict(self, state_dict: Dict[str, Any]): + def to_result_metric(item: ResultMeta) -> Dict[str, Any]: + result_metric = ResultMetric(item["meta"]) + result_metric.__dict__.update(item) + return result_metric + + state_dict = { + k: apply_to_collection(v, ResultMeta, to_result_metric) + for k, v in state_dict.items() + } + + for k, v in state_dict.items(): + self[k] = v + \ No newline at end of file diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 1372855e2ae34..1663f820202e9 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from copy import deepcopy import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -79,6 +80,11 @@ def _ddp_test_fn(rank, worldsize): for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] + state_dict = result.state_dict() + result = ResultCollection(True) + result.load_from_state_dict(state_dict) + + epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] result.reset() @@ -143,3 +149,74 @@ def test_result_metric_integration(): assert set(epoch_log.keys()) == set(epoch_expected.keys()) for k in epoch_expected.keys(): assert epoch_expected[k] == epoch_log[k] + + +def test_result_collection_restoration(): + + _result = None + metric_a = DummyMetric() + metric_b = DummyMetric() + metric_c = DummyMetric() + + result = ResultCollection(True) + + for _ in range(2): + + cumulative_sum = 0 + + for i in range(3): + a = metric_a(i) + b = metric_b(i) + c = metric_c(i) + + cumulative_sum += i + + result.log('h', 'a', metric_a, on_step=True, on_epoch=True) + result.log('h', 'b', metric_b, on_step=False, on_epoch=True) + result.log('h', 'c', metric_c, on_step=True, on_epoch=False) + + result.log('m', 'a_1', a, on_step=True, on_epoch=True) + result.log('m', 'b_1', b, on_step=False, on_epoch=True) + result.log('m', 'c_1', [c, c], on_step=True, on_epoch=False) + + batch_log = result.get_batch_metrics()[DefaultMetricsKeys.LOG] + batch_expected = {"a_step": i, "c": i, "a_1_step": i, "c_1": [i, i]} + assert set(batch_log.keys()) == set(batch_expected.keys()) + for k in batch_expected.keys(): + assert batch_expected[k] == batch_log[k] + + _result = deepcopy(result) + state_dict = result.state_dict() + + result = ResultCollection(True) + result.load_from_state_dict(state_dict) + + # the metric reference are lost during serialization. + # they will be restored with the LightningModule state on the next step. 
+ result.log('h', 'a', metric_a, on_step=True, on_epoch=True) + result.log('h', 'b', metric_b, on_step=False, on_epoch=True) + result.log('h', 'c', metric_c, on_step=True, on_epoch=False) + + assert _result.items() == result.items() + + epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] + _epoch_log = _result.get_epoch_metrics()[DefaultMetricsKeys.LOG] + + assert epoch_log == _epoch_log + + epoch_expected = {'a_epoch', 'b', 'b_1', 'a_1_epoch'} + + assert set(epoch_log.keys()) == epoch_expected + for k in list(epoch_expected): + if k in {'a_epoch', 'b'}: + assert epoch_log[k] == cumulative_sum + else: + assert epoch_log[k] == 1 + + _result.reset() + result.reset() + + # assert metric state reset to default values + assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) + assert metric_b.x == metric_b._defaults['x'] + assert metric_c.x == metric_c._defaults['x'] \ No newline at end of file diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py index b6f487b6d6a0c..a246dc37fae52 100644 --- a/tests/trainer/connectors/test_logger_connectors.py +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -130,11 +130,4 @@ def test_result_collection_on_tensor_with_mean_reduction(): 'loss_1_1_1': mean, 'loss_2_1_1': mean } - assert epoch_metrics[DefaultMetricsKeys.CALLBACK] == excepted - - -def test_result_collection_restoration(): - - result_collection = ResultCollection(True) - - result_collection \ No newline at end of file + assert epoch_metrics[DefaultMetricsKeys.CALLBACK] == excepted \ No newline at end of file From db4c054dbbc160bebd45a907969ff456a2a1221a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 12:26:49 +0000 Subject: [PATCH 113/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/step_result.py | 15 ++++----------- tests/core/test_metric_result_integration.py | 6 +++--- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index cecc6cbcc3dc9..b5e870ce51215 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -14,8 +14,8 @@ from collections.abc import Generator, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from weakref import proxy from typing import Any, Callable, Dict, Iterable, NamedTuple, Optional, Tuple, Union +from weakref import proxy import torch from torch import Tensor @@ -471,23 +471,16 @@ def get_state_dict(item: ResultMetric) -> Dict[str, Any]: del state['_modules']["value"] return ResultMeta(**state) - return { - k: apply_to_collection(v, ResultMetric, get_state_dict) - for k, v in self.items() - } + return {k: apply_to_collection(v, ResultMetric, get_state_dict) for k, v in self.items()} def load_from_state_dict(self, state_dict: Dict[str, Any]): + def to_result_metric(item: ResultMeta) -> Dict[str, Any]: result_metric = ResultMetric(item["meta"]) result_metric.__dict__.update(item) return result_metric - state_dict = { - k: apply_to_collection(v, ResultMeta, to_result_metric) - for k, v in state_dict.items() - } + state_dict = {k: apply_to_collection(v, ResultMeta, to_result_metric) for k, v in state_dict.items()} for k, v in state_dict.items(): self[k] = v - - diff --git a/tests/core/test_metric_result_integration.py 
b/tests/core/test_metric_result_integration.py index 1663f820202e9..ba8147f74c90c 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -13,6 +13,7 @@ # limitations under the License. from copy import deepcopy + import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -83,7 +84,6 @@ def _ddp_test_fn(rank, worldsize): state_dict = result.state_dict() result = ResultCollection(True) result.load_from_state_dict(state_dict) - epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] result.reset() @@ -161,7 +161,7 @@ def test_result_collection_restoration(): result = ResultCollection(True) for _ in range(2): - + cumulative_sum = 0 for i in range(3): @@ -219,4 +219,4 @@ def test_result_collection_restoration(): # assert metric state reset to default values assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) assert metric_b.x == metric_b._defaults['x'] - assert metric_c.x == metric_c._defaults['x'] \ No newline at end of file + assert metric_c.x == metric_c._defaults['x'] From 1a64a7280e9c8858555d925bd6324d6da6352725 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 16:14:39 +0200 Subject: [PATCH 114/455] revert bug report model --- pl_examples/bug_report_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 1e6bd099af1f9..abb65ba86fd93 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -53,10 +53,10 @@ def run(): model = BoringModel() trainer = Trainer( default_root_dir=os.getcwd(), - limit_train_batches=2, - limit_val_batches=2, - num_sanity_val_steps=2, - max_epochs=2, + limit_train_batches=1, + limit_val_batches=1, + num_sanity_val_steps=0, + max_epochs=1, weights_summary=None, ) trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data) From f128303fdb4121544dbc25faf86490bf82681b29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 16:14:47 +0200 Subject: [PATCH 115/455] clean up todo --- pytorch_lightning/loops/training_loop.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 916df3b861ccf..d29c0e71b7679 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -59,9 +59,6 @@ def done(self): if self.trainer.should_stop: return True - # TODO: moved to on_advance_end, check if correct? 
- # self.total_batch_idx += 1 - # stop epoch if we limited the number of training batches if self._num_training_batches_reached(self.is_last_batch): return True From 714c99d5c3f05e977c2ea433a9ebf6425b656f6a Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 11:22:12 -0400 Subject: [PATCH 116/455] update --- pytorch_lightning/core/step_result.py | 56 ++++++----- .../logger_connector/logger_connector.py | 6 +- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- tests/core/test_metric_result_integration.py | 95 +++++++++++++++---- tests/models/test_hooks.py | 3 + 6 files changed, 117 insertions(+), 47 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index cecc6cbcc3dc9..44ddb2fc716d1 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -14,7 +14,6 @@ from collections.abc import Generator, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from weakref import proxy from typing import Any, Callable, Dict, Iterable, NamedTuple, Optional, Tuple, Union import torch @@ -44,7 +43,6 @@ class Metadata: reduce_fx: Callable = torch.mean dataloader_idx: Optional[int] = None is_tensor: bool = True - should_reset: bool = True @property def forked(self) -> bool: @@ -99,7 +97,7 @@ def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: self.value = min(self.value, value.float().mean()) else: - self.value = proxy(value) + self.value = value self._forward_cache = value._forward_cache def compute(self) -> torch.Tensor: @@ -111,10 +109,7 @@ def compute(self) -> torch.Tensor: else: raise MisconfigurationException("Only mean, max are supported.") else: - try: - return self.value.compute() - except RuntimeError: - return torch.tensor(0.) 
+ return self.value.compute() def __repr__(self) -> str: if self.meta.is_tensor_and_mean_reduction: @@ -170,7 +165,10 @@ class ResultCollection(dict): def __init__(self, is_train: bool) -> None: super().__init__() self.is_train = is_train - self.reset() + self._on_epoch_end_reached = False + self._minimize = None + self._current_hook_name: Optional[str] = None + self._batch_idx: Optional[int] = None @property def batch_size(self) -> int: @@ -180,6 +178,14 @@ def batch_size(self) -> int: def batch_size(self, batch_size: int) -> None: self._batch_size = batch_size + @property + def batch_idx(self) -> int: + return self._batch_idx + + @batch_idx.setter + def batch_idx(self, batch_idx: int) -> None: + self._batch_idx = batch_idx + @property def on_epoch_end_reached(self) -> bool: return self._on_epoch_end_reached @@ -187,7 +193,8 @@ def on_epoch_end_reached(self) -> bool: @on_epoch_end_reached.setter def on_epoch_end_reached(self, on_epoch_end_reached): self._on_epoch_end_reached = on_epoch_end_reached - self._batch_size = None + self._minimize = None + self._batch_idx = None @property def metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: @@ -219,9 +226,6 @@ def detach_fn(v): extra = apply_to_collection(extra, torch.Tensor, detach_fn) self['extra'] = extra - def should_reset(self, hook_name: str) -> bool: - return hook_name not in ("on_train_start") - def log( self, hook_name: str, @@ -260,11 +264,12 @@ def log( on_epoch=on_epoch, reduce_fx=reduce_fx, dataloader_idx=dataloader_idx, - should_reset=self.should_reset(hook_name) ) self.instance_result_metric(key, meta, value) - self.update_metrics(key, value, batch_size or torch.tensor(1.)) + self.update_metrics(hook_name, key, value, batch_size or torch.tensor(1.)) + + self._current_hook_name = hook_name def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: @@ -285,7 +290,11 @@ def fn(v): self[key + '.on_epoch'] = meta.on_epoch self[key + '.dataloader_idx'] = meta.dataloader_idx - def update_metrics(self, key: str, value: Union[Dict, torch.Tensor], batch_size) -> None: + def update_metrics(self, hook_name: str, key: str, value: Union[Dict, torch.Tensor], batch_size) -> None: + + if isinstance(self._current_hook_name, str) and self._current_hook_name != hook_name and self.batch_idx in (None, 0): + # when restarting an new epoch, reset the tensor hooks dynamically. 
+ self.reset_metrics(hook_name, is_tensor=True) def fn(result_metric, v): assert isinstance(v, (torch.Tensor, Metric)) @@ -415,17 +424,20 @@ def cpu(self) -> 'ResultCollection': """Move all data to CPU.""" return self.to(device="cpu") - def reset(self) -> None: - """Call at the end of epoch to reset all metric objects""" - + def reset_metrics(self, hook_name: str = None, is_tensor: bool = False) -> None: + """Call at the end of epoch to reset all results provided as `Metric` or `tensor`""" def reset_fn(item: ResultMetric) -> None: - if item.meta.should_reset: + nonlocal hook_name + nonlocal is_tensor + if item.meta.is_tensor == is_tensor: + if isinstance(hook_name, str) and hook_name != item.meta.fx: + return item.reset() apply_to_collection(dict(self.items()), ResultMetric, reset_fn) - self._batch_size: Optional[int] = None - self._on_epoch_end_reached: bool = False - self._minimize: Optional[Tensor] = None + + def reset(self): + self.reset_metrics() def extract_batch_size(self, batch: Any) -> None: try: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0b83aae77e4df..8b762e5efe2d6 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -64,7 +64,7 @@ def configure_logger(self, logger): else: self.trainer.logger = logger - def on_evaluation_batch_start(self, batch, dataloader_idx, num_dataloaders): + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: if self.trainer.sanity_checking: return @@ -74,6 +74,7 @@ def on_evaluation_batch_start(self, batch, dataloader_idx, num_dataloaders): # track batch_size self.trainer.result_collections.extract_batch_size(batch) + self.trainer.result_collections.batch_idx = batch_idx @property def should_flush_logs(self): @@ -229,8 +230,9 @@ def update_evaluation_step_metrics(self) -> None: ############## TRAIN METRICS UPDATES START ############## # noqa E266 - def on_train_split_start(self, split_batch: Any) -> None: + def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: self.trainer.result_collections.extract_batch_size(split_batch) + self.trainer.result_collections.batch_idx = batch_idx def on_train_batch_end(self) -> None: self.trainer.result_collections.batch_size = 1 diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index fb80cc927893b..cde8fc9bb714a 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -211,7 +211,7 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: # set dataloader_idx to model and track batch_size - self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) + self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self.num_dataloaders) if self.trainer.testing: self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b17f80e6b6874..593f6138fca25 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -962,7 +962,7 @@ def run_train_split_start(self, 
split_idx, split_batch, opt_idx, optimizer): model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally - self.trainer.logger_connector.on_train_split_start(split_batch) + self.trainer.logger_connector.on_train_split_start(split_idx, split_batch) def update_running_loss(self, current_loss: torch.Tensor) -> None: if self.trainer.lightning_module.automatic_optimization: diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 1663f820202e9..48ae804564ada 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import pytest from copy import deepcopy import torch import torch.distributed as dist @@ -80,11 +80,6 @@ def _ddp_test_fn(rank, worldsize): for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] - state_dict = result.state_dict() - result = ResultCollection(True) - result.load_from_state_dict(state_dict) - - epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] result.reset() @@ -164,23 +159,27 @@ def test_result_collection_restoration(): cumulative_sum = 0 + result.on_epoch_end_reached = False + for i in range(3): + a = metric_a(i) b = metric_b(i) c = metric_c(i) cumulative_sum += i - result.log('h', 'a', metric_a, on_step=True, on_epoch=True) - result.log('h', 'b', metric_b, on_step=False, on_epoch=True) - result.log('h', 'c', metric_c, on_step=True, on_epoch=False) + result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True) + result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True) + result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False) - result.log('m', 'a_1', a, on_step=True, on_epoch=True) - result.log('m', 'b_1', b, on_step=False, on_epoch=True) - result.log('m', 'c_1', [c, c], on_step=True, on_epoch=False) + result.log('training_step', 'a_1', a, on_step=True, on_epoch=True) + result.log('training_step', 'b_1', b, on_step=False, on_epoch=True) + result.log('training_step', 'c_1', [c, c], on_step=True, on_epoch=False) - batch_log = result.get_batch_metrics()[DefaultMetricsKeys.LOG] + batch_log = result.metrics[DefaultMetricsKeys.LOG] batch_expected = {"a_step": i, "c": i, "a_1_step": i, "c_1": [i, i]} + assert set(batch_log.keys()) == set(batch_expected.keys()) for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] @@ -193,14 +192,17 @@ def test_result_collection_restoration(): # the metric reference are lost during serialization. # they will be restored with the LightningModule state on the next step. 
- result.log('h', 'a', metric_a, on_step=True, on_epoch=True) - result.log('h', 'b', metric_b, on_step=False, on_epoch=True) - result.log('h', 'c', metric_c, on_step=True, on_epoch=False) + result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True) + result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True) + result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False) assert _result.items() == result.items() - epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] - _epoch_log = _result.get_epoch_metrics()[DefaultMetricsKeys.LOG] + result.on_epoch_end_reached = True + _result.on_epoch_end_reached = True + + epoch_log = result.metrics[DefaultMetricsKeys.LOG] + _epoch_log = _result.metrics[DefaultMetricsKeys.LOG] assert epoch_log == _epoch_log @@ -213,10 +215,61 @@ def test_result_collection_restoration(): else: assert epoch_log[k] == 1 - _result.reset() - result.reset() + _result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True) + result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True) + + _result.reset_metrics() + result.reset_metrics() # assert metric state reset to default values assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) assert metric_b.x == metric_b._defaults['x'] - assert metric_c.x == metric_c._defaults['x'] \ No newline at end of file + assert metric_c.x == metric_c._defaults['x'] + + +def test_simple_loop(): + + result = ResultCollection(True) + + result.log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True) + result.log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True) + + for epoch in range(2): + + result.log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + result.log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + + for batch_idx, batch_size in enumerate(range(2)): + + result.batch_idx = batch_idx + + result.log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + result.log('c1', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + result.log('c2', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + + result.on_epoch_end_reached = True + + result.log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + result.log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + + assert result['a0.a'].value == torch.tensor(0.) + assert result['a0.a'].cumulated_batch_size == torch.tensor(1.) + assert result['a1.a'].value == torch.tensor(0.) + assert result['a1.a'].cumulated_batch_size == torch.tensor(1.) + + assert result['b0.a'].value == torch.tensor(1.) + epoch + assert result['b0.a'].cumulated_batch_size == torch.tensor(1.) + assert result['b1.a'].value == torch.tensor(1.) + epoch + assert result['b1.a'].cumulated_batch_size == torch.tensor(1.) + + assert result['c0.a'].value == torch.tensor(4.) + epoch * (batch_size + 1) + assert result['c0.a'].cumulated_batch_size == torch.tensor(2.) + assert result['c1.a'].value == torch.tensor(4.) + epoch * (batch_size + 1) + assert result['c1.a'].cumulated_batch_size == torch.tensor(2.) + assert result['c2.a'].value == torch.tensor(4.) + epoch * (batch_size + 1) + assert result['c2.a'].cumulated_batch_size == torch.tensor(2.) + + assert result['d0.a'].value == torch.tensor(3.) + epoch + assert result['d0.a'].cumulated_batch_size == torch.tensor(1.) + assert result['d1.a'].value == torch.tensor(3.) + epoch + assert result['d1.a'].cumulated_batch_size == torch.tensor(1.) 
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 78f8d2c0a94e9..63981f13a9752 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -70,6 +70,7 @@ def training_epoch_end(self, outputs): logger=False, prog_bar=True, ) + import pdb; pdb.set_trace() model = CurrentModel() trainer = Trainer( @@ -81,6 +82,8 @@ def training_epoch_end(self, outputs): assert trainer.state.finished, f"Training failed with {trainer.state}" metrics = trainer.progress_bar_dict + import pdb; pdb.set_trace() + # metrics added in training step should be unchanged by epoch end method assert metrics['step_metric'] == -1 # a metric shared in both methods gets overwritten by epoch_end From 162a2b907fb5492cf533c7208692e6977a551126 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 15:23:36 +0000 Subject: [PATCH 117/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/step_result.py | 4 +++- tests/core/test_metric_result_integration.py | 8 ++++---- tests/models/test_hooks.py | 7 ++++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 0cc57515df87d..f7d8c7f942a12 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -293,7 +293,8 @@ def fn(v): def update_metrics(self, hook_name: str, key: str, value: Union[Dict, torch.Tensor], batch_size) -> None: - if isinstance(self._current_hook_name, str) and self._current_hook_name != hook_name and self.batch_idx in (None, 0): + if isinstance(self._current_hook_name, + str) and self._current_hook_name != hook_name and self.batch_idx in (None, 0): # when restarting an new epoch, reset the tensor hooks dynamically. self.reset_metrics(hook_name, is_tensor=True) @@ -427,6 +428,7 @@ def cpu(self) -> 'ResultCollection': def reset_metrics(self, hook_name: str = None, is_tensor: bool = False) -> None: """Call at the end of epoch to reset all results provided as `Metric` or `tensor`""" + def reset_fn(item: ResultMetric) -> None: nonlocal hook_name nonlocal is_tensor diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 55d344eee5855..70756ec5f413b 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest from copy import deepcopy +import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -163,7 +163,7 @@ def test_result_collection_restoration(): result.on_epoch_end_reached = False for i in range(3): - + a = metric_a(i) b = metric_b(i) c = metric_c(i) @@ -180,7 +180,7 @@ def test_result_collection_restoration(): batch_log = result.metrics[DefaultMetricsKeys.LOG] batch_expected = {"a_step": i, "c": i, "a_1_step": i, "c_1": [i, i]} - + assert set(batch_log.keys()) == set(batch_expected.keys()) for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] @@ -226,7 +226,7 @@ def test_result_collection_restoration(): assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] - + def test_simple_loop(): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 63981f13a9752..2c8e0d5a25554 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -70,7 +70,8 @@ def training_epoch_end(self, outputs): logger=False, prog_bar=True, ) - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() model = CurrentModel() trainer = Trainer( @@ -82,8 +83,8 @@ def training_epoch_end(self, outputs): assert trainer.state.finished, f"Training failed with {trainer.state}" metrics = trainer.progress_bar_dict - import pdb; pdb.set_trace() - + import pdb + pdb.set_trace() # metrics added in training step should be unchanged by epoch end method assert metrics['step_metric'] == -1 # a metric shared in both methods gets overwritten by epoch_end From f59241118a57af13d11480d11b33b3df71253edb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 17:23:42 +0200 Subject: [PATCH 118/455] fix current epoch when dumping chkpt --- pytorch_lightning/loops/epoch_loop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 79f0dea990533..466ba965d9a60 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -185,8 +185,11 @@ def on_run_end(self): # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step self.training_loop.global_step -= 1 + # the iteration_count/current_epoch is already incremented + self.current_epoch -= 1 # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 self.check_checkpoint_callback(should_update=True, is_last=True) + self.current_epoch += 1 self.training_loop.global_step += 1 # hook From cca5426fe536030068a5693e3eab181242f68128 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 11:37:59 -0400 Subject: [PATCH 119/455] update --- pytorch_lightning/core/step_result.py | 8 +++----- tests/core/test_metric_result_integration.py | 19 ++++++++++++------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 0cc57515df87d..314ecf382107e 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -364,11 +364,8 @@ def is_empty_fn(v): is_empty = False # apply detection. 
- apply_to_collection(value, object, is_empty_fn, wrong_dtype=( - Mapping, - Sequence, - NamedTuple, - )) + wrong_dtype = (Mapping, Sequence, NamedTuple,) + apply_to_collection(value, object, is_empty_fn, wrong_dtype=wrong_dtype) # skip is the value was actually empty. if is_empty: @@ -439,6 +436,7 @@ def reset_fn(item: ResultMetric) -> None: def reset(self): self.reset_metrics() + self.on_epoch_end_reached = False def extract_batch_size(self, batch: Any) -> None: try: diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 55d344eee5855..71a930f3369de 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -96,7 +96,7 @@ def _ddp_test_fn(rank, worldsize): assert epoch_expected[k] == epoch_log[k] -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, min_gpus=2) def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.set_random_master_port() @@ -159,9 +159,7 @@ def test_result_collection_restoration(): for _ in range(2): cumulative_sum = 0 - - result.on_epoch_end_reached = False - + for i in range(3): a = metric_a(i) @@ -170,7 +168,10 @@ def test_result_collection_restoration(): cumulative_sum += i + import pdb; pdb.set_trace() result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True) + import pdb; pdb.set_trace() + result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True) result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False) @@ -219,8 +220,10 @@ def test_result_collection_restoration(): _result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True) result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True) - _result.reset_metrics() - result.reset_metrics() + _result.reset() + result.reset() + + print(result) # assert metric state reset to default values assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) @@ -228,7 +231,7 @@ def test_result_collection_restoration(): assert metric_c.x == metric_c._defaults['x'] -def test_simple_loop(): +def test_result_collection_simple_loop(): result = ResultCollection(True) @@ -274,3 +277,5 @@ def test_simple_loop(): assert result['d0.a'].cumulated_batch_size == torch.tensor(1.) assert result['d1.a'].value == torch.tensor(3.) + epoch assert result['d1.a'].cumulated_batch_size == torch.tensor(1.) + + result.reset() From cca2a84382635ccd579301beb615dcf0156d0a89 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 15:38:53 +0000 Subject: [PATCH 120/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/step_result.py | 6 +++++- tests/core/test_metric_result_integration.py | 12 +++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 031a90cd318ae..3e5011778bd40 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -365,7 +365,11 @@ def is_empty_fn(v): is_empty = False # apply detection. - wrong_dtype = (Mapping, Sequence, NamedTuple,) + wrong_dtype = ( + Mapping, + Sequence, + NamedTuple, + ) apply_to_collection(value, object, is_empty_fn, wrong_dtype=wrong_dtype) # skip is the value was actually empty. 
diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index bd428b3d5256a..37ca104b6a209 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -159,7 +159,7 @@ def test_result_collection_restoration(): for _ in range(2): cumulative_sum = 0 - + for i in range(3): a = metric_a(i) @@ -168,10 +168,12 @@ def test_result_collection_restoration(): cumulative_sum += i - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True) - import pdb; pdb.set_trace() - + import pdb + pdb.set_trace() + result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True) result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False) @@ -277,5 +279,5 @@ def test_result_collection_simple_loop(): assert result['d0.a'].cumulated_batch_size == torch.tensor(1.) assert result['d1.a'].value == torch.tensor(3.) + epoch assert result['d1.a'].cumulated_batch_size == torch.tensor(1.) - + result.reset() From 41930011ecbe21ddfb30e0645b63b40457dc1297 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 11:39:16 -0400 Subject: [PATCH 121/455] update --- pytorch_lightning/core/step_result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 031a90cd318ae..0ac15d7d7bae2 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -194,7 +194,6 @@ def on_epoch_end_reached(self) -> bool: @on_epoch_end_reached.setter def on_epoch_end_reached(self, on_epoch_end_reached): self._on_epoch_end_reached = on_epoch_end_reached - self._minimize = None self._batch_idx = None @property @@ -439,6 +438,7 @@ def reset_fn(item: ResultMetric) -> None: def reset(self): self.reset_metrics() self.on_epoch_end_reached = False + self._minimize = None def extract_batch_size(self, batch: Any) -> None: try: From bef6a2dc982fdfe459a92a6dd6cee7557d0b5245 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 17:50:59 +0200 Subject: [PATCH 122/455] integrate #7701 --- pytorch_lightning/loops/training_loop.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index d29c0e71b7679..03293ed7cf164 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -46,22 +46,11 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): @property def done(self): - # max steps reached, end training - if ( - self.max_steps is not None and self.max_steps <= self.global_step + 1 - and self.batch_loop._accumulated_batches_reached() - ): - return True - - # end epoch early - # stop when the flag is changed or we've gone past the amount - # requested in the batches - if self.trainer.should_stop: - return True - - # stop epoch if we limited the number of training batches - if self._num_training_batches_reached(self.is_last_batch): - return True + max_steps_reached = ( + self.max_steps is not None and self.max_steps <= self.global_step + 1 + and self.batch_loop._accumulated_batches_reached() + ) + return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) def run(self, *args, **kwargs): self.on_run_start() @@ -368,7 +357,7 @@ def should_check_val_fx(self, batch_idx: int, is_last_batch: 
bool, on_epoch: boo if on_epoch and is_last_batch and is_infinite_dataset: return True - if on_epoch and self.trainer.should_stop: + if self.trainer.should_stop: return True # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch From c90faa8f5547e7cfe3b52c263ac8787761329a3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 17:51:23 +0200 Subject: [PATCH 123/455] fix current epoch at end of training --- pytorch_lightning/loops/epoch_loop.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 466ba965d9a60..04f353c280e59 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -182,14 +182,16 @@ def on_run_end(self): return self._teardown_already_run = True + # NOTE: the iteration_count/current_epoch is already incremented + # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit + # To simulate that current behavior, we decrement here. + self.current_epoch -= 1 + # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step self.training_loop.global_step -= 1 - # the iteration_count/current_epoch is already incremented - self.current_epoch -= 1 # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 self.check_checkpoint_callback(should_update=True, is_last=True) - self.current_epoch += 1 self.training_loop.global_step += 1 # hook From 3e9fcd54e1d905599a717b2c3fda82c927a2b9be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 17:57:23 +0200 Subject: [PATCH 124/455] fix access to get_active_optimizers --- pytorch_lightning/loops/epoch_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 04f353c280e59..0efec706cb092 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -1,9 +1,10 @@ import logging from contextlib import suppress from copy import deepcopy -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple import torch +from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.callbacks.early_stopping import EarlyStopping @@ -215,7 +216,7 @@ def on_run_end(self): def should_accumulate(self): return self.training_loop.batch_loop.should_accumulate() - def get_active_optimizers(self, batch_idx): + def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: return self.training_loop.batch_loop.get_active_optimizers(batch_idx) def check_checkpoint_callback(self, should_update, is_last=False): From 0b91f63f5f8925ff90268c491756570fd70dd8e4 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 12:03:22 -0400 Subject: [PATCH 125/455] update --- pytorch_lightning/core/step_result.py | 18 +++++++- tests/core/test_metric_result_integration.py | 48 ++++++++++---------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index f85e5c0f9742d..e58cef56cf49b 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -44,6 +44,7 @@ class Metadata: reduce_fx: Callable = torch.mean dataloader_idx: Optional[int] = 
None is_tensor: bool = True + lightning_attribute_name: Optional[str] = None @property def forked(self) -> bool: @@ -239,6 +240,7 @@ def log( enable_graph: bool = False, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, + lightning_attribute_name: Optional[str] = None, ): """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -254,6 +256,9 @@ def log( if dataloader_idx: key += f'.{dataloader_idx}' + if on_step and self.on_epoch_end_reached: + raise MisconfigurationException("Logging `on_step` after `on_epoch_end_reached` isn't authorized.") + if key not in self: meta = Metadata( fx=hook_name, @@ -264,6 +269,7 @@ def log( on_epoch=on_epoch, reduce_fx=reduce_fx, dataloader_idx=dataloader_idx, + lightning_attribute_name=lightning_attribute_name, ) self.instance_result_metric(key, meta, value) @@ -490,7 +496,7 @@ def get_state_dict(item: ResultMetric) -> Dict[str, Any]: return {k: apply_to_collection(v, ResultMetric, get_state_dict) for k, v in self.items()} - def load_from_state_dict(self, state_dict: Dict[str, Any]): + def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Dict[str, Metric] = None): def to_result_metric(item: ResultMeta) -> Dict[str, Any]: result_metric = ResultMetric(item["meta"]) @@ -501,3 +507,13 @@ def to_result_metric(item: ResultMeta) -> Dict[str, Any]: for k, v in state_dict.items(): self[k] = v + + if metrics is not None: + + def re_assign_metric(item): + nonlocal metrics + lightning_attribute_name = item.meta.lightning_attribute_name + if isinstance(lightning_attribute_name, str) and lightning_attribute_name in metrics: + item.value = metrics[lightning_attribute_name] + + apply_to_collection(dict(self.items()), ResultMetric, re_assign_metric) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 37ca104b6a209..ef70a6fd2dd7c 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -156,7 +156,7 @@ def test_result_collection_restoration(): result = ResultCollection(True) - for _ in range(2): + for epoch in range(2): cumulative_sum = 0 @@ -168,15 +168,13 @@ def test_result_collection_restoration(): cumulative_sum += i - import pdb - pdb.set_trace() - result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True) - import pdb - pdb.set_trace() - - result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True) - result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False) - + result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True, lightning_attribute_name="metric_a") + result.log( + 'training_step', 'b', metric_b, on_step=False, on_epoch=True, lightning_attribute_name="metric_b" + ) + result.log( + 'training_step', 'c', metric_c, on_step=True, on_epoch=False, lightning_attribute_name="metric_c" + ) result.log('training_step', 'a_1', a, on_step=True, on_epoch=True) result.log('training_step', 'b_1', b, on_step=False, on_epoch=True) result.log('training_step', 'c_1', [c, c], on_step=True, on_epoch=False) @@ -192,13 +190,15 @@ def test_result_collection_restoration(): state_dict = result.state_dict() result = ResultCollection(True) - result.load_from_state_dict(state_dict) - - # the metric reference are lost during serialization. - # they will be restored with the LightningModule state on the next step. 
- result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True) - result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True) - result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False) + result.load_from_state_dict( + state_dict, + metrics={ + "metric_a": metric_a, + "metric_b": metric_b, + "metric_c": metric_c, + "metric_a_end": metric_a + } + ) assert _result.items() == result.items() @@ -210,23 +210,25 @@ def test_result_collection_restoration(): assert epoch_log == _epoch_log - epoch_expected = {'a_epoch', 'b', 'b_1', 'a_1_epoch'} + if epoch == 0: + epoch_expected = {'a_epoch', 'b', 'b_1', 'a_1_epoch'} + else: + epoch_expected = {'a_epoch', 'b', 'b_1', 'a_1_epoch', 'a'} assert set(epoch_log.keys()) == epoch_expected for k in list(epoch_expected): - if k in {'a_epoch', 'b'}: + if k in {'a_epoch', 'b', 'a'}: assert epoch_log[k] == cumulative_sum else: assert epoch_log[k] == 1 - _result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True) - result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True) + result.log( + 'train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, lightning_attribute_name="metric_a_end" + ) _result.reset() result.reset() - print(result) - # assert metric state reset to default values assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) assert metric_b.x == metric_b._defaults['x'] From 68c72a64430545da65ebe5b7c1b20c4bba07bcd3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 16:09:26 +0000 Subject: [PATCH 126/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/training_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 03293ed7cf164..c2223a8a3c391 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -47,8 +47,8 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): @property def done(self): max_steps_reached = ( - self.max_steps is not None and self.max_steps <= self.global_step + 1 - and self.batch_loop._accumulated_batches_reached() + self.max_steps is not None and self.max_steps <= self.global_step + 1 + and self.batch_loop._accumulated_batches_reached() ) return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) From bbe763330c51241ef61fa35612aea3eb41e62751 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 18:17:45 +0200 Subject: [PATCH 127/455] update chlog with refactors --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14dbee08920be..f2d97525b1de1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,7 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506)) * Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507)) * Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526)) - + * Refactored logic deciding when to run validation ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) + * Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) + * Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) From 20082759704bc6ac794c844882a7d0c3e9c3ce7c Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 12:59:33 -0400 Subject: [PATCH 128/455] update --- pytorch_lightning/core/step_result.py | 7 ++++--- pytorch_lightning/trainer/training_loop.py | 5 ++++- tests/core/test_metric_result_integration.py | 3 +-- .../loops/test_training_loop_flow_scalar.py | 17 +++++++++++++---- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index e58cef56cf49b..033b8453f72e9 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -448,7 +448,6 @@ def reset_fn(item: ResultMetric) -> None: def reset(self): self.reset_metrics() self.on_epoch_end_reached = False - self._minimize = None def extract_batch_size(self, batch: Any) -> None: try: @@ -496,7 +495,7 @@ def get_state_dict(item: ResultMetric) -> Dict[str, Any]: return {k: apply_to_collection(v, ResultMetric, get_state_dict) for k, v in self.items()} - def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Dict[str, Metric] = None): + def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Dict[str, Metric]): def to_result_metric(item: ResultMeta) -> Dict[str, Any]: result_metric = ResultMetric(item["meta"]) @@ -508,7 +507,9 @@ def to_result_metric(item: ResultMeta) -> Dict[str, Any]: for k, v in state_dict.items(): self[k] = v - if metrics is not None: + if metrics: + # the metric reference are lost during serialization and + # they need to be set back during loading def re_assign_metric(item): nonlocal metrics diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 593f6138fca25..0939ce4debbcd 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -382,7 +382,10 @@ def _prepare_outputs( for tbptt_output in batch_outputs: out = tbptt_output.extra - out['loss'] = tbptt_output.minimize.detach() + loss = tbptt_output.minimize + if isinstance(loss, torch.Tensor): + loss = loss.detach() + out['loss'] = loss processed_tbptt_outputs.append(out) # if there was only one tbptt step then we can collapse that dimension diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ef70a6fd2dd7c..4c54323760e68 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -191,8 +191,7 @@ def 
test_result_collection_restoration(): result = ResultCollection(True) result.load_from_state_dict( - state_dict, - metrics={ + state_dict, { "metric_a": metric_a, "metric_b": metric_b, "metric_c": metric_c, diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 2f503b62f56ee..2c93b8205c59f 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -22,6 +22,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import RunningStage from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.deterministic_model import DeterministicModel from tests.helpers.utils import no_warning_call @@ -149,6 +150,8 @@ def backward(self, loss, optimizer, optimizer_idx): assert len(trainer.logger_connector.callback_metrics) == 0 assert len(trainer.logger_connector.progress_bar_metrics) == 0 + trainer.state.stage = RunningStage.TRAINING + # make sure training outputs what is expected for batch_idx, batch in enumerate(model.train_dataloader()): break @@ -160,8 +163,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out['minimize'], torch.Tensor) - assert train_step_out['minimize'].item() == 171 + assert isinstance(train_step_out.minimize, torch.Tensor) + assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things opt_closure_result = trainer.train_loop.training_step_and_backward( @@ -229,6 +232,8 @@ def backward(self, loss, optimizer, optimizer_idx): assert len(trainer.logger_connector.callback_metrics) == 0 assert len(trainer.logger_connector.progress_bar_metrics) == 0 + trainer.state.stage = RunningStage.TRAINING + # make sure training outputs what is expected for batch_idx, batch in enumerate(model.train_dataloader()): break @@ -240,8 +245,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out['minimize'], torch.Tensor) - assert train_step_out['minimize'].item() == 171 + assert isinstance(train_step_out.minimize, torch.Tensor) + assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things opt_closure_result = trainer.train_loop.training_step_and_backward( @@ -316,6 +321,8 @@ def training_step(self, batch, batch_idx): with pytest.warns(UserWarning, match=r'.*training_step returned None.*'): trainer.fit(model) + trainer.state.stage = RunningStage.TRAINING + # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) @@ -359,6 +366,8 @@ def train_dataloader(self): with pytest.warns(UserWarning, match=r'.*train_dataloader yielded None.*'): trainer.fit(model) + trainer.state.stage = RunningStage.TRAINING + # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) From 8583aa1fd80988809f7a4a92975758be9ca7dec9 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 13:55:35 -0400 Subject: [PATCH 129/455] update --- 
pytorch_lightning/core/step_result.py | 15 ++++++++--- .../logger_connector/logger_connector.py | 27 +++++++++++-------- pytorch_lightning/trainer/evaluation_loop.py | 3 +++ pytorch_lightning/trainer/properties.py | 2 +- pytorch_lightning/trainer/training_loop.py | 1 + tests/models/test_hooks.py | 4 --- .../loops/test_evaluation_loop_flow.py | 13 ++++++--- 7 files changed, 42 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 033b8453f72e9..50fbb8fc1452a 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -171,6 +171,7 @@ def __init__(self, is_train: bool) -> None: self._minimize = None self._current_hook_name: Optional[str] = None self._batch_idx: Optional[int] = None + self._root_device: Optional[torch.device] = None @property def batch_size(self) -> int: @@ -180,6 +181,14 @@ def batch_size(self) -> int: def batch_size(self, batch_size: int) -> None: self._batch_size = batch_size + @property + def root_device(self) -> Optional[torch.device]: + return self._root_device + + @root_device.setter + def root_device(self, root_device: torch.device) -> None: + self._root_device = root_device + @property def batch_idx(self) -> int: return self._batch_idx @@ -280,12 +289,12 @@ def log( def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: def fn(v): + assert self.root_device is not None nonlocal meta meta = deepcopy(meta) meta.is_tensor = torch.is_tensor(v) metric = ResultMetric(meta) - device = getattr(v, "device", torch.device("cpu")) - return metric.to(device) + return metric.to(self.root_device) self[key] = apply_to_collection(value, (torch.Tensor, Metric), fn) # cache the meta for reduction @@ -305,7 +314,7 @@ def update_metrics(self, hook_name: str, key: str, value: Union[Dict, torch.Tens def fn(result_metric, v): assert isinstance(v, (torch.Tensor, Metric)) - result_metric(v, batch_size) + result_metric(v.to(self.root_device), batch_size) apply_to_collections(self[key], value, ResultMetric, fn) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8b762e5efe2d6..883ebe714e01c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -125,9 +125,6 @@ def log_metrics(self, metrics, grad_norm_dict, step=None): self.add_logged_metrics(scalar_metrics) def evaluation_epoch_end(self): - if self.trainer.sanity_checking: - return - # reset dataloader idx model_ref = self.trainer.lightning_module model_ref._current_dataloader_idx = None @@ -158,13 +155,13 @@ def prepare_eval_loop_results(self): self.add_to_eval_loop_results(dl_idx, has_been_initialized) def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: - if not self.trainer.sanity_checking: + metrics = self.trainer.result_collections.metrics - metrics = self.trainer.result_collections.metrics + # update metrics + self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) + self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) - # update metrics - self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) - self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) + if not self.trainer.sanity_checking: # log all the metrics as a single dict metrics_to_log = metrics[DefaultMetricsKeys.LOG] @@ -209,15 +206,15 @@ def 
increment_evaluation_log_step(self) -> None: self._test_log_step += 1 def update_evaluation_step_metrics(self) -> None: - if self.trainer.sanity_checking: - return - metrics = self.trainer.result_collections.metrics # update metrics self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) + if self.trainer.sanity_checking: + return + batch_log_metrics = metrics[DefaultMetricsKeys.LOG] # logs user requested information to logger @@ -228,8 +225,16 @@ def update_evaluation_step_metrics(self) -> None: # increment the step even if nothing was logged self.increment_evaluation_log_step() + def on_evaluation_start(self): + root_device = self.trainer.lightning_module.device + self.trainer.result_collections.root_device = root_device + ############## TRAIN METRICS UPDATES START ############## # noqa E266 + def on_train_start(self): + root_device = self.trainer.lightning_module.device + self.trainer.result_collections.root_device = root_device + def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: self.trainer.result_collections.extract_batch_size(split_batch) self.trainer.result_collections.batch_idx = batch_idx diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index cde8fc9bb714a..e7c13425e9870 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -79,6 +79,9 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() + + self.trainer.logger_connector.on_evaluation_start() + if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) else: diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 57d06d99dd773..fa31078b16577 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -516,7 +516,7 @@ def min_steps(self) -> Optional[int]: def result_collections(self) -> Optional[ResultCollection]: if self.training: return self.train_loop.train_results - elif self.validating: + elif self.validating or self.sanity_checking: return self.evaluation_loop.validation_results elif self.testing: return self.evaluation_loop.test_results diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0939ce4debbcd..4ef36f44cabde 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -102,6 +102,7 @@ def should_skip_training(self) -> bool: def on_train_start(self): # hook + self.trainer.logger_connector.on_train_start() self.trainer.call_hook("on_train_start") def on_train_end(self): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 2c8e0d5a25554..78f8d2c0a94e9 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -70,8 +70,6 @@ def training_epoch_end(self, outputs): logger=False, prog_bar=True, ) - import pdb - pdb.set_trace() model = CurrentModel() trainer = Trainer( @@ -83,8 +81,6 @@ def training_epoch_end(self, outputs): assert trainer.state.finished, f"Training failed with {trainer.state}" metrics = trainer.progress_bar_dict - import pdb - pdb.set_trace() # metrics added in training step should be unchanged by epoch end method assert metrics['step_metric'] == -1 # a 
metric shared in both methods gets overwritten by epoch_end diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 67ed756630734..075b9f7438124 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -19,6 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import RunningStage from tests.helpers.deterministic_model import DeterministicModel @@ -65,6 +66,8 @@ def backward(self, loss, optimizer, optimizer_idx): assert not model.validation_step_end_called assert not model.validation_epoch_end_called + trainer.state.stage = RunningStage.TRAINING + # make sure training outputs what is expected for batch_idx, batch in enumerate(model.train_dataloader()): break @@ -76,8 +79,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out['minimize'], torch.Tensor) - assert train_step_out['minimize'].item() == 171 + assert isinstance(train_step_out.minimize, torch.Tensor) + assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things opt_closure_result = trainer.train_loop.training_step_and_backward( @@ -138,6 +141,8 @@ def backward(self, loss, optimizer, optimizer_idx): assert model.validation_step_end_called assert not model.validation_epoch_end_called + trainer.state.stage = RunningStage.TRAINING + # make sure training outputs what is expected for batch_idx, batch in enumerate(model.train_dataloader()): break @@ -149,8 +154,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out['minimize'], torch.Tensor) - assert train_step_out['minimize'].item() == 171 + assert isinstance(train_step_out.minimize, torch.Tensor) + assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things opt_closure_result = trainer.train_loop.training_step_and_backward( From 417da98aef287f39d8ed461ee3c5ebf7cf2a0123 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 14:45:01 -0400 Subject: [PATCH 130/455] update --- pytorch_lightning/core/step_result.py | 17 +++--- .../logger_connector/logger_connector.py | 29 +++++------ pytorch_lightning/trainer/training_loop.py | 6 +-- .../logging_/test_train_loop_logging.py | 52 ++++++++----------- 4 files changed, 47 insertions(+), 57 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 50fbb8fc1452a..d4f02f8401595 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -107,9 +107,9 @@ def compute(self) -> torch.Tensor: if self.meta.is_tensor_and_mean_reduction: return torch.sum(self.value) / torch.sum(self.cumulated_batch_size) elif self.meta.is_tensor_and_max_reduction or self.meta.is_tensor_and_min_reduction: - return self.meta.fx(self.value) + return self.value else: - raise MisconfigurationException("Only mean, max are supported.") + raise MisconfigurationException("Only min, mean, max reduction are supported.") else: return self.value.compute() @@ -170,12 +170,13 @@ def __init__(self, is_train: bool) -> None: self._on_epoch_end_reached = False 
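
# Minimal sketch of the batch-size-weighted mean reduction that the compute()
# change above relies on for is_tensor_and_mean_reduction (variable names are
# illustrative, not the patch's internal state): values are accumulated together
# with the batch sizes they were logged with, and the epoch value is the ratio
# of the two sums.
import torch

_values = torch.tensor([2.0 * 3, 4.0 * 1])        # value * batch_size per step
_cumulated_batch_size = torch.tensor([3.0, 1.0])  # batch_size per step
_epoch_mean = torch.sum(_values) / torch.sum(_cumulated_batch_size)
assert _epoch_mean.item() == 2.5                  # (2*3 + 4*1) / (3 + 1)
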
self._minimize = None self._current_hook_name: Optional[str] = None + self._batch_size: Optional[int] = None self._batch_idx: Optional[int] = None self._root_device: Optional[torch.device] = None @property - def batch_size(self) -> int: - return self._batch_size + def batch_size(self) -> 1: + return self._batch_size or 1 @batch_size.setter def batch_size(self, batch_size: int) -> None: @@ -190,7 +191,7 @@ def root_device(self, root_device: torch.device) -> None: self._root_device = root_device @property - def batch_idx(self) -> int: + def batch_idx(self) -> Optional[int]: return self._batch_idx @batch_idx.setter @@ -282,7 +283,9 @@ def log( ) self.instance_result_metric(key, meta, value) - self.update_metrics(hook_name, key, value, batch_size or torch.tensor(1.)) + batch_size = torch.tensor(batch_size or self.batch_size, device=self.root_device) + + self.update_metrics(hook_name, key, value, batch_size) self._current_hook_name = hook_name @@ -314,7 +317,7 @@ def update_metrics(self, hook_name: str, key: str, value: Union[Dict, torch.Tens def fn(result_metric, v): assert isinstance(v, (torch.Tensor, Metric)) - result_metric(v.to(self.root_device), batch_size) + result_metric(v.to(self.root_device), batch_size.to(self.root_device)) apply_to_collections(self[key], value, ResultMetric, fn) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 883ebe714e01c..c26d9fc8e14e2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -64,18 +64,6 @@ def configure_logger(self, logger): else: self.trainer.logger = logger - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: - if self.trainer.sanity_checking: - return - - model = self.trainer.lightning_module - # set dataloader_idx only if multiple ones - model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - - # track batch_size - self.trainer.result_collections.extract_batch_size(batch) - self.trainer.result_collections.batch_idx = batch_idx - @property def should_flush_logs(self): should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 @@ -205,6 +193,19 @@ def increment_evaluation_log_step(self) -> None: elif self.trainer.state.stage is RunningStage.TESTING: self._test_log_step += 1 + def on_evaluation_start(self): + root_device = self.trainer.lightning_module.device + self.trainer.result_collections.root_device = root_device + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: + model = self.trainer.lightning_module + # set dataloader_idx only if multiple ones + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + + # track batch_size + self.trainer.result_collections.extract_batch_size(batch) + self.trainer.result_collections.batch_idx = batch_idx + def update_evaluation_step_metrics(self) -> None: metrics = self.trainer.result_collections.metrics @@ -225,10 +226,6 @@ def update_evaluation_step_metrics(self) -> None: # increment the step even if nothing was logged self.increment_evaluation_log_step() - def on_evaluation_start(self): - root_device = self.trainer.lightning_module.device - self.trainer.result_collections.root_device = root_device - ############## TRAIN METRICS UPDATES START ############## # 
noqa E266 def on_train_start(self): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4ef36f44cabde..8455431be60a8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -703,7 +703,7 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi # opt_idx=0 to opt_idx=None in the signature here # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + self.run_train_split_start(batch_idx, split_idx, split_batch, opt_idx, optimizer) result = AttributeDict() closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result) @@ -958,7 +958,7 @@ def save_loggers_on_train_batch_end(self): if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() - def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): + def run_train_split_start(self, batch_idx: int, split_idx, split_batch, opt_idx, optimizer): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1: @@ -966,7 +966,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally - self.trainer.logger_connector.on_train_split_start(split_idx, split_batch) + self.trainer.logger_connector.on_train_split_start(batch_idx, split_batch) def update_running_loss(self, current_loss: torch.Tensor) -> None: if self.trainer.lightning_module.automatic_optimization: diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 7f9acb4cac8a1..fe2e74f8ccf32 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -508,7 +508,7 @@ class TestCallback(callbacks.Callback): count = 1 choices = [False, True] # used to compute expected values - callback_funcs_called = collections.defaultdict(list) + callback_funcs_called = collections.defaultdict(dict) funcs_called_count = collections.defaultdict(int) funcs_attr = {} @@ -517,22 +517,18 @@ def make_logging( ): self.funcs_called_count[func_name] += 1 iterate = list(itertools.product(*[on_steps, on_epochs, prob_bars])) + value = self.count * func_idx + + current_epoch = pl_module.trainer.current_epoch + for idx, (on_step, on_epoch, prog_bar) in enumerate(iterate): # run logging custom_func_name = f"{func_idx}_{idx}_{func_name}" - pl_module.log( - custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar - ) - - # catch information for verification - - # on on_train_start is outside the main loop. Won't be called - if func_name == "on_train_start": - self.callback_funcs_called[func_name].append([self.count * func_idx]) + pl_module.log(custom_func_name, value, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) - # Saved only values from second epoch, so we can compute its mean or latest. 
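
# Minimal sketch of the root_device handling threaded through the commits above
# (device choice and names are illustrative): each logged value and its batch
# size are moved to a single device before accumulation, mirroring
# result_metric(v.to(self.root_device), batch_size.to(self.root_device)), so
# values logged from different devices can be reduced together.
import torch

_root_device = torch.device("cpu")
_running = torch.tensor(0.0, device=_root_device)
for _v in (torch.tensor(1.0), torch.tensor(2.0)):
    _running += _v.to(_root_device)
assert _running.item() == 3.0
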
- if pl_module.trainer.current_epoch == 1: - self.callback_funcs_called[func_name].append([self.count * func_idx]) + if current_epoch not in self.callback_funcs_called[custom_func_name]: + self.callback_funcs_called[custom_func_name][current_epoch] = [] + self.callback_funcs_called[custom_func_name][current_epoch].append(value) forked = on_step and on_epoch @@ -561,6 +557,8 @@ def make_logging( "func_name": func_name } + self.count += 1 + def on_train_start(self, trainer, pl_module): self.make_logging( pl_module, 'on_train_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices @@ -595,10 +593,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data on_epochs=self.choices, prob_bars=self.choices ) - # used to make sure aggregation works fine. - # we should obtain func[value * c for c in range(1, max_epochs * limit_train_batches)]) - # with func = np.mean if on_epoch else func = np.max - self.count += 1 def on_train_epoch_end(self, trainer, pl_module): self.make_logging( @@ -647,15 +641,6 @@ def training_step(self, batch, batch_idx): assert test_callback.funcs_called_count["on_epoch_end"] == 2 assert test_callback.funcs_called_count["on_train_epoch_end"] == 2 - # Make sure the func_name exists within callback_metrics. If not, we missed some - callback_metrics_keys = [*trainer.callback_metrics.keys()] - for func_name in test_callback.callback_funcs_called.keys(): - is_in = False - for callback_metrics_key in callback_metrics_keys: - if func_name in callback_metrics_key: - is_in = True - assert is_in, (func_name, callback_metrics_keys) - # function used to describe expected return logic def get_expected_output(func_attr, original_values): if func_attr["on_epoch"] and not func_attr["on_step"]: @@ -680,11 +665,16 @@ def get_expected_output(func_attr, original_values): func_attr = test_callback.funcs_attr[func_name] # retrived orginal logged values - original_values = test_callback.callback_funcs_called[func_attr["func_name"]] - - # compute expected output and compare to actual one - expected_output = get_expected_output(func_attr, original_values) - assert float(output_value) == float(expected_output) + values = test_callback.callback_funcs_called[func_name] + if len(values) > 0: + original_values = values[len(values) - 1] + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + try: + assert float(output_value) == float(expected_output) + except: + import pdb + pdb.set_trace() for func_name, func_attr in test_callback.funcs_attr.items(): if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: From cd8603ce8b605db0d484552ae105bf964d669c94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 20:54:12 +0200 Subject: [PATCH 131/455] avoid pickle errors --- pytorch_lightning/loops/training_loop.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index c2223a8a3c391..cfbc5cfd314e8 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -177,6 +177,11 @@ def on_run_end(self): self.trainer.call_hook('on_epoch_end') return self.epoch_output + def __getstate__(self): + # avoid pickling errors "cannot pickle generator object" + self._train_dataloader = None + return self.__dict__ + # 
------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP From 05567c8dcf8459a3316c030f2e504f395daf1c52 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 14:56:28 -0400 Subject: [PATCH 132/455] update --- pytorch_lightning/core/step_result.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 7 +------ tests/trainer/logging_/test_eval_loop_logging.py | 3 +++ 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index d4f02f8401595..c859d1c069c41 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -444,13 +444,13 @@ def cpu(self) -> 'ResultCollection': """Move all data to CPU.""" return self.to(device="cpu") - def reset_metrics(self, hook_name: str = None, is_tensor: bool = False) -> None: + def reset_metrics(self, hook_name: str = None, is_tensor: Optional[bool] = None) -> None: """Call at the end of epoch to reset all results provided as `Metric` or `tensor`""" def reset_fn(item: ResultMetric) -> None: nonlocal hook_name nonlocal is_tensor - if item.meta.is_tensor == is_tensor: + if is_tensor is None or item.meta.is_tensor == is_tensor: if isinstance(hook_name, str) and hook_name != item.meta.fx: return item.reset() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 973dc6444263e..6033bbe7c8d37 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1128,17 +1128,12 @@ def _run_sanity_check(self, ref_model): self.state.stage = stage # reset metrics - self._reset_metrics(ref_model) + self.result_collections.reset() # reset the seed to what it was before sanity check # prevents sanity check to affect random sampling in training reset_seed() - def _reset_metrics(self, ref_model): - for module in ref_model.modules(): - if isinstance(module, Metric): - module.reset() - def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: if ckpt_path is None: return diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 83523addff99a..a6ab93ca5345e 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -282,6 +282,7 @@ def validation_epoch_end(self, outputs) -> None: log_every_n_steps=1, weights_summary=None, callbacks=[ModelCheckpoint(dirpath=tmpdir)], + num_sanity_val_steps=2, ) trainer.fit(model) @@ -291,6 +292,8 @@ def validation_epoch_end(self, outputs) -> None: assert callback_metrics == {'debug_epoch', 'val_loss', 'val_loss_epoch'} # make sure values are correct + import pdb + pdb.set_trace() assert trainer.logged_metrics['val_loss_epoch'] == manual_mean assert trainer.callback_metrics['val_loss_epoch'] == manual_mean assert trainer.callback_metrics['val_loss'] == manual_mean From d193c38f1ea1425086c60cba32b3c3d7d3bfe002 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 21:07:12 +0200 Subject: [PATCH 133/455] update changelog --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f2d97525b1de1..0acd4e3f93f17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,7 +64,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
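
# Minimal sketch of the __getstate__ workaround from the "avoid pickle errors"
# commit above: generator attributes cannot be pickled, so they are dropped
# from the pickled state. Unlike the patch (which sets the attribute to None in
# place), this sketch copies the dict so the live object keeps its iterator;
# the attribute name is reused purely for illustration.
import pickle


class _LoopSketch:

    def __init__(self):
        self._train_dataloader = (x for x in range(3))  # generators break pickling

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_train_dataloader"] = None
        return state


pickle.dumps(_LoopSketch())  # succeeds because the generator is excluded
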
* Refactored logic deciding when to run validation ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) * Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) * Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701)) - + ... + * Refactored internal loop interface; added new classes `EpochLoop`, `TrainingLoop`, `BatchLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700)) + - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) From 87052b52f4188e680a033a31080b20a857d996c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 21:27:44 +0200 Subject: [PATCH 134/455] re-enable gradient norm tracking --- pytorch_lightning/loops/batch_loop.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 15ecc34bbc253..b3d889e4e995a 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -62,9 +62,8 @@ def run(self, batch, batch_idx, dataloader_idx): output = AttributeDict( signal=0, - # todo: Properly aggregate grad_norm accros opt_idx and split_idx - # grad_norm_dict=grad_norm_dict, - grad_norm_dict={}, + # TODO: Properly aggregate grad_norm accross opt_idx and split_idx + grad_norm_dict=self.grad_norm_dicts[-1], training_step_output_for_epoch_end=self.batch_outputs, ) return output @@ -75,6 +74,7 @@ def on_run_start(self, batch, batch_idx, dataloader_idx): # TODO: let loops track individual outputs self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] + self.grad_norm_dicts = [] def advance(self, batch, batch_idx, dataloader_idx): split_idx, split_batch = self._remaining_splits.pop(0) @@ -82,6 +82,7 @@ def advance(self, batch, batch_idx, dataloader_idx): # TODO: this list needs to go outside this loop # batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] + grad_norm_dict = {} if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): @@ -95,6 +96,9 @@ def advance(self, batch, batch_idx, dataloader_idx): if result: self.batch_outputs[0].append(result.training_step_output_for_epoch_end) + # TODO: Properly aggregate grad_norm accross opt_idx and split_idx + self.grad_norm_dicts.append(grad_norm_dict) + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP From 01be317c0ecbee4bb2c9232446be203071b67d68 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 25 May 2021 15:32:58 -0400 Subject: [PATCH 135/455] resolve logging tests --- pytorch_lightning/core/step_result.py | 26 ++++++++++++++----- .../logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 3 +-- pytorch_lightning/trainer/trainer.py | 6 ++--- .../logging_/test_eval_loop_logging.py | 3 +-- .../logging_/test_train_loop_logging.py | 6 +---- 6 files changed, 26 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index c859d1c069c41..1df26df3db339 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -154,6 +154,7 @@ 
def forward(self, *args, **kwargs): return self._forward_cache +# placeholder for apply_to_collection class ResultMeta(Dict): pass @@ -263,8 +264,9 @@ def log( key = f"{hook_name}.{name}" - if dataloader_idx: + if dataloader_idx is not None: key += f'.{dataloader_idx}' + hook_name += f'.{dataloader_idx}' if on_step and self.on_epoch_end_reached: raise MisconfigurationException("Logging `on_step` after `on_epoch_end_reached` isn't authorized.") @@ -308,12 +310,16 @@ def fn(v): self[key + '.on_epoch'] = meta.on_epoch self[key + '.dataloader_idx'] = meta.dataloader_idx - def update_metrics(self, hook_name: str, key: str, value: Union[Dict, torch.Tensor], batch_size) -> None: + def should_reset_tensors(self, hook_name: str) -> bool: + return (self._current_hook_name != hook_name and self._batch_idx in (None, 0)) + + def update_metrics( + self, hook_name: str, key: str, value: Union[Dict, torch.Tensor], batch_size: torch.Tensor + ) -> None: - if isinstance(self._current_hook_name, - str) and self._current_hook_name != hook_name and self.batch_idx in (None, 0): + if self.should_reset_tensors(hook_name): # when restarting an new epoch, reset the tensor hooks dynamically. - self.reset_metrics(hook_name, is_tensor=True) + self._reset_metrics(hook_name, is_tensor=True) def fn(result_metric, v): assert isinstance(v, (torch.Tensor, Metric)) @@ -444,7 +450,7 @@ def cpu(self) -> 'ResultCollection': """Move all data to CPU.""" return self.to(device="cpu") - def reset_metrics(self, hook_name: str = None, is_tensor: Optional[bool] = None) -> None: + def _reset_metrics(self, hook_name: str = None, is_tensor: Optional[bool] = None) -> None: """Call at the end of epoch to reset all results provided as `Metric` or `tensor`""" def reset_fn(item: ResultMetric) -> None: @@ -457,9 +463,15 @@ def reset_fn(item: ResultMetric) -> None: apply_to_collection(dict(self.items()), ResultMetric, reset_fn) + def reset_metrics(self): + self._reset_metrics(is_tensor=False) + self.on_epoch_end_reached = False + self._current_hook_name = None + def reset(self): - self.reset_metrics() + self._reset_metrics() self.on_epoch_end_reached = False + self._current_hook_name = None def extract_batch_size(self, batch: Any) -> None: try: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index c26d9fc8e14e2..95e51c3bd9286 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -284,7 +284,7 @@ def update_train_epoch_metrics(self) -> None: self.log_metrics(epoch_log_metrics, {}) # reset result collection for next epoch - self.trainer.result_collections.reset() + self.trainer.result_collections.reset_metrics() ############## TRAIN METRICS UPDATES END ############## # noqa E266 diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index e7c13425e9870..f4735304c97f0 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -102,8 +102,7 @@ def on_evaluation_model_train(self) -> None: model_ref.on_validation_model_train() def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: - if self.trainer.result_collections: - self.trainer.result_collections.reset() + self.trainer.result_collections.reset_metrics() if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) diff --git 
a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6033bbe7c8d37..556f4b4942497 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1125,15 +1125,15 @@ def _run_sanity_check(self, ref_model): self.on_sanity_check_end() - self.state.stage = stage - - # reset metrics + # reset validation metrics self.result_collections.reset() # reset the seed to what it was before sanity check # prevents sanity check to affect random sampling in training reset_seed() + self.state.stage = stage + def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: if ckpt_path is None: return diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index a6ab93ca5345e..824bc742d0dc5 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -292,8 +292,6 @@ def validation_epoch_end(self, outputs) -> None: assert callback_metrics == {'debug_epoch', 'val_loss', 'val_loss_epoch'} # make sure values are correct - import pdb - pdb.set_trace() assert trainer.logged_metrics['val_loss_epoch'] == manual_mean assert trainer.callback_metrics['val_loss_epoch'] == manual_mean assert trainer.callback_metrics['val_loss'] == manual_mean @@ -853,6 +851,7 @@ def test_step(self, batch, batch_idx): limit_test_batches=2, max_epochs=2, progress_bar_refresh_rate=1, + num_sanity_val_steps=2, ) # Train the model ⚡ diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index fe2e74f8ccf32..77f9033b4fe7b 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -670,11 +670,7 @@ def get_expected_output(func_attr, original_values): original_values = values[len(values) - 1] # compute expected output and compare to actual one expected_output = get_expected_output(func_attr, original_values) - try: - assert float(output_value) == float(expected_output) - except: - import pdb - pdb.set_trace() + assert float(output_value) == float(expected_output) for func_name, func_attr in test_callback.funcs_attr.items(): if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: From c847e66178bb56f19856b243b4fa61f70a774cb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 21:36:07 +0200 Subject: [PATCH 136/455] fix attribute errors in gradient clipping tests --- tests/trainer/test_trainer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a8567db70d0a6..7e16430d94977 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -894,21 +894,21 @@ def test_gradient_clipping(tmpdir): default_root_dir=tmpdir, ) - trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward + old_training_step_and_backward = trainer.train_loop.training_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ wrap the forward step in a closure so second order methods work """ # test that gradient is clipped correctly - ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + ret_val = old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) parameters = 
model.parameters() grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm) return ret_val - trainer.train_loop.training_step_and_backward = training_step_and_backward + trainer.train_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -932,14 +932,14 @@ def test_gradient_clipping_by_value(tmpdir): default_root_dir=tmpdir ) - trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward + old_training_step_and_backward = trainer.train_loop.training_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ wrap the forward step in a closure so second order methods work """ # test that gradient is clipped correctly - ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + ret_val = old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) parameters = model.parameters() grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] grad_max = torch.max(torch.stack(grad_max_list)) @@ -948,7 +948,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.train_loop.training_step_and_backward = training_step_and_backward + trainer.train_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -1012,14 +1012,14 @@ def test_gradient_clipping_by_value_fp16(tmpdir): default_root_dir=tmpdir, ) - trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward + old_training_step_and_backward = trainer.train_loop.training_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ wrap the forward step in a closure so second order methods work """ # test that gradient is clipped correctly - ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + ret_val = old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) parameters = model.parameters() grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] grad_max = torch.max(torch.stack(grad_max_list)) @@ -1028,7 +1028,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.train_loop.training_step_and_backward = training_step_and_backward + trainer.train_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) From fe3929831c13b05d79b8af7ca59502214b38dfdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 May 2021 21:56:41 +0200 Subject: [PATCH 137/455] fix skip_backward attribute error for swa callback --- .../callbacks/stochastic_weight_avg.py | 4 ++-- pytorch_lightning/loops/batch_loop.py | 24 ++++++++++++------- pytorch_lightning/loops/epoch_loop.py | 10 ++++++++ tests/callbacks/test_stochastic_weight_avg.py | 4 ++-- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 3ec7774d5f8b6..236145d00f4a8 100644 --- 
a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -220,12 +220,12 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo # performing only one pass over the train data-loader to compute activation statistics # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward. trainer.num_training_batches += 1 - trainer.train_loop._skip_backward = True + trainer.train_loop.skip_backward = True self._accumulate_grad_batches = trainer.accumulate_grad_batches trainer.accumulate_grad_batches = len(trainer.train_dataloader) def on_train_epoch_end(self, trainer: 'pl.Trainer', *args): - trainer.train_loop._skip_backward = False + trainer.train_loop.skip_backward = False def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1: diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index b3d889e4e995a..f60267f7ea379 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -12,7 +12,7 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.supporters import prefetch_iterator, TensorRunningAccum +from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters @@ -26,16 +26,24 @@ class BatchLoop(Loop): def __init__(self): super().__init__() - # self.accumulated_loss = None # TODO: needs to be done over epoch - self.warning_cache = WarningCache() - # self._teardown_already_run = False - self.running_loss = TensorRunningAccum(window_length=20) self.accumulated_loss = None - self._skip_backward = False + self.running_loss = TensorRunningAccum(window_length=20) + self.split_idx = None + self.warning_cache = WarningCache() + self._hiddens = None self._optimizer_freq_cumsum = None + self._skip_backward = False - self.split_idx = None + @property + def skip_backward(self) -> bool: + """ Determines whether the loop will skip backward during automatic optimization. """ + return self._skip_backward + + @skip_backward.setter + def skip_backward(self, value: bool): + """ Determines whether the loop will skip backward during automatic optimization. 
""" + self._skip_backward = value def connect(self, trainer, *args, **kwargs): self.trainer = trainer @@ -433,7 +441,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) - if not self._skip_backward and self.trainer.lightning_module.automatic_optimization: + if not self.skip_backward and self.trainer.lightning_module.automatic_optimization: is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 if is_first_batch_to_accumulate: diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 0efec706cb092..152618b0e7d6d 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -79,6 +79,16 @@ def max_steps(self, value): def running_loss(self): return self.training_loop.batch_loop.running_loss + @property + def skip_backward(self) -> bool: + """ Determines whether the loop will skip backward during automatic optimization. """ + return self.training_loop.batch_loop.skip_backward + + @skip_backward.setter + def skip_backward(self, value: bool): + """ Determines whether the loop will skip backward during automatic optimization. """ + self.training_loop.batch_loop.skip_backward = value + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer self.training_loop.connect(trainer) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 81efc12b34662..b8bb5e220eda9 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -74,7 +74,7 @@ def transfer_weights(self, *args, **kwargs): def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) - assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end) + assert trainer.train_loop.skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR) assert trainer.lr_schedulers[0]["interval"] == "epoch" @@ -92,7 +92,7 @@ def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) # make sure these are correctly set again - assert not trainer.train_loop._skip_backward + assert not trainer.train_loop.skip_backward assert trainer.accumulate_grad_batches == 2 assert trainer.num_training_batches == 5 From 8f0d9fee94bea79a3e6a95016e9a430a000b7f76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 19:57:35 +0000 Subject: [PATCH 138/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0acd4e3f93f17..4c5ac07987eaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,7 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701)) ... 
* Refactored internal loop interface; added new classes `EpochLoop`, `TrainingLoop`, `BatchLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700)) - + - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) From f32b4045b83bee99b80840f5b866654f88ffe13a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 03:13:21 +0200 Subject: [PATCH 139/455] fix tests for train flow --- .../trainer/loops/test_training_loop_flow_scalar.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 2f503b62f56ee..b025d9c00bf35 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -153,7 +153,7 @@ def backward(self, loss, optimizer, optimizer_idx): for batch_idx, batch in enumerate(model.train_dataloader()): break - out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) @@ -164,7 +164,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out['minimize'].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.train_loop.training_step_and_backward( + opt_closure_result = trainer.train_loop.training_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -233,7 +233,7 @@ def backward(self, loss, optimizer, optimizer_idx): for batch_idx, batch in enumerate(model.train_dataloader()): break - out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) @@ -244,7 +244,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out['minimize'].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.train_loop.training_step_and_backward( + opt_closure_result = trainer.train_loop.training_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 @@ -318,7 +318,7 @@ def training_step(self, batch, batch_idx): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output_for_epoch_end == [[]] assert out.signal == 0 @@ -361,7 +361,7 @@ def train_dataloader(self): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output_for_epoch_end == [[]] assert out.signal == 0 From d728c4e2aa705eaa6fe0208085bb9c0df82887a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 03:17:13 +0200 Subject: [PATCH 140/455] fix attribute errors in eval loop flow tests --- 
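
# Minimal sketch of the test pattern used in the gradient clipping fixes above:
# the loop's training_step_and_backward is replaced by a wrapper that calls the
# original and then runs an extra assertion. Names below are illustrative; the
# real tests wrap the method found on the trainer's batch loop.
class _LoopWithBackward:

    def training_step_and_backward(self, batch):
        return {"loss": batch * 2}


_loop = _LoopWithBackward()
_original = _loop.training_step_and_backward


def _wrapped(batch):
    result = _original(batch)           # run the real step
    assert result["loss"] == batch * 2  # extra check, like the grad-norm assert
    return result


_loop.training_step_and_backward = _wrapped  # monkeypatch, as the tests do
assert _loop.training_step_and_backward(3)["loss"] == 6
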
tests/trainer/loops/test_evaluation_loop_flow.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 67ed756630734..1a971f6a12df1 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -69,7 +69,7 @@ def backward(self, loss, optimizer, optimizer_idx): for batch_idx, batch in enumerate(model.train_dataloader()): break - out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) @@ -80,7 +80,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out['minimize'].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.train_loop.training_step_and_backward( + opt_closure_result = trainer.train_loop.training_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -142,7 +142,7 @@ def backward(self, loss, optimizer, optimizer_idx): for batch_idx, batch in enumerate(model.train_dataloader()): break - out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) @@ -153,7 +153,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out['minimize'].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.train_loop.training_step_and_backward( + opt_closure_result = trainer.train_loop.training_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 From 0a242cb094b40a65239f26a589b8eda7602e61e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 03:29:18 +0200 Subject: [PATCH 141/455] fix attribute error to warning cache --- tests/deprecated_api/test_remove_1-5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index d6c9b6d8f8f31..21a2f2d8e0f0f 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -242,7 +242,7 @@ def on_train_epoch_end(self, outputs): # noqa with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer.fit(model) - trainer.train_loop.warning_cache.clear() + trainer.train_loop.training_loop.warning_cache.clear() class NewSignature(Callback): From adc1322f7755c09fdbf03a6919187b70ca21ffd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 03:30:59 +0200 Subject: [PATCH 142/455] fix attribute error --- tests/trainer/test_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 7e16430d94977..3426277a89b10 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -973,21 +973,21 @@ def test_gradient_clipping_fp16(tmpdir): default_root_dir=tmpdir, ) - trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward + old_training_step_and_backward = 
trainer.train_loop.training_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ wrap the forward step in a closure so second order methods work """ # test that gradient is clipped correctly - ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + ret_val = old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) parameters = model.parameters() grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm) return ret_val - trainer.train_loop.training_step_and_backward = training_step_and_backward + trainer.train_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) From 15244568ef46de283cf0522474a62c57afacd5bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 03:39:09 +0200 Subject: [PATCH 143/455] undo comments --- pytorch_lightning/trainer/training_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index dfa5d4e8d0919..62138790138ee 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -311,8 +311,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): result = AttributeDict( closure_loss=closure_loss, loss=untouched_loss, - training_step_output=training_step_output, # Result object - training_step_output_for_epoch_end=training_step_output_for_epoch_end, # Result object + training_step_output=training_step_output, + training_step_output_for_epoch_end=training_step_output_for_epoch_end, ) return result From 8d813f2d082dd523cb236afc510bb26a12d4170c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 03:48:03 +0200 Subject: [PATCH 144/455] refactor references to training_loop / batch loop objects --- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/trainer/properties.py | 16 ++++++++++++++++ tests/deprecated_api/test_remove_1-5.py | 2 +- tests/trainer/loops/test_evaluation_loop_flow.py | 8 ++++---- .../loops/test_training_loop_flow_scalar.py | 12 ++++++------ tests/trainer/test_trainer.py | 16 ++++++++-------- 6 files changed, 36 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index cb368854144c1..90986b7eb81bd 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1324,7 +1324,7 @@ def training_step(...): # backward self._running_manual_backward = True - self.trainer.train_loop.training_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) + self.trainer.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) self._running_manual_backward = False def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index e469d1bc12394..87f1a1b97cfb3 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -28,6 +28,9 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import 
LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.loops.batch_loop import BatchLoop +from pytorch_lightning.loops.epoch_loop import EpochLoop +from pytorch_lightning.loops.training_loop import TrainingLoop from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector @@ -491,6 +494,19 @@ def sanity_checking(self, val: bool) -> None: elif self.sanity_checking: self.state.stage = None + @property + def epoch_loop(self) -> EpochLoop: + # TODO: the current train_loop should be renamed to epoch_loop + return self.train_loop + + @property + def training_loop(self) -> TrainingLoop: + return self.epoch_loop.training_loop + + @property + def batch_loop(self) -> BatchLoop: + return self.epoch_loop.training_loop.batch_loop + @property def global_step(self) -> int: return self.train_loop.global_step diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 21a2f2d8e0f0f..4edf42854fabb 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -242,7 +242,7 @@ def on_train_epoch_end(self, outputs): # noqa with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer.fit(model) - trainer.train_loop.training_loop.warning_cache.clear() + trainer.training_loop.warning_cache.clear() class NewSignature(Callback): diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 1a971f6a12df1..f55fa1ca94512 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -69,7 +69,7 @@ def backward(self, loss, optimizer, optimizer_idx): for batch_idx, batch in enumerate(model.train_dataloader()): break - out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) @@ -80,7 +80,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out['minimize'].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.train_loop.training_loop.batch_loop.training_step_and_backward( + opt_closure_result = trainer.batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -142,7 +142,7 @@ def backward(self, loss, optimizer, optimizer_idx): for batch_idx, batch in enumerate(model.train_dataloader()): break - out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) @@ -153,7 +153,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out['minimize'].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.train_loop.training_loop.batch_loop.training_step_and_backward( + opt_closure_result = trainer.batch_loop.training_step_and_backward( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 
b025d9c00bf35..6c8c02933a269 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -153,7 +153,7 @@ def backward(self, loss, optimizer, optimizer_idx): for batch_idx, batch in enumerate(model.train_dataloader()): break - out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) @@ -164,7 +164,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out['minimize'].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.train_loop.training_loop.batch_loop.training_step_and_backward( + opt_closure_result = trainer.batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -233,7 +233,7 @@ def backward(self, loss, optimizer, optimizer_idx): for batch_idx, batch in enumerate(model.train_dataloader()): break - out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) @@ -244,7 +244,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out['minimize'].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.train_loop.training_loop.batch_loop.training_step_and_backward( + opt_closure_result = trainer.batch_loop.training_step_and_backward( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 @@ -318,7 +318,7 @@ def training_step(self, batch, batch_idx): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output_for_epoch_end == [[]] assert out.signal == 0 @@ -361,7 +361,7 @@ def train_dataloader(self): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.train_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output_for_epoch_end == [[]] assert out.signal == 0 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3426277a89b10..865bd07bb43d3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -894,7 +894,7 @@ def test_gradient_clipping(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.train_loop.training_loop.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -908,7 +908,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.train_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward + trainer.batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -932,7 +932,7 @@ def test_gradient_clipping_by_value(tmpdir): default_root_dir=tmpdir ) - old_training_step_and_backward = trainer.train_loop.training_loop.batch_loop.training_step_and_backward + 
old_training_step_and_backward = trainer.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -948,7 +948,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.train_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward + trainer.batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -973,7 +973,7 @@ def test_gradient_clipping_fp16(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.train_loop.training_loop.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -987,7 +987,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.train_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward + trainer.batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) @@ -1012,7 +1012,7 @@ def test_gradient_clipping_by_value_fp16(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.train_loop.training_loop.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -1028,7 +1028,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.train_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward + trainer.batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) From 51737ce276969a2b03f38fa0f81ac62370b0cebd Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 03:49:06 -0400 Subject: [PATCH 145/455] update --- pytorch_lightning/core/step_result.py | 13 +++++-- pytorch_lightning/trainer/training_loop.py | 8 +++-- tests/core/test_metric_result_integration.py | 23 ++++++++---- tests/models/test_hooks.py | 2 ++ .../connectors/test_logger_connectors.py | 35 +++++++++++++------ tests/trainer/test_trainer.py | 2 +- 6 files changed, 60 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 1df26df3db339..c0bec1f283886 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -45,6 +45,7 @@ class Metadata: dataloader_idx: Optional[int] = None is_tensor: bool = True lightning_attribute_name: Optional[str] = None + has_reset: bool = False @property def forked(self) -> bool: @@ -126,6 +127,8 @@ def reset(self): else: self.value.reset() + self.meta.has_reset = True + def forward(self, *args, **kwargs): """ Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. 
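The `has_reset` flag introduced in this patch is what later lets the collection skip metrics that were reset at epoch end and never updated again. A minimal standalone sketch of that bookkeeping, using simplified stand-ins rather than the real `Metadata` and `ResultMetric` classes:

from dataclasses import dataclass

@dataclass
class _Meta:                        # simplified stand-in for Metadata
    name: str
    has_reset: bool = False

class _Metric:                      # simplified stand-in for ResultMetric
    def __init__(self, meta: _Meta) -> None:
        self.meta = meta
        self.value = 0.0

    def update(self, v: float) -> None:
        self.value += v
        self.meta.has_reset = False   # any new update makes the metric visible again

    def reset(self) -> None:
        self.value = 0.0
        self.meta.has_reset = True    # skipped by the collection until updated again

metrics = {"a": _Metric(_Meta("a")), "b": _Metric(_Meta("b"))}
metrics["a"].update(1.0)
metrics["b"].reset()
assert [k for k, m in metrics.items() if not m.meta.has_reset] == ["a"]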
@@ -165,7 +168,7 @@ class ResultCollection(dict): EPOCH_SUFFIX = "_epoch" DATALOADER_SUFFIX = "/dataloader_idx_{}" - def __init__(self, is_train: bool) -> None: + def __init__(self, is_train: bool, root_device: Optional[torch.device] = None) -> None: super().__init__() self.is_train = is_train self._on_epoch_end_reached = False @@ -173,10 +176,10 @@ def __init__(self, is_train: bool) -> None: self._current_hook_name: Optional[str] = None self._batch_size: Optional[int] = None self._batch_idx: Optional[int] = None - self._root_device: Optional[torch.device] = None + self._root_device: Optional[torch.device] = root_device @property - def batch_size(self) -> 1: + def batch_size(self) -> int: return self._batch_size or 1 @batch_size.setter @@ -324,6 +327,7 @@ def update_metrics( def fn(result_metric, v): assert isinstance(v, (torch.Tensor, Metric)) result_metric(v.to(self.root_device), batch_size.to(self.root_device)) + result_metric.meta.has_reset = False apply_to_collections(self[key], value, ResultMetric, fn) @@ -342,6 +346,9 @@ def valid_metrics(self) -> Generator: for key, item in self.items(): if item is None or isinstance(item, bool) or key == "extra": continue + elif isinstance(item, ResultMetric) and item.meta.has_reset: + # skip the metric which have been reset. + continue yield (key, item) def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) -> Tuple: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8455431be60a8..9145a7c8f8ed0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -299,6 +299,9 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # accumulate loss. if accumulate_grad_batches==1, no effect closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches + # detach the loss + training_step_output._minimize = training_step_output.minimize.detach() + # the loss will get scaled for amp. 
avoid any modifications to it untouched_loss = closure_loss.detach().clone() @@ -340,11 +343,10 @@ def _process_training_step_output(self, training_step_output, split_batch): result.minimize = loss self._hiddens = hiddens - training_step_output_for_epoch_end = copy(result) if self.trainer.move_metrics_to_cpu: - training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() + result = result.cpu() - return training_step_output_for_epoch_end, result + return result, result @staticmethod def _prepare_outputs( diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 4c54323760e68..8162f9bd4f59f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -110,12 +110,18 @@ def test_result_metric_integration(): metric_b = DummyMetric() metric_c = DummyMetric() - result = ResultCollection(True) + result = ResultCollection(True, torch.device("cpu")) for _ in range(3): cumulative_sum = 0 + result.on_epoch_end_reached = False + for i in range(5): + + # need to set batch_idx + result.batch_idx = i + metric_a(i) metric_b(i) metric_c(i) @@ -132,6 +138,8 @@ def test_result_metric_integration(): for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] + result.on_epoch_end_reached = True + epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] result.reset() @@ -154,14 +162,17 @@ def test_result_collection_restoration(): metric_b = DummyMetric() metric_c = DummyMetric() - result = ResultCollection(True) + result = ResultCollection(True, torch.device("cpu")) for epoch in range(2): + result.on_epoch_end_reached = False cumulative_sum = 0 for i in range(3): + result.batch_idx = i + a = metric_a(i) b = metric_b(i) c = metric_c(i) @@ -189,7 +200,7 @@ def test_result_collection_restoration(): _result = deepcopy(result) state_dict = result.state_dict() - result = ResultCollection(True) + result = ResultCollection(True, torch.device("cpu")) result.load_from_state_dict( state_dict, { "metric_a": metric_a, @@ -236,13 +247,15 @@ def test_result_collection_restoration(): def test_result_collection_simple_loop(): - result = ResultCollection(True) + result = ResultCollection(True, torch.device("cpu")) result.log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True) result.log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True) for epoch in range(2): + result.on_epoch_end_reached = False + result.log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) result.log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) @@ -280,5 +293,3 @@ def test_result_collection_simple_loop(): assert result['d0.a'].cumulated_batch_size == torch.tensor(1.) assert result['d1.a'].value == torch.tensor(3.) + epoch assert result['d1.a'].cumulated_batch_size == torch.tensor(1.) 
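The detach added to the training-step output above matters because the stored step output outlives the backward pass; only the closure loss should keep the autograd graph alive. A rough illustration with toy tensors, not the actual Lightning objects:

import torch

w = torch.tensor(2.0, requires_grad=True)
loss = (w * 3.0) ** 2                            # graph-attached loss from a "training step"

accumulate_grad_batches = 2
closure_loss = loss / accumulate_grad_batches    # stays attached, used for backward
stored_for_epoch_end = loss.detach()             # cached copy must not retain the graph

closure_loss.backward()
assert w.grad is not None
assert not stored_for_epoch_end.requires_grad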
- - result.reset() diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 78f8d2c0a94e9..a0d5f06a7fa48 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -81,6 +81,8 @@ def training_epoch_end(self, outputs): assert trainer.state.finished, f"Training failed with {trainer.state}" metrics = trainer.progress_bar_dict + import pdb + pdb.set_trace() # metrics added in training step should be unchanged by epoch end method assert metrics['step_metric'] == -1 # a metric shared in both methods gets overwritten by epoch_end diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py index a0a01df0ff858..d1e41c76e420f 100644 --- a/tests/trainer/connectors/test_logger_connectors.py +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -22,11 +22,16 @@ def test_result_collection_on_tensor_with_mean_reduction(): seed_everything(42) - result_collection = ResultCollection() + result_collection = ResultCollection(True, torch.device("cpu")) for i in range(1, 10): + + result_collection.batch_idx = i + for prob_bar in [False, True]: + for logger in [False, True]: + result_collection.log( "training_step", f"loss_1_{int(prob_bar)}_{int(logger)}", @@ -99,17 +104,23 @@ def test_result_collection_on_tensor_with_mean_reduction(): assert batch_metrics[DefaultMetricsKeys.LOG] == excepted excepted = { - 'loss_1_0_0': tensor([9.]), - 'loss_3_0_0': tensor([9.]), - 'loss_1_0_1': tensor([9.]), - 'loss_3_0_1': tensor([9.]), - 'loss_1_1_0': tensor([9.]), - 'loss_3_1_0': tensor([9.]), - 'loss_1_1_1': tensor([9.]), - 'loss_3_1_1': tensor([9.]) + 'loss_1_0_0': tensor(9.), + 'loss_1_0_0_step': tensor(9.), + 'loss_3_0_0': tensor(9.), + 'loss_1_0_1': tensor(9.), + 'loss_1_0_1_step': tensor(9.), + 'loss_3_0_1': tensor(9.), + 'loss_1_1_0': tensor(9.), + 'loss_1_1_0_step': tensor(9.), + 'loss_3_1_0': tensor(9.), + 'loss_1_1_1': tensor(9.), + 'loss_1_1_1_step': tensor(9.), + 'loss_3_1_1': tensor(9.) 
} assert batch_metrics[DefaultMetricsKeys.CALLBACK] == excepted + result_collection.on_epoch_end_reached = True + epoch_metrics = result_collection.get_epoch_metrics() mean = (tensor(excepted_values) * tensor(excepted_batches)).sum() / sum(excepted_batches) @@ -122,12 +133,16 @@ def test_result_collection_on_tensor_with_mean_reduction(): excepted = { 'loss_1_0_0': mean, + 'loss_1_0_0_epoch': mean, 'loss_2_0_0': mean, 'loss_1_0_1': mean, + 'loss_1_0_1_epoch': mean, 'loss_2_0_1': mean, 'loss_1_1_0': mean, + 'loss_1_1_0_epoch': mean, 'loss_2_1_0': mean, 'loss_1_1_1': mean, - 'loss_2_1_1': mean + 'loss_1_1_1_epoch': mean, + 'loss_2_1_1': mean, } assert epoch_metrics[DefaultMetricsKeys.CALLBACK] == excepted diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a8567db70d0a6..ac66ebbec3587 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -341,7 +341,7 @@ def mock_save_function(filepath, *args): for i, loss in enumerate(losses): trainer.train_loop.current_epoch = i trainer.train_loop.global_step = i - trainer.logger_connector.callback_metrics = {"checkpoint_on": torch.tensor(loss)} + trainer.logger_connector._callback_metrics = {"checkpoint_on": torch.tensor(loss)} checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) file_lists = set(os.listdir(tmpdir)) From 1461c8c7def8c14a304e278e3d7eb2a01ae507f9 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 04:09:57 -0400 Subject: [PATCH 146/455] resolve tests --- pytorch_lightning/core/step_result.py | 9 ++++----- tests/core/test_metric_result_integration.py | 9 +++------ tests/metrics/test_metric_lightning.py | 7 +++++-- tests/models/test_hooks.py | 2 -- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index c0bec1f283886..793ce1137475a 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -22,6 +22,7 @@ from torchmetrics import Metric from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _METRIC @@ -76,17 +77,15 @@ def is_tensor_and_min_reduction(self) -> bool: return self.is_tensor and (self.reduce_fx in (torch.min, min)) -class ResultMetric(Metric): - - DTYPE = torch.float32 +class ResultMetric(Metric, DeviceDtypeModuleMixin): def __init__(self, metadata: Metadata) -> None: super().__init__(compute_on_step=metadata.is_tensor) self.meta = metadata if self.meta.is_tensor: - self.add_state("value", torch.tensor(.0, dtype=self.DTYPE)) + self.add_state("value", torch.tensor(.0)) if self.meta.is_tensor_and_mean_reduction: - self.add_state("cumulated_batch_size", torch.tensor(.0, dtype=self.DTYPE)) + self.add_state("cumulated_batch_size", torch.tensor(.0)) def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: if self.meta.is_tensor_and_mean_reduction: diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 8162f9bd4f59f..d4e6db0e89c5d 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -220,14 +220,11 @@ def test_result_collection_restoration(): assert epoch_log == _epoch_log - if epoch == 0: - 
epoch_expected = {'a_epoch', 'b', 'b_1', 'a_1_epoch'} - else: - epoch_expected = {'a_epoch', 'b', 'b_1', 'a_1_epoch', 'a'} + epoch_expected = {'a_1_epoch', 'a_epoch', 'b', 'b_1'} - assert set(epoch_log.keys()) == epoch_expected + assert set(epoch_log.keys()) == epoch_expected, epoch_log.keys() for k in list(epoch_expected): - if k in {'a_epoch', 'b', 'a'}: + if k in {'a_epoch', 'b'}: assert epoch_log[k] == cumulative_sum else: assert epoch_log[k] == 1 diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index e52e39cb16488..0ae9b18f7a88c 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -78,6 +78,7 @@ def __init__(self): self.metric_step = SumMetric() self.metric_epoch = SumMetric() self.sum = 0.0 + self.total_sum = 0.0 def on_epoch_start(self): self.sum = 0.0 @@ -90,7 +91,9 @@ def training_step(self, batch, batch_idx): return {'loss': self.step(x), 'data': x} def training_epoch_end(self, outs): - self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum())) + total = torch.stack([o['data'] for o in outs]).sum() + self.log("sum_epoch", self.metric_epoch(total)) + self.total_sum = total model = TestModel() model.val_dataloader = None @@ -107,7 +110,7 @@ def training_epoch_end(self, outs): logged = trainer.logged_metrics assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum) - assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum) + assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.total_sum) def test_scriptable(tmpdir): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index a0d5f06a7fa48..78f8d2c0a94e9 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -81,8 +81,6 @@ def training_epoch_end(self, outputs): assert trainer.state.finished, f"Training failed with {trainer.state}" metrics = trainer.progress_bar_dict - import pdb - pdb.set_trace() # metrics added in training step should be unchanged by epoch end method assert metrics['step_metric'] == -1 # a metric shared in both methods gets overwritten by epoch_end From 92e8818496753eb6bf2f002f73b6d777cb496d20 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 05:18:51 -0400 Subject: [PATCH 147/455] update --- pytorch_lightning/core/step_result.py | 97 ++++++++++++++++++++++++--- 1 file changed, 86 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 793ce1137475a..d85091360cddd 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -14,8 +14,7 @@ from collections.abc import Generator, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, NamedTuple, Optional, Tuple, Union -from weakref import proxy +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch from torch import Tensor @@ -78,6 +77,9 @@ def is_tensor_and_min_reduction(self) -> bool: class ResultMetric(Metric, DeviceDtypeModuleMixin): + """ + This class is responsible to hold each single metric provided by ``LightningModule.log`` function. + """ def __init__(self, metadata: Metadata) -> None: super().__init__(compute_on_step=metadata.is_tensor) @@ -132,6 +134,7 @@ def forward(self, *args, **kwargs): """ Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. 
""" + # todo (tchaton) Remove this override when merged to TorchMetrics. # add current step with torch.no_grad(): self.update(*args, **kwargs) @@ -162,6 +165,9 @@ class ResultMeta(Dict): class ResultCollection(dict): + """ + This class is used to capture all the logged values using LightningModule.log function. + """ STEP_SUFFIX = "_step" EPOCH_SUFFIX = "_epoch" @@ -212,6 +218,17 @@ def on_epoch_end_reached(self, on_epoch_end_reached): @property def metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: + """ + This function returns either batch or epoch metrics depending on `on_epoch_end_reached` attribute. + The metrics are returned as: + + + { + DefaultMetricsKeys.PBAR: {...}, + DefaultMetricsKeys.LOG: {...}, + DefaultMetricsKeys.CALLBACK: {...} + } + """ return self.get_epoch_metrics() if self.on_epoch_end_reached else self.get_batch_metrics() @property @@ -220,6 +237,9 @@ def minimize(self) -> Optional[Tensor]: @minimize.setter def minimize(self, loss: Optional[torch.Tensor]) -> None: + """ + The `LightningModule.training_step` loss will be saved as the ResultCollection minimize attribute. + """ if loss is not None: if not isinstance(loss, Tensor): raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") @@ -233,6 +253,9 @@ def extra(self) -> Dict: @extra.setter def extra(self, extra: Dict) -> None: + """ + The `LightningModule.training_step` extras will be saved as the ResultCollection extra key. + """ def detach_fn(v): return v.detach() @@ -255,25 +278,33 @@ def log( batch_size: Optional[int] = None, lightning_attribute_name: Optional[str] = None, ): - """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" + """ + This function is used to log metrics from with + :meth:`~pytorch_lightning.core.lightning.LightningModule.log` + """ # no metrics should be logged with graphs - if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() + # move metrics to cpu on TPU. if isinstance(value, torch.Tensor) and value.device.type == "xla": value = value.cpu() + # storage key key = f"{hook_name}.{name}" + # add dataloader_suffix to both key and hook_name if dataloader_idx is not None: + # use as ResultCollection key key += f'.{dataloader_idx}' + # used to decide when to reset hook_name += f'.{dataloader_idx}' if on_step and self.on_epoch_end_reached: raise MisconfigurationException("Logging `on_step` after `on_epoch_end_reached` isn't authorized.") if key not in self: + # create metadata object if storage key doesn't exist in self meta = Metadata( fx=hook_name, name=name, @@ -285,17 +316,24 @@ def log( dataloader_idx=dataloader_idx, lightning_attribute_name=lightning_attribute_name, ) + # create one ResultMetric object per value. + # value can be provided as a nested collection. self.instance_result_metric(key, meta, value) + # compute batch_size batch_size = torch.tensor(batch_size or self.batch_size, device=self.root_device) + # update the ResultMetric self.update_metrics(hook_name, key, value, batch_size) + # save current_hook to know when to reset. self._current_hook_name = hook_name def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: - def fn(v): + def fn(v: Union[torch.Tensor, Metric]) -> ResultMetric: + # This local function is used to `ResultMetric`. 
+ # The `Metadata` is_tensor is modified on the fly assert self.root_device is not None nonlocal meta meta = deepcopy(meta) @@ -303,8 +341,11 @@ def fn(v): metric = ResultMetric(meta) return metric.to(self.root_device) + # store a mapping between storage key and collection of `ResultMetric` self[key] = apply_to_collection(value, (torch.Tensor, Metric), fn) - # cache the meta for reduction + + # when the value was a nested collection, store some metadata + # to facilate access for later metrics gathering if not isinstance(self[key], ResultMetric): self[key + '.forked'] = meta.forked self[key + '.logger'] = meta.logger @@ -313,6 +354,7 @@ def fn(v): self[key + '.dataloader_idx'] = meta.dataloader_idx def should_reset_tensors(self, hook_name: str) -> bool: + # reset tensor metrics only when hook_name changed and starting a new iteration over dataloader. return (self._current_hook_name != hook_name and self._batch_idx in (None, 0)) def update_metrics( @@ -324,6 +366,7 @@ def update_metrics( self._reset_metrics(hook_name, is_tensor=True) def fn(result_metric, v): + # this function is used to call forward function of ResultMetric object. assert isinstance(v, (torch.Tensor, Metric)) result_metric(v.to(self.root_device), batch_size.to(self.root_device)) result_metric.meta.has_reset = False @@ -332,25 +375,37 @@ def fn(result_metric, v): @staticmethod def _get_forward_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: + # skip if meta `on_step` is False if not result_metric.meta.on_step: return + # extract `ResultMetric` forward cache return result_metric._forward_cache.detach() @staticmethod - def _to_item(forward_cache: torch.Tensor) -> float: - return forward_cache.item() + def _to_item(t: torch.Tensor) -> float: + return t.item() def valid_metrics(self) -> Generator: + """ + This function is used to iterate over current valid metrics. + """ for key, item in self.items(): + # skip when item is None, bool or extra arguments from training_step. if item is None or isinstance(item, bool) or key == "extra": continue + + # skip when the metrics hasn't been updated. elif isinstance(item, ResultMetric) and item.meta.has_reset: - # skip the metric which have been reset. continue + yield (key, item) def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) -> Tuple: + """ + This function is used to extract the metadata for `ResultMetric` and `nested ResultMetrics`. + """ + if isinstance(result_metric, ResultMetric): name = result_metric.meta.name name_forked = result_metric.meta.forked_step_name if on_step else result_metric.meta.forked_epoch_name @@ -366,6 +421,7 @@ def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) metric_on_epoch = self[key + '.on_epoch'] dataloader_idx = self[key + '.dataloader_idx'] + # add dataloader_suffix is provided. if dataloader_idx is not None: dataloader_suffix = self.DATALOADER_SUFFIX.format(dataloader_idx) name += dataloader_suffix @@ -375,7 +431,11 @@ def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: metrics = {k: {} for k in DefaultMetricsKeys} + + # either extract `forward_cache` or `computed` from `ResultMetric` objects fn = self._get_forward_cache if on_step else self._get_computed_cache + + # select suffix suffix = self.STEP_SUFFIX if on_step else self.EPOCH_SUFFIX # iterate over all stored metrics. 
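The suffix handling in `_extract_metadata` and `get_metrics` is what turns one logged key into its `_step` and `_epoch` variants, with an optional `/dataloader_idx_{}` part appended. The sketch below shows only that naming convention, under the assumption that a key is forked when it is logged with both `on_step=True` and `on_epoch=True`; it is not the exact Lightning implementation:

STEP_SUFFIX = "_step"
EPOCH_SUFFIX = "_epoch"
DATALOADER_SUFFIX = "/dataloader_idx_{}"

def forked_names(name, on_step, on_epoch, dataloader_idx=None):
    forked = on_step and on_epoch
    names = []
    if on_step:
        names.append(name + (STEP_SUFFIX if forked else ""))
    if on_epoch:
        names.append(name + (EPOCH_SUFFIX if forked else ""))
    if dataloader_idx is not None:
        names = [n + DATALOADER_SUFFIX.format(dataloader_idx) for n in names]
    return names

assert forked_names("acc", True, True) == ["acc_step", "acc_epoch"]
assert forked_names("loss", False, True, dataloader_idx=0) == ["loss/dataloader_idx_0"]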
@@ -390,14 +450,15 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: def is_empty_fn(v): nonlocal is_empty + # update is_empty if any value is not None. if v is not None: is_empty = False # apply detection. + # todo: (tchaton) need to find a way to support NamedTuple wrong_dtype = ( Mapping, Sequence, - NamedTuple, ) apply_to_collection(value, object, is_empty_fn, wrong_dtype=wrong_dtype) @@ -434,12 +495,15 @@ def get_batch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: @staticmethod def _get_computed_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: + # skip if meta.on_epoch is False if not result_metric.meta.on_epoch: return + # perform reduction is not done alrady if not result_metric._computed: result_metric.compute() + # extract computed from ResultMetric. return result_metric._computed.detach() def get_epoch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: @@ -475,6 +539,9 @@ def reset_metrics(self): self._current_hook_name = None def reset(self): + """ + This function is used to reset entirely the ResultCollection + """ self._reset_metrics() self.on_epoch_end_reached = False self._current_hook_name = None @@ -521,6 +588,8 @@ def get_state_dict(item: ResultMetric) -> Dict[str, Any]: state = deepcopy(state) if 'value' in state['_modules'] and isinstance(state['_modules']["value"], Metric): del state['_modules']["value"] + + # ResultMeta is used as a placeholder for making re-loading simpler return ResultMeta(**state) return {k: apply_to_collection(v, ResultMetric, get_state_dict) for k, v in self.items()} @@ -528,16 +597,22 @@ def get_state_dict(item: ResultMetric) -> Dict[str, Any]: def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Dict[str, Metric]): def to_result_metric(item: ResultMeta) -> Dict[str, Any]: + # create a new ResultMetric result_metric = ResultMetric(item["meta"]) + # update its state result_metric.__dict__.update(item) - return result_metric + # move result_metric to root_device + return result_metric.to(self.root_device) + # transform ResultMeta into ResultMetric state_dict = {k: apply_to_collection(v, ResultMeta, to_result_metric) for k, v in state_dict.items()} + # add the state_dict as new key-value into self for k, v in state_dict.items(): self[k] = v if metrics: + # the metric reference are lost during serialization and # they need to be set back during loading From 6977cf946be70810ebdea14b4ea5a64119e730a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 11:57:57 +0200 Subject: [PATCH 148/455] add reset --- pytorch_lightning/loops/base.py | 4 ++++ pytorch_lightning/loops/training_loop.py | 1 + 2 files changed, 5 insertions(+) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index f06620ada8c79..493a24bc572bb 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -23,6 +23,7 @@ def done(self) -> bool: """Property indicating when loop is finished""" def run(self, *args: Any, **kwargs: Any) -> Any: + self.reset() self.on_run_start(*args, **kwargs) while not self.done: @@ -50,6 +51,9 @@ def on_advance_end(self) -> None: def on_run_end(self) -> Any: pass + def reset(self) -> None: + self.iteration_count = 0 + def increment_iteration(self, iteration: int) -> int: return iteration + 1 diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index cfbc5cfd314e8..c651c0092bd57 100644 --- a/pytorch_lightning/loops/training_loop.py +++ 
b/pytorch_lightning/loops/training_loop.py @@ -53,6 +53,7 @@ def done(self): return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) def run(self, *args, **kwargs): + self.reset() self.on_run_start() while True: From f710c3074846426e27dcaacab5fae7bfd7835e33 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 06:14:17 -0400 Subject: [PATCH 149/455] resolve more tests --- pytorch_lightning/core/lightning.py | 12 ++-- pytorch_lightning/core/step_result.py | 70 ++++++++++++++++++-- pytorch_lightning/trainer/trainer.py | 1 - pytorch_lightning/trainer/training_loop.py | 4 +- tests/core/test_metric_result_integration.py | 9 ++- 5 files changed, 79 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d96f69a2e78fb..04b89667e4567 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -80,7 +80,6 @@ class LightningModule( "model_size", "automatic_optimization", "truncated_bptt_steps", - "_results", ] + DeviceDtypeModuleMixin.__jit_unused_properties__ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -216,11 +215,6 @@ def logger(self): """ Reference to the logger object in the Trainer. """ return self.trainer.logger if self.trainer else None - @property - def _results(self) -> 'Optional[ResultCollection]': - if hasattr(self, "trainer"): - return self.trainer.result_collections - def _apply_batch_transfer_handler( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None ) -> Any: @@ -330,7 +324,9 @@ def log( ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) - if self._results is not None: + result_collections: Optional[ResultCollection] = self.trainer.result_collections + + if result_collections is not None: # TODO: if logged twice fail with crash # set the default depending on the fx_name @@ -360,7 +356,7 @@ def log( int, ), sync_fn) - self._results.log( + result_collections.log( self._current_fx_name, name, value, diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index d85091360cddd..3215943cc59c5 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -167,6 +167,42 @@ class ResultMeta(Dict): class ResultCollection(dict): """ This class is used to capture all the logged values using LightningModule.log function. + + Here is how to use the ResultCollection object. + + Example: + + # the root_device need to be provided before calling the ``log`` function + result = ResultCollection(True, torch.device("cpu")) + + # arguments: hook_name, key, value, metadata + result.log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True) + result.log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True) + + for epoch in range(2): + + # reset on ``on_epoch_end_reached`` + result.on_epoch_end_reached = False + + result.log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + result.log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + + for batch_idx, batch_size in enumerate(range(2)): + + # the batch_idx is used to reset the tensor metrics + result.batch_idx = batch_idx + + result.log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + result.log('c1', 'a', torch.tensor(2.) 
+ epoch, on_step=True, on_epoch=True) + + # used to indicate epoch end has been reached + result.on_epoch_end_reached = True + + result.log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + result.log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + + # result.reset_metrics() [Optional]: Reset only torchmetric.Metric object. + # result.reset() [Optional]: Reset the entire ResultCollection. """ STEP_SUFFIX = "_step" @@ -281,6 +317,25 @@ def log( """ This function is used to log metrics from with :meth:`~pytorch_lightning.core.lightning.LightningModule.log` + + Args: + + hook_name: Current hook name + name: Key provided by the user on logging + value: Either a number, tensor or a collection of the previous. + prog_bar: Whether to add this value to the progress bar. + logger: Whether to log this value to the loggers + on_step: Whether to use this value during batch iteration. + on_epoch: Whether to use this value at the end of the batch iteration. + Automatic reduction will be performed. + reduce_fx: Which function to use for reduction. Currently support min, max and mean. + enable_graph: Whether to keep autograd graph when storing the value. + dataloader_idx: The current dataloader idx. This will be used to automatically + add `/dataloader_idx_{}` on the metrics. + batch_size: Current batch size. + lightning_attribute_name: When providing `nn.Metric` as a value, the ``lightning_attribute_name`` + need to be provided to enable automatic saving / re-loading. + """ # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): @@ -290,6 +345,11 @@ def log( if isinstance(value, torch.Tensor) and value.device.type == "xla": value = value.cpu() + if isinstance(value, Metric) and lightning_attribute_name is None: + raise MisconfigurationException( + "The LightningModule attribute name should be provided when using torchmetrics.Metric" + ) + # storage key key = f"{hook_name}.{name}" @@ -365,8 +425,8 @@ def update_metrics( # when restarting an new epoch, reset the tensor hooks dynamically. self._reset_metrics(hook_name, is_tensor=True) + # this function is used to call the forward function of ResultMetric object. def fn(result_metric, v): - # this function is used to call forward function of ResultMetric object. assert isinstance(v, (torch.Tensor, Metric)) result_metric(v.to(self.root_device), batch_size.to(self.root_device)) result_metric.meta.has_reset = False @@ -391,11 +451,11 @@ def valid_metrics(self) -> Generator: This function is used to iterate over current valid metrics. """ for key, item in self.items(): - # skip when item is None, bool or extra arguments from training_step. + # skip when item is None, bool or extra arguments from training_step. if item is None or isinstance(item, bool) or key == "extra": continue - # skip when the metrics hasn't been updated. + # skip when the metrics hasn't been updated. 
elif isinstance(item, ResultMetric) and item.meta.has_reset: continue @@ -589,7 +649,7 @@ def get_state_dict(item: ResultMetric) -> Dict[str, Any]: if 'value' in state['_modules'] and isinstance(state['_modules']["value"], Metric): del state['_modules']["value"] - # ResultMeta is used as a placeholder for making re-loading simpler + # ResultMeta is used as a placeholder for making re-loading simpler return ResultMeta(**state) return {k: apply_to_collection(v, ResultMetric, get_state_dict) for k, v in self.items()} @@ -613,7 +673,7 @@ def to_result_metric(item: ResultMeta) -> Dict[str, Any]: if metrics: - # the metric reference are lost during serialization and + # the metric reference are lost during serialization and # they need to be set back during loading def re_assign_metric(item): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 556f4b4942497..e1871283d09db 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -22,7 +22,6 @@ import torch from torch.utils.data import DataLoader -from torchmetrics.metric import Metric from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9145a7c8f8ed0..0efcc210cc1c6 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,7 +14,7 @@ from collections import OrderedDict from contextlib import contextmanager, suppress -from copy import copy, deepcopy +from copy import deepcopy from functools import partial, update_wrapper from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -321,7 +321,7 @@ def _process_training_step_output(self, training_step_output, split_batch): if training_step_output_for_epoch_end is None: return None, None - result = self.trainer.lightning_module._results + result = self.trainer.result_collections loss = None hiddens = None diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index d4e6db0e89c5d..a8dbe62734f55 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -59,12 +59,17 @@ def _ddp_test_fn(rank, worldsize): metric_c = metric_c.to(f"cuda:{rank}") # dist_sync_on_step is False by default - result = ResultCollection(True) + result = ResultCollection(True, torch.device(f"cuda:{rank}")) for _ in range(3): cumulative_sum = 0 + result.on_epoch_end_reached = False + for i in range(5): + + result.batch_idx = i + metric_a(i) metric_b(i) metric_c(i) @@ -81,6 +86,8 @@ def _ddp_test_fn(rank, worldsize): for k in batch_expected.keys(): assert batch_expected[k] == batch_log[k] + result.on_epoch_end_reached = True + epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] result.reset() From 0bca3a38adb5b84a8e26f774e9d017d27c5eb91c Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 06:16:41 -0400 Subject: [PATCH 150/455] flake8 --- pytorch_lightning/trainer/training_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2e1b04360c6ad..dd29fe770a8b8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,7 +14,6 @@ from collections import OrderedDict from contextlib import contextmanager, suppress -from copy import deepcopy from functools import 
partial, update_wrapper from typing import Any, Callable, Dict, List, Optional, Tuple, Union From aec8c53bb9c3eb93394bc64b88ed89bf8a0c0919 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 06:20:58 -0400 Subject: [PATCH 151/455] resolve flake8 --- tests/core/test_metric_result_integration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index a8dbe62734f55..ac7cb9f56fe02 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -13,7 +13,6 @@ # limitations under the License. from copy import deepcopy -import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp From ebaa9fdc3d649bef67eb105a12f2c90e7ca91d0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 12:48:25 +0200 Subject: [PATCH 152/455] replace batch_idx with iteration_count --- pytorch_lightning/loops/epoch_loop.py | 4 ++-- pytorch_lightning/loops/training_loop.py | 18 ++++++++++-------- pytorch_lightning/trainer/trainer.py | 1 + 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 152618b0e7d6d..787716a4d4f47 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -56,7 +56,7 @@ def total_batch_idx(self): @property def batch_idx(self): - return self.training_loop.batch_idx + return self.training_loop.iteration_count @property def split_idx(self): @@ -162,7 +162,7 @@ def on_advance_end(self): # # handle epoch_output on epoch end # self.on_train_epoch_end(outputs) # Handled in on_run_end of training_loop now - if self.training_loop.batch_idx is None: + if self.training_loop.batches_seen == 0: return should_check_val = self.training_loop.should_check_val_fx( diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index c651c0092bd57..57cf32f49b2a8 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -25,7 +25,7 @@ def __init__(self, min_steps, max_steps): # the total batch index across all epochs self.total_batch_idx = 0 # the current batch index in the loop that runs over the dataloader(s) - self.batch_idx = 0 + self.iteration_count = 0 # the current split index when the batch gets split into chunks in truncated backprop through time self.split_idx = None @@ -39,6 +39,10 @@ def __init__(self, min_steps, max_steps): self.batch_loop = None + @property + def batch_idx(self) -> int: + return self.iteration_count + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer self.batch_loop = BatchLoop() @@ -76,7 +80,6 @@ def on_run_start(self): self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) self._dataloader_idx = 0 self._should_stop = False - self.batch_idx = None self.batches_seen = 0 self.is_last_batch = False @@ -85,15 +88,14 @@ def on_run_start(self): def advance(self): # TODO: profiling is gone - batch_idx, (batch, is_last) = next(self._train_dataloader) - self.batch_idx = batch_idx + _, (batch, is_last) = next(self._train_dataloader) self.is_last_batch = is_last # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.batch_loop.run(batch, batch_idx, 
self._dataloader_idx) + batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx) self.batches_seen += 1 # when returning -1 from train_step, we end epoch early @@ -105,7 +107,7 @@ def advance(self): self.epoch_output, batch_output.training_step_output_for_epoch_end, batch, - batch_idx, + self.iteration_count, self._dataloader_idx, ) @@ -118,7 +120,7 @@ def on_advance_end(self): # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- - should_check_val = self.should_check_val_fx(self.batch_idx, self.is_last_batch) + should_check_val = self.should_check_val_fx(self.iteration_count, self.is_last_batch) if should_check_val: self.trainer.validating = True self.trainer._run_evaluation() @@ -143,7 +145,7 @@ def on_advance_end(self): # this is the old on train_epoch_end? def on_run_end(self): - if self.batch_idx is None: + if self.batches_seen == 0: # dataloader/iterator did not produce a batch return diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9ca1b7626e4a7..f5b1b79ac178e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -937,6 +937,7 @@ def _run_train_new_loop(self) -> None: self.reset_train_val_dataloaders(model) try: + # TODO: move skip condition into EpochLoop.done() if self._should_skip_training(): return self.train_loop.run() From 2536afa8451d67e4a72a15eb1bce3e093fc3b595 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 07:04:09 -0400 Subject: [PATCH 153/455] update --- pytorch_lightning/core/lightning.py | 13 ++++++++++++- pytorch_lightning/core/step_result.py | 6 ++---- tests/core/test_metric_result_integration.py | 14 +++++++------- tests/metrics/test_metric_lightning.py | 3 ++- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 04b89667e4567..11f6cb1211865 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -32,6 +32,7 @@ from torch import ScriptModule, Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer +from torchmetrics.metric import Metric from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks @@ -272,6 +273,7 @@ def log( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, + lightning_attribute_name: Optional[str] = None, ) -> None: """ Log a key, value @@ -310,6 +312,7 @@ def log( each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, but some esoteric data type such as graph might need to explicitly provide the batch_size. + lightning_attribute_name: The name of the Metric attribute name. This is used for fault tolerant logging. """ if tbptt_reduce_fx is not None: rank_zero_deprecation( @@ -342,6 +345,13 @@ def log( f"Logged key: {name} should not contain information about dataloader_idx." ) + if lightning_attribute_name is None and isinstance(value, Metric): + # todo (tchaton): find a more optimized way to find associated metrics. 
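The loop that follows this todo resolves which attribute the logged metric object lives under, so the name can be stored for re-loading. A toy version of that lookup on a plain `nn.Module` (identity comparison here; the patch compares hashes, and a later commit in this series caches an `id()`-to-name map instead):

from torch import nn

class ToyModule(nn.Module):             # stand-in for a LightningModule with metric children
    def __init__(self):
        super().__init__()
        self.train_acc = nn.Identity()  # imagine a torchmetrics metric here
        self.val_acc = nn.Identity()

module = ToyModule()
logged_value = module.train_acc

# scan the children to recover the attribute name of the logged object
name = next(n for n, child in module.named_children() if child is logged_value)
assert name == "train_acc"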
+ for module_name, module in self.named_children(): + if isinstance(module, Metric) and hash(module) == hash(value): + lightning_attribute_name = module_name + break + sync_fn = partial( self.__sync, sync_fn=self.trainer.training_type_plugin.reduce, @@ -367,7 +377,8 @@ def log( reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), - batch_size=batch_size + batch_size=batch_size, + lightning_attribute_name=lightning_attribute_name, ) def log_dict( diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 3215943cc59c5..b8333003118c7 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -181,9 +181,6 @@ class ResultCollection(dict): for epoch in range(2): - # reset on ``on_epoch_end_reached`` - result.on_epoch_end_reached = False - result.log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) result.log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) @@ -201,7 +198,8 @@ class ResultCollection(dict): result.log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) result.log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) - # result.reset_metrics() [Optional]: Reset only torchmetric.Metric object. + # used to reset torchmetrics.Metric and set `on_epoch_end_reached` to False + result.reset_metrics() [Optional]: Reset only torchmetric.Metric object. # result.reset() [Optional]: Reset the entire ResultCollection. """ diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ac7cb9f56fe02..3ca9b3b62b9a1 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -75,9 +75,9 @@ def _ddp_test_fn(rank, worldsize): cumulative_sum += i - result.log('h', 'a', metric_a, on_step=True, on_epoch=True) - result.log('h', 'b', metric_b, on_step=False, on_epoch=True) - result.log('h', 'c', metric_c, on_step=True, on_epoch=False) + result.log('h', 'a', metric_a, on_step=True, on_epoch=True, lightning_attribute_name="metric_a") + result.log('h', 'b', metric_b, on_step=False, on_epoch=True, lightning_attribute_name="metric_b") + result.log('h', 'c', metric_c, on_step=True, on_epoch=False, lightning_attribute_name="metric_c") batch_log = result.get_batch_metrics()[DefaultMetricsKeys.LOG] batch_expected = {"a_step": i, "c": i} @@ -134,9 +134,9 @@ def test_result_metric_integration(): cumulative_sum += i - result.log('h', 'a', metric_a, on_step=True, on_epoch=True) - result.log('h', 'b', metric_b, on_step=False, on_epoch=True) - result.log('h', 'c', metric_c, on_step=True, on_epoch=False) + result.log('h', 'a', metric_a, on_step=True, on_epoch=True, lightning_attribute_name="metric_a") + result.log('h', 'b', metric_b, on_step=False, on_epoch=True, lightning_attribute_name="metric_b") + result.log('h', 'c', metric_c, on_step=True, on_epoch=False, lightning_attribute_name="metric_c") batch_log = result.get_batch_metrics()[DefaultMetricsKeys.LOG] batch_expected = {"a_step": i, "c": i} @@ -170,7 +170,7 @@ def test_result_collection_restoration(): result = ResultCollection(True, torch.device("cpu")) - for epoch in range(2): + for _ in range(2): result.on_epoch_end_reached = False cumulative_sum = 0 diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 0ae9b18f7a88c..8ce7f1050c6fe 100644 --- a/tests/metrics/test_metric_lightning.py +++ 
b/tests/metrics/test_metric_lightning.py @@ -92,7 +92,8 @@ def training_step(self, batch, batch_idx): def training_epoch_end(self, outs): total = torch.stack([o['data'] for o in outs]).sum() - self.log("sum_epoch", self.metric_epoch(total)) + self.metric_epoch(total) + self.log("sum_epoch", self.metric_epoch) self.total_sum = total model = TestModel() From 02b43f34b84cdc93ca74633b5878794867623d70 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 07:08:50 -0400 Subject: [PATCH 154/455] udpate --- pytorch_lightning/core/lightning.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 11f6cb1211865..a51f95ef6fcac 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -346,11 +346,13 @@ def log( ) if lightning_attribute_name is None and isinstance(value, Metric): - # todo (tchaton): find a more optimized way to find associated metrics. - for module_name, module in self.named_children(): - if isinstance(module, Metric) and hash(module) == hash(value): - lightning_attribute_name = module_name - break + # used to find this Metric associated LightningModule attribute name. + if not hasattr(self, "_map_metric_id_name"): + self._map_metric_id_name = { + id(module): module_name + for module_name, module in self.named_children() if isinstance(module, Metric) + } + lightning_attribute_name = self._map_metric_id_name[id(value)] sync_fn = partial( self.__sync, From 65fa28d254af5854b073999c5dddf229130a8096 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 07:10:15 -0400 Subject: [PATCH 155/455] update --- pytorch_lightning/core/lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a51f95ef6fcac..98f1157fb3c3d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -113,6 +113,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() + self._map_metric_id_name: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -347,7 +348,7 @@ def log( if lightning_attribute_name is None and isinstance(value, Metric): # used to find this Metric associated LightningModule attribute name. 
- if not hasattr(self, "_map_metric_id_name"): + if self._map_metric_id_name is None: self._map_metric_id_name = { id(module): module_name for module_name, module in self.named_children() if isinstance(module, Metric) From 31708c828125359f3de5fa72f6236e2c3b70590d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 13:12:35 +0200 Subject: [PATCH 156/455] implement reset method --- pytorch_lightning/loops/base.py | 7 ++++--- pytorch_lightning/loops/batch_loop.py | 10 +++++++--- pytorch_lightning/loops/epoch_loop.py | 4 ++++ pytorch_lightning/loops/training_loop.py | 5 ++++- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 493a24bc572bb..709438ba8c0ac 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -22,6 +22,10 @@ def connect(self, trainer, *args, **kwargs): def done(self) -> bool: """Property indicating when loop is finished""" + @abstractmethod + def reset(self) -> None: + pass + def run(self, *args: Any, **kwargs: Any) -> Any: self.reset() self.on_run_start(*args, **kwargs) @@ -51,9 +55,6 @@ def on_advance_end(self) -> None: def on_run_end(self) -> Any: pass - def reset(self) -> None: - self.iteration_count = 0 - def increment_iteration(self, iteration: int) -> int: return iteration + 1 diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index f60267f7ea379..723e3f01df999 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -21,6 +21,7 @@ from pytorch_lightning.utilities.warnings import WarningCache +# TODO: typing class BatchLoop(Loop): """ Runs over a single batch of data. """ @@ -76,14 +77,17 @@ def run(self, batch, batch_idx, dataloader_idx): ) return output - def on_run_start(self, batch, batch_idx, dataloader_idx): - self._hiddens = None - self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) + def reset(self) -> None: + self.iteration_count = 0 + self._hiddens = None # TODO: let loops track individual outputs self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] self.grad_norm_dicts = [] + def on_run_start(self, batch, batch_idx, dataloader_idx): + self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) + def advance(self, batch, batch_idx, dataloader_idx): split_idx, split_batch = self._remaining_splits.pop(0) self.split_idx = split_idx diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 787716a4d4f47..fb5756c0ad7df 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -21,6 +21,7 @@ log = logging.getLogger(__name__) +# TODO: typing class EpochLoop(Loop): def __init__(self, min_epochs, max_epochs, min_steps, max_steps): @@ -118,6 +119,9 @@ def done(self) -> bool: stop_epochs = self.current_epoch >= self.max_epochs if self.max_epochs is not None else False return stop_steps or should_stop or stop_epochs + def reset(self) -> None: + self.iteration_count = 0 + def on_run_start(self): # hook self.trainer.call_hook("on_train_start") diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 57cf32f49b2a8..bc229ec8df833 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -10,6 +10,7 @@ from pytorch_lightning.utilities.warnings import WarningCache +# TODO: typing class TrainingLoop(Loop): """ Runs over all batches in 
a dataloader (one epoch). """ @@ -72,7 +73,9 @@ def run(self, *args, **kwargs): return self.on_run_end() - def on_run_start(self): + def reset(self) -> None: + self.iteration_count = 0 + # modify dataloader if needed (ddp, etc...) train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) From 435bbb3de3073043a1dc56a7cfc19bbb2206793f Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 26 May 2021 07:19:52 -0400 Subject: [PATCH 157/455] resolve failing test --- tests/trainer/logging_/test_logger_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 4c2185d38295e..cc0abaae48a4b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -356,8 +356,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc) - self.log(f"{stage}/ap", ap) + self.log(f"{stage}/accuracy", acc, lightning_attribute_name=f"acc_{stage}") + self.log(f"{stage}/ap", ap, lightning_attribute_name=f"ap_{stage}") return loss From 7b01b731004cb465cf6945ae819c984686b923dc Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 26 May 2021 12:52:23 +0100 Subject: [PATCH 158/455] resolve test --- .../trainer/connectors/test_logger_connectors.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py index d1e41c76e420f..2c4627a92f45f 100644 --- a/tests/trainer/connectors/test_logger_connectors.py +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -32,6 +32,8 @@ def test_result_collection_on_tensor_with_mean_reduction(): for logger in [False, True]: + i = float(i) + result_collection.log( "training_step", f"loss_1_{int(prob_bar)}_{int(logger)}", @@ -74,11 +76,15 @@ def test_result_collection_on_tensor_with_mean_reduction(): ) excepted_values = [ - tensor(1), tensor(2), - tensor(3), tensor(4), - tensor(5), tensor(6), - tensor(7), tensor(8), - tensor(9) + tensor(1.), + tensor(2.), + tensor(3.), + tensor(4.), + tensor(5.), + tensor(6.), + tensor(7.), + tensor(8.), + tensor(9.) 
] excepted_batches = [1, 4, 9, 16, 25, 36, 49, 64, 81] total_value = tensor(excepted_values) * tensor(excepted_batches) From 7520a87b2d7b5097b4f47f020d00c4dda5503acd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 15:37:25 +0200 Subject: [PATCH 159/455] wip evaluation loop --- pytorch_lightning/loops/base.py | 5 +- pytorch_lightning/loops/evaluation_loop.py | 258 +++++++++++++++++++ pytorch_lightning/trainer/evaluation_loop.py | 13 +- pytorch_lightning/trainer/trainer.py | 74 ++++-- 4 files changed, 314 insertions(+), 36 deletions(-) create mode 100644 pytorch_lightning/loops/evaluation_loop.py diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 709438ba8c0ac..cb141aa649aa4 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -1,7 +1,7 @@ +from weakref import proxy from abc import ABC, abstractmethod from typing import Any, Optional -from _weakref import proxy import pytorch_lightning as pl @@ -57,6 +57,3 @@ def on_run_end(self) -> Any: def increment_iteration(self, iteration: int) -> int: return iteration + 1 - - def state_dict(self) -> dict: - return dict() diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py new file mode 100644 index 0000000000000..b5586610da107 --- /dev/null +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -0,0 +1,258 @@ +from pytorch_lightning.loops.base import Loop +from typing import Any + +class EvaluationLoop(Loop): + + def reset(self): + # TODO + pass + + def done(self): + # TODO + pass + + def advance(self, *args: Any, **kwargs: Any) -> None: + # TODO + pass + + + +# HELPERS + + def on_trainer_init(self) -> None: + self.trainer.num_sanity_val_batches = [] + self.trainer.num_test_batches = [] + self.trainer.num_val_batches = [] + self.trainer.test_dataloaders = None + self.trainer.val_dataloaders = None + + # .validate() and .test() set this when they load a checkpoint + self.trainer.validated_ckpt_path = None + self.trainer.tested_ckpt_path = None + + # when true, print evaluation results in .validate() and .test() + self.trainer.verbose_evaluate = True + + + def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: + model = self.trainer.lightning_module + + # select dataloaders + if self.trainer.testing: + self.trainer.reset_test_dataloader(model) + + dataloaders = self.trainer.test_dataloaders + max_batches = self.trainer.num_test_batches + else: + # val + if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: + self.trainer.reset_val_dataloader(model) + if self.trainer.sanity_checking: + self.trainer.num_sanity_val_batches = [ + min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches + ] + max_batches = self.trainer.num_sanity_val_batches + else: + max_batches = self.trainer.num_val_batches + dataloaders = self.trainer.val_dataloaders + return dataloaders, max_batches + + + def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: + return sum(max_batches) == 0 + + + def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: + self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() + if self.trainer.testing: + self.trainer.call_hook('on_test_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_start', *args, **kwargs) + + + def on_evaluation_model_eval(self) -> None: + model_ref = 
self.trainer.lightning_module + if self.trainer.testing: + model_ref.on_test_model_eval() + else: + model_ref.on_validation_model_eval() + + + def on_evaluation_model_train(self) -> None: + model_ref = self.trainer.lightning_module + if self.trainer.testing: + model_ref.on_test_model_train() + else: + model_ref.on_validation_model_train() + + + def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: + if self.trainer.testing: + self.trainer.call_hook('on_test_end', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_end', *args, **kwargs) + + if self.trainer.state.fn != TrainerFn.FITTING: + # summarize profile results + self.trainer.profiler.describe() + + + def reload_evaluation_dataloaders(self) -> None: + model = self.trainer.lightning_module + if self.trainer.testing: + self.trainer.reset_test_dataloader(model) + else: + self.trainer.reset_val_dataloader(model) + + + def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoader]) -> None: + # bookkeeping + self.outputs = [] + self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) + + # convert max_batches to list + if isinstance(max_batches, int): + max_batches = [max_batches] * len(dataloaders) + + self.max_batches = max_batches + self.num_dataloaders = self._get_num_dataloaders(dataloaders) + + + def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: + self.trainer.call_hook('on_epoch_start', *args, **kwargs) + + if self.trainer.testing: + self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) + + + def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]: + # make dataloader_idx arg in validation_step optional + step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) + + multiple_val_loaders = ( + not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1 + ) + multiple_test_loaders = (self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1) + + if multiple_test_loaders or multiple_val_loaders: + step_kwargs['dataloader_idx'] = dataloader_idx + + return step_kwargs + + + def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int: + # case where user does: + # return dl1, dl2 + if dataloaders is not None: + length = len(dataloaders) + if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): + length = len(dataloaders[0]) + return length + else: + return 0 + + + def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: + # configure step_kwargs + step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) + + model_ref = self.trainer.lightning_module + model_ref._results = Result() + + if self.trainer.testing: + model_ref._current_fx_name = "test_step" + with self.trainer.profiler.profile("test_step"): + output = self.trainer.accelerator.test_step(step_kwargs) + else: + model_ref._current_fx_name = "validation_step" + with self.trainer.profiler.profile("validation_step"): + output = self.trainer.accelerator.validation_step(step_kwargs) + + # capture any logged information + self.trainer.logger_connector.cache_logged_metrics() + # track batch size for weighted average + if isinstance(output, Result): + output.track_batch_size(batch) + + return output + + + def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + if 
self.trainer.testing: + output = self.trainer.call_hook('test_step_end', *args, **kwargs) + else: + output = self.trainer.call_hook('validation_step_end', *args, **kwargs) + return output + + + def _should_track_batch_outputs_for_epoch_end(self) -> bool: + model = self.trainer.lightning_module + if self.trainer.testing: + return is_overridden('test_epoch_end', model=model) + else: + return is_overridden('validation_epoch_end', model=model) + + + def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + # unset dataloder_idx in model + self.trainer.logger_connector.evaluation_epoch_end() + + # call the model epoch end + model = self.trainer.lightning_module + + if self.trainer.testing: + if is_overridden('test_epoch_end', model=model): + model._current_fx_name = 'test_epoch_end' + model.test_epoch_end(outputs) + + else: + if is_overridden('validation_epoch_end', model=model): + model._current_fx_name = 'validation_epoch_end' + model.validation_epoch_end(outputs) + + # capture logging + self.trainer.logger_connector.cache_logged_metrics() + + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + # set dataloader_idx to model and track batch_size + self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) + + if self.trainer.testing: + self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) + else: + self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) + + + def on_evaluation_batch_end( + self, + output: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + if self.trainer.testing: + self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) + else: + self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) + + # store predicitons if do_write_predictions and track eval loss history + self.store_predictions(output, batch_idx, dataloader_idx) + + + def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: + # Add step predictions to prediction collection to write later + if output is not None and self.predictions is not None: + if isinstance(output, Result) and self.trainer.testing: + self.predictions.add(output.pop('predictions', None)) + + # track debug metrics + self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) + + + def on_evaluation_epoch_end(self) -> None: + hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" + self.trainer.call_hook(hook_name) + self.trainer.call_hook('on_epoch_end') diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f048297892533..726f1f66c8812 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -36,18 +36,7 @@ def __init__(self, trainer: 'pl.Trainer'): self.num_dataloaders: Optional[int] = None def on_trainer_init(self) -> None: - self.trainer.num_sanity_val_batches = [] - self.trainer.num_test_batches = [] - self.trainer.num_val_batches = [] - self.trainer.test_dataloaders = None - self.trainer.val_dataloaders = None - - # .validate() and .test() set this when they load a checkpoint - self.trainer.validated_ckpt_path = None - self.trainer.tested_ckpt_path = None - - # when true, print evaluation results in .validate() and .test() - self.trainer.verbose_evaluate = True + pass def 
get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: model = self.trainer.lightning_module diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f5b1b79ac178e..7816943a9c6dd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -31,7 +31,6 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops.epoch_loop import EpochLoop -from pytorch_lightning.loops.training_loop import TrainingLoop from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.profiler import ( @@ -56,7 +55,6 @@ from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes -from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin @@ -84,6 +82,11 @@ NEW_LOOP = True +if NEW_LOOP: + from pytorch_lightning.loops.evaluation_loop import EvaluationLoop +else: + from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop + class Trainer( TrainerProperties, @@ -341,12 +344,14 @@ def __init__( if NEW_LOOP: self.train_loop = EpochLoop(min_epochs, max_epochs, min_steps, max_steps) + self.evaluation_loop = EvaluationLoop() self.train_loop.connect(self) + self.evaluation_loop.connect(self) else: # old loops: self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) + self.evaluation_loop = EvaluationLoop(self) - self.evaluation_loop = EvaluationLoop(self) self.predict_loop = PredictLoop(self) # training state @@ -391,8 +396,7 @@ def __init__( truncated_bptt_steps, terminate_on_nan, ) - self._setup_on_init(num_sanity_val_steps, ) - self.evaluation_loop.on_trainer_init() + self._setup_on_init(num_sanity_val_steps) self.predict_loop.on_trainer_init() # configure tuner @@ -437,6 +441,19 @@ def _setup_on_init( else: self.num_sanity_val_steps = num_sanity_val_steps + self.num_sanity_val_batches = [] + self.num_test_batches = [] + self.num_val_batches = [] + self.test_dataloaders = None + self.val_dataloaders = None + + # .validate() and .test() set this when they load a checkpoint + self.validated_ckpt_path = None + self.tested_ckpt_path = None + + # when true, print evaluation results in .validate() and .test() + self.verbose_evaluate = True + def _setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): # clean hparams if hasattr(model, "hparams"): @@ -1034,14 +1051,7 @@ def _run_train_old_loop(self) -> None: self.state.stage = None raise - def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: - if not (self.evaluating or self.sanity_checking): - rank_zero_warn( - f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}." - " This should not happen normally. 
Setting it to `RunningStage.VALIDATING`", RuntimeWarning - ) - self.validating = True - + def _run_evaluatin_old_loop(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: # prepare dataloaders dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() @@ -1049,13 +1059,6 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: if self.evaluation_loop.should_skip_evaluation(max_batches): return [], [] - # enable eval mode + no grads - self.evaluation_loop.on_evaluation_model_eval() - # ref model - model = self.lightning_module - model.zero_grad() - torch.set_grad_enabled(False) - # hook self.evaluation_loop.on_evaluation_start() @@ -1133,6 +1136,37 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: # hook self.evaluation_loop.on_evaluation_end() + return eval_loop_results + + def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: + if not (self.evaluating or self.sanity_checking): + rank_zero_warn( + f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}." + " This should not happen normally. Setting it to `RunningStage.VALIDATING`", RuntimeWarning + ) + self.validating = True + + # TODO: move this check inside new loop + # prepare dataloaders + dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() + + # TODO: move this check inside new loop + # check if we want to skip this evaluation + if self.evaluation_loop.should_skip_evaluation(max_batches): + return [], [] + + # enable eval mode + no grads + self.evaluation_loop.on_evaluation_model_eval() + # ref model + model = self.lightning_module + model.zero_grad() + torch.set_grad_enabled(False) + + if NEW_LOOP: + eval_loop_results = self.evaluation_loop.run() + else: + eval_loop_results = self._run_evaluatin_old_loop(on_epoch) + # save predictions to disk self.evaluation_loop.predictions.to_disk() From 8378f1c24658ed05eddaad08c97e878be516e519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 16:11:25 +0200 Subject: [PATCH 160/455] integrate #7357 --- pytorch_lightning/loops/epoch_loop.py | 33 ++++++++++-------------- pytorch_lightning/loops/training_loop.py | 28 +++++++------------- pytorch_lightning/trainer/trainer.py | 6 ++--- 3 files changed, 26 insertions(+), 41 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index fb5756c0ad7df..28efd9f0eb516 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -160,7 +160,14 @@ def advance(self): if epoch_output is None: return + # the global step is manually decreased here due to backwards compatibility with existing loggers + # as they expect that the same step is used when logging epoch end metrics even when the batch loop has + # finished. this means the attribute does not exactly track the number of optimizer steps applied. 
+ # TODO(@carmocca): deprecate and rename so users don't get confused + self.global_step -= 1 + # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) + self.global_step += 1 def on_advance_end(self): # # handle epoch_output on epoch end @@ -169,27 +176,15 @@ def on_advance_end(self): if self.training_loop.batches_seen == 0: return - should_check_val = self.training_loop.should_check_val_fx( - self.batch_idx, self.training_loop.is_last_batch, on_epoch=True - ) - should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) - should_train_only = self.trainer.disable_validation or should_skip_eval - - # update epoch level lr_schedulers if no val loop outside train loop is triggered - if not should_check_val or should_train_only: - self.training_loop.update_lr_schedulers("epoch") + self.training_loop.update_lr_schedulers('epoch') - if should_train_only: + did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.should_skip_evaluation( + self.trainer.num_val_batches + ) + if did_train_only: + self.global_step -= 1 self.check_checkpoint_callback(True) - - if should_check_val: - self.trainer.validating = True - self.trainer._run_evaluation(on_epoch=True) - self.trainer.training = True - - # TODO: move inside training_loop.on_run_end? equivalent? order? - # Needs to check batch_output signal -1, see #7677 - self.training_loop.increment_accumulated_grad_global_step() + self.global_step += 1 # why is this not the same as the old on_train_epoch_end? def on_run_end(self): diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index bc229ec8df833..d65ff7b346288 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -51,10 +51,7 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): @property def done(self): - max_steps_reached = ( - self.max_steps is not None and self.max_steps <= self.global_step + 1 - and self.batch_loop._accumulated_batches_reached() - ) + max_steps_reached = (self.max_steps is not None and self.max_steps <= self.global_step) return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) def run(self, *args, **kwargs): @@ -140,12 +137,12 @@ def on_advance_end(self): self.total_batch_idx += 1 - if self.done: - raise StopIteration - # progress global step according to grads progress self.increment_accumulated_grad_global_step() + if self.done: + raise StopIteration + # this is the old on train_epoch_end? def on_run_end(self): if self.batches_seen == 0: @@ -354,7 +351,7 @@ def increment_accumulated_grad_global_step(self): self.total_batch_idx, self.trainer.global_step ) - def should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: + def should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: """ Decide if we should run validation. 
""" if not self.trainer.enable_validation: return False @@ -365,26 +362,19 @@ def should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: boo # val_check_batch is inf for iterable datasets with no length defined is_infinite_dataset = self.trainer.val_check_batch == float('inf') - if on_epoch and is_last_batch and is_infinite_dataset: + if is_last_batch and is_infinite_dataset: return True if self.trainer.should_stop: return True # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch - is_val_check_batch = False - if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): + is_val_check_batch = is_last_batch + if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 elif self.trainer.val_check_batch != float('inf'): is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 - - # Note: num_training_batches is also inf for iterable datasets with no length defined - epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 - - if on_epoch: - return is_val_check_batch and epoch_end_val_check - else: - return is_val_check_batch and not epoch_end_val_check + return is_val_check_batch def save_loggers_on_train_batch_end(self): # when loggers should save to disk diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6fb5a0cbbef1b..d954d0440b705 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1051,7 +1051,7 @@ def _run_train_old_loop(self) -> None: self.state.stage = None raise - def _run_evaluatin_old_loop(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: + def _run_evaluatin_old_loop(self) -> _EVALUATE_OUTPUT: # prepare dataloaders dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() @@ -1127,7 +1127,7 @@ def _run_evaluatin_old_loop(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: return eval_loop_results - def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: + def _run_evaluation(self) -> _EVALUATE_OUTPUT: if not (self.evaluating or self.sanity_checking): rank_zero_warn( f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}." 
@@ -1154,7 +1154,7 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: if NEW_LOOP: eval_loop_results = self.evaluation_loop.run() else: - eval_loop_results = self._run_evaluatin_old_loop(on_epoch) + eval_loop_results = self._run_evaluatin_old_loop() # save predictions to disk self.evaluation_loop.predictions.to_disk() From a0a46a938be509236433bf62347291380b0fcc64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 16:48:48 +0200 Subject: [PATCH 161/455] wip eval loop refactor --- .../loops/evaluation_dataloader_loop.py | 355 ++++++++++++++++++ pytorch_lightning/loops/evaluation_loop.py | 258 ------------- pytorch_lightning/trainer/trainer.py | 4 +- 3 files changed, 357 insertions(+), 260 deletions(-) create mode 100644 pytorch_lightning/loops/evaluation_dataloader_loop.py diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py new file mode 100644 index 0000000000000..6e94c4c0e9136 --- /dev/null +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -0,0 +1,355 @@ +from weakref import proxy + +from pytorch_lightning.loops.base import Loop +from typing import Any, Optional, Sequence, Union +from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader + + +class EvaluationDataLoaderLoop(Loop): + + def __init__(self): + super().__init__() + self._dataloaders: Optional[Union[DataLoader, Sequence[DataLoader]]] = None + self._max_batches: Optional[Union[int, Sequence[int]]] = None + + def reset(self): + self.iteration_count = 0 + + # prepare dataloaders + self._dataloaders, self._max_batches = self.get_evaluation_dataloaders() + self._dataloaders = iter(enumerate(self._dataloaders)) + self._current_loader, self._current_loader_idx = None, None + + def done(self): + try: + self._current_loader_idx, self._current_loader = next(self._dataloaders) + except StopIteration: + return True + return False + + def advance(self, *args: Any, **kwargs: Any) -> None: + dataloader = self.trainer.accelerator.process_dataloader(self._current_loader) + dl_max_batches = self.max_batches[self._current_loader_idx] + + self.evaluation_loop(dataloader, self._current_loader_idx, dl_max_batches) + + def on_run_start(self, *args: Any, **kwargs: Any) -> None: + # hook + self.on_evaluation_start() + + def temp_run(self): + # check if we want to skip this evaluation + if self.should_skip_evaluation(self._max_batches): + return [], [] + + + + # set up the eval loop + self.setup(self._max_batches, self._dataloaders) + + # hook + self.on_evaluation_epoch_start() + + # run validation/testing + for dataloader_idx, dataloader in enumerate(self._dataloaders): + dataloader = self.trainer.accelerator.process_dataloader(dataloader) + dl_max_batches = self.max_batches[dataloader_idx] + + self.evaluation_loop(dataloader, dataloader_idx, dl_max_batches) + + outputs = self.outputs + + # reset outputs + self.outputs = [] + + # with a single dataloader don't pass a 2D list + if len(outputs) > 0 and self.num_dataloaders == 1: + outputs = outputs[0] + + # lightning module method + self.evaluation_epoch_end(outputs) + + # hook + self.on_evaluation_epoch_end() + + # log epoch metrics + eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results() + + # hook + self.on_evaluation_end() + + return eval_loop_results + + # TODO: Move this to separate loop + def evaluation_loop(self, dataloader, dataloader_idx, dl_max_batches): + dl_outputs = [] + for batch_idx, batch in 
enumerate(dataloader): + if batch is None: + continue + + # stop short when running on limited batches + if batch_idx >= dl_max_batches: + break + + # hook + self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) + + # lightning module methods + with self.trainer.profiler.profile("evaluation_step_and_end"): + output = self.evaluation_step(batch, batch_idx, dataloader_idx) + output = self.evaluation_step_end(output) + + # hook + store predictions + self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) + + # log batch metrics + self.trainer.logger_connector.log_evaluation_step_metrics() + + # track epoch level outputs + dl_outputs = self.trainer._track_output_for_epoch_end(dl_outputs, output) + + # store batch level output per dataloader + if self.should_track_batch_outputs_for_epoch_end: + self.outputs.append(dl_outputs) + + + +# HELPERS + + def on_trainer_init(self) -> None: + self.trainer.num_sanity_val_batches = [] + self.trainer.num_test_batches = [] + self.trainer.num_val_batches = [] + self.trainer.test_dataloaders = None + self.trainer.val_dataloaders = None + + # .validate() and .test() set this when they load a checkpoint + self.trainer.validated_ckpt_path = None + self.trainer.tested_ckpt_path = None + + # when true, print evaluation results in .validate() and .test() + self.trainer.verbose_evaluate = True + + + def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: + model = self.trainer.lightning_module + + # select dataloaders + if self.trainer.testing: + self.trainer.reset_test_dataloader(model) + + dataloaders = self.trainer.test_dataloaders + max_batches = self.trainer.num_test_batches + else: + # val + if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: + self.trainer.reset_val_dataloader(model) + if self.trainer.sanity_checking: + self.trainer.num_sanity_val_batches = [ + min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches + ] + max_batches = self.trainer.num_sanity_val_batches + else: + max_batches = self.trainer.num_val_batches + dataloaders = self.trainer.val_dataloaders + return dataloaders, max_batches + + + def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: + return sum(max_batches) == 0 + + + def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: + self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() + if self.trainer.testing: + self.trainer.call_hook('on_test_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_start', *args, **kwargs) + + + def on_evaluation_model_eval(self) -> None: + model_ref = self.trainer.lightning_module + if self.trainer.testing: + model_ref.on_test_model_eval() + else: + model_ref.on_validation_model_eval() + + + def on_evaluation_model_train(self) -> None: + model_ref = self.trainer.lightning_module + if self.trainer.testing: + model_ref.on_test_model_train() + else: + model_ref.on_validation_model_train() + + + def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: + if self.trainer.testing: + self.trainer.call_hook('on_test_end', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_end', *args, **kwargs) + + if self.trainer.state.fn != TrainerFn.FITTING: + # summarize profile results + self.trainer.profiler.describe() + + + def reload_evaluation_dataloaders(self) -> None: + model = self.trainer.lightning_module + if self.trainer.testing: + 
self.trainer.reset_test_dataloader(model) + else: + self.trainer.reset_val_dataloader(model) + + + def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoader]) -> None: + # bookkeeping + self.outputs = [] + self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) + + # convert max_batches to list + if isinstance(max_batches, int): + max_batches = [max_batches] * len(dataloaders) + + self.max_batches = max_batches + self.num_dataloaders = self._get_num_dataloaders(dataloaders) + + + def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: + self.trainer.call_hook('on_epoch_start', *args, **kwargs) + + if self.trainer.testing: + self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) + + + def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]: + # make dataloader_idx arg in validation_step optional + step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) + + multiple_val_loaders = ( + not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1 + ) + multiple_test_loaders = (self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1) + + if multiple_test_loaders or multiple_val_loaders: + step_kwargs['dataloader_idx'] = dataloader_idx + + return step_kwargs + + + def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int: + # case where user does: + # return dl1, dl2 + if dataloaders is not None: + length = len(dataloaders) + if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): + length = len(dataloaders[0]) + return length + else: + return 0 + + + def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: + # configure step_kwargs + step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) + + model_ref = self.trainer.lightning_module + model_ref._results = Result() + + if self.trainer.testing: + model_ref._current_fx_name = "test_step" + with self.trainer.profiler.profile("test_step"): + output = self.trainer.accelerator.test_step(step_kwargs) + else: + model_ref._current_fx_name = "validation_step" + with self.trainer.profiler.profile("validation_step"): + output = self.trainer.accelerator.validation_step(step_kwargs) + + # capture any logged information + self.trainer.logger_connector.cache_logged_metrics() + # track batch size for weighted average + if isinstance(output, Result): + output.track_batch_size(batch) + + return output + + + def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + if self.trainer.testing: + output = self.trainer.call_hook('test_step_end', *args, **kwargs) + else: + output = self.trainer.call_hook('validation_step_end', *args, **kwargs) + return output + + + def _should_track_batch_outputs_for_epoch_end(self) -> bool: + model = self.trainer.lightning_module + if self.trainer.testing: + return is_overridden('test_epoch_end', model=model) + else: + return is_overridden('validation_epoch_end', model=model) + + + def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + # unset dataloder_idx in model + self.trainer.logger_connector.evaluation_epoch_end() + + # call the model epoch end + model = self.trainer.lightning_module + + if self.trainer.testing: + if is_overridden('test_epoch_end', model=model): + model._current_fx_name = 'test_epoch_end' + 
model.test_epoch_end(outputs) + + else: + if is_overridden('validation_epoch_end', model=model): + model._current_fx_name = 'validation_epoch_end' + model.validation_epoch_end(outputs) + + # capture logging + self.trainer.logger_connector.cache_logged_metrics() + + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + # set dataloader_idx to model and track batch_size + self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) + + if self.trainer.testing: + self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) + else: + self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) + + + def on_evaluation_batch_end( + self, + output: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + if self.trainer.testing: + self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) + else: + self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) + + # store predicitons if do_write_predictions and track eval loss history + self.store_predictions(output, batch_idx, dataloader_idx) + + + def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: + # Add step predictions to prediction collection to write later + if output is not None and self.predictions is not None: + if isinstance(output, Result) and self.trainer.testing: + self.predictions.add(output.pop('predictions', None)) + + # track debug metrics + self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) + + + def on_evaluation_epoch_end(self) -> None: + hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" + self.trainer.call_hook(hook_name) + self.trainer.call_hook('on_epoch_end') diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index b5586610da107..e69de29bb2d1d 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -1,258 +0,0 @@ -from pytorch_lightning.loops.base import Loop -from typing import Any - -class EvaluationLoop(Loop): - - def reset(self): - # TODO - pass - - def done(self): - # TODO - pass - - def advance(self, *args: Any, **kwargs: Any) -> None: - # TODO - pass - - - -# HELPERS - - def on_trainer_init(self) -> None: - self.trainer.num_sanity_val_batches = [] - self.trainer.num_test_batches = [] - self.trainer.num_val_batches = [] - self.trainer.test_dataloaders = None - self.trainer.val_dataloaders = None - - # .validate() and .test() set this when they load a checkpoint - self.trainer.validated_ckpt_path = None - self.trainer.tested_ckpt_path = None - - # when true, print evaluation results in .validate() and .test() - self.trainer.verbose_evaluate = True - - - def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: - model = self.trainer.lightning_module - - # select dataloaders - if self.trainer.testing: - self.trainer.reset_test_dataloader(model) - - dataloaders = self.trainer.test_dataloaders - max_batches = self.trainer.num_test_batches - else: - # val - if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: - self.trainer.reset_val_dataloader(model) - if self.trainer.sanity_checking: - self.trainer.num_sanity_val_batches = [ - min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in 
self.trainer.num_val_batches - ] - max_batches = self.trainer.num_sanity_val_batches - else: - max_batches = self.trainer.num_val_batches - dataloaders = self.trainer.val_dataloaders - return dataloaders, max_batches - - - def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: - return sum(max_batches) == 0 - - - def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: - self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() - if self.trainer.testing: - self.trainer.call_hook('on_test_start', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_start', *args, **kwargs) - - - def on_evaluation_model_eval(self) -> None: - model_ref = self.trainer.lightning_module - if self.trainer.testing: - model_ref.on_test_model_eval() - else: - model_ref.on_validation_model_eval() - - - def on_evaluation_model_train(self) -> None: - model_ref = self.trainer.lightning_module - if self.trainer.testing: - model_ref.on_test_model_train() - else: - model_ref.on_validation_model_train() - - - def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: - if self.trainer.testing: - self.trainer.call_hook('on_test_end', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_end', *args, **kwargs) - - if self.trainer.state.fn != TrainerFn.FITTING: - # summarize profile results - self.trainer.profiler.describe() - - - def reload_evaluation_dataloaders(self) -> None: - model = self.trainer.lightning_module - if self.trainer.testing: - self.trainer.reset_test_dataloader(model) - else: - self.trainer.reset_val_dataloader(model) - - - def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoader]) -> None: - # bookkeeping - self.outputs = [] - self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) - - # convert max_batches to list - if isinstance(max_batches, int): - max_batches = [max_batches] * len(dataloaders) - - self.max_batches = max_batches - self.num_dataloaders = self._get_num_dataloaders(dataloaders) - - - def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: - self.trainer.call_hook('on_epoch_start', *args, **kwargs) - - if self.trainer.testing: - self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - - - def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]: - # make dataloader_idx arg in validation_step optional - step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) - - multiple_val_loaders = ( - not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1 - ) - multiple_test_loaders = (self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1) - - if multiple_test_loaders or multiple_val_loaders: - step_kwargs['dataloader_idx'] = dataloader_idx - - return step_kwargs - - - def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int: - # case where user does: - # return dl1, dl2 - if dataloaders is not None: - length = len(dataloaders) - if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): - length = len(dataloaders[0]) - return length - else: - return 0 - - - def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: - # configure step_kwargs - step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - - model_ref = 
self.trainer.lightning_module - model_ref._results = Result() - - if self.trainer.testing: - model_ref._current_fx_name = "test_step" - with self.trainer.profiler.profile("test_step"): - output = self.trainer.accelerator.test_step(step_kwargs) - else: - model_ref._current_fx_name = "validation_step" - with self.trainer.profiler.profile("validation_step"): - output = self.trainer.accelerator.validation_step(step_kwargs) - - # capture any logged information - self.trainer.logger_connector.cache_logged_metrics() - # track batch size for weighted average - if isinstance(output, Result): - output.track_batch_size(batch) - - return output - - - def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - if self.trainer.testing: - output = self.trainer.call_hook('test_step_end', *args, **kwargs) - else: - output = self.trainer.call_hook('validation_step_end', *args, **kwargs) - return output - - - def _should_track_batch_outputs_for_epoch_end(self) -> bool: - model = self.trainer.lightning_module - if self.trainer.testing: - return is_overridden('test_epoch_end', model=model) - else: - return is_overridden('validation_epoch_end', model=model) - - - def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - # unset dataloder_idx in model - self.trainer.logger_connector.evaluation_epoch_end() - - # call the model epoch end - model = self.trainer.lightning_module - - if self.trainer.testing: - if is_overridden('test_epoch_end', model=model): - model._current_fx_name = 'test_epoch_end' - model.test_epoch_end(outputs) - - else: - if is_overridden('validation_epoch_end', model=model): - model._current_fx_name = 'validation_epoch_end' - model.validation_epoch_end(outputs) - - # capture logging - self.trainer.logger_connector.cache_logged_metrics() - - - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - # set dataloader_idx to model and track batch_size - self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) - - if self.trainer.testing: - self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) - else: - self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) - - - def on_evaluation_batch_end( - self, - output: Optional[STEP_OUTPUT], - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - if self.trainer.testing: - self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) - else: - self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) - - # store predicitons if do_write_predictions and track eval loss history - self.store_predictions(output, batch_idx, dataloader_idx) - - - def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: - # Add step predictions to prediction collection to write later - if output is not None and self.predictions is not None: - if isinstance(output, Result) and self.trainer.testing: - self.predictions.add(output.pop('predictions', None)) - - # track debug metrics - self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) - - - def on_evaluation_epoch_end(self) -> None: - hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" - self.trainer.call_hook(hook_name) - self.trainer.call_hook('on_epoch_end') diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d954d0440b705..1b84edb82b07d 
100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -83,7 +83,7 @@ NEW_LOOP = True if NEW_LOOP: - from pytorch_lightning.loops.evaluation_loop import EvaluationLoop + from pytorch_lightning.loops.evaluation_dataloader_loop import EvaluationDataLoaderLoop else: from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop @@ -344,7 +344,7 @@ def __init__( if NEW_LOOP: self.train_loop = EpochLoop(min_epochs, max_epochs, min_steps, max_steps) - self.evaluation_loop = EvaluationLoop() + self.evaluation_loop = EvaluationDataLoaderLoop() self.train_loop.connect(self) self.evaluation_loop.connect(self) else: From cea36e7f9d966438f60f88677afecff343574b4f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 May 2021 14:49:47 +0000 Subject: [PATCH 162/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/base.py | 3 +- .../loops/evaluation_dataloader_loop.py | 39 +++++-------------- 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index cb141aa649aa4..7c44230384622 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -1,7 +1,6 @@ -from weakref import proxy from abc import ABC, abstractmethod from typing import Any, Optional - +from weakref import proxy import pytorch_lightning as pl diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py index 6e94c4c0e9136..190e8498c9521 100644 --- a/pytorch_lightning/loops/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -1,9 +1,10 @@ +from typing import Any, Optional, Sequence, Union from weakref import proxy -from pytorch_lightning.loops.base import Loop -from typing import Any, Optional, Sequence, Union from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader +from pytorch_lightning.loops.base import Loop + class EvaluationDataLoaderLoop(Loop): @@ -42,8 +43,6 @@ def temp_run(self): if self.should_skip_evaluation(self._max_batches): return [], [] - - # set up the eval loop self.setup(self._max_batches, self._dataloaders) @@ -113,7 +112,6 @@ def evaluation_loop(self, dataloader, dataloader_idx, dl_max_batches): self.outputs.append(dl_outputs) - # HELPERS def on_trainer_init(self) -> None: @@ -130,7 +128,6 @@ def on_trainer_init(self) -> None: # when true, print evaluation results in .validate() and .test() self.trainer.verbose_evaluate = True - def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: model = self.trainer.lightning_module @@ -154,11 +151,9 @@ def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[U dataloaders = self.trainer.val_dataloaders return dataloaders, max_batches - def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: return sum(max_batches) == 0 - def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() if self.trainer.testing: @@ -166,7 +161,6 @@ def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: else: self.trainer.call_hook('on_validation_start', *args, **kwargs) - def on_evaluation_model_eval(self) -> None: model_ref = self.trainer.lightning_module if self.trainer.testing: @@ -174,7 
+168,6 @@ def on_evaluation_model_eval(self) -> None: else: model_ref.on_validation_model_eval() - def on_evaluation_model_train(self) -> None: model_ref = self.trainer.lightning_module if self.trainer.testing: @@ -182,7 +175,6 @@ def on_evaluation_model_train(self) -> None: else: model_ref.on_validation_model_train() - def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) @@ -193,7 +185,6 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: # summarize profile results self.trainer.profiler.describe() - def reload_evaluation_dataloaders(self) -> None: model = self.trainer.lightning_module if self.trainer.testing: @@ -201,7 +192,6 @@ def reload_evaluation_dataloaders(self) -> None: else: self.trainer.reset_val_dataloader(model) - def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoader]) -> None: # bookkeeping self.outputs = [] @@ -214,7 +204,6 @@ def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoad self.max_batches = max_batches self.num_dataloaders = self._get_num_dataloaders(dataloaders) - def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: self.trainer.call_hook('on_epoch_start', *args, **kwargs) @@ -223,13 +212,12 @@ def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: else: self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]: # make dataloader_idx arg in validation_step optional step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) multiple_val_loaders = ( - not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1 + not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1 ) multiple_test_loaders = (self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1) @@ -238,7 +226,6 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict return step_kwargs - def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int: # case where user does: # return dl1, dl2 @@ -250,7 +237,6 @@ def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int: else: return 0 - def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: # configure step_kwargs step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) @@ -275,7 +261,6 @@ def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Op return output - def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: if self.trainer.testing: output = self.trainer.call_hook('test_step_end', *args, **kwargs) @@ -283,7 +268,6 @@ def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT output = self.trainer.call_hook('validation_step_end', *args, **kwargs) return output - def _should_track_batch_outputs_for_epoch_end(self) -> bool: model = self.trainer.lightning_module if self.trainer.testing: @@ -291,7 +275,6 @@ def _should_track_batch_outputs_for_epoch_end(self) -> bool: else: return is_overridden('validation_epoch_end', model=model) - def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: # unset dataloder_idx in model self.trainer.logger_connector.evaluation_epoch_end() @@ -312,7 +295,6 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: # 
capture logging self.trainer.logger_connector.cache_logged_metrics() - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: # set dataloader_idx to model and track batch_size self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) @@ -322,13 +304,12 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: else: self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) - def on_evaluation_batch_end( - self, - output: Optional[STEP_OUTPUT], - batch: Any, - batch_idx: int, - dataloader_idx: int, + self, + output: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int, ) -> None: if self.trainer.testing: self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) @@ -338,7 +319,6 @@ def on_evaluation_batch_end( # store predicitons if do_write_predictions and track eval loss history self.store_predictions(output, batch_idx, dataloader_idx) - def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: # Add step predictions to prediction collection to write later if output is not None and self.predictions is not None: @@ -348,7 +328,6 @@ def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, datal # track debug metrics self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) - def on_evaluation_epoch_end(self) -> None: hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" self.trainer.call_hook(hook_name) From 9e347c4ac490c6d968ea12899acaa48fe39a9edf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 27 May 2021 12:44:49 +0200 Subject: [PATCH 163/455] wip evaluation loop --- pytorch_lightning/loops/epoch_loop.py | 1 + .../loops/evaluation_dataloader_loop.py | 195 ++++-------------- pytorch_lightning/loops/evaluation_loop.py | 155 ++++++++++++++ 3 files changed, 192 insertions(+), 159 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 28efd9f0eb516..ab2f694c2d330 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -28,6 +28,7 @@ def __init__(self, min_epochs, max_epochs, min_steps, max_steps): super().__init__() self._teardown_already_run = False + # TODO: Move this to trainer (it's a trainer default, loops shouldn't have to care about this # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py index 190e8498c9521..93c06583e1952 100644 --- a/pytorch_lightning/loops/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -1,9 +1,15 @@ -from typing import Any, Optional, Sequence, Union -from weakref import proxy +from typing import Any, Optional, Sequence, Union, Tuple, List, Dict +from collections import OrderedDict -from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader +from torch.utils.data.dataloader import DataLoader +from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.evaluation_loop import 
EvaluationLoop +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.trainer.supporters import PredictionCollection +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import STEP_OUTPUT, EPOCH_OUTPUT class EvaluationDataLoaderLoop(Loop): @@ -12,55 +18,52 @@ def __init__(self): super().__init__() self._dataloaders: Optional[Union[DataLoader, Sequence[DataLoader]]] = None self._max_batches: Optional[Union[int, Sequence[int]]] = None + self.outputs = [] + self.evaluation_loop = EvaluationLoop() + + @property + def current_dataloader_idx(self) -> int: + return self.iteration_count + + @property + def num_dataloaders(self): + return self._get_num_dataloaders(self._dataloaders) def reset(self): self.iteration_count = 0 # prepare dataloaders self._dataloaders, self._max_batches = self.get_evaluation_dataloaders() - self._dataloaders = iter(enumerate(self._dataloaders)) - self._current_loader, self._current_loader_idx = None, None + # bookkeeping + self.outputs = [] + + if isinstance(self._max_batches, int): + self._max_batches = [self._max_batches] * len(self._dataloaders) + + self._max_batches = self._max_batches def done(self): - try: - self._current_loader_idx, self._current_loader = next(self._dataloaders) - except StopIteration: - return True - return False + return (self.current_dataloader_idx >= len(self._dataloaders)) or self.should_skip_evaluation(self._max_batches) def advance(self, *args: Any, **kwargs: Any) -> None: - dataloader = self.trainer.accelerator.process_dataloader(self._current_loader) - dl_max_batches = self.max_batches[self._current_loader_idx] + dataloader = self._dataloaders[self.current_dataloader_idx] + dataloader = self.trainer.accelerator.process_dataloader(dataloader) + dl_max_batches = self._max_batches[self.current_dataloader_idx] - self.evaluation_loop(dataloader, self._current_loader_idx, dl_max_batches) + dl_outputs = self.evaluation_loop.run(dataloader, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders) + + # store batch level output per dataloader + if self.should_track_batch_outputs_for_epoch_end: + self.outputs.append(dl_outputs) def on_run_start(self, *args: Any, **kwargs: Any) -> None: # hook self.on_evaluation_start() - - def temp_run(self): - # check if we want to skip this evaluation - if self.should_skip_evaluation(self._max_batches): - return [], [] - - # set up the eval loop - self.setup(self._max_batches, self._dataloaders) - - # hook self.on_evaluation_epoch_start() - # run validation/testing - for dataloader_idx, dataloader in enumerate(self._dataloaders): - dataloader = self.trainer.accelerator.process_dataloader(dataloader) - dl_max_batches = self.max_batches[dataloader_idx] - - self.evaluation_loop(dataloader, dataloader_idx, dl_max_batches) - + def on_run_end(self) -> Any: outputs = self.outputs - # reset outputs - self.outputs = [] - # with a single dataloader don't pass a 2D list if len(outputs) > 0 and self.num_dataloaders == 1: outputs = outputs[0] @@ -79,55 +82,10 @@ def temp_run(self): return eval_loop_results - # TODO: Move this to separate loop - def evaluation_loop(self, dataloader, dataloader_idx, dl_max_batches): - dl_outputs = [] - for batch_idx, batch in enumerate(dataloader): - if batch is None: - continue - - # stop short when running on limited batches - if batch_idx >= dl_max_batches: - break - - # hook - self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) - - # lightning module methods - with 
self.trainer.profiler.profile("evaluation_step_and_end"): - output = self.evaluation_step(batch, batch_idx, dataloader_idx) - output = self.evaluation_step_end(output) - - # hook + store predictions - self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) - - # log batch metrics - self.trainer.logger_connector.log_evaluation_step_metrics() - - # track epoch level outputs - dl_outputs = self.trainer._track_output_for_epoch_end(dl_outputs, output) - - # store batch level output per dataloader - if self.should_track_batch_outputs_for_epoch_end: - self.outputs.append(dl_outputs) # HELPERS - def on_trainer_init(self) -> None: - self.trainer.num_sanity_val_batches = [] - self.trainer.num_test_batches = [] - self.trainer.num_val_batches = [] - self.trainer.test_dataloaders = None - self.trainer.val_dataloaders = None - - # .validate() and .test() set this when they load a checkpoint - self.trainer.validated_ckpt_path = None - self.trainer.tested_ckpt_path = None - - # when true, print evaluation results in .validate() and .test() - self.trainer.verbose_evaluate = True - def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: model = self.trainer.lightning_module @@ -192,18 +150,6 @@ def reload_evaluation_dataloaders(self) -> None: else: self.trainer.reset_val_dataloader(model) - def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoader]) -> None: - # bookkeeping - self.outputs = [] - self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) - - # convert max_batches to list - if isinstance(max_batches, int): - max_batches = [max_batches] * len(dataloaders) - - self.max_batches = max_batches - self.num_dataloaders = self._get_num_dataloaders(dataloaders) - def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: self.trainer.call_hook('on_epoch_start', *args, **kwargs) @@ -212,20 +158,6 @@ def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: else: self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]: - # make dataloader_idx arg in validation_step optional - step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) - - multiple_val_loaders = ( - not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1 - ) - multiple_test_loaders = (self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1) - - if multiple_test_loaders or multiple_val_loaders: - step_kwargs['dataloader_idx'] = dataloader_idx - - return step_kwargs - def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int: # case where user does: # return dl1, dl2 @@ -237,37 +169,6 @@ def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int: else: return 0 - def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: - # configure step_kwargs - step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - - model_ref = self.trainer.lightning_module - model_ref._results = Result() - - if self.trainer.testing: - model_ref._current_fx_name = "test_step" - with self.trainer.profiler.profile("test_step"): - output = self.trainer.accelerator.test_step(step_kwargs) - else: - model_ref._current_fx_name = "validation_step" - with self.trainer.profiler.profile("validation_step"): - output = 
self.trainer.accelerator.validation_step(step_kwargs) - - # capture any logged information - self.trainer.logger_connector.cache_logged_metrics() - # track batch size for weighted average - if isinstance(output, Result): - output.track_batch_size(batch) - - return output - - def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - if self.trainer.testing: - output = self.trainer.call_hook('test_step_end', *args, **kwargs) - else: - output = self.trainer.call_hook('validation_step_end', *args, **kwargs) - return output - def _should_track_batch_outputs_for_epoch_end(self) -> bool: model = self.trainer.lightning_module if self.trainer.testing: @@ -295,30 +196,6 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: # capture logging self.trainer.logger_connector.cache_logged_metrics() - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - # set dataloader_idx to model and track batch_size - self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) - - if self.trainer.testing: - self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) - else: - self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) - - def on_evaluation_batch_end( - self, - output: Optional[STEP_OUTPUT], - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - if self.trainer.testing: - self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) - else: - self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) - - # store predicitons if do_write_predictions and track eval loss history - self.store_predictions(output, batch_idx, dataloader_idx) - def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: # Add step predictions to prediction collection to write later if output is not None and self.predictions is not None: diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index e69de29bb2d1d..c51fc8289330d 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -0,0 +1,155 @@ +from collections import OrderedDict +from typing import Any, Optional, Dict, Union + +from torch.utils.data import DataLoader + +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.trainer.supporters import PredictionCollection +from pytorch_lightning.utilities.types import STEP_OUTPUT + + +class EvaluationLoop(Loop): + + def __init__(self): + super().__init__() + self.predictions: Optional[PredictionCollection] = None + self.dataloader: Optional[DataLoader] = None + self.dl_max_batches: Optional[int] = None + self.dataloader_idx: Optional[int] = None + self.num_dataloaders: Optional[int] = None + self.outputs = [] + + + @property + def done(self) -> bool: + return self.batch_idx >= self.dl_max_batches + + def reset(self) -> None: + self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) + self.dataloader = None + self.dl_max_batches = None + self.dataloader_idx = None + self.num_dataloaders = None + self.outputs = [] + + def on_run_start(self, dataloader, dl_max_batches, dataloader_idx, num_dataloaders) -> None: + self.dataloader = dataloader + self.dl_max_batches = dl_max_batches + self.dataloader_idx = dataloader_idx + self.num_dataloaders = 
num_dataloaders + self.dataloader_iter = enumerate(self.dataloader) + + # fetch first batch + self.batch_idx, self.batch = next(self.dataloader_iter) + + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + self.batch_idx, self.batch = next(self.dataloader_iter) + + + def advance(self, dataloader, dl_max_batches, dataloader_idx, num_dataloaders) -> None: + + if self.batch is None: + return + + # hook + self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) + + # lightning module methods + with self.trainer.profiler.profile("evaluation_step_and_end"): + output = self.evaluation_step(batch, batch_idx, dataloader_idx) + output = self.evaluation_step_end(output) + + # hook + store predictions + self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) + + # log batch metrics + self.trainer.logger_connector.log_evaluation_step_metrics() + + # track epoch level outputs + self.outputs = self.trainer._track_output_for_epoch_end(self.outputs, output) + + def on_advance_end(self) -> None: + # fetch next batch + self.batch_idx, self.batch = next(self.dataloader_iter) + + def on_run_end(self) -> Any: + return self.outputs + + def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: + # configure step_kwargs + step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) + + model_ref = self.trainer.lightning_module + model_ref._results = Result() + + if self.trainer.testing: + model_ref._current_fx_name = "test_step" + with self.trainer.profiler.profile("test_step"): + output = self.trainer.accelerator.test_step(step_kwargs) + else: + model_ref._current_fx_name = "validation_step" + with self.trainer.profiler.profile("validation_step"): + output = self.trainer.accelerator.validation_step(step_kwargs) + + # capture any logged information + self.trainer.logger_connector.cache_logged_metrics() + # track batch size for weighted average + if isinstance(output, Result): + output.track_batch_size(batch) + + return output + + def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + if self.trainer.testing: + output = self.trainer.call_hook('test_step_end', *args, **kwargs) + else: + output = self.trainer.call_hook('validation_step_end', *args, **kwargs) + return output + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + # set dataloader_idx to model and track batch_size + self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) + + if self.trainer.testing: + self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) + else: + self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) + + def on_evaluation_batch_end( + self, + output: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + if self.trainer.testing: + self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) + else: + self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) + + # store predicitons if do_write_predictions and track eval loss history + self.store_predictions(output, batch_idx, dataloader_idx) + + def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: + # Add step predictions to prediction collection to write later + if output is not None and self.predictions is not None: + if isinstance(output, Result) and self.trainer.testing: + 
self.predictions.add(output.pop('predictions', None)) + + # track debug metrics + self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) + + def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]: + # make dataloader_idx arg in validation_step optional + step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) + + multiple_val_loaders = ( + not self.trainer.testing and self.num_dataloaders > 1 + ) + multiple_test_loaders = (self.trainer.testing and self.num_dataloaders > 1) + + if multiple_test_loaders or multiple_val_loaders: + step_kwargs['dataloader_idx'] = dataloader_idx + + return step_kwargs \ No newline at end of file From 86daef211265ce7c7ea309f9592e5cf127eea2fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 27 May 2021 13:46:52 +0200 Subject: [PATCH 164/455] stop iteration in base class --- pytorch_lightning/loops/base.py | 12 +++++++----- pytorch_lightning/loops/training_loop.py | 5 +++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 7c44230384622..60ebf5b31caa2 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -30,11 +30,13 @@ def run(self, *args: Any, **kwargs: Any) -> Any: self.on_run_start(*args, **kwargs) while not self.done: - - self.on_advance_start(*args, **kwargs) - self.advance(*args, **kwargs) - self.on_advance_end() - self.iteration_count = self.increment_iteration(self.iteration_count) + try: + self.on_advance_start(*args, **kwargs) + self.advance(*args, **kwargs) + self.on_advance_end() + self.iteration_count = self.increment_iteration(self.iteration_count) + except StopIteration: + break return self.on_run_end() diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index d65ff7b346288..f1a3d51913832 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -58,16 +58,17 @@ def run(self, *args, **kwargs): self.reset() self.on_run_start() + # TODO: while condition is different from super.run(), + # redesign the done conditions and use the base class run() implementation while True: try: self.on_advance_start() self.advance() self.on_advance_end() + self.iteration_count = self.increment_iteration(self.iteration_count) except StopIteration: break - self.iteration_count = self.increment_iteration(self.iteration_count) - return self.on_run_end() def reset(self) -> None: From 60802f5a50f3b95a7a2a3fa33132c8cf204e0f2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 27 May 2021 14:02:40 +0200 Subject: [PATCH 165/455] wip evaluation loop --- .../loops/evaluation_dataloader_loop.py | 10 ++++- pytorch_lightning/loops/evaluation_loop.py | 37 +++++++------------ pytorch_lightning/trainer/trainer.py | 1 + 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py index 93c06583e1952..4296a1621c98b 100644 --- a/pytorch_lightning/loops/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -29,6 +29,14 @@ def current_dataloader_idx(self) -> int: def num_dataloaders(self): return self._get_num_dataloaders(self._dataloaders) + def connect(self, trainer, *args, **kwargs): + super().connect(trainer, *args, **kwargs) + self.evaluation_loop.connect(trainer, *args, **kwargs) + + 
@property + def done(self): + return (self.current_dataloader_idx >= len(self._dataloaders)) or self.should_skip_evaluation(self._max_batches) + def reset(self): self.iteration_count = 0 @@ -42,8 +50,6 @@ def reset(self): self._max_batches = self._max_batches - def done(self): - return (self.current_dataloader_idx >= len(self._dataloaders)) or self.should_skip_evaluation(self._max_batches) def advance(self, *args: Any, **kwargs: Any) -> None: dataloader = self._dataloaders[self.current_dataloader_idx] diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index c51fc8289330d..248c79b800c3d 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -1,7 +1,5 @@ from collections import OrderedDict -from typing import Any, Optional, Dict, Union - -from torch.utils.data import DataLoader +from typing import Any, Optional, Dict, Union, Iterator from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop @@ -14,43 +12,40 @@ class EvaluationLoop(Loop): def __init__(self): super().__init__() self.predictions: Optional[PredictionCollection] = None - self.dataloader: Optional[DataLoader] = None - self.dl_max_batches: Optional[int] = None + self.dataloader: Optional[Iterator] = None + self.dl_max_batches: Optional[int] = None self.dataloader_idx: Optional[int] = None self.num_dataloaders: Optional[int] = None + self.batch_idx: Optional[int] = None self.outputs = [] + def connect(self, trainer, *args, **kwargs): + super().connect(trainer, *args, **kwargs) @property def done(self) -> bool: - return self.batch_idx >= self.dl_max_batches + return self.batch_idx is not None and self.batch_idx >= self.dl_max_batches def reset(self) -> None: + self.iteration_count = 0 self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) - self.dataloader = None self.dl_max_batches = None self.dataloader_idx = None self.num_dataloaders = None self.outputs = [] def on_run_start(self, dataloader, dl_max_batches, dataloader_idx, num_dataloaders) -> None: - self.dataloader = dataloader self.dl_max_batches = dl_max_batches self.dataloader_idx = dataloader_idx self.num_dataloaders = num_dataloaders - self.dataloader_iter = enumerate(self.dataloader) - - # fetch first batch - self.batch_idx, self.batch = next(self.dataloader_iter) - - def on_advance_start(self, *args: Any, **kwargs: Any) -> None: - self.batch_idx, self.batch = next(self.dataloader_iter) - + self.dataloader = enumerate(dataloader) def advance(self, dataloader, dl_max_batches, dataloader_idx, num_dataloaders) -> None: + batch_idx, batch = next(self.dataloader) + self.batch_idx = batch_idx - if self.batch is None: - return + if batch is None: + raise StopIteration # hook self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) @@ -69,10 +64,6 @@ def advance(self, dataloader, dl_max_batches, dataloader_idx, num_dataloaders) - # track epoch level outputs self.outputs = self.trainer._track_output_for_epoch_end(self.outputs, output) - def on_advance_end(self) -> None: - # fetch next batch - self.batch_idx, self.batch = next(self.dataloader_iter) - def on_run_end(self) -> Any: return self.outputs @@ -152,4 +143,4 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict if multiple_test_loaders or multiple_val_loaders: step_kwargs['dataloader_idx'] = dataloader_idx - return step_kwargs \ No newline at end of file + return step_kwargs diff --git 
a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1b84edb82b07d..33ea0658ee172 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1169,6 +1169,7 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: return eval_loop_results + # TODO: move inside evaluation loop def _track_output_for_epoch_end(self, outputs, output): if output is not None: if isinstance(output, Result): From faf1a76f7dbd9e4912fe57ef33618c985967efe4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 27 May 2021 16:20:29 +0200 Subject: [PATCH 166/455] WIP --- .../logger_connector/fx_validator.py | 13 +++++++-- .../logger_connector/logger_connector.py | 28 +++++++++++-------- pytorch_lightning/trainer/training_loop.py | 2 +- tests/trainer/test_trainer.py | 2 +- 4 files changed, 28 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 3db8aace451dd..7ab288e6041fd 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -75,16 +75,23 @@ class FxValidator: training_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), validation_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), test_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), + on_before_batch_transfer=None, + transfer_batch_to_device=None, + on_after_batch_transfer=None, + backward=None, + optimizer_step=None, # TODO(@carmocca): some {step,epoch}_{start,end} are missing ) - def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None: - if fx_name not in self.functions: + @classmethod + def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None: + """Check if the given function name is allowed to log""" + if fx_name not in cls.functions: raise RuntimeError( f'You are trying to `self.log()` inside `{fx_name}` but it is not implemented.' 
' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`' ) - allowed = self.functions[fx_name] + allowed = cls.functions[fx_name] if allowed is None: raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`") diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 95e51c3bd9286..221083be7a63e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -112,6 +112,10 @@ def log_metrics(self, metrics, grad_norm_dict, step=None): self.add_logged_metrics(scalar_metrics) + """ + Evaluation metric updates + """ + def evaluation_epoch_end(self): # reset dataloader idx model_ref = self.trainer.lightning_module @@ -226,7 +230,9 @@ def update_evaluation_step_metrics(self) -> None: # increment the step even if nothing was logged self.increment_evaluation_log_step() - ############## TRAIN METRICS UPDATES START ############## # noqa E266 + """ + Train metric updates + """ def on_train_start(self): root_device = self.trainer.lightning_module.device @@ -286,12 +292,12 @@ def update_train_epoch_metrics(self) -> None: # reset result collection for next epoch self.trainer.result_collections.reset_metrics() - ############## TRAIN METRICS UPDATES END ############## # noqa E266 - - ############## UTILS START ############## # noqa E266 + """ + Utilities and properties + """ @property - def callback_metrics(self) -> Dict: + def callback_metrics(self) -> Dict[str, float]: if self.trainer.result_collections: metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.CALLBACK] self._callback_metrics.update(metrics) @@ -300,31 +306,29 @@ def callback_metrics(self) -> Dict: return self._callback_metrics @property - def logged_metrics(self) -> Dict: + def logged_metrics(self) -> Dict[str, float]: if self.trainer.result_collections: metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.LOG] self._logged_metrics.update(metrics) return self._logged_metrics @property - def progress_bar_metrics(self) -> Dict: + def progress_bar_metrics(self) -> Dict[str, float]: if self.trainer.result_collections: metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics - def add_progress_bar_metrics(self, metrics): + def add_progress_bar_metrics(self, metrics: Dict[str, float]) -> None: self._progress_bar_metrics.update(metrics) self.trainer.dev_debugger.track_pbar_metrics_history(metrics) - def add_logged_metrics(self, metrics): + def add_logged_metrics(self, metrics: Dict[str, float]) -> None: self._logged_metrics.update(metrics) self.trainer.dev_debugger.track_logged_metrics_history(metrics) - def add_callback_metrics(self, metrics): + def add_callback_metrics(self, metrics: Dict[str, float]) -> None: self._callback_metrics.update(metrics) def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None: self._fx_validator.check_logging(fx_name=fx_name, on_step=on_step, on_epoch=on_epoch) - - ############## UTILS END ############## # noqa E266 diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f3bfd827deed0..55468fe7b609e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -545,7 +545,7 @@ def 
run_training_epoch(self): # TODO(@carmocca): deprecate and rename so users don't get confused self.global_step -= 1 # log epoch metrics - self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) + self.trainer.logger_connector.update_train_epoch_metrics() self.global_step += 1 self.update_lr_schedulers('epoch') diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ac66ebbec3587..73d31b2a2e54d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -341,7 +341,7 @@ def mock_save_function(filepath, *args): for i, loss in enumerate(losses): trainer.train_loop.current_epoch = i trainer.train_loop.global_step = i - trainer.logger_connector._callback_metrics = {"checkpoint_on": torch.tensor(loss)} + trainer.logger_connector.add_callback_metrics({"checkpoint_on": loss}) checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) file_lists = set(os.listdir(tmpdir)) From 920d7928f03e261d2b0c3a50a49782e61a57b3a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 01:01:29 +0200 Subject: [PATCH 167/455] fix predictions access --- pytorch_lightning/loops/evaluation_dataloader_loop.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py index 4296a1621c98b..5a83e1a9ef4d8 100644 --- a/pytorch_lightning/loops/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -29,6 +29,11 @@ def current_dataloader_idx(self) -> int: def num_dataloaders(self): return self._get_num_dataloaders(self._dataloaders) + @property + def predictions(self): + # TODO: fixme + return self.evaluation_loop.predictions + def connect(self, trainer, *args, **kwargs): super().connect(trainer, *args, **kwargs) self.evaluation_loop.connect(trainer, *args, **kwargs) @@ -50,7 +55,6 @@ def reset(self): self._max_batches = self._max_batches - def advance(self, *args: Any, **kwargs: Any) -> None: dataloader = self._dataloaders[self.current_dataloader_idx] dataloader = self.trainer.accelerator.process_dataloader(dataloader) From 395050a105ac27a584b96c336b2a4bdb79a13e65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 01:07:08 +0200 Subject: [PATCH 168/455] integrate #7736 --- pytorch_lightning/loops/batch_loop.py | 8 ++++---- pytorch_lightning/loops/epoch_loop.py | 1 - pytorch_lightning/loops/evaluation_dataloader_loop.py | 2 +- pytorch_lightning/loops/evaluation_loop.py | 2 +- pytorch_lightning/loops/training_loop.py | 2 +- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 723e3f01df999..9e414e8cd862c 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -9,9 +9,9 @@ from torch.optim import Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -179,9 +179,9 @@ def training_step_and_backward_closure( 
return_result: AttributeDict, ) -> Optional[torch.Tensor]: - step_result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) - if step_result is not None: - return_result.update(step_result) + result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + if result is not None: + return_result.update(result) return return_result.loss def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable: diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index ab2f694c2d330..01f789f36cebb 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -8,7 +8,6 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.training_loop import TrainingLoop from pytorch_lightning.trainer.supporters import TensorRunningAccum diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py index 5a83e1a9ef4d8..bdebb704a5a75 100644 --- a/pytorch_lightning/loops/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -3,9 +3,9 @@ from torch.utils.data.dataloader import DataLoader -from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.evaluation_loop import EvaluationLoop +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index 248c79b800c3d..5039cb2b9c7f6 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -1,8 +1,8 @@ from collections import OrderedDict from typing import Any, Optional, Dict, Union, Iterator -from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.types import STEP_OUTPUT diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index f1a3d51913832..67444f3bef999 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -1,9 +1,9 @@ from typing import Dict, List, Union import pytorch_lightning as pl -from pytorch_lightning.core.step_result import Result from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.batch_loop import BatchLoop +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature From 1fbcb3a96a8274b473a41725386bfe6294482ec8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 10:17:20 +0200 Subject: [PATCH 169/455] fix eval loop --- .../loops/evaluation_dataloader_loop.py | 6 +++--- 
pytorch_lightning/loops/evaluation_loop.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py index bdebb704a5a75..84d76c1348985 100644 --- a/pytorch_lightning/loops/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -92,9 +92,9 @@ def on_run_end(self) -> Any: return eval_loop_results - - -# HELPERS +# ------------------------------------------------------------------------------------------------------------ +# HELPER --- TO BE CLEANED UP +# ------------------------------------------------------------------------------------------------------------ def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: model = self.trainer.lightning_module diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index 5039cb2b9c7f6..455d632a915b2 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -11,12 +11,12 @@ class EvaluationLoop(Loop): def __init__(self): super().__init__() + self.batch_idx: Optional[int] = None self.predictions: Optional[PredictionCollection] = None self.dataloader: Optional[Iterator] = None self.dl_max_batches: Optional[int] = None self.dataloader_idx: Optional[int] = None self.num_dataloaders: Optional[int] = None - self.batch_idx: Optional[int] = None self.outputs = [] def connect(self, trainer, *args, **kwargs): @@ -28,25 +28,31 @@ def done(self) -> bool: def reset(self) -> None: self.iteration_count = 0 + self.batch_idx = None self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) self.dl_max_batches = None self.dataloader_idx = None self.num_dataloaders = None self.outputs = [] - def on_run_start(self, dataloader, dl_max_batches, dataloader_idx, num_dataloaders) -> None: + def on_run_start(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders) -> None: self.dl_max_batches = dl_max_batches self.dataloader_idx = dataloader_idx self.num_dataloaders = num_dataloaders self.dataloader = enumerate(dataloader) - def advance(self, dataloader, dl_max_batches, dataloader_idx, num_dataloaders) -> None: + def advance(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders) -> None: batch_idx, batch = next(self.dataloader) + + # TODO: is self.batch_idx needed or can it be set to iteration_count? 
self.batch_idx = batch_idx if batch is None: raise StopIteration + if batch_idx >= dl_max_batches: + raise StopIteration + # hook self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) @@ -67,6 +73,10 @@ def advance(self, dataloader, dl_max_batches, dataloader_idx, num_dataloaders) - def on_run_end(self) -> Any: return self.outputs +# ------------------------------------------------------------------------------------------------------------ +# HELPER --- TO BE CLEANED UP +# ------------------------------------------------------------------------------------------------------------ + def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: # configure step_kwargs step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) From 89408708d9786988b2c5c873a40436d4e33e5d17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 10:22:21 +0200 Subject: [PATCH 170/455] adjust test --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 865bd07bb43d3..3272193f4aa35 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1069,7 +1069,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): assert trainer.num_sanity_val_steps == num_sanity_val_steps with patch.object( - trainer.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_step + trainer.evaluation_loop.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_loop.evaluation_step ) as mocked: val_dataloaders = model.val_dataloader__multiple_mixed_length() trainer.fit(model, val_dataloaders=val_dataloaders) From 218c7842ba935dff8b7c04a432acb7da325b9699 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 May 2021 08:23:12 +0000 Subject: [PATCH 171/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/evaluation_dataloader_loop.py | 9 ++++++--- pytorch_lightning/loops/evaluation_loop.py | 7 +++---- tests/trainer/test_trainer.py | 4 +++- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py index 84d76c1348985..bd42a655f44ec 100644 --- a/pytorch_lightning/loops/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -1,5 +1,5 @@ -from typing import Any, Optional, Sequence, Union, Tuple, List, Dict from collections import OrderedDict +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from torch.utils.data.dataloader import DataLoader @@ -9,7 +9,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import STEP_OUTPUT, EPOCH_OUTPUT +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT class EvaluationDataLoaderLoop(Loop): @@ -60,7 +60,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader = self.trainer.accelerator.process_dataloader(dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] - dl_outputs = self.evaluation_loop.run(dataloader, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders) + dl_outputs = 
self.evaluation_loop.run( + dataloader, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders + ) # store batch level output per dataloader if self.should_track_batch_outputs_for_epoch_end: @@ -92,6 +94,7 @@ def on_run_end(self) -> Any: return eval_loop_results + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index 455d632a915b2..2b8038de971b5 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Any, Optional, Dict, Union, Iterator +from typing import Any, Dict, Iterator, Optional, Union from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.connectors.logger_connector.result import Result @@ -73,6 +73,7 @@ def advance(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders) - def on_run_end(self) -> Any: return self.outputs + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ @@ -145,9 +146,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict # make dataloader_idx arg in validation_step optional step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) - multiple_val_loaders = ( - not self.trainer.testing and self.num_dataloaders > 1 - ) + multiple_val_loaders = (not self.trainer.testing and self.num_dataloaders > 1) multiple_test_loaders = (self.trainer.testing and self.num_dataloaders > 1) if multiple_test_loaders or multiple_val_loaders: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3272193f4aa35..18c632e63b55c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1069,7 +1069,9 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): assert trainer.num_sanity_val_steps == num_sanity_val_steps with patch.object( - trainer.evaluation_loop.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_loop.evaluation_step + trainer.evaluation_loop.evaluation_loop, + "evaluation_step", + wraps=trainer.evaluation_loop.evaluation_loop.evaluation_step ) as mocked: val_dataloaders = model.val_dataloader__multiple_mixed_length() trainer.fit(model, val_dataloaders=val_dataloaders) From 63da8913ea912d35fc2b5d35de4982bce7a07854 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 10:28:02 +0200 Subject: [PATCH 172/455] fixed mock path --- tests/trainer/loops/test_evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 278ed8619d0de..2b92094de8320 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -17,7 +17,7 @@ from tests.helpers.boring_model import BoringModel -@mock.patch("pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.on_evaluation_epoch_end") +@mock.patch("pytorch_lightning.loops.evaluation_dataloader_loop.EvaluationDataLoaderLoop.on_evaluation_epoch_end") def 
test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """ Tests that `on_evaluation_epoch_end` is called From a2ff4975c60246ccfa5f54f749a6114eb1ac5391 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 10:31:39 +0200 Subject: [PATCH 173/455] fix new evaluation loop references --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 18c632e63b55c..a7ea0875b1bcc 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1099,7 +1099,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): assert trainer.num_sanity_val_steps == float("inf") with patch.object( - trainer.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_step + trainer.evaluation_loop.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_loop.evaluation_step ) as mocked: val_dataloaders = model.val_dataloader__multiple() trainer.fit(model, val_dataloaders=val_dataloaders) From d787e9e54c985118885509bb38367950bdf94388 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 10:38:09 +0200 Subject: [PATCH 174/455] fix pickling error --- pytorch_lightning/loops/evaluation_loop.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index 2b8038de971b5..75e5823ff2c87 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -73,6 +73,11 @@ def advance(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders) - def on_run_end(self) -> Any: return self.outputs + def __getstate__(self): + # avoid pickling errors "cannot pickle generator object" + self.dataloader = None + return self.__dict__ + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP From b95d1a9043189c595031da2d92c60a133c30b732 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 10:38:46 +0200 Subject: [PATCH 175/455] free memory before passing outputs to epoch end --- pytorch_lightning/loops/evaluation_dataloader_loop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py index bd42a655f44ec..2994e54a44290 100644 --- a/pytorch_lightning/loops/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -76,6 +76,9 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: def on_run_end(self) -> Any: outputs = self.outputs + # free memory + self.outputs = [] + # with a single dataloader don't pass a 2D list if len(outputs) > 0 and self.num_dataloaders == 1: outputs = outputs[0] From 87bd86501275d19659d3df3b2325566e2f14cc96 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 May 2021 08:57:33 +0000 Subject: [PATCH 176/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/test_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a7ea0875b1bcc..a1aae6cabdae6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1099,7 +1099,9 @@ def 
test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): assert trainer.num_sanity_val_steps == float("inf") with patch.object( - trainer.evaluation_loop.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_loop.evaluation_step + trainer.evaluation_loop.evaluation_loop, + "evaluation_step", + wraps=trainer.evaluation_loop.evaluation_loop.evaluation_step ) as mocked: val_dataloaders = model.val_dataloader__multiple() trainer.fit(model, val_dataloaders=val_dataloaders) From 0fe002b3eadd94766b79534f1bcc051d2d808629 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 11:09:12 +0200 Subject: [PATCH 177/455] remove unnecessary batch_idx --- pytorch_lightning/loops/evaluation_loop.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index 75e5823ff2c87..a9993694b8f12 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -11,7 +11,6 @@ class EvaluationLoop(Loop): def __init__(self): super().__init__() - self.batch_idx: Optional[int] = None self.predictions: Optional[PredictionCollection] = None self.dataloader: Optional[Iterator] = None self.dl_max_batches: Optional[int] = None @@ -24,11 +23,10 @@ def connect(self, trainer, *args, **kwargs): @property def done(self) -> bool: - return self.batch_idx is not None and self.batch_idx >= self.dl_max_batches + return self.iteration_count >= self.dl_max_batches def reset(self) -> None: self.iteration_count = 0 - self.batch_idx = None self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) self.dl_max_batches = None self.dataloader_idx = None @@ -44,15 +42,9 @@ def on_run_start(self, dataloader, dataloader_idx, dl_max_batches, num_dataloade def advance(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders) -> None: batch_idx, batch = next(self.dataloader) - # TODO: is self.batch_idx needed or can it be set to iteration_count? 
- self.batch_idx = batch_idx - if batch is None: raise StopIteration - if batch_idx >= dl_max_batches: - raise StopIteration - # hook self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) From b12e12c7fc0ce01c5cae8577126773570990a831 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 12:21:00 +0200 Subject: [PATCH 178/455] base class for dataloader loop --- pytorch_lightning/loops/base.py | 2 +- .../loops/dataloader/__init__.py | 0 .../loops/dataloader/dataloader_loop.py | 37 +++++++++++++++++++ .../loops/evaluation_dataloader_loop.py | 28 +++++++------- 4 files changed, 51 insertions(+), 16 deletions(-) create mode 100644 pytorch_lightning/loops/dataloader/__init__.py create mode 100644 pytorch_lightning/loops/dataloader/dataloader_loop.py diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 60ebf5b31caa2..d4b00ea4047c9 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -12,7 +12,7 @@ def __init__(self): self.trainer: Optional['pl.Trainer'] = None @abstractmethod - def connect(self, trainer, *args, **kwargs): + def connect(self, trainer, *args, **kwargs) -> None: """Connects Loop with all the necessary things like connectors and accelerators""" self.trainer = proxy(trainer) diff --git a/pytorch_lightning/loops/dataloader/__init__.py b/pytorch_lightning/loops/dataloader/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py new file mode 100644 index 0000000000000..c28f1f6a9de49 --- /dev/null +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -0,0 +1,37 @@ +from abc import abstractmethod +from typing import Sequence + +from torch.utils.data import DataLoader + +from pytorch_lightning.loops.base import Loop + + +# TODO: Handle max_batches also in base class here +class DataLoaderLoop(Loop): + + def __init__(self): + super().__init__() + + @property + @abstractmethod + def dataloaders(self) -> Sequence[DataLoader]: + pass + + @property + def current_dataloader_idx(self) -> int: + return self.iteration_count + + @property + def current_dataloader(self): + return self.dataloaders[self.current_dataloader_idx] + + @property + def num_dataloaders(self) -> int: + return len(self.dataloaders) + + @property + def done(self) -> bool: + return self.current_dataloader_idx >= self.num_dataloaders + + def reset(self) -> None: + self.iteration_count = 0 diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_dataloader_loop.py index 2994e54a44290..afddfecfcf74e 100644 --- a/pytorch_lightning/loops/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_dataloader_loop.py @@ -1,18 +1,16 @@ -from collections import OrderedDict -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union from torch.utils.data.dataloader import DataLoader -from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop from pytorch_lightning.loops.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden from 
pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT -class EvaluationDataLoaderLoop(Loop): +class EvaluationDataLoaderLoop(DataLoaderLoop): def __init__(self): super().__init__() @@ -22,27 +20,27 @@ def __init__(self): self.evaluation_loop = EvaluationLoop() @property - def current_dataloader_idx(self) -> int: - return self.iteration_count + def num_dataloaders(self) -> int: + return self._get_num_dataloaders(self.dataloaders) @property - def num_dataloaders(self): - return self._get_num_dataloaders(self._dataloaders) + def dataloaders(self) -> Sequence[DataLoader]: + return self._dataloaders @property def predictions(self): # TODO: fixme return self.evaluation_loop.predictions - def connect(self, trainer, *args, **kwargs): + def connect(self, trainer, *args, **kwargs) -> None: super().connect(trainer, *args, **kwargs) self.evaluation_loop.connect(trainer, *args, **kwargs) @property - def done(self): - return (self.current_dataloader_idx >= len(self._dataloaders)) or self.should_skip_evaluation(self._max_batches) + def done(self) -> bool: + return (self.current_dataloader_idx >= len(self.dataloaders)) or self.should_skip_evaluation(self._max_batches) - def reset(self): + def reset(self) -> None: self.iteration_count = 0 # prepare dataloaders @@ -56,8 +54,7 @@ def reset(self): self._max_batches = self._max_batches def advance(self, *args: Any, **kwargs: Any) -> None: - dataloader = self._dataloaders[self.current_dataloader_idx] - dataloader = self.trainer.accelerator.process_dataloader(dataloader) + dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] dl_outputs = self.evaluation_loop.run( @@ -125,6 +122,7 @@ def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[U dataloaders = self.trainer.val_dataloaders return dataloaders, max_batches + # TODO: this is currently also used in the new and old TrainingLoop def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: return sum(max_batches) == 0 From e0605b7cf41f6cb929e702adaa3e8c043518b2a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 12:29:59 +0200 Subject: [PATCH 179/455] move dataloader loop --- .../loops/{ => dataloader}/evaluation_dataloader_loop.py | 0 pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename pytorch_lightning/loops/{ => dataloader}/evaluation_dataloader_loop.py (100%) diff --git a/pytorch_lightning/loops/evaluation_dataloader_loop.py b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py similarity index 100% rename from pytorch_lightning/loops/evaluation_dataloader_loop.py rename to pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7f7106be5d1ed..efc3bce13227f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -83,7 +83,7 @@ NEW_LOOP = True if NEW_LOOP: - from pytorch_lightning.loops.evaluation_dataloader_loop import EvaluationDataLoaderLoop + from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop else: from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop From a592ccd3d69441dbd8516aa2a09da9f9c0f34845 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 14:04:40 +0200 Subject: [PATCH 180/455] remove abstract requirement for 
connect() method --- pytorch_lightning/loops/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index d4b00ea4047c9..9ef2903dd5761 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -11,7 +11,6 @@ def __init__(self): self.iteration_count: int = 0 self.trainer: Optional['pl.Trainer'] = None - @abstractmethod def connect(self, trainer, *args, **kwargs) -> None: """Connects Loop with all the necessary things like connectors and accelerators""" self.trainer = proxy(trainer) From cc207ee56454932da473e15ffc995a10ac548a78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 14:04:52 +0200 Subject: [PATCH 181/455] remove redundant assignment --- .../loops/dataloader/evaluation_dataloader_loop.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py index afddfecfcf74e..d4b89e4188f55 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py @@ -51,8 +51,6 @@ def reset(self) -> None: if isinstance(self._max_batches, int): self._max_batches = [self._max_batches] * len(self._dataloaders) - self._max_batches = self._max_batches - def advance(self, *args: Any, **kwargs: Any) -> None: dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] From f7f8201916a8a65b2421c8dbbfa4890892efb6f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 14:06:18 +0200 Subject: [PATCH 182/455] implemented prediction loop --- .../dataloader/prediction_dataloader_loop.py | 143 ++++++++++++++++++ pytorch_lightning/loops/prediction_loop.py | 93 ++++++++++++ pytorch_lightning/trainer/trainer.py | 19 ++- 3 files changed, 251 insertions(+), 4 deletions(-) create mode 100644 pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py create mode 100644 pytorch_lightning/loops/prediction_loop.py diff --git a/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py b/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py new file mode 100644 index 0000000000000..a30d2b93142f1 --- /dev/null +++ b/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py @@ -0,0 +1,143 @@ +from typing import Any, Sequence, List, Optional + +import torch +from torch.utils.data import DataLoader + +from pytorch_lightning.plugins import DDPSpawnPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _PREDICT_OUTPUT +from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop +from pytorch_lightning.loops.prediction_loop import PredictionLoop + + +class PredictionDataLoaderLoop(DataLoaderLoop): + + def __init__(self): + super().__init__() + self.prediction_loop = PredictionLoop() + self._return_predictions = False + self.predictions = None + self.epoch_batch_indices = None + self._dataloaders = None + self._max_batches = None + + @property + def return_predictions(self) -> bool: + return self._return_predictions + + @return_predictions.setter + def return_predictions(self, return_predictions: Optional[bool] = None) -> None: + # ``DDPSpawnPlugin`` plugins and derivate don't support return predictions. 
+ is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin) + if return_predictions and is_ddp_spawn: + raise MisconfigurationException( + "`return_predictions` should be set to `False` when using the `DDPSpawnPlugin` or children class. " + f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}." + ) + # For non ``DDPSpawnPlugin`` plugin, the `return_predictions` is True by default unless user decide otherwise. + self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions + + @property + def num_dataloaders(self) -> int: + return self._get_num_dataloaders(self.dataloaders) + + @property + def dataloaders(self) -> Sequence[DataLoader]: + return self._dataloaders + + @property + def done(self) -> bool: + return (self.current_dataloader_idx >= len(self.dataloaders)) or self.should_skip_predict(self._max_batches) + + def connect(self, trainer, *args, **kwargs) -> None: + super().connect(trainer, *args, **kwargs) + self.prediction_loop.connect(trainer, *args, **kwargs) + + def reset(self) -> None: + super().reset() + self._dataloaders, self._max_batches = self.get_predict_dataloaders() + + # convert max_batches to list + if isinstance(self._max_batches, int): + self._max_batches = [self._max_batches] * len(self.dataloaders) + + self.predictions = [] + self.epoch_batch_indices = [] + + def on_run_start(self) -> None: + self.on_predict_start() + + def advance(self, *args, **kwargs) -> None: + dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) + dataloader_iter = enumerate(dataloader) + dl_max_batches = self._max_batches[self.current_dataloader_idx] + + dl_predictions, dl_batch_indices = self.prediction_loop.run( + dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders, self.return_predictions + ) + self.predictions.append(dl_predictions) + self.epoch_batch_indices.append(dl_batch_indices) + + def on_run_end(self): + results = self.on_predict_epoch_end() + self.on_predict_end() + return results + +# ------------------------------------------------------------------------------------------------------------ +# HELPER --- TO BE CLEANED UP +# ------------------------------------------------------------------------------------------------------------ + + def get_predict_dataloaders(self): + self.trainer.reset_predict_dataloader(self.trainer.lightning_module) + + dataloaders = self.trainer.predict_dataloaders + max_batches = self.trainer.num_predict_batches + + return dataloaders, max_batches + + def should_skip_predict(self, max_batches): + return sum(max_batches) == 0 + + def on_predict_start(self) -> None: + # enable eval mode + no grads + self.on_predict_model_eval() + self.trainer.lightning_module.zero_grad() + self._previous_grad_status = torch.is_grad_enabled() + torch.set_grad_enabled(False) + + # hook + self.trainer.call_hook("on_predict_start") + self.trainer.call_hook("on_predict_epoch_start") + + def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: + self.trainer.profiler.describe() + + results = self.predictions + + self.trainer.call_hook("on_predict_epoch_end", results) + + if self.return_predictions: + return results[0] if self.num_dataloaders == 1 else results + + def on_predict_end(self): + # clear memory. the predictions are extracted in `on_predict_epoch_end`. + self.predictions = [] + self.epoch_batch_indices = [] + + # reset grad to its previous status. 
+ torch.set_grad_enabled(self._previous_grad_status) + + # hook + self.trainer.call_hook("on_predict_end") + + def on_predict_model_eval(self): + model_ref = self.trainer.lightning_module + model_ref.on_predict_model_eval() + + def _get_num_dataloaders(self, dataloaders: List[DataLoader]) -> int: + # case where user does: + # return dl1, dl2 + length = len(dataloaders) + if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): + length = len(dataloaders[0]) + return length \ No newline at end of file diff --git a/pytorch_lightning/loops/prediction_loop.py b/pytorch_lightning/loops/prediction_loop.py new file mode 100644 index 0000000000000..ded02a2aa083c --- /dev/null +++ b/pytorch_lightning/loops/prediction_loop.py @@ -0,0 +1,93 @@ +from collections import OrderedDict +from typing import Any, List + +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper +from pytorch_lightning.utilities.warnings import WarningCache + + +class PredictionLoop(Loop): + + def __init__(self): + super().__init__() + self.warning_cache = WarningCache() + self.dl_max_batches = None + self.num_dataloaders = None + self.return_predictions = False + self.predictions: List[Any] = [] + self.batch_indices: [List[int]] = [] + + @property + def should_store_predictions(self) -> bool: + any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks) + return self.return_predictions or any_pred + + @property + def done(self) -> bool: + return self.iteration_count >= self.dl_max_batches + + def reset(self) -> None: + self.batch_indices: List[int] = [] + self.predictions: List[Any] = [] + + def on_run_start(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders, return_predictions=False) -> None: + self.dl_max_batches = dl_max_batches + self.num_dataloaders = num_dataloaders + self.return_predictions = return_predictions + + def advance(self, dataloader_iter, dataloader_idx, dl_max_batches, *args, **kwargs) -> None: + batch_idx, batch = next(dataloader_iter) + if batch is None: + raise StopIteration + + # TODO: needed? 
+ # stop short when running on limited batches + if batch_idx >= dl_max_batches: + raise StopIteration + + # lightning module methods + with self.trainer.profiler.profile("predict_step"): + self.predict_step(batch, batch_idx, dataloader_idx) + + def on_run_end(self) -> Any: + return self.predictions, self.batch_indices + +# ------------------------------------------------------------------------------------------------------------ +# HELPER --- TO BE CLEANED UP +# ------------------------------------------------------------------------------------------------------------ + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + # configure step_kwargs + step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) + + # extract batch_indices and store them + self._store_batch_indices(dataloader_idx) + + model_ref = self.trainer.lightning_module + + self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx) + + model_ref._current_fx_name = "predict_step" + predictions = self.trainer.accelerator.predict_step(step_kwargs) + + if predictions is None: + self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") + + self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx) + + if self.should_store_predictions: + self.predictions.append(predictions) + + def _build_kwargs(self, batch, batch_idx, dataloader_idx): + step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) + if self.num_dataloaders: + step_kwargs['dataloader_idx'] = dataloader_idx + return step_kwargs + + def _store_batch_indices(self, dataloader_idx: int) -> None: + batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler + if isinstance(batch_sampler, IndexBatchSamplerWrapper): + self.batch_indices = batch_sampler.batch_indices + if self.should_store_predictions: + self.batch_indices.append(batch_sampler.batch_indices) + diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index efc3bce13227f..db17520c3e239 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,6 +29,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loops.dataloader.prediction_dataloader_loop import PredictionDataLoaderLoop from pytorch_lightning.loops.epoch_loop import EpochLoop from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment @@ -84,6 +85,7 @@ if NEW_LOOP: from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop + from pytorch_lightning.loops.dataloader.prediction_dataloader_loop import PredictionDataLoaderLoop else: from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop @@ -345,14 +347,15 @@ def __init__( if NEW_LOOP: self.train_loop = EpochLoop(min_epochs, max_epochs, min_steps, max_steps) self.evaluation_loop = EvaluationDataLoaderLoop() + self.predict_loop = PredictionDataLoaderLoop() self.train_loop.connect(self) self.evaluation_loop.connect(self) + self.predict_loop.connect(self) else: # old loops: self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) self.evaluation_loop = EvaluationLoop(self) - - self.predict_loop = PredictLoop(self) + self.predict_loop = PredictLoop(self) # training state if 
weights_summary is not None and weights_summary not in ModelSummary.MODES: @@ -397,7 +400,6 @@ def __init__( terminate_on_nan, ) self._setup_on_init(num_sanity_val_steps) - self.predict_loop.on_trainer_init() # configure tuner self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size) @@ -454,6 +456,9 @@ def _setup_on_init( # when true, print evaluation results in .validate() and .test() self.verbose_evaluate = True + self.num_predict_batches = [] + self.predicted_ckpt_path = None + def _setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): # clean hparams if hasattr(model, "hparams"): @@ -1201,7 +1206,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: return eval_loop_results - def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: + def _run_predict_old_loop(self) -> Optional[_PREDICT_OUTPUT]: # prepare dataloaders dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() @@ -1239,6 +1244,12 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: return results + def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: + if NEW_LOOP: + return self.predict_loop.run() + else: + return self._run_predict_old_loop() + def _run_sanity_check(self, ref_model): using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 From da14ecd4be332f3697ea85bed3cdc5780690a087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 14:42:14 +0200 Subject: [PATCH 183/455] fix predict multiple dataloaders --- pytorch_lightning/loops/prediction_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/prediction_loop.py b/pytorch_lightning/loops/prediction_loop.py index ded02a2aa083c..9ba20fed23dec 100644 --- a/pytorch_lightning/loops/prediction_loop.py +++ b/pytorch_lightning/loops/prediction_loop.py @@ -27,6 +27,7 @@ def done(self) -> bool: return self.iteration_count >= self.dl_max_batches def reset(self) -> None: + self.iteration_count = 0 self.batch_indices: List[int] = [] self.predictions: List[Any] = [] From 7660aaec21f76881eec6ab974901c825f3167b05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 14:57:51 +0200 Subject: [PATCH 184/455] rename the batch indices attributes --- pytorch_lightning/callbacks/prediction_writer.py | 2 +- pytorch_lightning/loops/prediction_loop.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/prediction_writer.py b/pytorch_lightning/callbacks/prediction_writer.py index cbcff74ff0278..85a1408104ccd 100644 --- a/pytorch_lightning/callbacks/prediction_writer.py +++ b/pytorch_lightning/callbacks/prediction_writer.py @@ -109,7 +109,7 @@ def on_predict_batch_end( if not self.interval.on_batch: return is_distributed = trainer.accelerator_connector.is_distributed - batch_indices = trainer.predict_loop.batch_indices if is_distributed else None + batch_indices = trainer.predict_loop.prediction_loop.current_batch_indices if is_distributed else None self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx) def on_predict_epoch_end( diff --git a/pytorch_lightning/loops/prediction_loop.py b/pytorch_lightning/loops/prediction_loop.py index 9ba20fed23dec..658e8c03866c3 100644 --- a/pytorch_lightning/loops/prediction_loop.py +++ b/pytorch_lightning/loops/prediction_loop.py @@ -15,7 +15,8 @@ def __init__(self): 
self.num_dataloaders = None self.return_predictions = False self.predictions: List[Any] = [] - self.batch_indices: [List[int]] = [] + self.current_batch_indices: [List[int]] = [] + self.all_batch_indices: [List[int]] = [] @property def should_store_predictions(self) -> bool: @@ -28,7 +29,7 @@ def done(self) -> bool: def reset(self) -> None: self.iteration_count = 0 - self.batch_indices: List[int] = [] + self.all_batch_indices: List[int] = [] self.predictions: List[Any] = [] def on_run_start(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders, return_predictions=False) -> None: @@ -51,7 +52,7 @@ def advance(self, dataloader_iter, dataloader_idx, dl_max_batches, *args, **kwar self.predict_step(batch, batch_idx, dataloader_idx) def on_run_end(self) -> Any: - return self.predictions, self.batch_indices + return self.predictions, self.all_batch_indices # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP @@ -88,7 +89,7 @@ def _build_kwargs(self, batch, batch_idx, dataloader_idx): def _store_batch_indices(self, dataloader_idx: int) -> None: batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler if isinstance(batch_sampler, IndexBatchSamplerWrapper): - self.batch_indices = batch_sampler.batch_indices + self.current_batch_indices = batch_sampler.batch_indices if self.should_store_predictions: - self.batch_indices.append(batch_sampler.batch_indices) + self.all_batch_indices.append(batch_sampler.batch_indices) From 105f420f097702b3b72cdbc538cc265bc6483f4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 14:58:57 +0200 Subject: [PATCH 185/455] remove redundant stop iteration --- pytorch_lightning/loops/prediction_loop.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/loops/prediction_loop.py b/pytorch_lightning/loops/prediction_loop.py index 658e8c03866c3..3a0a410d34d76 100644 --- a/pytorch_lightning/loops/prediction_loop.py +++ b/pytorch_lightning/loops/prediction_loop.py @@ -42,12 +42,6 @@ def advance(self, dataloader_iter, dataloader_idx, dl_max_batches, *args, **kwar if batch is None: raise StopIteration - # TODO: needed? 
- # stop short when running on limited batches - if batch_idx >= dl_max_batches: - raise StopIteration - - # lightning module methods with self.trainer.profiler.profile("predict_step"): self.predict_step(batch, batch_idx, dataloader_idx) From 6655c9ef0793e03058bd2d9d19ab1731ea1ab6e3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 May 2021 12:59:45 +0000 Subject: [PATCH 186/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../loops/dataloader/prediction_dataloader_loop.py | 9 +++++---- pytorch_lightning/loops/prediction_loop.py | 6 ++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py b/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py index a30d2b93142f1..42b086a7205c5 100644 --- a/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py @@ -1,13 +1,13 @@ -from typing import Any, Sequence, List, Optional +from typing import Any, List, Optional, Sequence import torch from torch.utils.data import DataLoader +from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop +from pytorch_lightning.loops.prediction_loop import PredictionLoop from pytorch_lightning.plugins import DDPSpawnPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PREDICT_OUTPUT -from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop -from pytorch_lightning.loops.prediction_loop import PredictionLoop class PredictionDataLoaderLoop(DataLoaderLoop): @@ -83,6 +83,7 @@ def on_run_end(self): self.on_predict_end() return results + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ @@ -140,4 +141,4 @@ def _get_num_dataloaders(self, dataloaders: List[DataLoader]) -> int: length = len(dataloaders) if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): length = len(dataloaders[0]) - return length \ No newline at end of file + return length diff --git a/pytorch_lightning/loops/prediction_loop.py b/pytorch_lightning/loops/prediction_loop.py index 3a0a410d34d76..5f158ecbaa871 100644 --- a/pytorch_lightning/loops/prediction_loop.py +++ b/pytorch_lightning/loops/prediction_loop.py @@ -32,7 +32,9 @@ def reset(self) -> None: self.all_batch_indices: List[int] = [] self.predictions: List[Any] = [] - def on_run_start(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders, return_predictions=False) -> None: + def on_run_start( + self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders, return_predictions=False + ) -> None: self.dl_max_batches = dl_max_batches self.num_dataloaders = num_dataloaders self.return_predictions = return_predictions @@ -48,6 +50,7 @@ def advance(self, dataloader_iter, dataloader_idx, dl_max_batches, *args, **kwar def on_run_end(self) -> Any: return self.predictions, self.all_batch_indices + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ @@ -86,4 +89,3 @@ def 
_store_batch_indices(self, dataloader_idx: int) -> None: self.current_batch_indices = batch_sampler.batch_indices if self.should_store_predictions: self.all_batch_indices.append(batch_sampler.batch_indices) - From 5441a3c246f8a7d1713e5c216ac073f30dedb373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 15:37:30 +0200 Subject: [PATCH 187/455] typing --- pytorch_lightning/loops/dataloader/dataloader_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py index c28f1f6a9de49..c6449cb6aaeeb 100644 --- a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -22,7 +22,7 @@ def current_dataloader_idx(self) -> int: return self.iteration_count @property - def current_dataloader(self): + def current_dataloader(self) -> DataLoader: return self.dataloaders[self.current_dataloader_idx] @property From ad95fb22ba4e390b61c0452b64bb01230c6d608d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 15:49:21 +0200 Subject: [PATCH 188/455] fix pickle problem with eval and predict dataloader --- .../loops/dataloader/evaluation_dataloader_loop.py | 3 ++- pytorch_lightning/loops/evaluation_loop.py | 13 +++---------- pytorch_lightning/loops/prediction_loop.py | 2 +- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py index d4b89e4188f55..f7e4dfaca9544 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py @@ -53,10 +53,11 @@ def reset(self) -> None: def advance(self, *args: Any, **kwargs: Any) -> None: dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) + dataloader_iter = enumerate(dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] dl_outputs = self.evaluation_loop.run( - dataloader, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders + dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders ) # store batch level output per dataloader diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index a9993694b8f12..3221e718dcd1b 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -33,14 +33,13 @@ def reset(self) -> None: self.num_dataloaders = None self.outputs = [] - def on_run_start(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders) -> None: + def on_run_start(self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloaders) -> None: self.dl_max_batches = dl_max_batches self.dataloader_idx = dataloader_idx self.num_dataloaders = num_dataloaders - self.dataloader = enumerate(dataloader) - def advance(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders) -> None: - batch_idx, batch = next(self.dataloader) + def advance(self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloaders) -> None: + batch_idx, batch = next(dataloader_iter) if batch is None: raise StopIteration @@ -65,12 +64,6 @@ def advance(self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders) - def on_run_end(self) -> Any: return self.outputs - def __getstate__(self): - # avoid pickling errors "cannot pickle 
generator object" - self.dataloader = None - return self.__dict__ - - # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ diff --git a/pytorch_lightning/loops/prediction_loop.py b/pytorch_lightning/loops/prediction_loop.py index 5f158ecbaa871..fe51cfd3e92d5 100644 --- a/pytorch_lightning/loops/prediction_loop.py +++ b/pytorch_lightning/loops/prediction_loop.py @@ -33,7 +33,7 @@ def reset(self) -> None: self.predictions: List[Any] = [] def on_run_start( - self, dataloader, dataloader_idx, dl_max_batches, num_dataloaders, return_predictions=False + self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloaders, return_predictions=False ) -> None: self.dl_max_batches = dl_max_batches self.num_dataloaders = num_dataloaders From 881614487f0693ea0740b9eb5579017cce8c8c10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 28 May 2021 15:58:12 +0200 Subject: [PATCH 189/455] fix problem with train dataloader pickle if it is an attribute of loop --- pytorch_lightning/loops/epoch_loop.py | 4 +++- pytorch_lightning/loops/training_loop.py | 21 +++++++-------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 01f789f36cebb..884e1f6af8811 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -151,10 +151,12 @@ def on_advance_start(self): # equal to old on_train_epoch_start self.trainer.call_hook("on_train_epoch_start") def advance(self): + train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) with self.trainer.profiler.profile("run_training_epoch"): # run train epoch - epoch_output = self.training_loop.run() + epoch_output = self.training_loop.run(train_dataloader) # log epoch metrics if epoch_output is None: diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 67444f3bef999..0c4ed21d75424 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import Dict, List, Union, Iterator import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop @@ -30,7 +30,6 @@ def __init__(self, min_steps, max_steps): # the current split index when the batch gets split into chunks in truncated backprop through time self.split_idx = None - self._train_dataloader = None self._dataloader_idx = None self._should_stop = False @@ -62,8 +61,8 @@ def run(self, *args, **kwargs): # redesign the done conditions and use the base class run() implementation while True: try: - self.on_advance_start() - self.advance() + self.on_advance_start(*args, **kwargs) + self.advance(*args, **kwargs) self.on_advance_end() self.iteration_count = self.increment_iteration(self.iteration_count) except StopIteration: @@ -73,23 +72,17 @@ def run(self, *args, **kwargs): def reset(self) -> None: self.iteration_count = 0 - - # modify dataloader if needed (ddp, etc...) 
- train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - - # reset - self._train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) - self._dataloader_idx = 0 - self._should_stop = False self.batches_seen = 0 self.is_last_batch = False + self._dataloader_idx = 0 + self._should_stop = False # track epoch output self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] - def advance(self): + def advance(self, dataloader_iter: Iterator, **kwargs): # TODO: profiling is gone - _, (batch, is_last) = next(self._train_dataloader) + _, (batch, is_last) = next(dataloader_iter) self.is_last_batch = is_last # ------------------------------------ From 02b660037dab15d9b90e347eb47a749fa064e392 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 May 2021 14:19:52 +0000 Subject: [PATCH 190/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/evaluation_loop.py | 1 + pytorch_lightning/loops/training_loop.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index 3221e718dcd1b..e3883e337ddc1 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -64,6 +64,7 @@ def advance(self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloade def on_run_end(self) -> Any: return self.outputs + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 0c4ed21d75424..c2fb3bdf8b233 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union, Iterator +from typing import Dict, Iterator, List, Union import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop From afc75951fb7f61e26dba559547d495c3c720ba59 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 19:20:02 +0200 Subject: [PATCH 191/455] Rename result --- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/core/{step_result.py => result.py} | 0 .../trainer/connectors/logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/properties.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- tests/core/test_metric_result_integration.py | 2 +- tests/models/test_tpu.py | 2 +- tests/trainer/connectors/test_logger_connectors.py | 2 +- 10 files changed, 9 insertions(+), 9 deletions(-) rename pytorch_lightning/core/{step_result.py => result.py} (100%) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 98f1157fb3c3d..50bb6dce706e2 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,8 +38,8 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.core.result 
import ResultCollection from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES -from pytorch_lightning.core.step_result import ResultCollection from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/result.py similarity index 100% rename from pytorch_lightning/core/step_result.py rename to pytorch_lightning/core/result.py diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 221083be7a63e..80af04a1af1e4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -19,7 +19,7 @@ import torch from pytorch_lightning.core import memory -from pytorch_lightning.core.step_result import DefaultMetricsKeys +from pytorch_lightning.core.result import DefaultMetricsKeys from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.states import RunningStage, TrainerFn diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f4735304c97f0..32f1ee006bf0f 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.core.step_result import ResultCollection +from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 6406f3bdb3466..1d39595fe527b 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -26,7 +26,7 @@ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.step_result import ResultCollection +from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b972ce9c88c8f..8b77b32f8b73b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -28,7 +28,7 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.step_result import ResultCollection +from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments 
import ClusterEnvironment diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 55468fe7b609e..5f5d9c9833735 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -22,7 +22,7 @@ from torch.optim import Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.step_result import ResultCollection +from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 3ca9b3b62b9a1..929cae77fc382 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -19,7 +19,7 @@ from torchmetrics import Metric import tests.helpers.utils as tutils -from pytorch_lightning.core.step_result import DefaultMetricsKeys, ResultCollection +from pytorch_lightning.core.result import DefaultMetricsKeys, ResultCollection from tests.helpers.runif import RunIf diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 9b978e66512e7..7632d3d124e54 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -24,7 +24,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.core.step_result import ResultCollection +from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py index 2c4627a92f45f..c7803c93992ff 100644 --- a/tests/trainer/connectors/test_logger_connectors.py +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -15,7 +15,7 @@ from torch import tensor from pytorch_lightning import seed_everything -from pytorch_lightning.core.step_result import DefaultMetricsKeys, ResultCollection +from pytorch_lightning.core.result import DefaultMetricsKeys, ResultCollection def test_result_collection_on_tensor_with_mean_reduction(): From 3ac48e041edd134d68b9fdb176d11b9676112163 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 19:20:23 +0200 Subject: [PATCH 192/455] Move result --- pytorch_lightning/core/lightning.py | 2 +- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- .../{core => trainer/connectors/logger_connector}/result.py | 0 pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/properties.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- tests/core/test_metric_result_integration.py | 2 +- tests/models/test_tpu.py | 2 +- tests/trainer/connectors/test_logger_connectors.py | 2 +- 10 files changed, 9 insertions(+), 9 deletions(-) rename pytorch_lightning/{core => trainer/connectors/logger_connector}/result.py (100%) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 50bb6dce706e2..67a3aeeb127ed 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,8 +38,8 @@ from pytorch_lightning.core.hooks import 
CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 80af04a1af1e4..3744fadeaefcf 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -19,9 +19,9 @@ import torch from pytorch_lightning.core import memory -from pytorch_lightning.core.result import DefaultMetricsKeys from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator +from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars diff --git a/pytorch_lightning/core/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py similarity index 100% rename from pytorch_lightning/core/result.py rename to pytorch_lightning/trainer/connectors/logger_connector/result.py diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 32f1ee006bf0f..7875f26f470f5 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.core.result import ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 1d39595fe527b..666620091aece 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -26,13 +26,13 @@ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from 
pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.states import RunningStage, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8b77b32f8b73b..0d2ff45900c30 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -28,7 +28,6 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment @@ -48,6 +47,7 @@ from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5f5d9c9833735..fa926617134ae 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -22,8 +22,8 @@ from torch.optim import Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_info diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 929cae77fc382..c6c133a87b62c 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -19,7 +19,7 @@ from torchmetrics import Metric import tests.helpers.utils as tutils -from pytorch_lightning.core.result import DefaultMetricsKeys, ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys, ResultCollection from tests.helpers.runif import RunIf diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 7632d3d124e54..11a9873cece30 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -24,8 +24,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.core.result import ResultCollection from pytorch_lightning.plugins import TPUSpawnPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities import _TPU_AVAILABLE from 
pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py index c7803c93992ff..51a251246f65d 100644 --- a/tests/trainer/connectors/test_logger_connectors.py +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -15,7 +15,7 @@ from torch import tensor from pytorch_lightning import seed_everything -from pytorch_lightning.core.result import DefaultMetricsKeys, ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys, ResultCollection def test_result_collection_on_tensor_with_mean_reduction(): From bdf85a20aaf09f000826c311a6522fa9bd5f31da Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 19:26:10 +0200 Subject: [PATCH 193/455] Imports --- pytorch_lightning/core/lightning.py | 8 +++++--- pytorch_lightning/overrides/base.py | 6 +++--- pytorch_lightning/overrides/data_parallel.py | 4 ++-- pytorch_lightning/overrides/distributed.py | 4 ++-- pytorch_lightning/plugins/precision/apex_amp.py | 2 +- pytorch_lightning/plugins/training_type/deepspeed.py | 6 +++--- pytorch_lightning/plugins/training_type/parallel.py | 4 ++-- 7 files changed, 18 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 67a3aeeb127ed..a67cec4c7230a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -26,7 +26,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import torch from torch import ScriptModule, Tensor @@ -39,7 +39,6 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -51,6 +50,9 @@ from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache +if TYPE_CHECKING: + from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection + warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -328,7 +330,7 @@ def log( ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) - result_collections: Optional[ResultCollection] = self.trainer.result_collections + result_collections: Optional['ResultCollection'] = self.trainer.result_collections if result_collections is not None: # TODO: if logged twice fail with crash diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 88e8ed6375e1b..11bd9fa0842a7 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -15,13 +15,13 @@ from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from 
pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule'): """ Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step`` or ``test_step``. @@ -66,7 +66,7 @@ def on_post_move_to_device(self): pass -def unwrap_lightning_module(wrapped_model) -> LightningModule: +def unwrap_lightning_module(wrapped_model) -> 'pl.LightningModule': model = wrapped_model if isinstance(model, (DistributedDataParallel, DataParallel)): model = model.module diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 3d6e527ef95a9..a2be6d6f8e704 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -17,7 +17,7 @@ import torch -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -53,7 +53,7 @@ class LightningParallelModule(_LightningModuleWrapperBase): """ - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule'): super().__init__(pl_module) _ignore_scalar_return_in_dp() diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 559e1161ce676..a02c1d6f694f6 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -18,13 +18,13 @@ from torch.nn.parallel import DistributedDataParallel from torch.utils.data import BatchSampler, DistributedSampler, Sampler -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase class LightningDistributedModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule'): """ Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. 
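[Editor's note, not part of the patch series: the [PATCH 193/455] "Imports" hunks around this point break the circular import between `pytorch_lightning.core.lightning` and the trainer internals by switching to `import pytorch_lightning as pl`, annotating with strings such as `'pl.LightningModule'`, and guarding the `ResultCollection` import behind `typing.TYPE_CHECKING`. The stand-alone sketch below illustrates that technique only; the imported name is a placeholder and nothing here ships with the patch.]

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, IDEs), never at runtime,
    # so this import cannot participate in an import cycle.
    from decimal import Decimal  # placeholder standing in for e.g. ResultCollection


def describe(value: Optional["Decimal"] = None) -> str:
    # The quoted annotation keeps `Decimal` out of the runtime namespace.
    return "no value" if value is None else f"value={value}"


print(describe())  # -> no value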
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 71c2119e734fd..aa3aad7689cf0 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -39,7 +39,7 @@ def __init__(self, amp_level: str = "O2") -> None: def master_params(self, optimizer: Optimizer) -> _PARAMETERS: return amp.master_params(optimizer) - def dispatch(self, trainer: "pl.Trainer") -> None: + def dispatch(self, trainer: 'pl.Trainer') -> None: if not self._connected: accelerator = trainer.accelerator _, accelerator.optimizers = amp.initialize( diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 8dd04aafa6b86..4ee6bf8408ad2 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -22,8 +22,8 @@ import torch +import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -51,7 +51,7 @@ def remove_module_hooks(model: torch.nn.Module) -> None: class LightningDeepSpeedModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: LightningModule, precision: int): + def __init__(self, pl_module: 'pl.LightningModule', precision: int): super().__init__(pl_module) self.precision = precision @@ -378,7 +378,7 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) return distributed_sampler_kwargs - def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]: + def init_optimizers(self, trainer, model: 'pl.LightningModule') -> Tuple[List, List, List]: # Skip initializing optimizers here as DeepSpeed handles optimizers via config. # User may have specified config options instead in configure_optimizers, but this is handled # via `_initialize_deepspeed_train` diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index a8028e5be1a69..dfda76d3404a8 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -19,7 +19,7 @@ import torch from torch.nn.parallel import DistributedDataParallel -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin @@ -99,7 +99,7 @@ def torch_distributed_backend(self): return torch_backend @staticmethod - def configure_sync_batchnorm(model: LightningModule) -> LightningModule: + def configure_sync_batchnorm(model: 'pl.LightningModule') -> 'pl.LightningModule': """ Add global batchnorm for a model spread across multiple GPUs and nodes. 
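[Editor's note, not part of the patch series: the `PredictionLoop` and `EvaluationLoop` changes earlier in this series drive iteration by calling `next()` on a shared `dataloader_iter` inside `advance()`, let the resulting `StopIteration` end the run, and clear per-dataloader state in `reset()` before each `run()`. The sketch below mirrors that control flow in isolation, catching `StopIteration` around `advance()` the way the `run()` override added to the training loop in this series does; all names are illustrative and not part of the patch.]

from typing import Any, Iterator, List


class MiniLoop:
    """Stand-alone illustration of the reset/advance/run protocol used by the new loops."""

    def __init__(self, max_batches: int) -> None:
        self.max_batches = max_batches
        self.iteration_count = 0
        self.outputs: List[Any] = []

    @property
    def done(self) -> bool:
        return self.iteration_count >= self.max_batches

    def reset(self) -> None:
        self.iteration_count = 0
        self.outputs = []

    def advance(self, dataloader_iter: Iterator) -> None:
        # `next()` raises StopIteration once the iterator is exhausted.
        batch_idx, batch = next(dataloader_iter)
        self.outputs.append((batch_idx, sum(batch)))  # stand-in for a predict/eval step

    def run(self, dataloader_iter: Iterator) -> List[Any]:
        self.reset()
        while not self.done:
            try:
                self.advance(dataloader_iter)
            except StopIteration:
                break
            self.iteration_count += 1
        return self.outputs


batches = enumerate([[1, 2], [3, 4], [5, 6]])
print(MiniLoop(max_batches=2).run(batches))  # -> [(0, 3), (1, 7)]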
From a65fd2eaff68254bf952fc260ab81c6d58967f95 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 19:41:51 +0200 Subject: [PATCH 194/455] Comments --- pytorch_lightning/core/lightning.py | 2 +- .../plugins/training_type/ddp2.py | 2 +- pytorch_lightning/plugins/training_type/dp.py | 3 +- .../logger_connector/logger_connector.py | 42 +++--- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/properties.py | 3 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- tests/callbacks/test_progress_bar.py | 2 +- .../connectors/test_logger_connectors.py | 136 ----------------- .../logging_/test_eval_loop_logging.py | 2 +- .../trainer/logging_/test_logger_connector.py | 138 +++++++++++++++++- .../loops/test_evaluation_loop_flow.py | 7 +- 13 files changed, 169 insertions(+), 174 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a67cec4c7230a..fff1884947790 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -330,7 +330,7 @@ def log( ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) - result_collections: Optional['ResultCollection'] = self.trainer.result_collections + result_collections: Optional['ResultCollection'] = self.trainer.result_collection if result_collections is not None: # TODO: if logged twice fail with crash diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index 13e6530270d4e..146dfc6c97dae 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -48,7 +48,7 @@ def reduce(self, tensor, *args, **kwargs): reduced value, except when the input was not a tensor the output remains is unchanged """ - def _reduce(t: torch.Tensor): + def _reduce(t: torch.Tensor) -> torch.Tensor: dtype_tensor = t.dtype return t.float().mean().type(dtype_tensor) diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 20367ec6a8f35..08899d72db17f 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -18,7 +18,6 @@ from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -65,7 +64,7 @@ def reduce(self, tensor, *args, **kwargs): reduced value, except when the input was not a tensor the output remains is unchanged """ - def _reduce(t: torch.Tensor): + def _reduce(t: torch.Tensor) -> torch.Tensor: dtype_tensor = t.dtype return t.float().mean().type(dtype_tensor) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 3744fadeaefcf..bfa80ce9ae097 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -120,13 +120,13 @@ def evaluation_epoch_end(self): # reset dataloader idx model_ref = self.trainer.lightning_module model_ref._current_dataloader_idx = None - self.trainer.result_collections.on_epoch_end_reached = True + self.trainer.result_collection.on_epoch_end_reached = True def add_to_eval_loop_results(self, 
dl_idx, has_been_initialized): if self.trainer.sanity_checking: return - callback_metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.CALLBACK] + callback_metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.CALLBACK] if os.getenv("PL_DEV_DEBUG", '0') == '1': callback_metrics["debug_epoch"] = self.trainer.current_epoch callback_metrics = deepcopy(callback_metrics) @@ -147,7 +147,7 @@ def prepare_eval_loop_results(self): self.add_to_eval_loop_results(dl_idx, has_been_initialized) def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: - metrics = self.trainer.result_collections.metrics + metrics = self.trainer.result_collection.metrics # update metrics self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) @@ -199,7 +199,7 @@ def increment_evaluation_log_step(self) -> None: def on_evaluation_start(self): root_device = self.trainer.lightning_module.device - self.trainer.result_collections.root_device = root_device + self.trainer.result_collection.root_device = root_device def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: model = self.trainer.lightning_module @@ -207,11 +207,11 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None # track batch_size - self.trainer.result_collections.extract_batch_size(batch) - self.trainer.result_collections.batch_idx = batch_idx + self.trainer.result_collection.extract_batch_size(batch) + self.trainer.result_collection.batch_idx = batch_idx def update_evaluation_step_metrics(self) -> None: - metrics = self.trainer.result_collections.metrics + metrics = self.trainer.result_collection.metrics # update metrics self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) @@ -236,17 +236,17 @@ def update_evaluation_step_metrics(self) -> None: def on_train_start(self): root_device = self.trainer.lightning_module.device - self.trainer.result_collections.root_device = root_device + self.trainer.result_collection.root_device = root_device def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: - self.trainer.result_collections.extract_batch_size(split_batch) - self.trainer.result_collections.batch_idx = batch_idx + self.trainer.result_collection.extract_batch_size(split_batch) + self.trainer.result_collection.batch_idx = batch_idx def on_train_batch_end(self) -> None: - self.trainer.result_collections.batch_size = 1 + self.trainer.result_collection.batch_size = 1 def update_train_step_metrics(self, batch_output): - metrics = self.trainer.result_collections.metrics + metrics = self.trainer.result_collection.metrics # update metrics self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) @@ -268,11 +268,11 @@ def update_train_step_metrics(self, batch_output): def on_train_epoch_end(self): # inform cached logger connector epoch finished - self.trainer.result_collections.on_epoch_end_reached = True + self.trainer.result_collection.on_epoch_end_reached = True def update_train_epoch_metrics(self) -> None: - metrics = self.trainer.result_collections.metrics + metrics = self.trainer.result_collection.metrics # update metrics self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) @@ -290,7 +290,7 @@ def update_train_epoch_metrics(self) -> None: self.log_metrics(epoch_log_metrics, {}) # reset result collection for next epoch - self.trainer.result_collections.reset_metrics() + self.trainer.result_collection.reset_metrics() """ 
Utilities and properties @@ -298,8 +298,8 @@ def update_train_epoch_metrics(self) -> None: @property def callback_metrics(self) -> Dict[str, float]: - if self.trainer.result_collections: - metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.CALLBACK] + if self.trainer.result_collection: + metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.CALLBACK] self._callback_metrics.update(metrics) if os.getenv("PL_DEV_DEBUG", '0') == '1': self._callback_metrics["debug_epoch"] = self.trainer.current_epoch @@ -307,15 +307,15 @@ def callback_metrics(self) -> Dict[str, float]: @property def logged_metrics(self) -> Dict[str, float]: - if self.trainer.result_collections: - metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.LOG] + if self.trainer.result_collection: + metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.LOG] self._logged_metrics.update(metrics) return self._logged_metrics @property def progress_bar_metrics(self) -> Dict[str, float]: - if self.trainer.result_collections: - metrics = self.trainer.result_collections.metrics[DefaultMetricsKeys.PBAR] + if self.trainer.result_collection: + metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 7875f26f470f5..078776602d3a9 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -102,7 +102,7 @@ def on_evaluation_model_train(self) -> None: model_ref.on_validation_model_train() def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: - self.trainer.result_collections.reset_metrics() + self.trainer.result_collection.reset_metrics() if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 666620091aece..24d9c208280f4 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -517,14 +517,13 @@ def min_steps(self) -> Optional[int]: return self.train_loop.min_steps @property - def result_collections(self) -> Optional[ResultCollection]: + def result_collection(self) -> Optional[ResultCollection]: if self.training: return self.train_loop.train_results elif self.validating or self.sanity_checking: return self.evaluation_loop.validation_results elif self.testing: return self.evaluation_loop.test_results - return None # Used to represent the concrete type TrainerProperties class methods are called on. 
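[Editor's note, not part of the patch series: the `_reduce` helpers whose return types are annotated in the DDP2/DP plugin hunks of this commit convert the tensor to float for the mean and then cast the result back, so the reduction runs in full precision while the caller keeps its original dtype. A minimal sketch of that idea follows, using `.to()` instead of the patch's `Tensor.type()`, which is equivalent for this purpose.]

import torch


def mean_preserving_dtype(t: torch.Tensor) -> torch.Tensor:
    # Accumulate in float32, then cast the scalar result back to the input dtype.
    return t.float().mean().to(t.dtype)


x = torch.tensor([1.0, 2.0, 4.0], dtype=torch.float16)
print(mean_preserving_dtype(x).dtype)  # torch.float16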
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0d2ff45900c30..7861bb8005ade 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1114,7 +1114,7 @@ def _run_sanity_check(self, ref_model): self.on_sanity_check_end() # reset validation metrics - self.result_collections.reset() + self.result_collection.reset() # reset the seed to what it was before sanity check # prevents sanity check to affect random sampling in training diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 04c2fc140d484..fac4e2352a9e1 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -320,7 +320,7 @@ def _process_training_step_output(self, training_step_output, split_batch): if training_step_output_for_epoch_end is None: return None, None - result = self.trainer.result_collections + result = self.trainer.result_collection loss = None hiddens = None diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index f28178626b25e..00fefcccc3435 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -389,7 +389,7 @@ def training_step(self, batch, batch_idx): return super().training_step(batch, batch_idx) def on_train_end(self) -> None: - print(self.trainer.result_collections) + print(self.trainer.result_collection) trainer = Trainer( default_root_dir=tmpdir, diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py index 51a251246f65d..0017e9165389a 100644 --- a/tests/trainer/connectors/test_logger_connectors.py +++ b/tests/trainer/connectors/test_logger_connectors.py @@ -16,139 +16,3 @@ from pytorch_lightning import seed_everything from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys, ResultCollection - - -def test_result_collection_on_tensor_with_mean_reduction(): - - seed_everything(42) - - result_collection = ResultCollection(True, torch.device("cpu")) - - for i in range(1, 10): - - result_collection.batch_idx = i - - for prob_bar in [False, True]: - - for logger in [False, True]: - - i = float(i) - - result_collection.log( - "training_step", - f"loss_1_{int(prob_bar)}_{int(logger)}", - torch.tensor(i), - on_step=True, - on_epoch=True, - batch_size=i**2, - prog_bar=prob_bar, - logger=logger - ) - result_collection.log( - "training_step", - f"loss_2_{int(prob_bar)}_{int(logger)}", - torch.tensor(i), - on_step=False, - on_epoch=True, - batch_size=i**2, - prog_bar=prob_bar, - logger=logger - ) - result_collection.log( - "training_step", - f"loss_3_{int(prob_bar)}_{int(logger)}", - torch.tensor(i), - on_step=True, - on_epoch=False, - batch_size=i**2, - prog_bar=prob_bar, - logger=logger - ) - result_collection.log( - "training_step", - f"loss_4_{int(prob_bar)}_{int(logger)}", - torch.tensor(i), - on_step=False, - on_epoch=False, - batch_size=i**2, - prog_bar=prob_bar, - logger=logger - ) - - excepted_values = [ - tensor(1.), - tensor(2.), - tensor(3.), - tensor(4.), - tensor(5.), - tensor(6.), - tensor(7.), - tensor(8.), - tensor(9.) 
- ] - excepted_batches = [1, 4, 9, 16, 25, 36, 49, 64, 81] - total_value = tensor(excepted_values) * tensor(excepted_batches) - assert result_collection["training_step.loss_1_0_0"].value == sum(total_value) - assert result_collection["training_step.loss_1_0_0"].cumulated_batch_size == sum(excepted_batches) - - batch_metrics = result_collection.get_batch_metrics() - - expected = { - 'loss_1_1_0_step': tensor([9.]), - 'loss_3_1_0': tensor([9.]), - 'loss_1_1_1_step': tensor([9.]), - 'loss_3_1_1': tensor([9.]) - } - assert batch_metrics[DefaultMetricsKeys.PBAR] == expected - - excepted = { - 'loss_1_0_1_step': tensor([9.]), - 'loss_3_0_1': tensor([9.]), - 'loss_1_1_1_step': tensor([9.]), - 'loss_3_1_1': tensor([9.]) - } - assert batch_metrics[DefaultMetricsKeys.LOG] == excepted - - excepted = { - 'loss_1_0_0': tensor(9.), - 'loss_1_0_0_step': tensor(9.), - 'loss_3_0_0': tensor(9.), - 'loss_1_0_1': tensor(9.), - 'loss_1_0_1_step': tensor(9.), - 'loss_3_0_1': tensor(9.), - 'loss_1_1_0': tensor(9.), - 'loss_1_1_0_step': tensor(9.), - 'loss_3_1_0': tensor(9.), - 'loss_1_1_1': tensor(9.), - 'loss_1_1_1_step': tensor(9.), - 'loss_3_1_1': tensor(9.) - } - assert batch_metrics[DefaultMetricsKeys.CALLBACK] == excepted - - result_collection.on_epoch_end_reached = True - - epoch_metrics = result_collection.get_epoch_metrics() - - mean = (tensor(excepted_values) * tensor(excepted_batches)).sum() / sum(excepted_batches) - - expected = {'loss_1_1_0_epoch': mean, 'loss_2_1_0': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} - assert epoch_metrics[DefaultMetricsKeys.PBAR] == expected - - excepted = {'loss_1_0_1_epoch': mean, 'loss_2_0_1': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} - assert epoch_metrics[DefaultMetricsKeys.LOG] == excepted - - excepted = { - 'loss_1_0_0': mean, - 'loss_1_0_0_epoch': mean, - 'loss_2_0_0': mean, - 'loss_1_0_1': mean, - 'loss_1_0_1_epoch': mean, - 'loss_2_0_1': mean, - 'loss_1_1_0': mean, - 'loss_1_1_0_epoch': mean, - 'loss_2_1_0': mean, - 'loss_1_1_1': mean, - 'loss_1_1_1_epoch': mean, - 'loss_2_1_1': mean, - } - assert epoch_metrics[DefaultMetricsKeys.CALLBACK] == excepted diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 824bc742d0dc5..5f32682f35663 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -542,7 +542,7 @@ def validation_step(self, batch, batch_idx): self.log('val_loss', loss) def on_validation_end(self) -> None: - print(self.trainer.result_collections) + print(self.trainer.result_collection) max_epochs = 1 model = TestModel() diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index cc0abaae48a4b..3b0717fedf81a 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -22,10 +22,11 @@ from torch.utils.data import DataLoader from torchmetrics import Accuracy, AveragePrecision -from pytorch_lightning import LightningModule +from pytorch_lightning import LightningModule, seed_everything from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator +from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys, ResultCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import 
BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -432,3 +433,138 @@ def _assert_called(model, stage): trainer.test(model) _assert_called(model, 'test') + + +def test_result_collection_on_tensor_with_mean_reduction(): + seed_everything(42) + + result_collection = ResultCollection(True, torch.device("cpu")) + + for i in range(1, 10): + + result_collection.batch_idx = i + + for prob_bar in [False, True]: + + for logger in [False, True]: + + i = float(i) + + result_collection.log( + "training_step", + f"loss_1_{int(prob_bar)}_{int(logger)}", + torch.tensor(i), + on_step=True, + on_epoch=True, + batch_size=i**2, + prog_bar=prob_bar, + logger=logger + ) + result_collection.log( + "training_step", + f"loss_2_{int(prob_bar)}_{int(logger)}", + torch.tensor(i), + on_step=False, + on_epoch=True, + batch_size=i**2, + prog_bar=prob_bar, + logger=logger + ) + result_collection.log( + "training_step", + f"loss_3_{int(prob_bar)}_{int(logger)}", + torch.tensor(i), + on_step=True, + on_epoch=False, + batch_size=i**2, + prog_bar=prob_bar, + logger=logger + ) + result_collection.log( + "training_step", + f"loss_4_{int(prob_bar)}_{int(logger)}", + torch.tensor(i), + on_step=False, + on_epoch=False, + batch_size=i**2, + prog_bar=prob_bar, + logger=logger + ) + + excepted_values = [ + torch.tensor(1.), + torch.tensor(2.), + torch.tensor(3.), + torch.tensor(4.), + torch.tensor(5.), + torch.tensor(6.), + torch.tensor(7.), + torch.tensor(8.), + torch.tensor(9.) + ] + excepted_batches = [1, 4, 9, 16, 25, 36, 49, 64, 81] + total_value = torch.tensor(excepted_values) * torch.tensor(excepted_batches) + assert result_collection["training_step.loss_1_0_0"].value == sum(total_value) + assert result_collection["training_step.loss_1_0_0"].cumulated_batch_size == sum(excepted_batches) + + batch_metrics = result_collection.get_batch_metrics() + + expected = { + 'loss_1_1_0_step': torch.tensor([9.]), + 'loss_3_1_0': torch.tensor([9.]), + 'loss_1_1_1_step': torch.tensor([9.]), + 'loss_3_1_1': torch.tensor([9.]) + } + assert batch_metrics[DefaultMetricsKeys.PBAR] == expected + + excepted = { + 'loss_1_0_1_step': torch.tensor([9.]), + 'loss_3_0_1': torch.tensor([9.]), + 'loss_1_1_1_step': torch.tensor([9.]), + 'loss_3_1_1': torch.tensor([9.]) + } + assert batch_metrics[DefaultMetricsKeys.LOG] == excepted + + excepted = { + 'loss_1_0_0': torch.tensor(9.), + 'loss_1_0_0_step': torch.tensor(9.), + 'loss_3_0_0': torch.tensor(9.), + 'loss_1_0_1': torch.tensor(9.), + 'loss_1_0_1_step': torch.tensor(9.), + 'loss_3_0_1': torch.tensor(9.), + 'loss_1_1_0': torch.tensor(9.), + 'loss_1_1_0_step': torch.tensor(9.), + 'loss_3_1_0': torch.tensor(9.), + 'loss_1_1_1': torch.tensor(9.), + 'loss_1_1_1_step': torch.tensor(9.), + 'loss_3_1_1': torch.tensor(9.) 
+ } + assert batch_metrics[DefaultMetricsKeys.CALLBACK] == excepted + + result_collection.on_epoch_end_reached = True + + epoch_metrics = result_collection.get_epoch_metrics() + + mean = (torch.tensor(excepted_values) * torch.tensor(excepted_batches)).sum() / sum(excepted_batches) + + expected = {'loss_1_1_0_epoch': mean, 'loss_2_1_0': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} + assert epoch_metrics[DefaultMetricsKeys.PBAR] == expected + + excepted = {'loss_1_0_1_epoch': mean, 'loss_2_0_1': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} + assert epoch_metrics[DefaultMetricsKeys.LOG] == excepted + + excepted = { + 'loss_1_0_0': mean, + 'loss_1_0_0_epoch': mean, + 'loss_2_0_0': mean, + 'loss_1_0_1': mean, + 'loss_1_0_1_epoch': mean, + 'loss_2_0_1': mean, + 'loss_1_1_0': mean, + 'loss_1_1_0_epoch': mean, + 'loss_2_1_0': mean, + 'loss_1_1_1': mean, + 'loss_1_1_1_epoch': mean, + 'loss_2_1_1': mean, + } + assert epoch_metrics[DefaultMetricsKeys.CALLBACK] == excepted diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 075b9f7438124..4ef44c1d5ac90 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -66,12 +66,9 @@ def backward(self, loss, optimizer, optimizer_idx): assert not model.validation_step_end_called assert not model.validation_epoch_end_called + # simulate training manually trainer.state.stage = RunningStage.TRAINING - - # make sure training outputs what is expected - for batch_idx, batch in enumerate(model.train_dataloader()): - break - + batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) From f92504c2bdf45c55c3cc1b56e805075e42fbb398 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 19:42:21 +0200 Subject: [PATCH 195/455] Extra file --- .../connectors/test_logger_connectors.py | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 tests/trainer/connectors/test_logger_connectors.py diff --git a/tests/trainer/connectors/test_logger_connectors.py b/tests/trainer/connectors/test_logger_connectors.py deleted file mode 100644 index 0017e9165389a..0000000000000 --- a/tests/trainer/connectors/test_logger_connectors.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import torch -from torch import tensor - -from pytorch_lightning import seed_everything -from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys, ResultCollection From 3166240dbd08adada9c5fb52e049599cf9de668a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 19:59:14 +0200 Subject: [PATCH 196/455] Reorder --- pytorch_lightning/trainer/properties.py | 283 +++++++++++++----------- 1 file changed, 150 insertions(+), 133 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 24d9c208280f4..9384c190227b2 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -33,6 +33,7 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.states import RunningStage, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn @@ -61,6 +62,10 @@ class TrainerProperties(ABC): logger_connector: LoggerConnector state: TrainerState train_loop: TrainLoop + evaluation_loop: EvaluationLoop + ''' + Accelerator properties + ''' @property def accelerator(self) -> Accelerator: @@ -135,46 +140,92 @@ def data_parallel_device_ids(self) -> Optional[List[int]]: return self.accelerator_connector.parallel_device_ids @property - def log_dir(self) -> Optional[str]: - if self.logger is None: - dirpath = self.default_root_dir - else: - dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir') + def lightning_module(self) -> LightningModule: + return self.accelerator.lightning_module - dirpath = self.accelerator.broadcast(dirpath) - return dirpath + @property + def optimizers(self) -> Optional[List[Optimizer]]: + return self.accelerator.optimizers + + @optimizers.setter + def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: + # Necessary to rewrap optimizers to lightning + # They will be re-created when accessing + # the `lightning_optimizers` trainer property + self._lightning_optimizers = None + + self.accelerator.optimizers = new_optims @property - def use_amp(self) -> bool: - return self.precision == 16 + def lr_schedulers(self) -> Optional[list]: + return self.accelerator.lr_schedulers + + @lr_schedulers.setter + def lr_schedulers(self, new_schedulers: Optional[list]) -> None: + self.accelerator.lr_schedulers = new_schedulers @property - def callback_metrics(self) -> dict: - return self.logger_connector.callback_metrics + def optimizer_frequencies(self) -> list: + return self.accelerator.optimizer_frequencies - @callback_metrics.setter - def callback_metrics(self, x: dict) -> None: - self.logger_connector.callback_metrics = x + @optimizer_frequencies.setter + def optimizer_frequencies(self, new_freqs: list) -> None: + self.accelerator.optimizer_frequencies = new_freqs @property - def logged_metrics(self) -> dict: - return self.logger_connector.logged_metrics + def amp_backend(self) -> Optional[str]: + return self.accelerator.amp_backend - @logged_metrics.setter - def logged_metrics(self, x: dict) -> None: - self.logger_connector.logged_metrics = x + @property + def precision(self) -> 
Union[str, int]: + return self.accelerator.precision @property - def progress_bar_metrics(self) -> dict: - return self.logger_connector.progress_bar_metrics + def scaler(self): + return self.accelerator.scaler - @progress_bar_metrics.setter - def progress_bar_metrics(self, x: dict) -> None: - self.logger_connector.progress_bar_metrics = x + @property + def gpus(self) -> Optional[Union[List[int], str, int]]: + return self.accelerator_connector.gpus @property - def interrupted(self) -> bool: - return self.state.status == TrainerStatus.INTERRUPTED + def model(self) -> torch.nn.Module: + """ + The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. + To access the pure LightningModule, use + :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. + """ + return self.accelerator.model + + @model.setter + def model(self, model: torch.nn.Module) -> None: + """ + Setter for the model, pass-through to accelerator and plugin where the model reference is stored. + Used by the Tuner to reset the state of Trainer and Accelerator. + + Args: + model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending + on the backend. + """ + self.accelerator.model = model + + ''' + General properties + ''' + + @property + def log_dir(self) -> Optional[str]: + if self.logger is None: + dirpath = self.default_root_dir + else: + dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir') + + dirpath = self.accelerator.broadcast(dirpath) + return dirpath + + @property + def use_amp(self) -> bool: + return self.precision == 16 @property def is_global_zero(self) -> bool: @@ -195,39 +246,16 @@ def slurm_job_id(self) -> Optional[int]: job_id = None return job_id - @classmethod - def default_attributes(cls) -> dict: - init_signature = inspect.signature(cls) - return {k: v.default for k, v in init_signature.parameters.items()} - - @classmethod - def get_deprecated_arg_names(cls) -> List: - """Returns a list with deprecated Trainer arguments.""" - depr_arg_names = [] - for name, val in cls.__dict__.items(): - if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)): - depr_arg_names.extend(val) - return depr_arg_names - - @classmethod - def from_argparse_args(cls: Type['_T'], args: Union[Namespace, ArgumentParser], **kwargs) -> '_T': - return from_argparse_args(cls, args, **kwargs) - - @classmethod - def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: - return parse_argparser(cls, arg_parser) - - @classmethod - def match_env_arguments(cls) -> Namespace: - return parse_env_variables(cls) - - @classmethod - def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: - return add_argparse_args(cls, parent_parser, **kwargs) + @property + def lightning_optimizers(self) -> List[LightningOptimizer]: + if self._lightning_optimizers is None: + self.convert_to_lightning_optimizers() + return self._lightning_optimizers @property - def gpus(self) -> Optional[Union[List[int], str, int]]: - return self.accelerator_connector.gpus + def distributed_sampler_kwargs(self) -> Optional[dict]: + if isinstance(self.training_type_plugin, ParallelPlugin): + return self.training_type_plugin.distributed_sampler_kwargs @property def data_parallel(self) -> bool: @@ -336,91 +364,47 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: def save_checkpoint(self, filepath, weights_only: bool = False) -> None: 
self.checkpoint_connector.save_checkpoint(filepath, weights_only) - @property - def model(self) -> torch.nn.Module: - """ - The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. - To access the pure LightningModule, use - :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. - """ - return self.accelerator.model - - @model.setter - def model(self, model: torch.nn.Module) -> None: - """ - Setter for the model, pass-through to accelerator and plugin where the model reference is stored. - Used by the Tuner to reset the state of Trainer and Accelerator. - - Args: - model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending - on the backend. - """ - self.accelerator.model = model - - @property - def lightning_optimizers(self) -> List[LightningOptimizer]: - if self._lightning_optimizers is None: - self.convert_to_lightning_optimizers() - return self._lightning_optimizers - - @property - def lightning_module(self) -> LightningModule: - return self.accelerator.lightning_module + ''' + Parsing properties + ''' - @property - def optimizers(self) -> Optional[List[Optimizer]]: - return self.accelerator.optimizers - - @optimizers.setter - def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: - # Necessary to rewrap optimizers to lightning - # They will be re-created when accessing - # the `lightning_optimizers` trainer property - self._lightning_optimizers = None - - self.accelerator.optimizers = new_optims - - @property - def lr_schedulers(self) -> Optional[list]: - return self.accelerator.lr_schedulers - - @lr_schedulers.setter - def lr_schedulers(self, new_schedulers: Optional[list]) -> None: - self.accelerator.lr_schedulers = new_schedulers - - @property - def optimizer_frequencies(self) -> list: - return self.accelerator.optimizer_frequencies + @classmethod + def default_attributes(cls) -> dict: + init_signature = inspect.signature(cls) + return {k: v.default for k, v in init_signature.parameters.items()} - @optimizer_frequencies.setter - def optimizer_frequencies(self, new_freqs: list) -> None: - self.accelerator.optimizer_frequencies = new_freqs + @classmethod + def get_deprecated_arg_names(cls) -> List: + """Returns a list with deprecated Trainer arguments.""" + depr_arg_names = [] + for name, val in cls.__dict__.items(): + if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)): + depr_arg_names.extend(val) + return depr_arg_names - @property - def amp_backend(self) -> Optional[str]: - return self.accelerator.amp_backend + @classmethod + def from_argparse_args(cls: Type['_T'], args: Union[Namespace, ArgumentParser], **kwargs) -> '_T': + return from_argparse_args(cls, args, **kwargs) - @property - def precision(self) -> Union[str, int]: - return self.accelerator.precision + @classmethod + def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: + return parse_argparser(cls, arg_parser) - @property - def scaler(self): - return self.accelerator.scaler + @classmethod + def match_env_arguments(cls) -> Namespace: + return parse_env_variables(cls) - # TODO: refactor this so that it can be done in LightningOptimizer - def __getstate__(self): - # remove lightning_optimizers - self._lightning_optimizers = None - return self.__dict__ + @classmethod + def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: + return add_argparse_args(cls, parent_parser, **kwargs) - def __setstate__(self, state): - self.__dict__ = state + ''' 
+ State properties + ''' @property - def distributed_sampler_kwargs(self) -> Optional[dict]: - if isinstance(self.training_type_plugin, ParallelPlugin): - return self.training_type_plugin.distributed_sampler_kwargs + def interrupted(self) -> bool: + return self.state.status == TrainerStatus.INTERRUPTED @property def training(self) -> bool: @@ -492,6 +476,10 @@ def sanity_checking(self, val: bool) -> None: elif self.sanity_checking: self.state.stage = None + ''' + Loop properties + ''' + @property def global_step(self) -> int: return self.train_loop.global_step @@ -516,6 +504,22 @@ def max_steps(self) -> Optional[int]: def min_steps(self) -> Optional[int]: return self.train_loop.min_steps + ''' + Logging properties + ''' + + @property + def callback_metrics(self) -> dict: + return self.logger_connector.callback_metrics + + @property + def logged_metrics(self) -> dict: + return self.logger_connector.logged_metrics + + @property + def progress_bar_metrics(self) -> dict: + return self.logger_connector.progress_bar_metrics + @property def result_collection(self) -> Optional[ResultCollection]: if self.training: @@ -525,6 +529,19 @@ def result_collection(self) -> Optional[ResultCollection]: elif self.testing: return self.evaluation_loop.test_results + ''' + Other + ''' + + # TODO: refactor this so that it can be done in LightningOptimizer + def __getstate__(self): + # remove lightning_optimizers + self._lightning_optimizers = None + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + # Used to represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) From 48bcc2ed2373168b69ed52da4e90dc2fc1ef154a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 22:09:34 +0200 Subject: [PATCH 197/455] Remove dev debugger metric tracking --- .../connectors/logger_connector/logger_connector.py | 3 --- pytorch_lightning/utilities/debugging.py | 12 ------------ 2 files changed, 15 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 7bd834d5925b4..844593016d7b2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -230,7 +230,6 @@ def log_metrics(self, metrics, grad_norm_dict, step=None): # track the logged metrics self.logged_metrics.update(scalar_metrics) - self.trainer.dev_debugger.track_logged_metrics_history(scalar_metrics) def add_progress_bar_metrics(self, metrics): for k, v in metrics.items(): @@ -239,8 +238,6 @@ def add_progress_bar_metrics(self, metrics): self._progress_bar_metrics.metrics[k] = v - self.trainer.dev_debugger.track_pbar_metrics_history(metrics) - def evaluation_epoch_end(self): # reset dataloader idx model_ref = self.trainer.lightning_module diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 56833fd03735a..3a3afd2b36329 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -39,8 +39,6 @@ class InternalDebugger(object): def __init__(self, trainer): self.enabled = os.environ.get('PL_DEV_DEBUG', '0') == '1' self.trainer = trainer - self.logged_metrics = [] - self.pbar_added_metrics = [] self.saved_train_losses = [] self.saved_val_losses = [] self.saved_test_losses = [] @@ -110,11 +108,6 @@ def track_load_dataloader_call(self, name, dataloaders): 
elif 'test' in name: self.test_dataloader_calls.append(values) - @enabled_only - def track_logged_metrics_history(self, scalar_metrics): - scalar_metrics['global_step'] = self.trainer.global_step - self.logged_metrics.append(scalar_metrics) - @enabled_only def track_train_loss_history(self, batch_idx, loss): loss_dict = {'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, 'loss': loss.detach()} @@ -151,11 +144,6 @@ def track_eval_loss_history(self, batch_idx, dataloader_idx, output): else: self.saved_val_losses.append(loss_dict) - @enabled_only - def track_pbar_metrics_history(self, metrics): - metrics['debug_epoch'] = self.trainer.current_epoch - self.pbar_added_metrics.append(metrics) - @enabled_only def track_early_stopping_history(self, callback, current): debug_dict = { From 25460426500987ccd82651df0d4165a09e51db90 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 22:14:04 +0200 Subject: [PATCH 198/455] Fix tests --- tests/accelerators/test_multi_nodes_gpu.py | 3 -- .../logging_/test_eval_loop_logging.py | 42 +++---------------- .../logging_/test_train_loop_logging.py | 8 ---- 3 files changed, 6 insertions(+), 47 deletions(-) diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py index 42a9b1c064199..ca6f816260f6e 100644 --- a/tests/accelerators/test_multi_nodes_gpu.py +++ b/tests/accelerators/test_multi_nodes_gpu.py @@ -13,7 +13,6 @@ # limitations under the License. import os import sys -from unittest import mock import pytest import torch @@ -73,7 +72,6 @@ def validation_step(self, batch, batch_idx): # use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)` @pytest.mark.skip("Multi-node testing is currently disabled") @RunIf(special=True) -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test__validation_step__log(tmpdir): """ Tests that validation_step can log @@ -131,6 +129,5 @@ def backward(self, loss, optimizer, optimizer_idx): # we don't want to enable val metrics during steps because it is not something that users should do # on purpose DO NOT allow step_b... 
it's silly to monitor val step metrics callback_metrics = set(trainer.callback_metrics.keys()) - callback_metrics.remove('debug_epoch') expected_cb_metrics = {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} assert expected_cb_metrics == callback_metrics diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 331734aa9b412..fc83fa8520b41 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -16,7 +16,6 @@ """ import collections import itertools -import os from unittest import mock from unittest.mock import call @@ -160,7 +159,6 @@ def backward(self, loss, optimizer, optimizer_idx): assert expected_cb_metrics == callback_metrics -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) def test_eval_epoch_logging(tmpdir, batches, log_interval, max_epochs): """ @@ -199,16 +197,12 @@ def validation_epoch_end(self, outputs): assert pbar_metrics == expected_pbar_metrics callback_metrics = set(trainer.callback_metrics.keys()) - callback_metrics.remove('debug_epoch') expected_callback_metrics = set() expected_callback_metrics = expected_callback_metrics.union(logged_metrics) expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) expected_callback_metrics.remove('epoch') assert callback_metrics == expected_callback_metrics - # assert the loggers received the expected number - assert len(trainer.dev_debugger.logged_metrics) == max_epochs - def test_eval_float_logging(tmpdir): """ @@ -244,7 +238,6 @@ def validation_step(self, batch, batch_idx): assert logged_metrics == expected_logged_metrics -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_eval_logging_auto_reduce(tmpdir): """ Tests that only training_step can be used @@ -288,26 +281,12 @@ def validation_epoch_end(self, outputs) -> None: # make sure all the metrics are available for callbacks manual_mean = model.manual_epoch_end_mean callback_metrics = set(trainer.callback_metrics.keys()) - assert callback_metrics == {'debug_epoch', 'val_loss', 'val_loss_epoch'} + assert callback_metrics == {'val_loss', 'val_loss_epoch'} # make sure values are correct assert trainer.logged_metrics['val_loss_epoch'] == manual_mean assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step'] - # make sure correct values were logged - logged_val = trainer.dev_debugger.logged_metrics - - # 3 val batches - assert logged_val[0]['val_loss_step'] == model.seen_vals[0] - assert logged_val[1]['val_loss_step'] == model.seen_vals[1] - assert logged_val[2]['val_loss_step'] == model.seen_vals[2] - - # epoch mean - assert logged_val[3]['val_loss_epoch'] == model.manual_epoch_end_mean - - # only those logged - assert len(logged_val) == 4 - @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) def test_eval_epoch_only_logging(tmpdir, batches, log_interval, max_epochs): @@ -414,7 +393,6 @@ def test_step(self, batch, *args): assert {"test_loss", "test_loss_epoch"} == set(results[0]) -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_log_works_in_val_callback(tmpdir): """ Tests that log can be called within callback @@ -588,7 +566,6 @@ def get_expected_output(func_attr, original_values): # Make sure the func_name output equals the average from all logged values when on_epoch true # pop extra keys - trainer.callback_metrics.pop("debug_epoch") 
trainer.callback_metrics.pop("val_loss") for func_name, output_value in trainer.callback_metrics.items(): # not sure how to handle this now @@ -615,7 +592,6 @@ def get_expected_output(func_attr, original_values): assert func_name not in trainer.logger_connector.progress_bar_metrics -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_log_works_in_test_callback(tmpdir): """ Tests that log can be called within callback @@ -770,10 +746,6 @@ def get_expected_output(func_attr, original_values): return expected_output # Make sure the func_name output equals the average from all logged values when on_epoch true - # pop extra keys - assert "debug_epoch" in trainer.callback_metrics - trainer.callback_metrics.pop("debug_epoch") - for dl_idx in range(num_dataloaders): key = f"test_loss/dataloader_idx_{dl_idx}" assert key in trainer.callback_metrics @@ -807,7 +779,6 @@ def get_expected_output(func_attr, original_values): @mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir): """ This tests make sure we properly log_metrics to loggers @@ -870,7 +841,7 @@ def get_metrics_at_idx(idx): else: return mock_calls[idx][2]["metrics"] - expected = ['valid_loss_0_step', 'valid_loss_2', 'global_step'] + expected = ['valid_loss_0_step', 'valid_loss_2'] assert sorted(get_metrics_at_idx(1)) == sorted(expected) assert sorted(get_metrics_at_idx(2)) == sorted(expected) @@ -879,12 +850,12 @@ def get_metrics_at_idx(idx): expected = model.val_losses[3] assert get_metrics_at_idx(2)["valid_loss_0_step"] == expected - expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step'] + expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch'] assert sorted(get_metrics_at_idx(3)) == sorted(expected) expected = torch.stack(model.val_losses[2:4]).mean() assert get_metrics_at_idx(3)["valid_loss_1"] == expected - expected = ['valid_loss_0_step', 'valid_loss_2', 'global_step'] + expected = ['valid_loss_0_step', 'valid_loss_2'] assert sorted(get_metrics_at_idx(4)) == sorted(expected) assert sorted(get_metrics_at_idx(5)) == sorted(expected) @@ -894,7 +865,7 @@ def get_metrics_at_idx(idx): expected = model.val_losses[5] assert get_metrics_at_idx(5)["valid_loss_0_step"] == expected - expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step'] + expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch'] assert sorted(get_metrics_at_idx(6)) == sorted(expected) expected = torch.stack(model.val_losses[4:]).mean() @@ -905,9 +876,8 @@ def get_metrics_at_idx(idx): 'train_loss', 'valid_loss_0_epoch', 'valid_loss_0', - 'debug_epoch', 'valid_loss_1', 'test_loss', } assert set(trainer.callback_metrics) == expected_callback_metrics - assert set(results[0]) == {'test_loss', 'debug_epoch'} + assert set(results[0]) == {'test_loss'} diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 546fb9ff8fdac..e5fc755fa9a9d 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -34,7 +34,6 @@ from tests.helpers.runif import RunIf -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test__training_step__log(tmpdir): """ Tests that only training_step can be used @@ -122,7 +121,6 @@ def backward(self, loss, optimizer, optimizer_idx): assert pbar_metrics == expected_pbar_metrics callback_metrics = set(trainer.callback_metrics.keys()) - 
callback_metrics.remove('debug_epoch') expected_callback_metrics = set() expected_callback_metrics = expected_callback_metrics.union(logged_metrics) expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) @@ -131,7 +129,6 @@ def backward(self, loss, optimizer, optimizer_idx): assert callback_metrics == expected_callback_metrics -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test__training_step__epoch_end__log(tmpdir): """ Tests that only training_step can be used @@ -183,7 +180,6 @@ def backward(self, loss, optimizer, optimizer_idx): assert pbar_metrics == expected_pbar_metrics callback_metrics = set(trainer.callback_metrics.keys()) - callback_metrics.remove('debug_epoch') expected_callback_metrics = set() expected_callback_metrics = expected_callback_metrics.union(logged_metrics) expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) @@ -192,7 +188,6 @@ def backward(self, loss, optimizer, optimizer_idx): assert callback_metrics == expected_callback_metrics -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) def test__training_step__step_end__epoch_end__log(tmpdir, batches, log_interval, max_epochs): """ @@ -245,7 +240,6 @@ def training_epoch_end(self, outputs): assert pbar_metrics == expected_pbar_metrics callback_metrics = set(trainer.callback_metrics.keys()) - callback_metrics.remove('debug_epoch') expected_callback_metrics = set() expected_callback_metrics = expected_callback_metrics.union(logged_metrics) expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) @@ -495,7 +489,6 @@ def validation_step(self, batch, batch_idx): trainer.fit(model, train_data, val_data) -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_log_works_in_train_callback(tmpdir): """ Tests that log can be called within callback @@ -667,7 +660,6 @@ def get_expected_output(func_attr, original_values): # Make sure the func_name output equals the average from all logged values when on_epoch true # pop extra keys - trainer.callback_metrics.pop("debug_epoch") assert trainer.logged_metrics["train_loss"] == model.manual_loss[-1] assert trainer.callback_metrics["train_loss"] == model.manual_loss[-1] trainer.callback_metrics.pop("train_loss") From c7171931db1bb5d6f98e496d6a1dd9af6ec66730 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 22:18:39 +0200 Subject: [PATCH 199/455] Fix test --- tests/trainer/logging_/test_train_loop_logging.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index e5fc755fa9a9d..cecba77652a0c 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -247,9 +247,6 @@ def training_epoch_end(self, outputs): expected_callback_metrics.remove('epoch') assert callback_metrics == expected_callback_metrics - # assert the loggers received the expected number - assert len(trainer.dev_debugger.logged_metrics) == ((batches / log_interval) * max_epochs) + max_epochs - @pytest.mark.parametrize(['batches', 'fx', 'result'], [(1, min, 0), (2, max, 1), (11, max, 10)]) def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result): From 3a76df360d9b0231e92e0733b2e325f4a92a84e9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 22:18:48 +0200 Subject: [PATCH 200/455] Import --- tests/trainer/logging_/test_train_loop_logging.py | 2 -- 1 
file changed, 2 deletions(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index cecba77652a0c..de6236c4387e6 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -17,8 +17,6 @@ import collections import itertools -import os -from unittest import mock import numpy as np import pytest From 3209b0a6d1f4b975d2aa373cf2b9fe0f0fdb91a1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 22:26:43 +0200 Subject: [PATCH 201/455] Clean logging tests --- tests/accelerators/test_multi_nodes_gpu.py | 11 +- tests/helpers/boring_model.py | 13 - .../logging_/test_eval_loop_logging.py | 276 +++++------------- .../logging_/test_train_loop_logging.py | 241 +++------------ 4 files changed, 114 insertions(+), 427 deletions(-) diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py index 42a9b1c064199..cae257666e390 100644 --- a/tests/accelerators/test_multi_nodes_gpu.py +++ b/tests/accelerators/test_multi_nodes_gpu.py @@ -13,7 +13,6 @@ # limitations under the License. import os import sys -from unittest import mock import pytest import torch @@ -73,7 +72,6 @@ def validation_step(self, batch, batch_idx): # use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)` @pytest.mark.skip("Multi-node testing is currently disabled") @RunIf(special=True) -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test__validation_step__log(tmpdir): """ Tests that validation_step can log @@ -117,7 +115,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.fit(model) # make sure all the metrics are available for callbacks - expected_logged_metrics = { + assert set(trainer.logged_metrics) == { 'a2', 'a_step', 'a_epoch', @@ -125,12 +123,7 @@ def backward(self, loss, optimizer, optimizer_idx): 'b_epoch', 'epoch', } - logged_metrics = set(trainer.logged_metrics.keys()) - assert expected_logged_metrics == logged_metrics # we don't want to enable val metrics during steps because it is not something that users should do # on purpose DO NOT allow step_b... 
it's silly to monitor val step metrics - callback_metrics = set(trainer.callback_metrics.keys()) - callback_metrics.remove('debug_epoch') - expected_cb_metrics = {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} - assert expected_cb_metrics == callback_metrics + assert set(trainer.callback_metrics) == {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index eb81baeb2c29d..fff9d2de79f77 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -34,19 +34,6 @@ def __len__(self): return self.len -class RandomDictStringDataset(Dataset): - - def __init__(self, size, length): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - return {"id": str(index), "x": self.data[index]} - - def __len__(self): - return self.len - - class RandomDataset(Dataset): def __init__(self, size, length): diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 331734aa9b412..c6239a00dd3ee 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -16,7 +16,6 @@ """ import collections import itertools -import os from unittest import mock from unittest.mock import call @@ -24,12 +23,9 @@ import pytest import torch -from pytorch_lightning import callbacks, seed_everything, Trainer -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning import callbacks, Trainer from pytorch_lightning.loggers import TensorBoardLogger from tests.helpers import BoringModel, RandomDataset -from tests.helpers.deterministic_model import DeterministicModel def test__validation_step__log(tmpdir): @@ -37,25 +33,18 @@ def test__validation_step__log(tmpdir): Tests that validation_step can log """ - class TestModel(DeterministicModel): + class TestModel(BoringModel): def training_step(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - acc = acc + batch_idx - self.log('a', acc, on_step=True, on_epoch=True) + out = super().training_step(batch, batch_idx) + self.log('a', out['loss'], on_step=True, on_epoch=True) self.log('a2', 2) - - self.training_step_called = True - return acc + return out def validation_step(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - acc = acc + batch_idx - self.log('b', acc, on_step=True, on_epoch=True) - self.training_step_called = True - - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + out = super().validation_step(batch, batch_idx) + self.log('b', out['x'], on_step=True, on_epoch=True) + return out model = TestModel() model.validation_step_end = None @@ -90,39 +79,27 @@ def backward(self, loss, optimizer, optimizer_idx): assert expected_cb_metrics == callback_metrics -def test__validation_step__step_end__epoch_end__log(tmpdir): +def test__validation_step__epoch_end__log(tmpdir): """ - Tests that validation_step can log + Tests that validation_epoch_end can log """ - class TestModel(DeterministicModel): + class TestModel(BoringModel): def training_step(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - acc = acc + batch_idx - self.log('a', acc) - self.log('b', acc, on_step=True, on_epoch=True) - self.training_step_called = True - return acc + out = super().training_step(batch, batch_idx) + self.log('a', out['loss']) + self.log('b', out['loss'], on_step=True, 
on_epoch=True) + return out def validation_step(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - acc = acc + batch_idx - self.log('c', acc) - self.log('d', acc, on_step=True, on_epoch=True) - self.validation_step_called = True - return acc - - def validation_step_end(self, acc): - self.validation_step_end_called = True - return ['random_thing'] + out = super().validation_step(batch, batch_idx) + self.log('c', out['x']) + self.log('d', out['x'], on_step=True, on_epoch=True) + return out def validation_epoch_end(self, outputs): self.log('g', torch.tensor(2, device=self.device), on_epoch=True) - self.validation_epoch_end_called = True - - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) model = TestModel() @@ -136,9 +113,8 @@ def backward(self, loss, optimizer, optimizer_idx): ) trainer.fit(model) - # make sure all the metrics are available for callbacks - logged_metrics = set(trainer.logged_metrics.keys()) - expected_logged_metrics = { + # make sure all the metrics are available for loggers + assert set(trainer.logged_metrics) == { 'epoch', 'a', 'b_step', @@ -148,24 +124,15 @@ def backward(self, loss, optimizer, optimizer_idx): 'd_epoch', 'g', } - assert expected_logged_metrics == logged_metrics - progress_bar_metrics = set(trainer.progress_bar_metrics.keys()) - expected_pbar_metrics = set() - assert expected_pbar_metrics == progress_bar_metrics + assert not trainer.progress_bar_metrics # we don't want to enable val metrics during steps because it is not something that users should do - callback_metrics = set(trainer.callback_metrics.keys()) - expected_cb_metrics = {'a', 'b', 'b_epoch', 'c', 'd', 'd_epoch', 'g', 'b_step'} - assert expected_cb_metrics == callback_metrics + assert set(trainer.callback_metrics) == {'a', 'b', 'b_epoch', 'c', 'd', 'd_epoch', 'g', 'b_step'} -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) def test_eval_epoch_logging(tmpdir, batches, log_interval, max_epochs): - """ - Tests that only training_step can be used - """ class TestModel(BoringModel): @@ -185,35 +152,23 @@ def validation_epoch_end(self, outputs): ) trainer.fit(model) - # make sure all the metrics are available for callbacks - logged_metrics = set(trainer.logged_metrics.keys()) - expected_logged_metrics = { + # assert the loggers received the expected number + logged_metrics = set(trainer.logged_metrics) + assert logged_metrics == { 'c', 'd/e/f', 'epoch', } - assert logged_metrics == expected_logged_metrics - pbar_metrics = set(trainer.progress_bar_metrics.keys()) - expected_pbar_metrics = {'c'} - assert pbar_metrics == expected_pbar_metrics - - callback_metrics = set(trainer.callback_metrics.keys()) - callback_metrics.remove('debug_epoch') - expected_callback_metrics = set() - expected_callback_metrics = expected_callback_metrics.union(logged_metrics) - expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) - expected_callback_metrics.remove('epoch') - assert callback_metrics == expected_callback_metrics + pbar_metrics = set(trainer.progress_bar_metrics) + assert pbar_metrics == {'c'} - # assert the loggers received the expected number - assert len(trainer.dev_debugger.logged_metrics) == max_epochs + # make sure all the metrics are available for callbacks + callback_metrics = set(trainer.callback_metrics) + assert callback_metrics == (logged_metrics | pbar_metrics) - {'epoch'} def 
test_eval_float_logging(tmpdir): - """ - Tests that only training_step can be used - """ class TestModel(BoringModel): @@ -235,45 +190,28 @@ def validation_step(self, batch, batch_idx): ) trainer.fit(model) - # make sure all the metrics are available for callbacks - logged_metrics = set(trainer.logged_metrics.keys()) - expected_logged_metrics = { - 'a', - 'epoch', - } - assert logged_metrics == expected_logged_metrics + assert set(trainer.logged_metrics) == {'a', 'epoch'} -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_eval_logging_auto_reduce(tmpdir): - """ - Tests that only training_step can be used - """ - seed_everything(1234) class TestModel(BoringModel): - - def on_pretrain_routine_end(self) -> None: - self.seen_vals = [] - self.manual_epoch_end_mean = None - - def on_validation_epoch_start(self) -> None: - self.seen_vals = [] + val_losses = [] + manual_epoch_end_mean = None def validation_step(self, batch, batch_idx): output = self.layer(batch) loss = self.loss(batch, output) - self.seen_vals.append(loss) + self.val_losses.append(loss) self.log('val_loss', loss, on_epoch=True, on_step=True, prog_bar=True) return {"x": loss} def validation_epoch_end(self, outputs) -> None: - for passed_in, manually_tracked in zip(outputs, self.seen_vals): + for passed_in, manually_tracked in zip(outputs, self.val_losses): assert passed_in['x'] == manually_tracked self.manual_epoch_end_mean = torch.stack([x['x'] for x in outputs]).mean() model = TestModel() - trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=3, @@ -281,93 +219,61 @@ def validation_epoch_end(self, outputs) -> None: max_epochs=1, log_every_n_steps=1, weights_summary=None, - callbacks=[ModelCheckpoint(dirpath=tmpdir)], + num_sanity_val_steps=0, ) trainer.fit(model) # make sure all the metrics are available for callbacks - manual_mean = model.manual_epoch_end_mean - callback_metrics = set(trainer.callback_metrics.keys()) - assert callback_metrics == {'debug_epoch', 'val_loss', 'val_loss_epoch'} + assert set(trainer.callback_metrics) == {'val_loss', 'val_loss_epoch'} # make sure values are correct - assert trainer.logged_metrics['val_loss_epoch'] == manual_mean + assert trainer.logged_metrics['val_loss_epoch'] == model.manual_epoch_end_mean assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step'] - # make sure correct values were logged - logged_val = trainer.dev_debugger.logged_metrics - - # 3 val batches - assert logged_val[0]['val_loss_step'] == model.seen_vals[0] - assert logged_val[1]['val_loss_step'] == model.seen_vals[1] - assert logged_val[2]['val_loss_step'] == model.seen_vals[2] - - # epoch mean - assert logged_val[3]['val_loss_epoch'] == model.manual_epoch_end_mean - - # only those logged - assert len(logged_val) == 4 - @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) def test_eval_epoch_only_logging(tmpdir, batches, log_interval, max_epochs): """ - Tests that only test_epoch_end can be used to log, and we return them in the results. + Tests that test_epoch_end can be used to log, and we return them in the results. 
""" class TestModel(BoringModel): def test_epoch_end(self, outputs): - self.log('c', torch.tensor(2), on_epoch=True, prog_bar=True, logger=True) + self.log('c', torch.tensor(2)) self.log('d/e/f', 2) model = TestModel() - trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=batches, - limit_val_batches=batches, max_epochs=max_epochs, + limit_test_batches=batches, log_every_n_steps=log_interval, weights_summary=None, ) - trainer.fit(model) results = trainer.test(model) - expected_result_metrics = { - 'c': torch.tensor(2), - 'd/e/f': 2, - } - for result in results: - assert result == expected_result_metrics - - -def test_monitor_val_epoch_end(tmpdir): - epoch_min_loss_override = 0 - model = BoringModel() - checkpoint_callback = callbacks.ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="avg_val_loss") - trainer = Trainer( - max_epochs=epoch_min_loss_override + 2, - logger=False, - callbacks=[checkpoint_callback], - ) - trainer.fit(model) + assert len(results) == 1 + assert results[0] == {'c': torch.tensor(2), 'd/e/f': 2} -def test_multi_dataloaders_add_suffix_properly(tmpdir): +@pytest.mark.parametrize('suffix', (False, True)) +def test_multi_dataloaders_add_suffix_properly(tmpdir, suffix): class TestModel(BoringModel): - def test_step(self, batch, *args): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log("test_loss", loss, on_step=True, on_epoch=True) + def test_step(self, batch, batch_idx, dataloader_idx=0): + out = super().test_step(batch, batch_idx) + self.log("test_loss", out['y'], on_step=True, on_epoch=True) + return out def test_dataloader(self): - return [ - torch.utils.data.DataLoader(RandomDataset(32, 64)), - torch.utils.data.DataLoader(RandomDataset(32, 64)) - ] + if suffix: + return [ + torch.utils.data.DataLoader(RandomDataset(32, 64)), + torch.utils.data.DataLoader(RandomDataset(32, 64)) + ] + return super().test_dataloader() model = TestModel() model.test_epoch_end = None @@ -383,38 +289,13 @@ def test_dataloader(self): ) results = trainer.test(model) - assert {"test_loss/dataloader_idx_0", "test_loss_epoch/dataloader_idx_0"} == set(results[0]) - assert {"test_loss/dataloader_idx_1", "test_loss_epoch/dataloader_idx_1"} == set(results[1]) + for i, r in enumerate(results): + expected = {'test_loss', 'test_loss_epoch'} + if suffix: + expected = {e + f'/dataloader_idx_{i}' for e in expected} + assert set(r) == expected -def test_single_dataloader_no_suffix_added(tmpdir): - - class TestModel(BoringModel): - - def test_step(self, batch, *args): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log("test_loss", loss, on_step=True, on_epoch=True) - - model = TestModel() - model.test_epoch_end = None - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=0, - limit_val_batches=0, - limit_test_batches=5, - max_epochs=1, - log_every_n_steps=1, - weights_summary=None, - ) - results = trainer.test(model) - - assert len(results) == 1 - assert {"test_loss", "test_loss_epoch"} == set(results[0]) - - -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_log_works_in_val_callback(tmpdir): """ Tests that log can be called within callback @@ -588,7 +469,6 @@ def get_expected_output(func_attr, original_values): # Make sure the func_name output equals the average from all logged values when on_epoch true # pop extra keys - trainer.callback_metrics.pop("debug_epoch") trainer.callback_metrics.pop("val_loss") for func_name, output_value in trainer.callback_metrics.items(): # not sure how to handle this now @@ -615,7 
+495,6 @@ def get_expected_output(func_attr, original_values): assert func_name not in trainer.logger_connector.progress_bar_metrics -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_log_works_in_test_callback(tmpdir): """ Tests that log can be called within callback @@ -770,10 +649,6 @@ def get_expected_output(func_attr, original_values): return expected_output # Make sure the func_name output equals the average from all logged values when on_epoch true - # pop extra keys - assert "debug_epoch" in trainer.callback_metrics - trainer.callback_metrics.pop("debug_epoch") - for dl_idx in range(num_dataloaders): key = f"test_loss/dataloader_idx_{dl_idx}" assert key in trainer.callback_metrics @@ -807,7 +682,6 @@ def get_expected_output(func_attr, original_values): @mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir): """ This tests make sure we properly log_metrics to loggers @@ -860,7 +734,6 @@ def test_step(self, batch, batch_idx): expected_num_calls = 1 + 2 + 1 + 2 + 1 assert len(mock_log_metrics.mock_calls) == expected_num_calls - assert mock_log_metrics.mock_calls[0] == call({'hp_metric': -1}, 0) def get_metrics_at_idx(idx): @@ -870,44 +743,43 @@ def get_metrics_at_idx(idx): else: return mock_calls[idx][2]["metrics"] - expected = ['valid_loss_0_step', 'valid_loss_2', 'global_step'] - assert sorted(get_metrics_at_idx(1)) == sorted(expected) - assert sorted(get_metrics_at_idx(2)) == sorted(expected) + expected = {'valid_loss_0_step', 'valid_loss_2'} + assert set(get_metrics_at_idx(1)) == expected + assert set(get_metrics_at_idx(2)) == expected expected = model.val_losses[2] assert get_metrics_at_idx(1)["valid_loss_0_step"] == expected expected = model.val_losses[3] assert get_metrics_at_idx(2)["valid_loss_0_step"] == expected - expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step'] - assert sorted(get_metrics_at_idx(3)) == sorted(expected) + expected = {'valid_loss_0_epoch', 'valid_loss_1', 'epoch'} + assert set(get_metrics_at_idx(3)) == expected expected = torch.stack(model.val_losses[2:4]).mean() assert get_metrics_at_idx(3)["valid_loss_1"] == expected - expected = ['valid_loss_0_step', 'valid_loss_2', 'global_step'] - assert sorted(get_metrics_at_idx(4)) == sorted(expected) - assert sorted(get_metrics_at_idx(5)) == sorted(expected) + expected = {'valid_loss_0_step', 'valid_loss_2'} + assert set(get_metrics_at_idx(4)) == expected + assert set(get_metrics_at_idx(5)) == expected expected = model.val_losses[4] assert get_metrics_at_idx(4)["valid_loss_0_step"] == expected expected = model.val_losses[5] assert get_metrics_at_idx(5)["valid_loss_0_step"] == expected - expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step'] - assert sorted(get_metrics_at_idx(6)) == sorted(expected) + expected = {'valid_loss_0_epoch', 'valid_loss_1', 'epoch'} + assert set(get_metrics_at_idx(6)) == expected expected = torch.stack(model.val_losses[4:]).mean() assert get_metrics_at_idx(6)["valid_loss_1"] == expected results = trainer.test(model) - expected_callback_metrics = { + expected = { 'train_loss', 'valid_loss_0_epoch', 'valid_loss_0', - 'debug_epoch', 'valid_loss_1', 'test_loss', } - assert set(trainer.callback_metrics) == expected_callback_metrics - assert set(results[0]) == {'test_loss', 'debug_epoch'} + assert set(trainer.callback_metrics) == expected + assert set(results[0]) == {'test_loss'} diff --git 
a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 546fb9ff8fdac..b259160cff256 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -17,8 +17,6 @@ import collections import itertools -import os -from unittest import mock import numpy as np import pytest @@ -28,58 +26,51 @@ import pytorch_lightning as pl from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.core.lightning import LightningModule -from tests.helpers.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset -from tests.helpers.deterministic_model import DeterministicModel +from tests.helpers.boring_model import BoringModel, RandomDictDataset from tests.helpers.runif import RunIf -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test__training_step__log(tmpdir): """ Tests that only training_step can be used """ - class TestModel(DeterministicModel): + class TestModel(BoringModel): def training_step(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - acc = acc + batch_idx + out = super().training_step(batch, batch_idx) + loss = out['loss'] # ----------- # default # ----------- - self.log('default', acc) + self.log('default', loss) # ----------- # logger # ----------- # on_step T on_epoch F - self.log('l_s', acc, on_step=True, on_epoch=False, prog_bar=False, logger=True) + self.log('l_s', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True) # on_step F on_epoch T - self.log('l_e', acc, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.log('l_e', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True) # on_step T on_epoch T - self.log('l_se', acc, on_step=True, on_epoch=True, prog_bar=False, logger=True) + self.log('l_se', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True) # ----------- # pbar # ----------- # on_step T on_epoch F - self.log('p_s', acc, on_step=True, on_epoch=False, prog_bar=True, logger=False) + self.log('p_s', loss, on_step=True, on_epoch=False, prog_bar=True, logger=False) # on_step F on_epoch T - self.log('p_e', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False) + self.log('p_e', loss, on_step=False, on_epoch=True, prog_bar=True, logger=False) # on_step T on_epoch T - self.log('p_se', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False) - - self.training_step_called = True - return acc + self.log('p_se', loss, on_step=True, on_epoch=True, prog_bar=True, logger=False) - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + return loss model = TestModel() model.val_dataloader = None @@ -95,14 +86,8 @@ def backward(self, loss, optimizer, optimizer_idx): ) trainer.fit(model) - # make sure correct steps were called - assert model.training_step_called - assert not model.training_step_end_called - assert not model.training_epoch_end_called - - # make sure all the metrics are available for callbacks - logged_metrics = set(trainer.logged_metrics.keys()) - expected_logged_metrics = { + logged_metrics = set(trainer.logged_metrics) + assert logged_metrics == { 'epoch', 'default', 'l_e', @@ -110,51 +95,36 @@ def backward(self, loss, optimizer, optimizer_idx): 'l_se_step', 'l_se_epoch', } - assert logged_metrics == expected_logged_metrics - pbar_metrics = set(trainer.progress_bar_metrics.keys()) - expected_pbar_metrics = { + 
pbar_metrics = set(trainer.progress_bar_metrics) + assert pbar_metrics == { 'p_e', 'p_s', 'p_se_step', 'p_se_epoch', } - assert pbar_metrics == expected_pbar_metrics - callback_metrics = set(trainer.callback_metrics.keys()) - callback_metrics.remove('debug_epoch') - expected_callback_metrics = set() - expected_callback_metrics = expected_callback_metrics.union(logged_metrics) - expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) - expected_callback_metrics.update({'p_se', 'l_se'}) - expected_callback_metrics.remove('epoch') - assert callback_metrics == expected_callback_metrics + assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics | {'p_se', 'l_se'}) - {'epoch'} -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test__training_step__epoch_end__log(tmpdir): """ - Tests that only training_step can be used + Tests that training_epoch_end can log """ - class TestModel(DeterministicModel): + class TestModel(BoringModel): def training_step(self, batch, batch_idx): - self.training_step_called = True - acc = self.step(batch, batch_idx) - acc = acc + batch_idx - self.log('a', acc, on_step=True, on_epoch=True) - self.log_dict({'a1': acc, 'a2': acc}) - return acc + out = super().training_step(batch, batch_idx) + loss = out['loss'] + self.log('a', loss, on_step=True, on_epoch=True) + self.log_dict({'a1': loss, 'a2': loss}) + return out def training_epoch_end(self, outputs): - self.training_epoch_end_called = True self.log('b1', outputs[0]['loss']) self.log('b', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True) - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) - model = TestModel() model.val_dataloader = None @@ -168,52 +138,33 @@ def backward(self, loss, optimizer, optimizer_idx): ) trainer.fit(model) - # make sure correct steps were called - assert model.training_step_called - assert not model.training_step_end_called - assert model.training_epoch_end_called - - # make sure all the metrics are available for callbacks - logged_metrics = set(trainer.logged_metrics.keys()) - expected_logged_metrics = {'epoch', 'a_step', 'a_epoch', 'b', 'b1', 'a1', 'a2'} - assert logged_metrics == expected_logged_metrics + logged_metrics = set(trainer.logged_metrics) + assert logged_metrics == {'epoch', 'a_step', 'a_epoch', 'b', 'b1', 'a1', 'a2'} - pbar_metrics = set(trainer.progress_bar_metrics.keys()) - expected_pbar_metrics = {'b'} - assert pbar_metrics == expected_pbar_metrics + pbar_metrics = set(trainer.progress_bar_metrics) + assert pbar_metrics == {'b'} - callback_metrics = set(trainer.callback_metrics.keys()) - callback_metrics.remove('debug_epoch') - expected_callback_metrics = set() - expected_callback_metrics = expected_callback_metrics.union(logged_metrics) - expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) - expected_callback_metrics.remove('epoch') - expected_callback_metrics.add('a') - assert callback_metrics == expected_callback_metrics + assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics | {'a'}) - {'epoch'} -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) def test__training_step__step_end__epoch_end__log(tmpdir, batches, log_interval, max_epochs): """ - Tests that only training_step can be used + Tests that training_step_end and training_epoch_end can log """ class TestModel(BoringModel): def training_step(self, batch, 
batch_idx): - self.training_step_called = True loss = self.step(batch[0]) self.log('a', loss, on_step=True, on_epoch=True) return loss def training_step_end(self, out): - self.training_step_end_called = True self.log('b', out, on_step=True, on_epoch=True, prog_bar=True, logger=True) return out def training_epoch_end(self, outputs): - self.training_epoch_end_called = True self.log('c', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True) self.log('d/e/f', 2) @@ -230,31 +181,14 @@ def training_epoch_end(self, outputs): ) trainer.fit(model) - # make sure correct steps were called - assert model.training_step_called - assert model.training_step_end_called - assert model.training_epoch_end_called - # make sure all the metrics are available for callbacks - logged_metrics = set(trainer.logged_metrics.keys()) - expected_logged_metrics = {'a_step', 'a_epoch', 'b_step', 'b_epoch', 'c', 'd/e/f', 'epoch'} - assert logged_metrics == expected_logged_metrics + logged_metrics = set(trainer.logged_metrics) + assert logged_metrics == {'a_step', 'a_epoch', 'b_step', 'b_epoch', 'c', 'd/e/f', 'epoch'} - pbar_metrics = set(trainer.progress_bar_metrics.keys()) - expected_pbar_metrics = {'c', 'b_epoch', 'b_step'} - assert pbar_metrics == expected_pbar_metrics + pbar_metrics = set(trainer.progress_bar_metrics) + assert pbar_metrics == {'c', 'b_epoch', 'b_step'} - callback_metrics = set(trainer.callback_metrics.keys()) - callback_metrics.remove('debug_epoch') - expected_callback_metrics = set() - expected_callback_metrics = expected_callback_metrics.union(logged_metrics) - expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) - expected_callback_metrics.update({'a', 'b'}) - expected_callback_metrics.remove('epoch') - assert callback_metrics == expected_callback_metrics - - # assert the loggers received the expected number - assert len(trainer.dev_debugger.logged_metrics) == ((batches / log_interval) * max_epochs) + max_epochs + assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics | {'a', 'b'}) - {'epoch'} @pytest.mark.parametrize(['batches', 'fx', 'result'], [(1, min, 0), (2, max, 1), (11, max, 10)]) @@ -267,7 +201,7 @@ class TestModel(BoringModel): def training_step(self, batch, batch_idx): acc = self.step(batch[0]) - self.log('foo', torch.tensor(batch_idx).long(), on_step=False, on_epoch=True, reduce_fx=fx) + self.log('foo', torch.tensor(batch_idx, dtype=torch.long), on_step=False, on_epoch=True, reduce_fx=fx) return acc def validation_step(self, batch, batch_idx): @@ -347,7 +281,6 @@ def train_dataloader(self): model = TestModel() model.training_epoch_end = None - model.example_input_array = torch.randn(5, truncated_bptt_steps) trainer = Trainer( default_root_dir=tmpdir, @@ -360,9 +293,7 @@ def train_dataloader(self): ) trainer.fit(model) - generated = set(trainer.logged_metrics.keys()) - expected = {'a_step', 'a_epoch', 'epoch'} - assert generated == expected + assert set(trainer.logged_metrics) == {'a_step', 'a_epoch', 'epoch'} def test_different_batch_types_for_sizing(tmpdir): @@ -400,102 +331,9 @@ def val_dataloader(self): ) trainer.fit(model) - generated = set(trainer.logger_connector.logged_metrics) - expected = {'a_step', 'a_epoch', 'n_step', 'n_epoch', 'epoch'} - - assert generated == expected - - -def test_validation_step_with_string_data_logging(tmpdir): - - class TestModel(BoringModel): - - def on_train_epoch_start(self) -> None: - print("override any method to prove your bug") - - def training_step(self, batch, batch_idx): - output = 
self.layer(batch["x"]) - loss = self.loss(batch, output) - return {"loss": loss} - - def validation_step(self, batch, batch_idx): - output = self.layer(batch["x"]) - loss = self.loss(batch, output) - self.log("x", loss) - return {"x": loss} - - # fake data - train_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64)) - val_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64)) - - # model - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=1, - limit_val_batches=1, - max_epochs=1, - weights_summary=None, - ) - trainer.fit(model, train_data, val_data) - - -def test_nested_datasouce_batch(tmpdir): - - class NestedDictStringDataset(Dataset): - - def __init__(self, size, length): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - x = { - 'post_text': ['bird is fast', 'big cat'], - 'dense_0': [ - torch.tensor([-0.1000, 0.2000], dtype=torch.float64), - torch.tensor([1, 1], dtype=torch.uint8), - ], - 'post_id': ['115', '116'], - 'label': [torch.tensor([0, 1]), torch.tensor([1, 1], dtype=torch.uint8)] - } - return x - - def __len__(self): - return self.len - - class TestModel(BoringModel): - - def on_train_epoch_start(self) -> None: - print("override any method to prove your bug") - - def training_step(self, batch, batch_idx): - output = self.layer(torch.rand(32)) - loss = self.loss(batch, output) - return {"loss": loss} - - def validation_step(self, batch, batch_idx): - output = self.layer(torch.rand(32)) - loss = self.loss(batch, output) - self.log("x", loss) - return {"x": loss} - - # fake data - train_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64)) - val_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64)) - - # model - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=1, - limit_val_batches=1, - max_epochs=1, - weights_summary=None, - ) - trainer.fit(model, train_data, val_data) + assert set(trainer.logged_metrics) == {'a_step', 'a_epoch', 'n_step', 'n_epoch', 'epoch'} -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_log_works_in_train_callback(tmpdir): """ Tests that log can be called within callback @@ -666,8 +504,6 @@ def get_expected_output(func_attr, original_values): return expected_output # Make sure the func_name output equals the average from all logged values when on_epoch true - # pop extra keys - trainer.callback_metrics.pop("debug_epoch") assert trainer.logged_metrics["train_loss"] == model.manual_loss[-1] assert trainer.callback_metrics["train_loss"] == model.manual_loss[-1] trainer.callback_metrics.pop("train_loss") @@ -742,7 +578,6 @@ def training_step(self, batch, batch_idx): return acc def validation_step(self, batch, batch_idx): - self.training_step_called = True output = self.layer(batch) loss = self.loss(batch, output) self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG') From 0a48efb1fe3706b626f687845d9393db189a92d4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 23:22:39 +0200 Subject: [PATCH 202/455] Fix tests --- tests/trainer/flags/test_fast_dev_run.py | 1 - .../trainer/logging_/test_logger_connector.py | 27 ------------------- 2 files changed, 28 deletions(-) diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 1dffefb092716..8320134058c4e 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -95,7 +95,6 @@ def 
_make_fast_dev_run_assertions(trainer, model): # there should be no logger with fast_dev_run assert isinstance(trainer.logger, DummyLogger) - assert len(trainer.dev_debugger.logged_metrics) == fast_dev_run # checkpoint callback should not have been called with fast_dev_run assert trainer.checkpoint_callback == checkpoint_callback diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 75b81392ff916..c225cd9c6501c 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -575,33 +575,6 @@ def validation_step(self, *args, **kwargs): assert 'val_loss_custom_naming_1' in logged -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_logged_metrics_steps(tmpdir): - - class TestModel(BoringModel): - - def validation_step(self, batch, batch_idx): - loss_val = torch.randn(1) - self.log('val_loss', loss_val) - return loss_val - - model = TestModel() - model.validation_epoch_end = None - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=2, - log_every_n_steps=1, - weights_summary=None, - ) - trainer.fit(model) - - assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 1 - assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 3 - - def test_metrics_reset(tmpdir): """Tests that metrics are reset correctly after the end of the train/val/test epoch.""" From 282f580f418bb0c428a577116fdaa93373df5fe6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 23:38:36 +0200 Subject: [PATCH 203/455] Fix test --- tests/models/test_grad_norm.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index 0e380e085ce6a..6a7e2b06c46e0 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -59,15 +59,25 @@ def on_after_backward(self): self.stored_grad_norms.append(out) -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize("norm_type", [1., 1.25, 2, 3, 5, 10, 'inf']) def test_grad_tracking(tmpdir, norm_type, rtol=5e-3): # rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above - reset_seed() - # use a custom grad tracking module and a list logger - model = ModelWithManualGradTracker(norm_type) + class TestModel(ModelWithManualGradTracker): + logged_metrics = [] + + def on_train_batch_end(self, *_) -> None: + if self.trainer.logged_metrics: + # add batch level logged metrics + # copy so they don't get reduced + self.logged_metrics.append(self.trainer.logged_metrics.copy()) + + def on_train_end(self): + # add aggregated logged metrics + self.logged_metrics.append(self.trainer.logged_metrics) + + model = TestModel(norm_type) trainer = Trainer( default_root_dir=tmpdir, @@ -76,18 +86,13 @@ def test_grad_tracking(tmpdir, norm_type, rtol=5e-3): log_every_n_steps=1, # request grad_norms every batch ) trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" - logged_metrics = trainer.dev_debugger.logged_metrics - assert len(logged_metrics) == len(model.stored_grad_norms) + assert len(model.logged_metrics) == len(model.stored_grad_norms) # compare the logged metrics against tracked norms on `.backward` - for mod, log in zip(model.stored_grad_norms, logged_metrics): - common = mod.keys() & log.keys() - - log, mod = [log[k] for k in common], [mod[k] for k in common] - - assert np.allclose(log, mod, 
rtol=rtol) + for mod, log in zip(model.stored_grad_norms, model.logged_metrics): + for k in (mod.keys() & log.keys()): + assert np.allclose(mod[k], log[k], rtol=rtol), k @pytest.mark.parametrize("log_every_n_steps", [1, 2, 3]) From 741ec5bf43022d8ce538714e3a78e02838a98b4b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 23:40:09 +0200 Subject: [PATCH 204/455] flake8 --- tests/models/test_grad_norm.py | 2 -- tests/trainer/logging_/test_logger_connector.py | 1 - 2 files changed, 3 deletions(-) diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index 6a7e2b06c46e0..e547e4a973e89 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from unittest import mock from unittest.mock import patch import numpy as np diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index c225cd9c6501c..dcfa4a2681388 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -14,7 +14,6 @@ """ Tests to ensure that the training loop works with a dict (1.0) """ -import os from copy import deepcopy from typing import Any, Callable from unittest import mock From 60f3440dac129549fbc20ea963911e2cee2c7de6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 23:48:20 +0200 Subject: [PATCH 205/455] flake8 --- tests/trainer/logging_/test_train_loop_logging.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index b259160cff256..338e567be46d3 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -21,7 +21,6 @@ import numpy as np import pytest import torch -from torch.utils.data import Dataset import pytorch_lightning as pl from pytorch_lightning import callbacks, Trainer From 62b8dce8f5a37dced61af4410f4118461cd01c63 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 23:49:05 +0200 Subject: [PATCH 206/455] Docstring --- tests/trainer/logging_/test_eval_loop_logging.py | 2 +- tests/trainer/logging_/test_train_loop_logging.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index c6239a00dd3ee..1a32cc5699a30 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Tests to ensure that the training loop works with a dict (1.0) +Test logging in the evaluation loop """ import collections import itertools diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 338e567be46d3..20a0014eb7cba 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
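# [illustrative sketch, not part of any commit above]
# The grad-norm test rewrite replaces `trainer.dev_debugger.logged_metrics` with metrics
# captured from a hook. The same pattern can live in a standalone Callback; the name
# `MetricsRecorder` is a placeholder used only for this example.
from pytorch_lightning.callbacks import Callback


class MetricsRecorder(Callback):

    def __init__(self):
        self.history = []

    def on_train_batch_end(self, trainer, pl_module, *_):
        # copy so later epoch-level reductions do not mutate the stored dict
        self.history.append(trainer.logged_metrics.copy())
# [end sketch]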
""" -Tests to ensure that the training loop works with a dict (1.0) +Test logging in the training loop """ import collections From 8a0b1a3ad31a96ef1954007aada04e9f78fd8499 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 29 May 2021 23:56:47 +0200 Subject: [PATCH 207/455] Fix tests --- tests/checkpointing/test_model_checkpoint.py | 21 ++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 2f867d4e998b4..62b9d8364b01c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -83,6 +83,7 @@ def __init__(self): super().__init__() self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) self.val_logs = torch.randn(max_epochs, limit_val_batches) + self.scores = [] def training_step(self, batch, batch_idx): log_value = self.train_log_epochs[self.current_epoch, batch_idx] @@ -109,6 +110,14 @@ def configure_optimizers(self): return [optimizer], [lr_scheduler] + def on_train_epoch_end(self): + if 'train' in monitor: + self.scores.append(self.trainer.logged_metrics[monitor]) + + def on_validation_epoch_end(self): + if not self.trainer.sanity_checking and 'val' in monitor: + self.scores.append(self.trainer.logged_metrics[monitor]) + filename = '{' + f'{monitor}' + ':.4f}-{epoch}' checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) @@ -131,13 +140,12 @@ def configure_optimizers(self): assert trainer.state.finished, f"Training failed with {trainer.state}" ckpt_files = list(Path(tmpdir).glob('*.ckpt')) - scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric] lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates - assert len(ckpt_files) == len(scores) == max_epochs + assert len(ckpt_files) == len(model.scores) == max_epochs assert len(lr_scheduler_debug) == max_epochs for epoch in range(max_epochs): - score = scores[epoch] + score = model.scores[epoch] expected_score = getattr(model, f'{monitor}s')[epoch].mean().item() expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt' assert math.isclose(score, expected_score, rel_tol=1e-4) @@ -193,6 +201,7 @@ def __init__(self): super().__init__() self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches) self.val_loop_count = 0 + self.scores = [] def validation_step(self, batch, batch_idx): log_value = self.val_logs[self.val_loop_count, batch_idx] @@ -202,6 +211,7 @@ def validation_step(self, batch, batch_idx): def validation_epoch_end(self, outputs): self.val_loop_count += 1 super().validation_epoch_end(outputs) + self.scores.append(self.trainer.logged_metrics[monitor]) def configure_optimizers(self): optimizer = optim.SGD(self.parameters(), lr=lr) @@ -236,7 +246,6 @@ def configure_optimizers(self): assert trainer.state.finished, f"Training failed with {trainer.state}" ckpt_files = list(Path(tmpdir).glob('*.ckpt')) - scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric] lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates # on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the @@ -246,14 +255,14 @@ def configure_optimizers(self): additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0] additional_ckpt = True - assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_val_checks * max_epochs + 
additional_ckpt + assert len(ckpt_files) == len(model.scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt assert len(lr_scheduler_debug) == max_epochs def _make_assertions(epoch, ix, version=''): global_ix = ix + per_epoch_val_checks * epoch duplicated = bool(version) - score = scores[global_ix] + score = model.scores[global_ix] expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item() expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt' assert math.isclose(score, expected_score, rel_tol=1e-4) From e17ed6b5f7c1711bd261032bfacdf948db576c23 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 00:06:55 +0200 Subject: [PATCH 208/455] Progress --- .../logger_connector/logger_connector.py | 46 +++++-------------- pytorch_lightning/trainer/properties.py | 28 +++++------ pytorch_lightning/utilities/debugging.py | 12 ----- tests/trainer/test_trainer.py | 2 +- 4 files changed, 27 insertions(+), 61 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index bfa80ce9ae097..2cf9ec9e20339 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -110,7 +110,7 @@ def log_metrics(self, metrics, grad_norm_dict, step=None): self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step) self.trainer.logger.save() - self.add_logged_metrics(scalar_metrics) + self._logged_metrics.update(scalar_metrics) """ Evaluation metric updates @@ -127,8 +127,6 @@ def add_to_eval_loop_results(self, dl_idx, has_been_initialized): return callback_metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.CALLBACK] - if os.getenv("PL_DEV_DEBUG", '0') == '1': - callback_metrics["debug_epoch"] = self.trainer.current_epoch callback_metrics = deepcopy(callback_metrics) for key in list(callback_metrics.keys()): if "dataloader_idx" in key: @@ -150,8 +148,8 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: metrics = self.trainer.result_collection.metrics # update metrics - self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) - self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) + self._progress_bar_metrics.update(metrics[DefaultMetricsKeys.PBAR]) + self._callback_metrics.update(metrics[DefaultMetricsKeys.CALLBACK]) if not self.trainer.sanity_checking: @@ -214,15 +212,14 @@ def update_evaluation_step_metrics(self) -> None: metrics = self.trainer.result_collection.metrics # update metrics - self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) - self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) + self._progress_bar_metrics.update(metrics[DefaultMetricsKeys.PBAR]) + self._callback_metrics.update(metrics[DefaultMetricsKeys.CALLBACK]) if self.trainer.sanity_checking: return - batch_log_metrics = metrics[DefaultMetricsKeys.LOG] - # logs user requested information to logger + batch_log_metrics = metrics[DefaultMetricsKeys.LOG] if len(batch_log_metrics) > 0: kwargs = dict() if "step" in batch_log_metrics else dict(step=self.evaluation_log_step) self.log_metrics(batch_log_metrics, {}, **kwargs) @@ -249,15 +246,14 @@ def update_train_step_metrics(self, batch_output): metrics = self.trainer.result_collection.metrics # update metrics - self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) - self.add_callback_metrics(metrics[DefaultMetricsKeys.CALLBACK]) + 
self._progress_bar_metrics.update(metrics[DefaultMetricsKeys.PBAR]) + self._callback_metrics.update(metrics[DefaultMetricsKeys.CALLBACK]) if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: return - batch_log_metrics = metrics[DefaultMetricsKeys.LOG] - # when metrics should be logged + batch_log_metrics = metrics[DefaultMetricsKeys.LOG] if self.should_update_logs or self.trainer.fast_dev_run is True: # logs user requested information to logger grad_norm_dict = batch_output.grad_norm_dict @@ -274,16 +270,11 @@ def update_train_epoch_metrics(self) -> None: metrics = self.trainer.result_collection.metrics - # update metrics - self.add_progress_bar_metrics(metrics[DefaultMetricsKeys.PBAR]) - - callback_metrics = metrics[DefaultMetricsKeys.CALLBACK] - - self._callback_metrics.update(callback_metrics) - - epoch_log_metrics = metrics[DefaultMetricsKeys.LOG] + self._progress_bar_metrics.update(metrics[DefaultMetricsKeys.PBAR]) + self._callback_metrics.update(metrics[DefaultMetricsKeys.CALLBACK]) # add the metrics to the loggers + epoch_log_metrics = metrics[DefaultMetricsKeys.LOG] if epoch_log_metrics and len(epoch_log_metrics) > 0: epoch_log_metrics["epoch"] = self.trainer.current_epoch self._logged_metrics.update(epoch_log_metrics) @@ -301,8 +292,6 @@ def callback_metrics(self) -> Dict[str, float]: if self.trainer.result_collection: metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.CALLBACK] self._callback_metrics.update(metrics) - if os.getenv("PL_DEV_DEBUG", '0') == '1': - self._callback_metrics["debug_epoch"] = self.trainer.current_epoch return self._callback_metrics @property @@ -319,16 +308,5 @@ def progress_bar_metrics(self) -> Dict[str, float]: self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics - def add_progress_bar_metrics(self, metrics: Dict[str, float]) -> None: - self._progress_bar_metrics.update(metrics) - self.trainer.dev_debugger.track_pbar_metrics_history(metrics) - - def add_logged_metrics(self, metrics: Dict[str, float]) -> None: - self._logged_metrics.update(metrics) - self.trainer.dev_debugger.track_logged_metrics_history(metrics) - - def add_callback_metrics(self, metrics: Dict[str, float]) -> None: - self._callback_metrics.update(metrics) - def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None: self._fx_validator.check_logging(fx_name=fx_name, on_step=on_step, on_epoch=on_epoch) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 9384c190227b2..49029432f481e 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -63,9 +63,9 @@ class TrainerProperties(ABC): state: TrainerState train_loop: TrainLoop evaluation_loop: EvaluationLoop - ''' + """ Accelerator properties - ''' + """ @property def accelerator(self) -> Accelerator: @@ -209,9 +209,9 @@ def model(self, model: torch.nn.Module) -> None: """ self.accelerator.model = model - ''' + """ General properties - ''' + """ @property def log_dir(self) -> Optional[str]: @@ -364,9 +364,9 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: def save_checkpoint(self, filepath, weights_only: bool = False) -> None: self.checkpoint_connector.save_checkpoint(filepath, weights_only) - ''' + """ Parsing properties - ''' + """ @classmethod def default_attributes(cls) -> dict: @@ -398,9 +398,9 @@ def match_env_arguments(cls) -> Namespace: def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: 
return add_argparse_args(cls, parent_parser, **kwargs) - ''' + """ State properties - ''' + """ @property def interrupted(self) -> bool: @@ -476,9 +476,9 @@ def sanity_checking(self, val: bool) -> None: elif self.sanity_checking: self.state.stage = None - ''' + """ Loop properties - ''' + """ @property def global_step(self) -> int: @@ -504,9 +504,9 @@ def max_steps(self) -> Optional[int]: def min_steps(self) -> Optional[int]: return self.train_loop.min_steps - ''' + """ Logging properties - ''' + """ @property def callback_metrics(self) -> dict: @@ -529,9 +529,9 @@ def result_collection(self) -> Optional[ResultCollection]: elif self.testing: return self.evaluation_loop.test_results - ''' + """ Other - ''' + """ # TODO: refactor this so that it can be done in LightningOptimizer def __getstate__(self): diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 56833fd03735a..3a3afd2b36329 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -39,8 +39,6 @@ class InternalDebugger(object): def __init__(self, trainer): self.enabled = os.environ.get('PL_DEV_DEBUG', '0') == '1' self.trainer = trainer - self.logged_metrics = [] - self.pbar_added_metrics = [] self.saved_train_losses = [] self.saved_val_losses = [] self.saved_test_losses = [] @@ -110,11 +108,6 @@ def track_load_dataloader_call(self, name, dataloaders): elif 'test' in name: self.test_dataloader_calls.append(values) - @enabled_only - def track_logged_metrics_history(self, scalar_metrics): - scalar_metrics['global_step'] = self.trainer.global_step - self.logged_metrics.append(scalar_metrics) - @enabled_only def track_train_loss_history(self, batch_idx, loss): loss_dict = {'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, 'loss': loss.detach()} @@ -151,11 +144,6 @@ def track_eval_loss_history(self, batch_idx, dataloader_idx, output): else: self.saved_val_losses.append(loss_dict) - @enabled_only - def track_pbar_metrics_history(self, metrics): - metrics['debug_epoch'] = self.trainer.current_epoch - self.pbar_added_metrics.append(metrics) - @enabled_only def track_early_stopping_history(self, callback, current): debug_dict = { diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 73d31b2a2e54d..d5e3ea919c57e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -341,7 +341,7 @@ def mock_save_function(filepath, *args): for i, loss in enumerate(losses): trainer.train_loop.current_epoch = i trainer.train_loop.global_step = i - trainer.logger_connector.add_callback_metrics({"checkpoint_on": loss}) + trainer.logger_connector.callback_metrics.update({"checkpoint_on": loss}) checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) file_lists = set(os.listdir(tmpdir)) From 19b28368b7ded0b447100d7d6654c4b44e3cc516 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 01:27:44 +0200 Subject: [PATCH 209/455] Fix test --- tests/models/test_grad_norm.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index e547e4a973e89..de6f7c3c2f1db 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -66,14 +66,8 @@ class TestModel(ModelWithManualGradTracker): logged_metrics = [] def on_train_batch_end(self, *_) -> None: - if self.trainer.logged_metrics: - # add batch level logged metrics - # copy so they don't get reduced - 
self.logged_metrics.append(self.trainer.logged_metrics.copy()) - - def on_train_end(self): - # add aggregated logged metrics - self.logged_metrics.append(self.trainer.logged_metrics) + # copy so they don't get reduced + self.logged_metrics.append(self.trainer.logged_metrics.copy()) model = TestModel(norm_type) From 1529f4a5462a0bc5dcaebce906c03622841ecd86 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 01:49:02 +0200 Subject: [PATCH 210/455] Minor changes --- tests/models/test_tpu.py | 18 ++++++------------ .../trainer/logging_/test_eval_loop_logging.py | 1 - .../logging_/test_train_loop_logging.py | 3 +-- .../trainer/loops/test_evaluation_loop_flow.py | 4 +--- .../loops/test_training_loop_flow_scalar.py | 8 ++------ .../optimization/test_multiple_optimizers.py | 5 +---- 6 files changed, 11 insertions(+), 28 deletions(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 11a9873cece30..cfe9a17852908 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -21,7 +21,7 @@ import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils -from pytorch_lightning import Trainer +from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.plugins import TPUSpawnPlugin @@ -424,20 +424,14 @@ def test_if_test_works_with_checkpoint_false(tmpdir): def test_tpu_sync_dist(): """Test tpu spawn sync dist operation """ - def test_sync_dist(rank): - tensor = torch.tensor([1.0]) - training_type_plugin = TPUSpawnPlugin() - - res = ResultCollection() - res.log( - "test_tensor", - tensor, - sync_fn=training_type_plugin.reduce, + def test_sync_dist(_): + value = LightningModule._LightningModule__sync( + torch.tensor([1.0]), + sync_fn=TPUSpawnPlugin().reduce, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM ) - - assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors" + assert value.item() == 8 xmp.spawn(test_sync_dist, nprocs=8, start_method='fork') diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index e69c7c905381b..1f1f235dc0763 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -725,7 +725,6 @@ def test_step(self, batch, batch_idx): limit_test_batches=2, max_epochs=2, progress_bar_refresh_rate=1, - num_sanity_val_steps=2, ) # Train the model ⚡ diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index be6db48b910ef..d4ba81ce656e6 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -363,8 +363,7 @@ def make_logging( custom_func_name = f"{func_idx}_{idx}_{func_name}" pl_module.log(custom_func_name, value, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) - if current_epoch not in self.callback_funcs_called[custom_func_name]: - self.callback_funcs_called[custom_func_name][current_epoch] = [] + self.callback_funcs_called[custom_func_name].setdefault(current_epoch, []) self.callback_funcs_called[custom_func_name][current_epoch].append(value) forked = on_step and on_epoch diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 4ef44c1d5ac90..1936916fd2f1e 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ 
b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -141,9 +141,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected - for batch_idx, batch in enumerate(model.train_dataloader()): - break - + batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 2c93b8205c59f..cc878d46c1866 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -153,9 +153,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected - for batch_idx, batch in enumerate(model.train_dataloader()): - break - + batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) @@ -235,9 +233,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected - for batch_idx, batch in enumerate(model.train_dataloader()): - break - + batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py index c795107a36371..495f51ab8d394 100644 --- a/tests/trainer/optimization/test_multiple_optimizers.py +++ b/tests/trainer/optimization/test_multiple_optimizers.py @@ -30,10 +30,7 @@ def configure_optimizers(self): def test_unbalanced_logging_with_multiple_optimizers(tmpdir): - """ - This tests ensures reduction works in unbalanced logging settings, - even when a Callback also logs. 
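# [illustrative sketch, not part of any commit above]
# The loop-flow tests now fetch a single batch with `next(iter(...))` instead of a
# for-loop with an immediate break. Assuming any standard DataLoader, the idiom is:
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)
batch_idx, batch = 0, next(iter(loader))  # first batch only, no for/break needed
# [end sketch]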
- """ + """This tests ensures reduction works in unbalanced logging settings""" class TestModel(MultiOptModel): From caea6cfafbb60e7052c0baab9e20dd17ea6866ba Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 01:49:08 +0200 Subject: [PATCH 211/455] Minor changes --- tests/models/test_tpu.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index cfe9a17852908..8e3bab7350f7f 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -25,7 +25,6 @@ from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.plugins import TPUSpawnPlugin -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException From b595ef2f409405145522bae86599b4484eafd4f5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 02:03:43 +0200 Subject: [PATCH 212/455] Remove print --- tests/callbacks/test_progress_bar.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 00fefcccc3435..965f74f802f05 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -388,9 +388,6 @@ def training_step(self, batch, batch_idx): self.log('bar', {"baz": torch.tensor([1])}, prog_bar=True) return super().training_step(batch, batch_idx) - def on_train_end(self) -> None: - print(self.trainer.result_collection) - trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, From 14cf512364184b3382afa01f92e6d4df3e228096 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 02:11:22 +0200 Subject: [PATCH 213/455] flake8 --- tests/trainer/logging_/test_logger_connector.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 490a7f1065e96..3b00fc0b6f27d 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -14,9 +14,6 @@ """ Tests to ensure that the training loop works with a dict (1.0) """ -import os -from copy import deepcopy -from typing import Any, Callable from unittest import mock import pytest From d9c91bc81a1a37e0228778b2598d1c4b30d85632 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 16:03:39 +0200 Subject: [PATCH 214/455] Progress --- pytorch_lightning/core/lightning.py | 50 ++++++++----------- .../plugins/training_type/ddp2.py | 19 ++++--- pytorch_lightning/plugins/training_type/dp.py | 19 ++++--- .../connectors/logger_connector/result.py | 2 +- pytorch_lightning/trainer/training_loop.py | 28 ++++------- pytorch_lightning/utilities/types.py | 1 + .../trainer/logging_/test_logger_connector.py | 4 +- 7 files changed, 54 insertions(+), 69 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index fff1884947790..d3395d3d1da58 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -26,7 +26,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import 
torch from torch import ScriptModule, Tensor @@ -50,9 +50,6 @@ from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache -if TYPE_CHECKING: - from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection - warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -115,7 +112,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() - self._map_metric_id_name: Optional[Dict[int, str]] = None + self._metric_attributes: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -276,7 +273,7 @@ def log( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, - lightning_attribute_name: Optional[str] = None, + metric_attribute: Optional[str] = None, ) -> None: """ Log a key, value @@ -314,8 +311,9 @@ def log( the name (when using multiple). If False, user needs to give unique names for each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, - but some esoteric data type such as graph might need to explicitly provide the batch_size. - lightning_attribute_name: The name of the Metric attribute name. This is used for fault tolerant logging. + but some some data structures might need to explicitly provide it. + metric_attribute: The attribute name for the metric in the LightningModule. + Necessary to save/restore its state. """ if tbptt_reduce_fx is not None: rank_zero_deprecation( @@ -330,11 +328,8 @@ def log( ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) - result_collections: Optional['ResultCollection'] = self.trainer.result_collection - - if result_collections is not None: - # TODO: if logged twice fail with crash - + result_collection = self.trainer.result_collection + if result_collection is not None: # set the default depending on the fx_name on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) @@ -348,14 +343,15 @@ def log( f"Logged key: {name} should not contain information about dataloader_idx." ) - if lightning_attribute_name is None and isinstance(value, Metric): - # used to find this Metric associated LightningModule attribute name. 
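# [illustrative sketch, not part of any commit above]
# A possible user-side call of the new `metric_attribute` argument of `self.log`
# introduced in this hunk. `MetricModel` is a hypothetical module and assumes the
# torchmetrics API of this era; only the `self.log(...)` call is the point here.
import torch
import torchmetrics
from pytorch_lightning import LightningModule


class MetricModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.train_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x)
        self.train_acc(logits.softmax(dim=-1), y)
        # `metric_attribute` tells Lightning which attribute holds the metric state so it
        # can be saved/restored; for direct child modules it is discovered automatically
        # through `named_children()` and may be omitted.
        self.log("train_acc", self.train_acc, on_epoch=True, metric_attribute="train_acc")
        return torch.nn.functional.cross_entropy(logits, y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
# [end sketch]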
- if self._map_metric_id_name is None: - self._map_metric_id_name = { - id(module): module_name - for module_name, module in self.named_children() if isinstance(module, Metric) + if metric_attribute is None and isinstance(value, Metric): + if self._metric_attributes is None: + # compute once + self._metric_attributes = { + module: name + for name, module in self.named_children() if isinstance(module, Metric) } - lightning_attribute_name = self._map_metric_id_name[id(value)] + # try to find the passed metric in the LightningModule + metric_attribute = self._metric_attributes.get(value, None) sync_fn = partial( self.__sync, @@ -365,13 +361,9 @@ def log( sync_dist_group=sync_dist_group, device=self.device, ) - value = apply_to_collection(value, ( - torch.Tensor, - float, - int, - ), sync_fn) + value = apply_to_collection(value, (torch.Tensor, float, int), sync_fn) - result_collections.log( + result_collection.log( self._current_fx_name, name, value, @@ -383,7 +375,7 @@ def log( enable_graph=enable_graph, dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, - lightning_attribute_name=lightning_attribute_name, + lightning_attribute_name=metric_attribute, ) def log_dict( @@ -403,7 +395,7 @@ def log_dict( add_dataloader_idx: bool = True, ) -> None: """ - Log a dictonary of values at once + Log a dictionary of values at once Example:: @@ -460,7 +452,7 @@ def __sync( if isinstance(value, torch.Tensor): value = value.clone() else: - return torch.tensor(value, device=device, dtype=torch.float) + value = torch.tensor(value, device=device, dtype=torch.float) sync_fn = sync_fn or sync_ddp_if_available dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index 146dfc6c97dae..185e955135141 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -15,6 +15,7 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.types import _METRIC_COLLECTION class DDP2Plugin(DDPPlugin): @@ -34,27 +35,25 @@ def setup(self, model): self.task_idx = self.cluster_environment.local_rank() # the difference to DDP is that we don't call children processes here - def reduce(self, tensor, *args, **kwargs): + def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: """ - Reduces a tensor from all processes to one aggregated tensor. + Reduces a collection of tensors from all processes. It can be applied to just a single tensor. In DDP2, the reduction here is only across local devices within the node. Args: - tensor: the tensor to sync and reduce + collection: The collection of tensors to sync and reduce. *args: ignored for DDP2 **kwargs: ignored for DDP2 Return: - reduced value, except when the input was not a tensor the output remains is unchanged + Reduced tensor values or the same value if it was not or did not contain a tensor. 
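# [illustrative sketch, not part of any commit above]
# The DDP2/DP `reduce` rewrites rely on `apply_to_collection` to mean-reduce every tensor
# inside an arbitrary, possibly nested collection while preserving the original dtype.
# Standalone, the same pattern looks like this:
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


def mean(t: torch.Tensor) -> torch.Tensor:
    original_dtype = t.dtype
    return t.float().mean().to(original_dtype)


metrics = {"loss": torch.tensor([1.0, 3.0]), "counts": {"n": torch.tensor([2, 4])}}
reduced = apply_to_collection(metrics, torch.Tensor, mean)
# -> {"loss": tensor(2.), "counts": {"n": tensor(3)}}
# [end sketch]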
""" - def _reduce(t: torch.Tensor) -> torch.Tensor: - dtype_tensor = t.dtype - return t.float().mean().type(dtype_tensor) + def mean(t: torch.Tensor) -> torch.Tensor: + original_dtype = t.dtype + return t.float().mean().to(original_dtype) - tensor = apply_to_collection(tensor, torch.Tensor, _reduce) - - return tensor + return apply_to_collection(collection, torch.Tensor, mean) @property def root_device(self): diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 08899d72db17f..18aeb6a451d4a 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -19,6 +19,7 @@ from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.types import _METRIC_COLLECTION class DataParallelPlugin(ParallelPlugin): @@ -51,26 +52,24 @@ def setup(self, model): model.to(self.root_device) self._model = DataParallel(LightningParallelModule(model), self.parallel_devices) - def reduce(self, tensor, *args, **kwargs): + def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: """ - Reduces a tensor from all parallel processes to one aggregated tensor. + Reduces a collection of tensors from all processes. It can be applied to just a single tensor. Args: - tensor: the tensor to sync and reduce + collection: The collection of tensors to sync and reduce. *args: ignored for DP **kwargs: ignored for DP Return: - reduced value, except when the input was not a tensor the output remains is unchanged + Reduced tensor values or the same value if it was not or did not contain a tensor. """ - def _reduce(t: torch.Tensor) -> torch.Tensor: - dtype_tensor = t.dtype - return t.float().mean().type(dtype_tensor) + def mean(t: torch.Tensor) -> torch.Tensor: + original_dtype = t.dtype + return t.float().mean().to(original_dtype) - tensor = apply_to_collection(tensor, torch.Tensor, _reduce) - - return tensor + return apply_to_collection(collection, torch.Tensor, mean) @property def root_device(self): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index b8333003118c7..ce73e6780850f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -331,7 +331,7 @@ def log( dataloader_idx: The current dataloader idx. This will be used to automatically add `/dataloader_idx_{}` on the metrics. batch_size: Current batch size. - lightning_attribute_name: When providing `nn.Metric` as a value, the ``lightning_attribute_name`` + lightning_attribute_name: When providing `nn.Metric` as a value, the ``metric_attribute`` need to be provided to enable automatic saving / re-loading. 
""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index fac4e2352a9e1..ba9e0fdcfa199 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -201,19 +201,17 @@ def reset_train_val_dataloaders(self, model) -> None: self.trainer.reset_val_dataloader(model) def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): - hook_overridden = self._should_add_batch_output_to_epoch_output() + if not hook_overridden: + return # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(batch_end_outputs): - - # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - if not hook_overridden: - continue - # with 1 step (no tbptt) don't use a sequence at epoch end - if isinstance(opt_outputs, - list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], ResultCollection): + if ( + isinstance(opt_outputs, list) and len(opt_outputs) == 1 + and not isinstance(opt_outputs[0], ResultCollection) + ): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) @@ -298,9 +296,6 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # accumulate loss. if accumulate_grad_batches==1, no effect closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches - # detach the loss - training_step_output._minimize = training_step_output.minimize.detach() - # the loss will get scaled for amp. avoid any modifications to it untouched_loss = closure_loss.detach().clone() @@ -661,30 +656,29 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): - result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) + result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers - result = self._run_optimization(batch_idx, split_idx, split_batch) + result = self._run_optimization(batch_idx, split_batch) if result: batch_outputs[0].append(result.training_step_output_for_epoch_end) output = AttributeDict( signal=0, - # todo: Properly aggregate grad_norm accros opt_idx and split_idx grad_norm_dict=grad_norm_dict, training_step_output_for_epoch_end=batch_outputs, ) return output - def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimizer=None): + def _run_optimization(self, batch_idx, split_batch, opt_idx=0, optimizer=None): # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change # opt_idx=0 to opt_idx=None in the signature here # toggle model params + set info to logger_connector - self.run_train_split_start(batch_idx, split_idx, split_batch, opt_idx, optimizer) + self.run_train_split_start(batch_idx, split_batch, opt_idx, optimizer) result = AttributeDict() closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result) @@ -934,7 +928,7 @@ def save_loggers_on_train_batch_end(self): if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() - def run_train_split_start(self, batch_idx: int, split_idx, split_batch, opt_idx, optimizer): + def run_train_split_start(self, 
batch_idx, split_batch, opt_idx, optimizer): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1: diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 8a81040af07db..f209287358f84 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -23,6 +23,7 @@ from torchmetrics import Metric _METRIC = Union[Metric, torch.Tensor, Number] +_METRIC_COLLECTION = Union[_METRIC, Dict[str, '_METRIC_COLLECTION']] STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]] EPOCH_OUTPUT = List[STEP_OUTPUT] _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3b00fc0b6f27d..1e6f697f2717c 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -329,8 +329,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc, lightning_attribute_name=f"acc_{stage}") - self.log(f"{stage}/ap", ap, lightning_attribute_name=f"ap_{stage}") + self.log(f"{stage}/accuracy", acc, metric_attribute=f"acc_{stage}") + self.log(f"{stage}/ap", ap, metric_attribute=f"ap_{stage}") return loss From 2da03692b296e4a734ba20c87906f73c17312012 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 16:05:22 +0200 Subject: [PATCH 215/455] Progress --- pytorch_lightning/core/lightning.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d3395d3d1da58..416a018391a03 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -448,16 +448,17 @@ def __sync( if not isinstance(value, (torch.Tensor, numbers.Number)): return value + sync_fn = sync_fn or sync_ddp_if_available + dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() + if not sync_dist or not dist_available: + return value + # TODO: Find a way to make the reduction only once, so we don't need to clone. 
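# [illustrative sketch, not part of any commit above]
# `__sync` falls back to `sync_ddp_if_available` as its default sync_fn; that helper
# returns the value unchanged when no torch.distributed process group is initialized,
# so the same logging code runs untouched in a single-process setup:
import torch
from pytorch_lightning.utilities.distributed import sync_ddp_if_available

value = torch.tensor(3.0)
synced = sync_ddp_if_available(value, reduce_op="mean")
assert torch.equal(synced, value)  # no process group initialized -> returned unchanged
# [end sketch]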
if isinstance(value, torch.Tensor): value = value.clone() else: value = torch.tensor(value, device=device, dtype=torch.float) - sync_fn = sync_fn or sync_ddp_if_available - dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() - if not sync_dist or not dist_available: - return value return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) def write_prediction( From f1900e257d92fd88752a4d0e519b0ac2e656ad2f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 16:20:46 +0200 Subject: [PATCH 216/455] Dead code --- pytorch_lightning/trainer/training_loop.py | 42 ++++++++----------- .../loops/test_evaluation_loop_flow.py | 4 +- .../loops/test_training_loop_flow_scalar.py | 8 ++-- 3 files changed, 23 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ba9e0fdcfa199..099e0971757d7 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -282,10 +282,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( - training_step_output, split_batch - ) - if training_step_output_for_epoch_end is None: + training_step_output = self._process_training_step_output(training_step_output) + if training_step_output is None: return # enable empty loss when using manual opt @@ -304,16 +302,12 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): closure_loss=closure_loss, loss=untouched_loss, training_step_output=training_step_output, - training_step_output_for_epoch_end=training_step_output_for_epoch_end, ) return result - def _process_training_step_output(self, training_step_output, split_batch): - training_step_output_for_epoch_end = training_step_output - - # enable validation_step return None - if training_step_output_for_epoch_end is None: - return None, None + def _process_training_step_output(self, training_step_output): + if training_step_output is None: + return None result = self.trainer.result_collection @@ -340,7 +334,7 @@ def _process_training_step_output(self, training_step_output, split_batch): if self.trainer.move_metrics_to_cpu: result = result.cpu() - return result, result + return result @staticmethod def _prepare_outputs( @@ -485,21 +479,20 @@ def run_training_epoch(self): if batch_output.signal == -1: break - # ----------------------------------------- - # SAVE METRICS TO LOGGERS AND PROGRESS_BAR - # ----------------------------------------- - self.trainer.logger_connector.update_train_step_metrics(batch_output) - # hook - # TODO: add outputs to batches self.on_train_batch_end( epoch_output, - batch_output.training_step_output_for_epoch_end, + batch_output.training_step_output, batch, batch_idx, dataloader_idx, ) + # ----------------------------------------- + # SAVE METRICS TO LOGGERS AND PROGRESS_BAR + # ----------------------------------------- + self.trainer.logger_connector.update_train_step_metrics(batch_output) + # ----------------------------------------- # VALIDATE IF NEEDED # ----------------------------------------- @@ -635,7 +628,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): return AttributeDict( signal=0, grad_norm_dict={}, - training_step_output_for_epoch_end=batch_outputs, + training_step_output=batch_outputs, ) # hook @@ 
-658,20 +651,19 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): for opt_idx, optimizer in self.get_active_optimizers(batch_idx): result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: - batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) + batch_outputs[opt_idx].append(result.training_step_output) grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) if result: - batch_outputs[0].append(result.training_step_output_for_epoch_end) + batch_outputs[0].append(result.training_step_output) - output = AttributeDict( + return AttributeDict( signal=0, grad_norm_dict=grad_norm_dict, - training_step_output_for_epoch_end=batch_outputs, + training_step_output=batch_outputs, ) - return output def _run_optimization(self, batch_idx, split_batch, opt_idx=0, optimizer=None): # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 1936916fd2f1e..61fae95c70312 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -73,7 +73,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) - train_step_out = out.training_step_output_for_epoch_end + train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out.minimize, torch.Tensor) @@ -146,7 +146,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) - train_step_out = out.training_step_output_for_epoch_end + train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out.minimize, torch.Tensor) diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index cc878d46c1866..d9d8273eb6429 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -158,7 +158,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) - train_step_out = out.training_step_output_for_epoch_end + train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out.minimize, torch.Tensor) @@ -238,7 +238,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert out.signal == 0 assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) - train_step_out = out.training_step_output_for_epoch_end + train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out.minimize, torch.Tensor) @@ -323,7 +323,7 @@ def training_step(self, batch, batch_idx): for batch_idx, batch in enumerate(model.train_dataloader()): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) if not batch_idx % 2: - assert out.training_step_output_for_epoch_end == [[]] + assert out.training_step_output == [[]] assert out.signal == 0 @@ -368,5 
+368,5 @@ def train_dataloader(self): for batch_idx, batch in enumerate(model.train_dataloader()): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) if not batch_idx % 2: - assert out.training_step_output_for_epoch_end == [[]] + assert out.training_step_output == [[]] assert out.signal == 0 From 82fd2d14640b222330899af47375fcac0cd06a9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 May 2021 17:54:35 +0200 Subject: [PATCH 217/455] fix min epochs check in done() --- pytorch_lightning/loops/epoch_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 884e1f6af8811..4bef4272fcd69 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -102,7 +102,7 @@ def done(self) -> bool: should_stop = False if self.trainer.should_stop: # early stopping - met_min_epochs = (self.current_epoch >= self.min_epochs - 1) if self.min_epochs else True + met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: # TODO: THIS is now in on_run_end, always run? From 86bb8126910da46a6df1f15a9a9f8616aba27e52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 May 2021 18:00:24 +0200 Subject: [PATCH 218/455] fix name in mock --- tests/trainer/loops/test_evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 331552e272e85..b45212add5208 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -21,7 +21,7 @@ from tests.helpers.runif import RunIf -@mock.patch("pytorch_lightning.loops.evaluation_dataloader_loop.EvaluationDataLoaderLoop.on_evaluation_epoch_end") +@mock.patch("pytorch_lightning.loops.dataloader.evaluation_dataloader_loop.EvaluationDataLoaderLoop.on_evaluation_epoch_end") def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """ Tests that `on_evaluation_epoch_end` is called From 121a3451bc526b355bda0ebfac6ed56989d774b2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 18:00:24 +0200 Subject: [PATCH 219/455] Apply func improvements and tests --- pytorch_lightning/utilities/apply_func.py | 98 ++++++++++------------- tests/utilities/test_apply_func.py | 67 +++++++++++++++- 2 files changed, 110 insertions(+), 55 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index af125e5ee5a95..5655933f054ef 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -54,13 +54,18 @@ def from_numpy(value, device: torch.device = None): ] +def _is_namedtuple(obj): + # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8 + return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + + def apply_to_collection( data: Any, dtype: Union[type, tuple], function: Callable, *args, wrong_dtype: Optional[Union[type, tuple]] = None, - remove_none: bool = False, + include_none: bool = True, **kwargs ) -> Any: """ @@ -73,14 +78,12 @@ def apply_to_collection( *args: positional arguments (will be forwarded to calls of ``function``) wrong_dtype: the given function won't be applied if this type is specified and the given 
collections is of the :attr:`wrong_type` even if it is of type :attr`dtype` - remove_none: Whether to skip an element if the output of function is ``None`` - while applying onto the collection. + include_none: Whether to include an element if the output of ``function`` is ``None``. **kwargs: keyword arguments (will be forwarded to calls of ``function``) Returns: the resulting collection """ - elem_type = type(data) # Breaking condition if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): @@ -88,31 +91,23 @@ def apply_to_collection( # Recursively apply to collection items if isinstance(data, Mapping): - _out = {} + out = {} for k, v in data.items(): v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - if remove_none and v is None: - continue - _out[k] = v - return elem_type(_out) - - if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple - _out = [] - for d in data: - v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - if remove_none and v is None: - continue - _out.append(v) - return elem_type(*_out) - - if isinstance(data, Sequence) and not isinstance(data, str): - _out = [] + if include_none or v is not None: + out[k] = v + return out + + is_namedtuple = _is_namedtuple(data) + is_sequence = isinstance(data, Sequence) and not isinstance(data, str) + if is_namedtuple or is_sequence: + elem_type = type(data) + out = [] for d in data: v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - if remove_none and v is None: - continue - _out.append(v) - return elem_type(_out) + if include_none or v is not None: + out.append(v) + return elem_type(*out) if is_namedtuple else elem_type(out) # data is neither of dtype, nor a collection return data @@ -120,7 +115,7 @@ def apply_to_collection( def apply_to_collections( data1: Any, - data2: Any, + data2: Optional[Any], dtype: Union[type, tuple], function: Callable, *args, @@ -131,7 +126,8 @@ def apply_to_collections( Recursively applies a function to all elements of a certain dtype. 
Args: - data: the collection to apply the function to + data: The first collection + data2: The second collection dtype: the given function will be applied to all elements of this dtype function: the function to apply *args: positional arguments (will be forwarded to calls of ``function``) @@ -142,35 +138,29 @@ def apply_to_collections( Returns: the resulting collection """ - elem_type_1 = type(data1) - - # Breaking condition - if isinstance(data1, dtype) and (wrong_dtype is None or not isinstance(data1, wrong_dtype)): + if isinstance(data1, dtype) and data2 is not None and (wrong_dtype is None or not isinstance(data1, wrong_dtype)): return function(data1, data2, *args, **kwargs) - # Recursively apply to collection items - if isinstance(data1, Mapping): - return elem_type_1({ - k1: apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - for (k1, v1), (k1, v2) in zip(data1.items(), data2.items()) - }) - - if isinstance(data1, tuple) and hasattr(data1, '_fields'): # named tuple - return elem_type_1( - *( - apply_to_collections(d1, d2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - for d1, d2 in zip(data1, data2) - ) - ) - - if isinstance(data1, Sequence) and not isinstance(data1, str): - return elem_type_1([ - apply_to_collections(d1, d2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - for d1, d2 in zip(data1, data2) - ]) - - # data is neither of dtype, nor a collection - return data1 + if isinstance(data1, Mapping) and data2 is not None: + # use union because we want to fail if a key does not exist in both + zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()} + return { + k: apply_to_collections(*v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for k, v in zipped.items() + } + + is_namedtuple = _is_namedtuple(data1) + is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str) + if (is_namedtuple or is_sequence) and data2 is not None: + assert len(data1) == len(data2), 'Sequence collections have different sizes' + elem_type = type(data1) + out = [ + apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for v1, v2 in zip(data1, data2) + ] + return elem_type(*out) if is_namedtuple else elem_type(out) + + return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) class TransferableDataType(ABC): diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index a7eea3a749f26..a3cf292c9a837 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -15,9 +15,10 @@ from collections import namedtuple import numpy as np +import pytest import torch -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections def test_recursive_application_to_collection(): @@ -76,3 +77,67 @@ def test_recursive_application_to_collection(): assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor' assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result' + + +def test_apply_to_collection_include_none(): + to_reduce = [1, 2, 3.4, 5.6, 7] + + def fn(x): + if isinstance(x, float): + return x + + reduced = apply_to_collection(to_reduce, (int, float), fn) + assert reduced == [None, None, 3.4, 5.6, None] + + reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False) + assert reduced == 
[3.4, 5.6] + + +def test_apply_to_collections(): + to_reduce_1 = {'a': {'b': [1, 2]}, 'c': 5} + to_reduce_2 = {'a': {'b': [3, 4]}, 'c': 6} + + def fn(a, b): + return a + b + + # basic test + reduced = apply_to_collections(to_reduce_1, to_reduce_2, int, fn) + assert reduced == {'a': {'b': [4, 6]}, 'c': 11} + + with pytest.raises(KeyError): + # strict mode - if a key does not exist in both we fail + apply_to_collections({**to_reduce_2, 'd': 'foo'}, to_reduce_1, float, fn) + + # multiple dtypes + reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn) + assert reduced == {'a': {'b': [1, 2, 3, 4]}, 'c': 11} + + # wrong dtype + reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn, wrong_dtype=int) + assert reduced == {'a': {'b': [1, 2, 3, 4]}, 'c': 5} + + # list takes precedence because it is the type of data1 + reduced = apply_to_collections([1, 2, 3], [4], (int, list), fn) + assert reduced == [1, 2, 3, 4] + + # different sizes + with pytest.raises(AssertionError, match='Sequence collections have different sizes'): + apply_to_collections([[1, 2], [3]], [4], int, fn) + + def fn(a, b): + return a.keys() | b.keys() + + # base case + reduced = apply_to_collections(to_reduce_1, to_reduce_2, dict, fn) + assert reduced == {'a', 'c'} + + # type conversion + to_reduce = [(1, 2), (3, 4)] + reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x)) + assert reduced == [(2, 4), (6, 8)] + + # named tuple + foo = namedtuple('Foo', ['bar']) + to_reduce = [foo(1), foo(2), foo(3)] + reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x)) + assert reduced == [foo(2), foo(4), foo(6)] From aac7ce9f9dca805907f1eaf6b1fb647d2ab9b0a1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 18:48:19 +0200 Subject: [PATCH 220/455] Typo --- tests/accelerators/test_multi_nodes_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py index cae257666e390..463307ead8717 100644 --- a/tests/accelerators/test_multi_nodes_gpu.py +++ b/tests/accelerators/test_multi_nodes_gpu.py @@ -125,5 +125,5 @@ def backward(self, loss, optimizer, optimizer_idx): } # we don't want to enable val metrics during steps because it is not something that users should do - # on purpose DO NOT allow step_b... it's silly to monitor val step metrics + # on purpose DO NOT allow b_step... 
it's silly to monitor val step metrics assert set(trainer.callback_metrics) == {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} From 82028ff4a38c8890d53692a70b406cd51c29dc0f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 18:56:27 +0200 Subject: [PATCH 221/455] Rename --- .../trainer/connectors/logger_connector/result.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ce73e6780850f..7b8b40bd6459a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -501,7 +501,7 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: # extract forward_cache or computed from the ResultMetric # ignore when the output of fn is None - value = apply_to_collection(result_metric, ResultMetric, fn, remove_none=True) + value = apply_to_collection(result_metric, ResultMetric, fn, include_none=False) # detect if the value is None. This can be nested. is_empty = True @@ -513,12 +513,8 @@ def is_empty_fn(v): is_empty = False # apply detection. - # todo: (tchaton) need to find a way to support NamedTuple - wrong_dtype = ( - Mapping, - Sequence, - ) - apply_to_collection(value, object, is_empty_fn, wrong_dtype=wrong_dtype) + # TODO(@tchaton): need to find a way to support NamedTuple + apply_to_collection(value, object, is_empty_fn, wrong_dtype=(Mapping, Sequence)) # skip is the value was actually empty. if is_empty: @@ -543,7 +539,7 @@ def is_empty_fn(v): # populate progress_bar metrics. By default, the value should be converted to a float. if prog_bar: - value = apply_to_collection(value, torch.Tensor, self._to_item, remove_none=True) + value = apply_to_collection(value, torch.Tensor, self._to_item, include_none=False) metrics[DefaultMetricsKeys.PBAR][name_forked] = value return metrics From eba64c0eaf747b8c6eab8fb629aa2e80c8c2ac0d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 30 May 2021 19:02:00 +0200 Subject: [PATCH 222/455] Remove print --- tests/trainer/logging_/test_eval_loop_logging.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 518e92623df1a..de29a0f44cf5d 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -420,9 +420,6 @@ def validation_step(self, batch, batch_idx): loss = self.loss(batch, output) self.log('val_loss', loss) - def on_validation_end(self) -> None: - print(self.trainer.result_collection) - max_epochs = 1 model = TestModel() model.validation_epoch_end = None From 7dfd0032272d95c6589c04f708286237a4c45c6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 May 2021 21:52:53 +0200 Subject: [PATCH 223/455] fix accumulated_batches_reached condition --- pytorch_lightning/loops/batch_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 9e414e8cd862c..78757419d3d91 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -78,7 +78,7 @@ def run(self, batch, batch_idx, dataloader_idx): return output def reset(self) -> None: - self.iteration_count = 0 + # self.iteration_count = 0 self._hiddens = None # TODO: let loops track individual outputs 
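# Context for the accumulation check adjusted in the hunk below (illustrative
# helper, not part of the patch): whether the "+ 1" belongs in the modulo test
# depends on whether the counter already includes the current batch.
def accumulated(counter: int, accumulate_grad_batches: int, counts_current_batch: bool) -> bool:
    n = counter if counts_current_batch else counter + 1
    return n % accumulate_grad_batches == 0


# a counter incremented after every batch: step after batches 4, 8, ...
assert accumulated(4, 4, counts_current_batch=True)
# a zero-based index of the current batch: index 3 is the 4th batch
assert accumulated(3, 4, counts_current_batch=False)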
@@ -355,7 +355,8 @@ def _track_gradient_norm(self): def _accumulated_batches_reached(self): # TODO: use progress tracking of batches instead of iteration count, because iteration count may reset - return (self.iteration_count + 1) % self.trainer.accumulate_grad_batches == 0 + # iteration count is required to be global here, not reset + return self.iteration_count % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self, is_last_batch=False): # TODO: use progress tracking of batches instead of iteration count, because iteration count may reset From 2c6019ee1e5d37330b63aff37cc93794f7af3a10 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 30 May 2021 19:53:42 +0000 Subject: [PATCH 224/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/loops/test_evaluation_loop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index b45212add5208..1b94923dd2ddf 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -21,7 +21,9 @@ from tests.helpers.runif import RunIf -@mock.patch("pytorch_lightning.loops.dataloader.evaluation_dataloader_loop.EvaluationDataLoaderLoop.on_evaluation_epoch_end") +@mock.patch( + "pytorch_lightning.loops.dataloader.evaluation_dataloader_loop.EvaluationDataLoaderLoop.on_evaluation_epoch_end" +) def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """ Tests that `on_evaluation_epoch_end` is called From 04c4bdbec08ec2f1aec1c52d2704e14bcf1acdd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 May 2021 22:30:56 +0200 Subject: [PATCH 225/455] remove obsolete __getstate__ --- pytorch_lightning/loops/training_loop.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index c2fb3bdf8b233..a935cf73ac65e 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -174,12 +174,6 @@ def on_run_end(self): self.trainer.call_hook('on_epoch_end') return self.epoch_output - def __getstate__(self): - # avoid pickling errors "cannot pickle generator object" - self._train_dataloader = None - return self.__dict__ - - # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ From 8942d7a45495f45be6ca7d4c18b7170392876c47 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 01:13:49 +0200 Subject: [PATCH 226/455] Fix tests --- pytorch_lightning/core/lightning.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 416a018391a03..1a6f2514753a6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -445,20 +445,18 @@ def __sync( device: torch.device = None, ) -> _METRIC: """Sync across workers when using distributed training""" - if not isinstance(value, (torch.Tensor, numbers.Number)): + if isinstance(value, torch.Tensor): + # TODO: Find a way to make the reduction only once, so we don't need to clone. 
+ value = value.clone() + elif isinstance(value, numbers.Number): + value = torch.tensor(value, device=device, dtype=torch.float) + else: return value sync_fn = sync_fn or sync_ddp_if_available dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() if not sync_dist or not dist_available: return value - - # TODO: Find a way to make the reduction only once, so we don't need to clone. - if isinstance(value, torch.Tensor): - value = value.clone() - else: - value = torch.tensor(value, device=device, dtype=torch.float) - return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) def write_prediction( From 8b4939ed5a62a60cca678a762ee5a960019c4c32 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 01:20:42 +0200 Subject: [PATCH 227/455] flake8 --- tests/trainer/logging_/test_logger_connector.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 2bf5e7bf15eb9..7b016fa71d931 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -11,11 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Tests to ensure that the training loop works with a dict (1.0) -""" -from copy import deepcopy -from typing import Any, Callable from unittest import mock import pytest From 46751b8fc007a0abf3e470c0d473dd768be39ffd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 31 May 2021 01:30:54 +0200 Subject: [PATCH 228/455] fix reload dataloaders --- .../dataloader/evaluation_dataloader_loop.py | 31 ++++++++++++++- pytorch_lightning/trainer/trainer.py | 38 ++++++++++++------- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py index f7e4dfaca9544..3d9efc0d59715 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py @@ -44,7 +44,7 @@ def reset(self) -> None: self.iteration_count = 0 # prepare dataloaders - self._dataloaders, self._max_batches = self.get_evaluation_dataloaders() + self._dataloaders, self._max_batches = self.get_eval_dataloaders(), self.get_max_batches() # bookkeeping self.outputs = [] @@ -98,6 +98,35 @@ def on_run_end(self) -> Any: # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ + def get_max_batches(self): + # select dataloaders + if self.trainer.testing: + max_batches = self.trainer.num_test_batches + else: + if self.trainer.sanity_checking: + self.trainer.num_sanity_val_batches = [ + min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches + ] + max_batches = self.trainer.num_sanity_val_batches + else: + max_batches = self.trainer.num_val_batches + return max_batches + + def get_eval_dataloaders(self): + model = self.trainer.lightning_module + + # select dataloaders + if self.trainer.testing: + # self.trainer.reset_test_dataloader(model) + dataloaders = self.trainer.test_dataloaders + else: + # val + # if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: + # 
self.trainer.reset_val_dataloader(model) + dataloaders = self.trainer.val_dataloaders + return dataloaders + + # TODO: remove this method, got split into two above def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: model = self.trainer.lightning_module diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index db17520c3e239..c241f36e4069f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1064,6 +1064,13 @@ def _run_evaluatin_old_loop(self) -> _EVALUATE_OUTPUT: if self.evaluation_loop.should_skip_evaluation(max_batches): return [], [] + # enable eval mode + no grads + self.evaluation_loop.on_evaluation_model_eval() + # ref model + model = self.lightning_module + model.zero_grad() + torch.set_grad_enabled(False) + # hook self.evaluation_loop.on_evaluation_start() @@ -1140,23 +1147,26 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: ) self.validating = True - # TODO: move this check inside new loop - # prepare dataloaders - dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() + if NEW_LOOP: + # # TODO: move this check inside new loop + # # prepare dataloaders - # TODO: move this check inside new loop - # check if we want to skip this evaluation - if self.evaluation_loop.should_skip_evaluation(max_batches): - return [], [] + dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() - # enable eval mode + no grads - self.evaluation_loop.on_evaluation_model_eval() - # ref model - model = self.lightning_module - model.zero_grad() - torch.set_grad_enabled(False) + # max_batches = self.evaluation_loop.get_max_batches() + # + # # TODO: move this check inside new loop + # # check if we want to skip this evaluation + if self.evaluation_loop.should_skip_evaluation(max_batches): + return [], [] + + # enable eval mode + no grads + self.evaluation_loop.on_evaluation_model_eval() + # ref model + model = self.lightning_module + model.zero_grad() + torch.set_grad_enabled(False) - if NEW_LOOP: eval_loop_results = self.evaluation_loop.run() else: eval_loop_results = self._run_evaluatin_old_loop() From 60d867cb91ea0981d57d520835c227df5c5e5a92 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 01:31:09 +0200 Subject: [PATCH 229/455] Refactor __sync --- pytorch_lightning/core/lightning.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 1a6f2514753a6..ceba630b6e6aa 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -361,7 +361,7 @@ def log( sync_dist_group=sync_dist_group, device=self.device, ) - value = apply_to_collection(value, (torch.Tensor, float, int), sync_fn) + value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn) result_collection.log( self._current_fx_name, @@ -437,22 +437,16 @@ def log_dict( @staticmethod def __sync( - value: _METRIC, + value: Union[torch.Tensor, numbers.Number], sync_fn: Optional[Callable] = None, sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, device: torch.device = None, - ) -> _METRIC: + ) -> torch.Tensor: """Sync across workers when using distributed training""" - if isinstance(value, torch.Tensor): - # TODO: Find a way to make the reduction only once, so we don't need to clone. 
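# Side note on the dtype tuple passed to apply_to_collection above: plain ints,
# floats, bools and complex values are all instances of numbers.Number, so logged
# python scalars nested inside dicts or lists are converted and synced uniformly.
# Illustrative check only:
import numbers

assert all(isinstance(x, numbers.Number) for x in (3, 2.5, True, 1j))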
- value = value.clone() - elif isinstance(value, numbers.Number): + if isinstance(value, numbers.Number): value = torch.tensor(value, device=device, dtype=torch.float) - else: - return value - sync_fn = sync_fn or sync_ddp_if_available dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() if not sync_dist or not dist_available: From 9b4b58c66fbe06588cf40c7f58f36992fe4718c2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 01:39:48 +0200 Subject: [PATCH 230/455] Fix tests --- pytorch_lightning/trainer/training_loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 099e0971757d7..01dd191a76129 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -479,6 +479,11 @@ def run_training_epoch(self): if batch_output.signal == -1: break + # ----------------------------------------- + # SAVE METRICS TO LOGGERS AND PROGRESS_BAR + # ----------------------------------------- + self.trainer.logger_connector.update_train_step_metrics(batch_output) + # hook self.on_train_batch_end( epoch_output, @@ -488,11 +493,6 @@ def run_training_epoch(self): dataloader_idx, ) - # ----------------------------------------- - # SAVE METRICS TO LOGGERS AND PROGRESS_BAR - # ----------------------------------------- - self.trainer.logger_connector.update_train_step_metrics(batch_output) - # ----------------------------------------- # VALIDATE IF NEEDED # ----------------------------------------- From 994421c33941f4cf74cca90a736c79491a728258 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 01:49:54 +0200 Subject: [PATCH 231/455] lru cache --- pytorch_lightning/core/lightning.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ceba630b6e6aa..1f61372dcd105 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -15,6 +15,7 @@ import collections import copy +import functools import inspect import logging import numbers @@ -47,7 +48,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -112,7 +113,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() - self._metric_attributes: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -344,14 +344,8 @@ def log( ) if metric_attribute is None and isinstance(value, Metric): - if self._metric_attributes is None: - # compute once - self._metric_attributes = { - module: name - for name, module in self.named_children() if isinstance(module, Metric) - } # try to find the passed metric in the LightningModule - metric_attribute = self._metric_attributes.get(value, None) + 
metric_attribute = self.__metric_attributes.get(value, None) sync_fn = partial( self.__sync, @@ -378,6 +372,11 @@ def log( lightning_attribute_name=metric_attribute, ) + @property + @functools.lru_cache(maxsize=1) + def __metric_attributes(self) -> Dict[Metric, str]: + return {module: name for name, module in self.named_children() if isinstance(module, Metric)} + def log_dict( self, dictionary: dict, From 6dd519b6b57fab8a342751c276ed29c00006a49b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 02:25:34 +0200 Subject: [PATCH 232/455] Resolve minimize --- .../connectors/logger_connector/result.py | 8 +++++++ pytorch_lightning/trainer/training_loop.py | 21 ++++--------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 7b8b40bd6459a..7c541a6a6c8fc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -677,3 +677,11 @@ def re_assign_metric(item): item.value = metrics[lightning_attribute_name] apply_to_collection(dict(self.items()), ResultMetric, re_assign_metric) + + def __getstate__(self) -> dict: + d = self.__dict__.copy() + # can't deepcopy tensors with grad_fn + minimize = d.get('_minimize') + if minimize is not None: + d['_minimize'] = minimize.detach() + return d diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 01dd191a76129..82ad2d4522caf 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -286,24 +286,14 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): if training_step_output is None: return - # enable empty loss when using manual opt closure_loss = None - untouched_loss = None - + loss = None if self.trainer.lightning_module.automatic_optimization: # accumulate loss. if accumulate_grad_batches==1, no effect closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches - # the loss will get scaled for amp. 
avoid any modifications to it - untouched_loss = closure_loss.detach().clone() - - # result - result = AttributeDict( - closure_loss=closure_loss, - loss=untouched_loss, - training_step_output=training_step_output, - ) - return result + loss = closure_loss.detach().clone() + return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output) def _process_training_step_output(self, training_step_output): if training_step_output is None: @@ -373,10 +363,7 @@ def _prepare_outputs( for tbptt_output in batch_outputs: out = tbptt_output.extra - loss = tbptt_output.minimize - if isinstance(loss, torch.Tensor): - loss = loss.detach() - out['loss'] = loss + out['loss'] = tbptt_output.minimize.detach() processed_tbptt_outputs.append(out) # if there was only one tbptt step then we can collapse that dimension From 56b67beaa89f407316e715374a0050194f9be987 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 02:29:19 +0200 Subject: [PATCH 233/455] Revert changes --- pytorch_lightning/core/lightning.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 1f61372dcd105..95029a48975ea 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -15,7 +15,6 @@ import collections import copy -import functools import inspect import logging import numbers @@ -48,7 +47,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -113,6 +112,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() + self._metric_attributes: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -344,8 +344,14 @@ def log( ) if metric_attribute is None and isinstance(value, Metric): + if self._metric_attributes is None: + # compute once + self._metric_attributes = { + id(module): name + for name, module in self.named_children() if isinstance(module, Metric) + } # try to find the passed metric in the LightningModule - metric_attribute = self.__metric_attributes.get(value, None) + metric_attribute = self._metric_attributes.get(id(value)) sync_fn = partial( self.__sync, @@ -372,11 +378,6 @@ def log( lightning_attribute_name=metric_attribute, ) - @property - @functools.lru_cache(maxsize=1) - def __metric_attributes(self) -> Dict[Metric, str]: - return {module: name for name, module in self.named_children() if isinstance(module, Metric)} - def log_dict( self, dictionary: dict, From b5a4fb92c6dc1a8940a409183a51a44b2cec144d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 02:29:57 +0200 Subject: [PATCH 234/455] flake8 --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 95029a48975ea..3e2da52bfa292 100644 --- 
a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -47,7 +47,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() From b050ff5c79f5861dbd866f460aecf226f23ce701 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 03:42:31 +0200 Subject: [PATCH 235/455] Reduce diff --- .../logger_connector/logger_connector.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 2cf9ec9e20339..78f1f9ceaea9f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -48,6 +48,16 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_ste self.trainer.log_every_n_steps = log_every_n_steps self.trainer.move_metrics_to_cpu = move_metrics_to_cpu + @property + def should_flush_logs(self): + should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 + return should_flush or self.trainer.should_stop + + @property + def should_update_logs(self): + should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 + return should_log_every_n_steps or self.trainer.should_stop + def configure_logger(self, logger): if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) @@ -64,16 +74,6 @@ def configure_logger(self, logger): else: self.trainer.logger = logger - @property - def should_flush_logs(self): - should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 - return should_flush or self.trainer.should_stop - - @property - def should_update_logs(self): - should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 - return should_log_every_n_steps or self.trainer.should_stop - def log_metrics(self, metrics, grad_norm_dict, step=None): """Logs the metric dict passed in. 
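The `should_flush_logs` and `should_update_logs` properties reordered in the hunk above gate how often metrics reach the logger. A minimal sketch of the cadence they encode, written against plain arguments instead of trainer attributes (the standalone function names are illustrative only):

def should_update_logs(global_step: int, log_every_n_steps: int, should_stop: bool) -> bool:
    # record metrics on every N-th global step (1-indexed) and always on the final step
    return (global_step + 1) % log_every_n_steps == 0 or should_stop


def should_flush_logs(global_step: int, flush_logs_every_n_steps: int, should_stop: bool) -> bool:
    # persist buffered logs less frequently than they are recorded
    return (global_step + 1) % flush_logs_every_n_steps == 0 or should_stop

For example, with `log_every_n_steps=50` the update condition is true at global steps 49, 99, 149, and so on.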
If `step` parameter is None and `step` key is presented is metrics, From a608562beaebdc46f972c7bc06aa8f2492cce2a5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 03:57:43 +0200 Subject: [PATCH 236/455] Move fx validator --- pytorch_lightning/core/lightning.py | 2 +- .../connectors/logger_connector/logger_connector.py | 13 ++++--------- .../trainer/connectors/logger_connector/result.py | 4 +++- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3e2da52bfa292..44c6c9dc6c96a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -335,7 +335,7 @@ def log( on_epoch = self.__auto_choose_log_on_epoch(on_epoch) assert self._current_fx_name is not None - self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) + result_collection.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) # make sure user doesn't introduce logic for multi-dataloaders if "/dataloader_idx_" in name: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 78f1f9ceaea9f..d5fb9a027adac 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -20,7 +20,6 @@ from pytorch_lightning.core import memory from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType @@ -34,7 +33,6 @@ def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self.trainer = trainer self.log_gpu_memory = log_gpu_memory self.eval_loop_results = [] - self._fx_validator = FxValidator() self._val_log_step: int = 0 self._test_log_step: int = 0 self._progress_bar_metrics: Dict[str, float] = {} @@ -129,10 +127,9 @@ def add_to_eval_loop_results(self, dl_idx, has_been_initialized): callback_metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.CALLBACK] callback_metrics = deepcopy(callback_metrics) for key in list(callback_metrics.keys()): - if "dataloader_idx" in key: - if f"dataloader_idx_{dl_idx}" not in key: - # remove dl_idx from self.callback_metrics not belonging to this dataset. 
- del callback_metrics[key] + if "dataloader_idx" in key and f"dataloader_idx_{dl_idx}" not in key: + # remove callback metrics that don't belong to this dataloader + del callback_metrics[key] if has_been_initialized: self.eval_loop_results[dl_idx].update(callback_metrics) else: @@ -240,6 +237,7 @@ def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: self.trainer.result_collection.batch_idx = batch_idx def on_train_batch_end(self) -> None: + # TODO: why self.trainer.result_collection.batch_size = 1 def update_train_step_metrics(self, batch_output): @@ -307,6 +305,3 @@ def progress_bar_metrics(self) -> Dict[str, float]: metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics - - def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None: - self._fx_validator.check_logging(fx_name=fx_name, on_step=on_step, on_epoch=on_epoch) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 7c541a6a6c8fc..69bd092a6b07f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -20,6 +20,7 @@ from torch import Tensor from torchmetrics import Metric +from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.enums import LightningEnum @@ -172,7 +173,7 @@ class ResultCollection(dict): Example: - # the root_device need to be provided before calling the ``log`` function + # the root_device need to be provided before calling the ``log`` function result = ResultCollection(True, torch.device("cpu")) # arguments: hook_name, key, value, metadata @@ -216,6 +217,7 @@ def __init__(self, is_train: bool, root_device: Optional[torch.device] = None) - self._batch_size: Optional[int] = None self._batch_idx: Optional[int] = None self._root_device: Optional[torch.device] = root_device + self.fx_validator = FxValidator() @property def batch_size(self) -> int: From 339d91fc055fb8fd3cec6e03f01bf0ad4495dbc6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 04:32:09 +0200 Subject: [PATCH 237/455] Minimize can be None in manual --- pytorch_lightning/trainer/training_loop.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 82ad2d4522caf..641ee14ca1de5 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -335,14 +335,14 @@ def _prepare_outputs( Extract required information from batch or epoch end results. Args: - outputs: A 3-dimensional list of ``Result`` objects with dimensions: - [optimizer outs][batch outs][tbptt steps]. + outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions: + ``[optimizer outs][batch outs][tbptt steps]``. batch_mode: If True, ignore the batch output dimension. Returns: - The cleaned outputs with ``Result`` objects converted to dictionaries. All list dimensions of size one will - be collapsed. + The cleaned outputs with ``ResultCollection`` objects converted to dictionaries. + All list dimensions of size one will be collapsed. 
""" processed_outputs = [] for opt_outputs in outputs: @@ -363,7 +363,8 @@ def _prepare_outputs( for tbptt_output in batch_outputs: out = tbptt_output.extra - out['loss'] = tbptt_output.minimize.detach() + if tbptt_output.minimize is not None: + out['loss'] = tbptt_output.minimize.detach() processed_tbptt_outputs.append(out) # if there was only one tbptt step then we can collapse that dimension From 061e407e00e77b6ac2e5a37607ad3dfb654fd5b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 31 May 2021 12:01:38 +0200 Subject: [PATCH 238/455] simplify --- .../loops/dataloader/evaluation_dataloader_loop.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py index 3d9efc0d59715..e83d2f43422df 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py @@ -113,18 +113,9 @@ def get_max_batches(self): return max_batches def get_eval_dataloaders(self): - model = self.trainer.lightning_module - - # select dataloaders if self.trainer.testing: - # self.trainer.reset_test_dataloader(model) - dataloaders = self.trainer.test_dataloaders - else: - # val - # if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: - # self.trainer.reset_val_dataloader(model) - dataloaders = self.trainer.val_dataloaders - return dataloaders + return self.trainer.test_dataloaders + return self.trainer.val_dataloaders # TODO: remove this method, got split into two above def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: From 6d14b4791ca6867440c8ca73ac0ad50170b47575 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 31 May 2021 10:14:19 +0000 Subject: [PATCH 239/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/training_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index a935cf73ac65e..7b6d4c9f4e195 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -174,6 +174,7 @@ def on_run_end(self): self.trainer.call_hook('on_epoch_end') return self.epoch_output + # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP # ------------------------------------------------------------------------------------------------------------ From 4621a525857b38cc6e5d86621e1843f9fb55a868 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 31 May 2021 13:09:28 +0200 Subject: [PATCH 240/455] cosmetic change --- pytorch_lightning/loops/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index a935cf73ac65e..078bf71427519 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -50,7 +50,7 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): @property def done(self): - max_steps_reached = (self.max_steps is not None and self.max_steps <= self.global_step) + max_steps_reached = self.max_steps is not None and self.global_step >= 
self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) def run(self, *args, **kwargs): From 5f37bc9effc4e28943020a4f18c35313af696a9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 31 May 2021 13:19:15 +0200 Subject: [PATCH 241/455] fix "on_train_start" hook call --- pytorch_lightning/loops/epoch_loop.py | 6 ++++-- pytorch_lightning/trainer/trainer.py | 7 ++++++- pytorch_lightning/trainer/training_loop.py | 4 ---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 4bef4272fcd69..4d6b2437057ec 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -97,7 +97,7 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): @property def done(self) -> bool: # TODO: Move track steps inside training loop and move part of these condition inside training loop - stop_steps = self.max_steps and self.max_steps <= self.global_step + stop_steps = self.max_steps and self.global_step >= self.max_steps should_stop = False if self.trainer.should_stop: @@ -124,7 +124,9 @@ def reset(self) -> None: def on_run_start(self): # hook - self.trainer.call_hook("on_train_start") + # TODO: move it here, currently in Trainer._run_train_new_loop + # self.trainer.call_hook("on_train_start") + pass def on_advance_start(self): # equal to old on_train_epoch_start model = self.trainer.lightning_module diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c241f36e4069f..d28584ee5cb37 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -941,6 +941,7 @@ def _should_skip_training(self) -> bool: def _run_train_new_loop(self) -> None: self._pre_training_routine() + if not self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() @@ -955,9 +956,13 @@ def _run_train_new_loop(self) -> None: # reload data when needed model = self.lightning_module - # This might move somewhere else + # TODO: This might move somewhere else self.reset_train_val_dataloaders(model) + # hook + # TODO: move this inside loop together with skip condition below + self.call_hook("on_train_start") + try: # TODO: move skip condition into EpochLoop.done() if self._should_skip_training(): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7a12d7e766dae..901b40bc25c4f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -98,10 +98,6 @@ def should_skip_training(self) -> bool: should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0 - def on_train_start(self): - # hook - self.trainer.call_hook("on_train_start") - def on_train_end(self): if self._teardown_already_run: return From 209c5c27b77bff2658a4407094da5d1e619df87a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 31 May 2021 15:05:21 +0200 Subject: [PATCH 242/455] simplify should_skip_training condition and hook calls --- pytorch_lightning/loops/epoch_loop.py | 16 ++++++++++------ pytorch_lightning/trainer/trainer.py | 8 +------- tests/callbacks/test_progress_bar.py | 3 ++- tests/models/test_restore.py | 2 +- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git 
a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 4d6b2437057ec..349a7da88e9f6 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -97,7 +97,8 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs): @property def done(self) -> bool: # TODO: Move track steps inside training loop and move part of these condition inside training loop - stop_steps = self.max_steps and self.global_step >= self.max_steps + stop_steps = self.max_steps is not None and self.global_step >= self.max_steps + stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs should_stop = False if self.trainer.should_stop: @@ -116,17 +117,17 @@ def done(self) -> bool: ) self.trainer.should_stop = False - stop_epochs = self.current_epoch >= self.max_epochs if self.max_epochs is not None else False return stop_steps or should_stop or stop_epochs def reset(self) -> None: self.iteration_count = 0 + def run(self): + if not self._should_skip_training(): + return super().run() + def on_run_start(self): - # hook - # TODO: move it here, currently in Trainer._run_train_new_loop - # self.trainer.call_hook("on_train_start") - pass + self.trainer.call_hook("on_train_start") def on_advance_start(self): # equal to old on_train_epoch_start model = self.trainer.lightning_module @@ -226,6 +227,9 @@ def on_run_end(self): # reset bookkeeping self.trainer._running_stage = None + def _should_skip_training(self) -> bool: + return self.done or self.trainer.num_training_batches == 0 + def should_accumulate(self): return self.training_loop.batch_loop.should_accumulate() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d28584ee5cb37..da2cf556f946f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -934,6 +934,7 @@ def _run_train(self) -> None: else: self._run_train_old_loop() + # TODO: remove together with old loop def _should_skip_training(self) -> bool: should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs @@ -959,14 +960,7 @@ def _run_train_new_loop(self) -> None: # TODO: This might move somewhere else self.reset_train_val_dataloaders(model) - # hook - # TODO: move this inside loop together with skip condition below - self.call_hook("on_train_start") - try: - # TODO: move skip condition into EpochLoop.done() - if self._should_skip_training(): - return self.train_loop.run() except KeyboardInterrupt: rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index f4f8f34c1b4c1..a16cd3c93ab9c 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -350,7 +350,8 @@ def test_main_progress_bar_update_amount( checkpoint_callback=False, ) trainer.fit(model) - progress_bar.main_progress_bar.update.assert_has_calls([call(delta) for delta in train_deltas]) + if train_batches > 0: + progress_bar.main_progress_bar.update.assert_has_calls([call(delta) for delta in train_deltas]) if val_batches > 0: progress_bar.val_progress_bar.update.assert_has_calls([call(delta) for delta in val_deltas]) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 09ae795297eb5..e92b68aa9082d 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -143,7 +143,7 @@ def 
test_try_resume_from_non_existing_checkpoint(tmpdir): class CaptureCallbacksBeforeTraining(Callback): callbacks = [] - def on_train_start(self, trainer, pl_module): + def on_pretrain_routine_end(self, trainer, pl_module): self.callbacks = deepcopy(trainer.callbacks) From 86d30f678016315a912d5d065f9a03f978c74abe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 31 May 2021 16:59:10 +0200 Subject: [PATCH 243/455] fix test when on_train_start does not run anymore --- tests/models/test_restore.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index e92b68aa9082d..c3e78698fadbc 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -431,10 +431,10 @@ class CustomModel(CustomClassificationModelDP): def __init__(self): super().__init__() - self.on_train_start_called = False + self.on_pretrain_routine_end_called = False # set the epoch start hook so we can predict before the model does the full training - def on_train_start(self): + def on_pretrain_routine_end(self): assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0 # if model and state loaded correctly, predictions will be good even though we @@ -443,14 +443,14 @@ def on_train_start(self): dataloader = self.train_dataloader() tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader) - self.on_train_start_called = True + self.on_pretrain_routine_end_called = True # new model model = CustomModel() # fit new model which should load hpc weights new_trainer.fit(model, datamodule=dm) - assert model.on_train_start_called + assert model.on_pretrain_routine_end_called # test freeze on gpu model.freeze() From dfe25c100b2d7db91a7b629363dfdc5918679415 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 31 May 2021 18:49:20 +0200 Subject: [PATCH 244/455] fix hook calling in old loop --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index da2cf556f946f..4c385349f7d63 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -999,7 +999,7 @@ def _run_train_old_loop(self) -> None: self.train_loop.reset_train_val_dataloaders(model) # hook - self.train_loop.on_train_start() + self.call_hook("on_train_start") try: if self._should_skip_training(): From 92dabe3ec2148a39a89b5716dc945509db8aa56a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 31 May 2021 18:49:48 +0200 Subject: [PATCH 245/455] add back forgotten warning message --- pytorch_lightning/loops/batch_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 78757419d3d91..9aae8b70ff1f7 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -55,7 +55,8 @@ def done(self): def run(self, batch, batch_idx, dataloader_idx): if batch is None: - return AttributeDict(signal=0, grad_norm_dic={}) + self.warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") + return AttributeDict(signal=0, grad_norm_dic={}, training_step_output_for_epoch_end=[[]]) # hook response = self.trainer.call_hook("on_batch_start") From 064597133ad6f829ec0d37cd0e15a04c5181badb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 20:33:50 +0200 Subject: [PATCH 246/455] DefaultMetricKeys -> MetricSource --- .../logger_connector/logger_connector.py | 34 +++++++++---------- .../connectors/logger_connector/result.py | 18 +++++----- tests/core/test_metric_result_integration.py | 16 ++++----- .../trainer/logging_/test_logger_connector.py | 14 ++++---- 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index d5fb9a027adac..0fad69a8214c7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -20,7 +20,7 @@ from pytorch_lightning.core import memory from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys +from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars @@ -124,7 +124,7 @@ def add_to_eval_loop_results(self, dl_idx, has_been_initialized): if self.trainer.sanity_checking: return - callback_metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.CALLBACK] + callback_metrics = self.trainer.result_collection.metrics[MetricSource.CALLBACK] callback_metrics = deepcopy(callback_metrics) for key in list(callback_metrics.keys()): if "dataloader_idx" in key and f"dataloader_idx_{dl_idx}" not in key: @@ -145,13 +145,13 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: metrics = self.trainer.result_collection.metrics # update metrics - self._progress_bar_metrics.update(metrics[DefaultMetricsKeys.PBAR]) - self._callback_metrics.update(metrics[DefaultMetricsKeys.CALLBACK]) + self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) + self._callback_metrics.update(metrics[MetricSource.CALLBACK]) if not self.trainer.sanity_checking: # log all the metrics as a single dict - metrics_to_log = metrics[DefaultMetricsKeys.LOG] + metrics_to_log = metrics[MetricSource.LOG] if len(metrics_to_log) > 0: self.log_metrics(metrics_to_log, {}) @@ -209,14 +209,14 @@ def update_evaluation_step_metrics(self) -> None: metrics = self.trainer.result_collection.metrics # update metrics - self._progress_bar_metrics.update(metrics[DefaultMetricsKeys.PBAR]) - self._callback_metrics.update(metrics[DefaultMetricsKeys.CALLBACK]) + self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) + self._callback_metrics.update(metrics[MetricSource.CALLBACK]) if self.trainer.sanity_checking: return # logs user requested information to logger - batch_log_metrics = metrics[DefaultMetricsKeys.LOG] + batch_log_metrics = metrics[MetricSource.LOG] if len(batch_log_metrics) > 0: kwargs = dict() if "step" in batch_log_metrics else dict(step=self.evaluation_log_step) self.log_metrics(batch_log_metrics, {}, **kwargs) @@ -244,14 +244,14 @@ def update_train_step_metrics(self, batch_output): metrics = self.trainer.result_collection.metrics # update 
metrics - self._progress_bar_metrics.update(metrics[DefaultMetricsKeys.PBAR]) - self._callback_metrics.update(metrics[DefaultMetricsKeys.CALLBACK]) + self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) + self._callback_metrics.update(metrics[MetricSource.CALLBACK]) if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: return # when metrics should be logged - batch_log_metrics = metrics[DefaultMetricsKeys.LOG] + batch_log_metrics = metrics[MetricSource.LOG] if self.should_update_logs or self.trainer.fast_dev_run is True: # logs user requested information to logger grad_norm_dict = batch_output.grad_norm_dict @@ -268,11 +268,11 @@ def update_train_epoch_metrics(self) -> None: metrics = self.trainer.result_collection.metrics - self._progress_bar_metrics.update(metrics[DefaultMetricsKeys.PBAR]) - self._callback_metrics.update(metrics[DefaultMetricsKeys.CALLBACK]) + self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) + self._callback_metrics.update(metrics[MetricSource.CALLBACK]) # add the metrics to the loggers - epoch_log_metrics = metrics[DefaultMetricsKeys.LOG] + epoch_log_metrics = metrics[MetricSource.LOG] if epoch_log_metrics and len(epoch_log_metrics) > 0: epoch_log_metrics["epoch"] = self.trainer.current_epoch self._logged_metrics.update(epoch_log_metrics) @@ -288,20 +288,20 @@ def update_train_epoch_metrics(self) -> None: @property def callback_metrics(self) -> Dict[str, float]: if self.trainer.result_collection: - metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.CALLBACK] + metrics = self.trainer.result_collection.metrics[MetricSource.CALLBACK] self._callback_metrics.update(metrics) return self._callback_metrics @property def logged_metrics(self) -> Dict[str, float]: if self.trainer.result_collection: - metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.LOG] + metrics = self.trainer.result_collection.metrics[MetricSource.LOG] self._logged_metrics.update(metrics) return self._logged_metrics @property def progress_bar_metrics(self) -> Dict[str, float]: if self.trainer.result_collection: - metrics = self.trainer.result_collection.metrics[DefaultMetricsKeys.PBAR] + metrics = self.trainer.result_collection.metrics[MetricSource.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 69bd092a6b07f..f80ba207240a1 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -28,7 +28,7 @@ from pytorch_lightning.utilities.types import _METRIC -class DefaultMetricsKeys(LightningEnum): +class MetricSource(LightningEnum): CALLBACK = "callback" PBAR = "pbar" LOG = "log" @@ -260,9 +260,9 @@ def metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: { - DefaultMetricsKeys.PBAR: {...}, - DefaultMetricsKeys.LOG: {...}, - DefaultMetricsKeys.CALLBACK: {...} + MetricSource.PBAR: {...}, + MetricSource.LOG: {...}, + MetricSource.CALLBACK: {...} } """ return self.get_epoch_metrics() if self.on_epoch_end_reached else self.get_batch_metrics() @@ -490,7 +490,7 @@ def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) return name, name_forked, logger, prog_bar, metric_on_epoch def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: - metrics = {k: {} for k in DefaultMetricsKeys} + metrics = {k: {} for k 
in MetricSource} # either extract `forward_cache` or `computed` from `ResultMetric` objects fn = self._get_forward_cache if on_step else self._get_computed_cache @@ -529,20 +529,20 @@ def is_empty_fn(v): # populate logging metrics if logger: - metrics[DefaultMetricsKeys.LOG][name_forked] = value + metrics[MetricSource.LOG][name_forked] = value # populate callback metrics # callback metrics don't take `_step` forked metrics. if not self.is_train and (not metric_on_epoch or on_step): pass else: - metrics[DefaultMetricsKeys.CALLBACK][name] = value - metrics[DefaultMetricsKeys.CALLBACK][name_forked] = value + metrics[MetricSource.CALLBACK][name] = value + metrics[MetricSource.CALLBACK][name_forked] = value # populate progress_bar metrics. By default, the value should be converted to a float. if prog_bar: value = apply_to_collection(value, torch.Tensor, self._to_item, include_none=False) - metrics[DefaultMetricsKeys.PBAR][name_forked] = value + metrics[MetricSource.PBAR][name_forked] = value return metrics diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index c6c133a87b62c..728da426ba9ff 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -19,7 +19,7 @@ from torchmetrics import Metric import tests.helpers.utils as tutils -from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys, ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection from tests.helpers.runif import RunIf @@ -79,7 +79,7 @@ def _ddp_test_fn(rank, worldsize): result.log('h', 'b', metric_b, on_step=False, on_epoch=True, lightning_attribute_name="metric_b") result.log('h', 'c', metric_c, on_step=True, on_epoch=False, lightning_attribute_name="metric_c") - batch_log = result.get_batch_metrics()[DefaultMetricsKeys.LOG] + batch_log = result.get_batch_metrics()[MetricSource.LOG] batch_expected = {"a_step": i, "c": i} assert set(batch_log.keys()) == set(batch_expected.keys()) for k in batch_expected.keys(): @@ -87,7 +87,7 @@ def _ddp_test_fn(rank, worldsize): result.on_epoch_end_reached = True - epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] + epoch_log = result.get_epoch_metrics()[MetricSource.LOG] result.reset() # assert metric state reset to default values @@ -138,7 +138,7 @@ def test_result_metric_integration(): result.log('h', 'b', metric_b, on_step=False, on_epoch=True, lightning_attribute_name="metric_b") result.log('h', 'c', metric_c, on_step=True, on_epoch=False, lightning_attribute_name="metric_c") - batch_log = result.get_batch_metrics()[DefaultMetricsKeys.LOG] + batch_log = result.get_batch_metrics()[MetricSource.LOG] batch_expected = {"a_step": i, "c": i} assert set(batch_log.keys()) == set(batch_expected.keys()) for k in batch_expected.keys(): @@ -146,7 +146,7 @@ def test_result_metric_integration(): result.on_epoch_end_reached = True - epoch_log = result.get_epoch_metrics()[DefaultMetricsKeys.LOG] + epoch_log = result.get_epoch_metrics()[MetricSource.LOG] result.reset() # assert metric state reset to default values @@ -196,7 +196,7 @@ def test_result_collection_restoration(): result.log('training_step', 'b_1', b, on_step=False, on_epoch=True) result.log('training_step', 'c_1', [c, c], on_step=True, on_epoch=False) - batch_log = result.metrics[DefaultMetricsKeys.LOG] + batch_log = result.metrics[MetricSource.LOG] batch_expected = {"a_step": i, "c": i, "a_1_step": i, "c_1": [i, i]} assert 
set(batch_log.keys()) == set(batch_expected.keys()) @@ -221,8 +221,8 @@ def test_result_collection_restoration(): result.on_epoch_end_reached = True _result.on_epoch_end_reached = True - epoch_log = result.metrics[DefaultMetricsKeys.LOG] - _epoch_log = _result.metrics[DefaultMetricsKeys.LOG] + epoch_log = result.metrics[MetricSource.LOG] + _epoch_log = _result.metrics[MetricSource.LOG] assert epoch_log == _epoch_log diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 7b016fa71d931..eae2dc3e34937 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -22,7 +22,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator -from pytorch_lightning.trainer.connectors.logger_connector.result import DefaultMetricsKeys, ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -478,7 +478,7 @@ def test_result_collection_on_tensor_with_mean_reduction(): 'loss_1_1_1_step': torch.tensor([9.]), 'loss_3_1_1': torch.tensor([9.]) } - assert batch_metrics[DefaultMetricsKeys.PBAR] == expected + assert batch_metrics[MetricSource.PBAR] == expected excepted = { 'loss_1_0_1_step': torch.tensor([9.]), @@ -486,7 +486,7 @@ def test_result_collection_on_tensor_with_mean_reduction(): 'loss_1_1_1_step': torch.tensor([9.]), 'loss_3_1_1': torch.tensor([9.]) } - assert batch_metrics[DefaultMetricsKeys.LOG] == excepted + assert batch_metrics[MetricSource.LOG] == excepted excepted = { 'loss_1_0_0': torch.tensor(9.), @@ -502,7 +502,7 @@ def test_result_collection_on_tensor_with_mean_reduction(): 'loss_1_1_1_step': torch.tensor(9.), 'loss_3_1_1': torch.tensor(9.) 
} - assert batch_metrics[DefaultMetricsKeys.CALLBACK] == excepted + assert batch_metrics[MetricSource.CALLBACK] == excepted result_collection.on_epoch_end_reached = True @@ -511,10 +511,10 @@ def test_result_collection_on_tensor_with_mean_reduction(): mean = (torch.tensor(excepted_values) * torch.tensor(excepted_batches)).sum() / sum(excepted_batches) expected = {'loss_1_1_0_epoch': mean, 'loss_2_1_0': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} - assert epoch_metrics[DefaultMetricsKeys.PBAR] == expected + assert epoch_metrics[MetricSource.PBAR] == expected excepted = {'loss_1_0_1_epoch': mean, 'loss_2_0_1': mean, 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean} - assert epoch_metrics[DefaultMetricsKeys.LOG] == excepted + assert epoch_metrics[MetricSource.LOG] == excepted excepted = { 'loss_1_0_0': mean, @@ -530,4 +530,4 @@ def test_result_collection_on_tensor_with_mean_reduction(): 'loss_1_1_1_epoch': mean, 'loss_2_1_1': mean, } - assert epoch_metrics[DefaultMetricsKeys.CALLBACK] == excepted + assert epoch_metrics[MetricSource.CALLBACK] == excepted From cb003defad4b32cb93f30b4724cf28e24c3d12b8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 20:34:15 +0200 Subject: [PATCH 247/455] Remove TODO --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f80ba207240a1..8cf41e42cde1e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -36,7 +36,7 @@ class MetricSource(LightningEnum): @dataclass class Metadata: - fx: str # TODO: distinction? 
+ fx: str name: str prog_bar: bool = False logger: bool = True From 5ca5b977bf21b4eac03635b810e5439e38b2c6f8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 20:44:20 +0200 Subject: [PATCH 248/455] forked_name --- .../trainer/connectors/logger_connector/result.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 8cf41e42cde1e..8096fb4b59fe4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -52,16 +52,9 @@ class Metadata: def forked(self) -> bool: return self.on_step and self.on_epoch - @property - def forked_step_name(self) -> str: - if self.forked: - return self.name + "_step" - return self.name - - @property - def forked_epoch_name(self) -> str: + def forked_name(self, on_step: bool) -> str: if self.forked: - return self.name + "_epoch" + return f'{self.name}_{"step" if on_step else "epoch"}' return self.name @property @@ -468,7 +461,7 @@ def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) if isinstance(result_metric, ResultMetric): name = result_metric.meta.name - name_forked = result_metric.meta.forked_step_name if on_step else result_metric.meta.forked_epoch_name + name_forked = result_metric.meta.forked_name(on_step) logger = result_metric.meta.logger prog_bar = result_metric.meta.prog_bar metric_on_epoch = result_metric.meta.on_epoch From 0a6b18588d770346d436f31ad9aa3a0514e268ab Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 20:59:12 +0200 Subject: [PATCH 249/455] is_tensor_and --- .../connectors/logger_connector/result.py | 61 ++++++++++--------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 8096fb4b59fe4..4ca09abcdbf76 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -58,16 +58,16 @@ def forked_name(self, on_step: bool) -> str: return self.name @property - def is_tensor_and_mean_reduction(self) -> bool: - return self.is_tensor and self.reduce_fx == torch.mean + def is_mean_reduction(self) -> bool: + return self.reduce_fx == torch.mean @property - def is_tensor_and_max_reduction(self) -> bool: - return self.is_tensor and (self.reduce_fx in (torch.max, max)) + def is_max_reduction(self) -> bool: + return self.reduce_fx in (torch.max, max) @property - def is_tensor_and_min_reduction(self) -> bool: - return self.is_tensor and (self.reduce_fx in (torch.min, min)) + def is_min_reduction(self) -> bool: + return self.reduce_fx in (torch.min, min) class ResultMetric(Metric, DeviceDtypeModuleMixin): @@ -79,44 +79,45 @@ def __init__(self, metadata: Metadata) -> None: super().__init__(compute_on_step=metadata.is_tensor) self.meta = metadata if self.meta.is_tensor: - self.add_state("value", torch.tensor(.0)) - if self.meta.is_tensor_and_mean_reduction: - self.add_state("cumulated_batch_size", torch.tensor(.0)) + self.add_state("value", torch.tensor(0, dtype=torch.float)) + if self.meta.is_mean_reduction: + self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) + # TODO: self.value? 
def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: - if self.meta.is_tensor_and_mean_reduction: - self.value += value.float().mean() * batch_size - self.cumulated_batch_size += batch_size - - elif self.meta.is_tensor_and_max_reduction: - self.value = max(self.value, value.float().mean()) + # TODO: support for non-tensor, sync returns tensor always + if self.meta.is_tensor: + if self.meta.is_mean_reduction: + self.value += value.float().mean() * batch_size + self.cumulated_batch_size += batch_size - elif self.meta.is_tensor_and_min_reduction: - self.value = min(self.value, value.float().mean()) + elif self.meta.is_max_reduction: + self.value = max(self.value, value.float().mean()) + elif self.meta.is_min_reduction: + self.value = min(self.value, value.float().mean()) else: - self.value = value + self.value = value # noqa: attribute-defined-outside-init self._forward_cache = value._forward_cache def compute(self) -> torch.Tensor: if self.meta.is_tensor: - if self.meta.is_tensor_and_mean_reduction: + if self.meta.is_mean_reduction: return torch.sum(self.value) / torch.sum(self.cumulated_batch_size) - elif self.meta.is_tensor_and_max_reduction or self.meta.is_tensor_and_min_reduction: + elif self.meta.is_max_reduction or self.meta.is_min_reduction: return self.value - else: - raise MisconfigurationException("Only min, mean, max reduction are supported.") - else: - return self.value.compute() + raise MisconfigurationException( + f"Only [min, max, mean] reductions are supported. Found {self.meta.reduce_fx}" + ) + return self.value.compute() def __repr__(self) -> str: - if self.meta.is_tensor_and_mean_reduction: - attr = f"value={self.value}, cumulated_batch_size={self.cumulated_batch_size}" - else: - attr = f"value={getattr(self, 'value', None)}" - return f"{self.__class__.__name__}({attr})" + state = f"value={self.value}" + if self.meta.is_tensor and self.meta.is_mean_reduction: + state += f", cumulated_batch_size={self.cumulated_batch_size}" + return f"{self.__class__.__name__}({state})" - def reset(self): + def reset(self) -> None: if self.meta.is_tensor: super().reset() else: From 30eff02bcfa2afceb60fec6e472316b8dd2924e1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 21:27:38 +0200 Subject: [PATCH 250/455] Add missing variable --- pytorch_lightning/core/lightning.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7736f3bd904e9..cb227a9dffa1d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -332,6 +332,8 @@ def log( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + result_collection = self.trainer.result_collection + assert result_collection is not None assert self._current_fx_name is not None result_collection.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) @@ -359,8 +361,7 @@ def log( ) value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn) - assert self.trainer.result_collection is not None - self.trainer.result_collection.log( + result_collection.log( self._current_fx_name, name, value, From e4c7d2f7da5aa8fcde99a0e2a2d261941d47601b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 21:27:44 +0200 Subject: [PATCH 251/455] Refactor fwd --- .../connectors/logger_connector/result.py | 39 +++++-------------- 1 file changed, 9 insertions(+), 30 deletions(-) diff --git 
a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 4ca09abcdbf76..622e40823bf29 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -82,10 +82,10 @@ def __init__(self, metadata: Metadata) -> None: self.add_state("value", torch.tensor(0, dtype=torch.float)) if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) - # TODO: self.value? + # TODO: self.value when not tensor? def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: - # TODO: support for non-tensor, sync returns tensor always + # TODO: support for non-tensor. sync returns tensor always if self.meta.is_tensor: if self.meta.is_mean_reduction: self.value += value.float().mean() * batch_size @@ -122,36 +122,15 @@ def reset(self) -> None: super().reset() else: self.value.reset() - self.meta.has_reset = True - def forward(self, *args, **kwargs): - """ - Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. - """ - # todo (tchaton) Remove this override when merged to TorchMetrics. - # add current step - with torch.no_grad(): - self.update(*args, **kwargs) - - if self.compute_on_step: - self._to_sync = self.dist_sync_on_step - - # save context before switch - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # call reset, update, compute, on single batch - self.reset() - self.update(*args, **kwargs) - self._forward_cache = self.compute() - - # restore context - for attr, val in cache.items(): - setattr(self, attr, val) - self._to_sync = True - self._computed = None - - return self._forward_cache + def forward(self, value: _METRIC, *args, **kwargs) -> torch.Tensor: + """Overridden to avoid `self._forward_cache = None` after `update`""" + prev_fwd_cache = getattr(value, '_forward_cache', None) + out = super().forward(*args, **kwargs) + if out is None: + self._forward_cache = prev_fwd_cache + return out # placeholder for apply_to_collection From a9adba058bb4a44505f72b7900808c0f9bcc2958 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 31 May 2021 22:49:05 +0200 Subject: [PATCH 252/455] integrate #7772 --- pytorch_lightning/loops/training_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index bc9e9d925b14c..d5bfbf9055822 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -301,7 +301,8 @@ def _prepare_outputs( for tbptt_output in batch_outputs: out = tbptt_output.extra - out['loss'] = tbptt_output.minimize + if tbptt_output.minimize is not None: + out['loss'] = tbptt_output.minimize.detach() processed_tbptt_outputs.append(out) # if there was only one tbptt step then we can collapse that dimension From 3881b18a94da46740461cf73427092ed35e138d1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 00:42:59 +0200 Subject: [PATCH 253/455] Forgot to pass value --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 622e40823bf29..faa1ec1b6b1ff 100644 --- 
a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -127,7 +127,7 @@ def reset(self) -> None: def forward(self, value: _METRIC, *args, **kwargs) -> torch.Tensor: """Overridden to avoid `self._forward_cache = None` after `update`""" prev_fwd_cache = getattr(value, '_forward_cache', None) - out = super().forward(*args, **kwargs) + out = super().forward(value, *args, **kwargs) if out is None: self._forward_cache = prev_fwd_cache return out From 041ec8e7ad1efb62065738422759d732c8c83c46 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 00:56:55 +0200 Subject: [PATCH 254/455] Docstrings --- .../connectors/logger_connector/result.py | 41 +++++++------------ 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index faa1ec1b6b1ff..a04a0b9fd165d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -71,9 +71,7 @@ def is_min_reduction(self) -> bool: class ResultMetric(Metric, DeviceDtypeModuleMixin): - """ - This class is responsible to hold each single metric provided by ``LightningModule.log`` function. - """ + """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" def __init__(self, metadata: Metadata) -> None: super().__init__(compute_on_step=metadata.is_tensor) @@ -140,41 +138,30 @@ class ResultMeta(Dict): class ResultCollection(dict): """ - This class is used to capture all the logged values using LightningModule.log function. - - Here is how to use the ResultCollection object. + Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` Example: - # the root_device need to be provided before calling the ``log`` function + # `root_device` needs to be provided before logging result = ResultCollection(True, torch.device("cpu")) + # you can log to a specific collection. # arguments: hook_name, key, value, metadata - result.log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True) - result.log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True) - - for epoch in range(2): - - result.log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) - result.log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) - - for batch_idx, batch_size in enumerate(range(2)): + result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) + result.log('validation_step', 'recall', torch.tensor(...), on_step=True, on_epoch=True) + for epoch in epochs: + for batch_idx, batch in enumerate(dataloader): # the batch_idx is used to reset the tensor metrics result.batch_idx = batch_idx + result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) - result.log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) - result.log('c1', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) - - # used to indicate epoch end has been reached - result.on_epoch_end_reached = True - - result.log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) - result.log('d1', 'a', torch.tensor(3.) 
+ epoch, on_step=False, on_epoch=True) + result.on_epoch_end_reached = True # indicate epoch end has been reached + result.log('training_epoch_end', 'acc', torch.tensor(...), on_step=False, on_epoch=True) - # used to reset torchmetrics.Metric and set `on_epoch_end_reached` to False - result.reset_metrics() [Optional]: Reset only torchmetric.Metric object. - # result.reset() [Optional]: Reset the entire ResultCollection. + # Optionally: + result.reset_metrics() # reset the `torchmetrics.Metric` + result.reset() # reset the entire `ResultCollection` """ STEP_SUFFIX = "_step" From 1047836087e988defadc26bba3d5dc3b84d91f66 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 00:59:07 +0200 Subject: [PATCH 255/455] Revert if condition --- .../trainer/connectors/logger_connector/result.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index a04a0b9fd165d..f44ecd45b7110 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -493,9 +493,7 @@ def is_empty_fn(v): # populate callback metrics # callback metrics don't take `_step` forked metrics. - if not self.is_train and (not metric_on_epoch or on_step): - pass - else: + if self.is_train or metric_on_epoch and not on_step: metrics[MetricSource.CALLBACK][name] = value metrics[MetricSource.CALLBACK][name_forked] = value From b33f5dc5d9fc0ed6891b038e071febbc63a4bc6e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 01:00:23 +0200 Subject: [PATCH 256/455] is_train -> training --- .../trainer/connectors/logger_connector/result.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f44ecd45b7110..3a5c5594d63dc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -168,9 +168,9 @@ class ResultCollection(dict): EPOCH_SUFFIX = "_epoch" DATALOADER_SUFFIX = "/dataloader_idx_{}" - def __init__(self, is_train: bool, root_device: Optional[torch.device] = None) -> None: + def __init__(self, training: bool, root_device: Optional[torch.device] = None) -> None: super().__init__() - self.is_train = is_train + self.training = training self._on_epoch_end_reached = False self._minimize = None self._current_hook_name: Optional[str] = None @@ -493,7 +493,7 @@ def is_empty_fn(v): # populate callback metrics # callback metrics don't take `_step` forked metrics. 
- if self.is_train or metric_on_epoch and not on_step: + if self.training or metric_on_epoch and not on_step: metrics[MetricSource.CALLBACK][name] = value metrics[MetricSource.CALLBACK][name_forked] = value From 43777a6dddd79b0816d541ce1809c060fd70d9d2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 01:01:28 +0200 Subject: [PATCH 257/455] Unnecessary getter/setter --- .../trainer/connectors/logger_connector/result.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 3a5c5594d63dc..12e104748d900 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -142,7 +142,7 @@ class ResultCollection(dict): Example: - # `root_device` needs to be provided before logging + # `device` needs to be provided before logging result = ResultCollection(True, torch.device("cpu")) # you can log to a specific collection. @@ -176,7 +176,7 @@ def __init__(self, training: bool, root_device: Optional[torch.device] = None) - self._current_hook_name: Optional[str] = None self._batch_size: Optional[int] = None self._batch_idx: Optional[int] = None - self._root_device: Optional[torch.device] = root_device + self.root_device: Optional[torch.device] = root_device self.fx_validator = FxValidator() @property @@ -187,14 +187,6 @@ def batch_size(self) -> int: def batch_size(self, batch_size: int) -> None: self._batch_size = batch_size - @property - def root_device(self) -> Optional[torch.device]: - return self._root_device - - @root_device.setter - def root_device(self, root_device: torch.device) -> None: - self._root_device = root_device - @property def batch_idx(self) -> Optional[int]: return self._batch_idx From b7b963344477928ed7aa4bac50ee8fd5b874e094 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 01:02:19 +0200 Subject: [PATCH 258/455] Unnecessary getter/setter --- .../trainer/connectors/logger_connector/result.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 12e104748d900..9ed07de920757 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -175,7 +175,7 @@ def __init__(self, training: bool, root_device: Optional[torch.device] = None) - self._minimize = None self._current_hook_name: Optional[str] = None self._batch_size: Optional[int] = None - self._batch_idx: Optional[int] = None + self.batch_idx: Optional[int] = None self.root_device: Optional[torch.device] = root_device self.fx_validator = FxValidator() @@ -187,14 +187,6 @@ def batch_size(self) -> int: def batch_size(self, batch_size: int) -> None: self._batch_size = batch_size - @property - def batch_idx(self) -> Optional[int]: - return self._batch_idx - - @batch_idx.setter - def batch_idx(self, batch_idx: int) -> None: - self._batch_idx = batch_idx - @property def on_epoch_end_reached(self) -> bool: return self._on_epoch_end_reached @@ -202,7 +194,7 @@ def on_epoch_end_reached(self) -> bool: @on_epoch_end_reached.setter def on_epoch_end_reached(self, on_epoch_end_reached): self._on_epoch_end_reached = on_epoch_end_reached - self._batch_idx = None + self.batch_idx = None @property def metrics(self) -> Dict[str, 
Dict[str, torch.Tensor]]: From 5768b558c21944ca4b4c13e04fb1c90a4ff97ff4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 01:05:13 +0200 Subject: [PATCH 259/455] Linter --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 9ed07de920757..e2fbc6c0801bf 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -359,7 +359,7 @@ def fn(v: Union[torch.Tensor, Metric]) -> ResultMetric: def should_reset_tensors(self, hook_name: str) -> bool: # reset tensor metrics only when hook_name changed and starting a new iteration over dataloader. - return (self._current_hook_name != hook_name and self._batch_idx in (None, 0)) + return self._current_hook_name != hook_name and self.batch_idx in (None, 0) def update_metrics( self, hook_name: str, key: str, value: Union[Dict, torch.Tensor], batch_size: torch.Tensor From 08fac3d4c9e4bfe8e16745b5bca042851e79c330 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 01:06:48 +0200 Subject: [PATCH 260/455] root_device -> device --- .../logger_connector/logger_connector.py | 10 ++++------ .../connectors/logger_connector/result.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0fad69a8214c7..4dd7ee9c48165 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -192,9 +192,8 @@ def increment_evaluation_log_step(self) -> None: elif self.trainer.state.stage is RunningStage.TESTING: self._test_log_step += 1 - def on_evaluation_start(self): - root_device = self.trainer.lightning_module.device - self.trainer.result_collection.root_device = root_device + def on_evaluation_start(self) -> None: + self.trainer.result_collection.device = self.trainer.lightning_module.device def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: model = self.trainer.lightning_module @@ -228,9 +227,8 @@ def update_evaluation_step_metrics(self) -> None: Train metric updates """ - def on_train_start(self): - root_device = self.trainer.lightning_module.device - self.trainer.result_collection.root_device = root_device + def on_train_start(self) -> None: + self.trainer.result_collection.device = self.trainer.lightning_module.device_device def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: self.trainer.result_collection.extract_batch_size(split_batch) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index e2fbc6c0801bf..216c1090f25eb 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -168,7 +168,7 @@ class ResultCollection(dict): EPOCH_SUFFIX = "_epoch" DATALOADER_SUFFIX = "/dataloader_idx_{}" - def __init__(self, training: bool, root_device: Optional[torch.device] = None) -> None: + def __init__(self, training: bool, device: Optional[torch.device] = None) -> None: super().__init__() 
self.training = training self._on_epoch_end_reached = False @@ -176,7 +176,7 @@ def __init__(self, training: bool, root_device: Optional[torch.device] = None) - self._current_hook_name: Optional[str] = None self._batch_size: Optional[int] = None self.batch_idx: Optional[int] = None - self.root_device: Optional[torch.device] = root_device + self.device: Optional[torch.device] = device self.fx_validator = FxValidator() @property @@ -325,7 +325,7 @@ def log( self.instance_result_metric(key, meta, value) # compute batch_size - batch_size = torch.tensor(batch_size or self.batch_size, device=self.root_device) + batch_size = torch.tensor(batch_size or self.batch_size, device=self.device) # update the ResultMetric self.update_metrics(hook_name, key, value, batch_size) @@ -338,12 +338,12 @@ def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, to def fn(v: Union[torch.Tensor, Metric]) -> ResultMetric: # This local function is used to `ResultMetric`. # The `Metadata` is_tensor is modified on the fly - assert self.root_device is not None + assert self.device is not None nonlocal meta meta = deepcopy(meta) meta.is_tensor = torch.is_tensor(v) metric = ResultMetric(meta) - return metric.to(self.root_device) + return metric.to(self.device) # store a mapping between storage key and collection of `ResultMetric` self[key] = apply_to_collection(value, (torch.Tensor, Metric), fn) @@ -372,7 +372,7 @@ def update_metrics( # this function is used to call the forward function of ResultMetric object. def fn(result_metric, v): assert isinstance(v, (torch.Tensor, Metric)) - result_metric(v.to(self.root_device), batch_size.to(self.root_device)) + result_metric(v.to(self.device), batch_size.to(self.device)) result_metric.meta.has_reset = False apply_to_collections(self[key], value, ResultMetric, fn) @@ -599,8 +599,8 @@ def to_result_metric(item: ResultMeta) -> Dict[str, Any]: result_metric = ResultMetric(item["meta"]) # update its state result_metric.__dict__.update(item) - # move result_metric to root_device - return result_metric.to(self.root_device) + # move result_metric to device + return result_metric.to(self.device) # transform ResultMeta into ResultMetric state_dict = {k: apply_to_collection(v, ResultMeta, to_result_metric) for k, v in state_dict.items()} From 78861517d109c147b54661a06db82c81156b2e93 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 01:14:08 +0200 Subject: [PATCH 261/455] Docstrings --- .../connectors/logger_connector/result.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 216c1090f25eb..4c75d64381ccd 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -198,28 +198,19 @@ def on_epoch_end_reached(self, on_epoch_end_reached): @property def metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: - """ - This function returns either batch or epoch metrics depending on `on_epoch_end_reached` attribute. 
- The metrics are returned as: - - - { - MetricSource.PBAR: {...}, - MetricSource.LOG: {...}, - MetricSource.CALLBACK: {...} - } - """ + """This function returns either batch or epoch metrics depending on ``on_epoch_end_reached``.""" return self.get_epoch_metrics() if self.on_epoch_end_reached else self.get_batch_metrics() @property def minimize(self) -> Optional[Tensor]: + """ + The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss + will be saved as the ``minimize`` attribute. + """ return self._minimize @minimize.setter def minimize(self, loss: Optional[torch.Tensor]) -> None: - """ - The `LightningModule.training_step` loss will be saved as the ResultCollection minimize attribute. - """ if loss is not None: if not isinstance(loss, Tensor): raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") From be94ec995ae61069467b39cbc5e5bf484f5748b8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 01:45:06 +0200 Subject: [PATCH 262/455] Better exceptions --- pytorch_lightning/core/lightning.py | 16 +++++- .../logger_connector/logger_connector.py | 2 +- .../connectors/logger_connector/result.py | 51 +++++-------------- .../logging_/test_train_loop_logging.py | 42 +++++++++++++++ 4 files changed, 72 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index cb227a9dffa1d..1b85cf4028383 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -339,7 +339,10 @@ def log( # make sure user doesn't introduce logic for multi-dataloaders if "/dataloader_idx_" in name: - raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.") + raise MisconfigurationException( + f"You called `self.log` with the key `{name}`" + " but it should not contain information about `dataloader_idx`" + ) if metric_attribute is None and isinstance(value, Metric): if self._metric_attributes is None: @@ -348,8 +351,19 @@ def log( id(module): name for name, module in self.named_children() if isinstance(module, Metric) } + if not self._metric_attributes: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." + " You can fix this by setting an attribute for the metric in your `LightningModule`." + ) # try to find the passed metric in the LightningModule metric_attribute = self._metric_attributes.get(id(value)) + if metric_attribute is None: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." 
+ f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one" + f" of {list(self._metric_attributes.values())}" + ) sync_fn = partial( self.__sync, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 4dd7ee9c48165..68fd859f6d10d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -228,7 +228,7 @@ def update_evaluation_step_metrics(self) -> None: """ def on_train_start(self) -> None: - self.trainer.result_collection.device = self.trainer.lightning_module.device_device + self.trainer.result_collection.device = self.trainer.lightning_module.device def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: self.trainer.result_collection.extract_batch_size(split_batch) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 4c75d64381ccd..4ccb4b9914fc8 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -80,10 +80,10 @@ def __init__(self, metadata: Metadata) -> None: self.add_state("value", torch.tensor(0, dtype=torch.float)) if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) - # TODO: self.value when not tensor? + # FIXME: self.value when not tensor? def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: - # TODO: support for non-tensor. sync returns tensor always + # FIXME: support for non-tensor. sync returns tensor always if self.meta.is_tensor: if self.meta.is_mean_reduction: self.value += value.float().mean() * batch_size @@ -164,6 +164,7 @@ class ResultCollection(dict): result.reset() # reset the entire `ResultCollection` """ + # FIXME STEP_SUFFIX = "_step" EPOCH_SUFFIX = "_epoch" DATALOADER_SUFFIX = "/dataloader_idx_{}" @@ -181,6 +182,7 @@ def __init__(self, training: bool, device: Optional[torch.device] = None) -> Non @property def batch_size(self) -> int: + # FIXME return self._batch_size or 1 @batch_size.setter @@ -220,14 +222,15 @@ def minimize(self, loss: Optional[torch.Tensor]) -> None: @property def extra(self) -> Dict: + """ + Extras are any keys other than the loss returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` + """ return self.get('extra', {}) @extra.setter def extra(self, extra: Dict) -> None: - """ - The `LightningModule.training_step` extras will be saved as the ResultCollection extra key. - """ - + # FIXME: Should probably fail instead of detaching def detach_fn(v): return v.detach() @@ -247,31 +250,9 @@ def log( enable_graph: bool = False, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, - lightning_attribute_name: Optional[str] = None, + lightning_attribute_name: str = None, ): - """ - This function is used to log metrics from with - :meth:`~pytorch_lightning.core.lightning.LightningModule.log` - - Args: - - hook_name: Current hook name - name: Key provided by the user on logging - value: Either a number, tensor or a collection of the previous. - prog_bar: Whether to add this value to the progress bar. - logger: Whether to log this value to the loggers - on_step: Whether to use this value during batch iteration. 
- on_epoch: Whether to use this value at the end of the batch iteration. - Automatic reduction will be performed. - reduce_fx: Which function to use for reduction. Currently support min, max and mean. - enable_graph: Whether to keep autograd graph when storing the value. - dataloader_idx: The current dataloader idx. This will be used to automatically - add `/dataloader_idx_{}` on the metrics. - batch_size: Current batch size. - lightning_attribute_name: When providing `nn.Metric` as a value, the ``metric_attribute`` - need to be provided to enable automatic saving / re-loading. - - """ + """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() @@ -280,14 +261,13 @@ def log( if isinstance(value, torch.Tensor) and value.device.type == "xla": value = value.cpu() - if isinstance(value, Metric) and lightning_attribute_name is None: - raise MisconfigurationException( - "The LightningModule attribute name should be provided when using torchmetrics.Metric" + if on_step and self.on_epoch_end_reached: + raise RuntimeError( + "Logging `on_step` when `on_epoch_end_reached` isn't allowed. This shouldn't have happened." ) # storage key key = f"{hook_name}.{name}" - # add dataloader_suffix to both key and hook_name if dataloader_idx is not None: # use as ResultCollection key @@ -295,9 +275,6 @@ def log( # used to decide when to reset hook_name += f'.{dataloader_idx}' - if on_step and self.on_epoch_end_reached: - raise MisconfigurationException("Logging `on_step` after `on_epoch_end_reached` isn't authorized.") - if key not in self: # create metadata object if storage key doesn't exist in self meta = Metadata( diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index d4ba81ce656e6..28150b9203097 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -25,6 +25,8 @@ import pytorch_lightning as pl from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.metrics import Accuracy, Metric +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDictDataset from tests.helpers.runif import RunIf @@ -755,3 +757,43 @@ def validation_step(self, batch, batch_idx): assert trainer.callback_metrics["val_acc"] == 8 / 32. 
assert "train_loss" in trainer.callback_metrics + + +def test_logging_raises(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.log('foo/dataloader_idx_0', -1) + + trainer = Trainer(default_root_dir=tmpdir) + model = TestModel() + with pytest.raises(MisconfigurationException, match='`self.log` with the key `foo/dataloader_idx_0`'): + trainer.fit(model) + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.log('foo', Accuracy()) + + trainer = Trainer(default_root_dir=tmpdir) + model = TestModel() + with pytest.raises(MisconfigurationException, match='fix this by setting an attribute for the metric in your'): + trainer.fit(model) + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.bar = Accuracy() + + def training_step(self, batch, batch_idx): + self.log('foo', Accuracy()) + + trainer = Trainer(default_root_dir=tmpdir) + model = TestModel() + with pytest.raises( + MisconfigurationException, + match=r"`self.log\(foo, ..., metric_attribute=name\)` where `name` is one of \['bar'\]" + ): + trainer.fit(model) From 76eb0ee49a6fe041100d696c474ba9af914d5882 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 01:59:55 +0200 Subject: [PATCH 263/455] flake8 and mypy --- pytorch_lightning/callbacks/pruning.py | 4 ++-- .../trainer/connectors/logger_connector/result.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 2 ++ tests/trainer/logging_/test_train_loop_logging.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index e7da752d1c844..1d4510f4c4f14 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -259,13 +259,13 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> Callable: return partial(pruning_fn, **kwargs) - def make_pruning_permanent(self, pl_module: LightningModule) -> None: + def make_pruning_permanent(self, module: nn.Module) -> None: """ Removes pruning buffers from any pruned modules Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180 """ - for _, module in pl_module.named_modules(): + for _, module in module.named_modules(): for k in list(module._forward_pre_hooks): hook = module._forward_pre_hooks[k] if isinstance(hook, pytorch_prune.BasePruningMethod): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 4ccb4b9914fc8..ef913991609c3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -268,7 +268,7 @@ def log( # storage key key = f"{hook_name}.{name}" - # add dataloader_suffix to both key and hook_name + # add dataloader_suffix to both key and hook_name if dataloader_idx is not None: # use as ResultCollection key key += f'.{dataloader_idx}' diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 078776602d3a9..dfe0d8daa40d5 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -102,6 +102,7 @@ def on_evaluation_model_train(self) -> None: model_ref.on_validation_model_train() def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: + assert 
self.trainer.result_collection is not None self.trainer.result_collection.reset_metrics() if self.trainer.testing: @@ -213,6 +214,7 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: # set dataloader_idx to model and track batch_size + assert self.num_dataloaders is not None self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self.num_dataloaders) if self.trainer.testing: diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 28150b9203097..3eb01f366569b 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -25,7 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.metrics import Accuracy, Metric +from pytorch_lightning.metrics import Accuracy from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDictDataset from tests.helpers.runif import RunIf From ae932403021a38f7eb3389fda18db5c45dcae244 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 02:07:06 +0200 Subject: [PATCH 264/455] Clean up batch size? --- .../logger_connector/logger_connector.py | 4 ---- .../trainer/connectors/logger_connector/result.py | 15 +++------------ pytorch_lightning/trainer/training_loop.py | 3 --- 3 files changed, 3 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 68fd859f6d10d..e6851057b6aea 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -234,10 +234,6 @@ def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: self.trainer.result_collection.extract_batch_size(split_batch) self.trainer.result_collection.batch_idx = batch_idx - def on_train_batch_end(self) -> None: - # TODO: why - self.trainer.result_collection.batch_size = 1 - def update_train_step_metrics(self, batch_output): metrics = self.trainer.result_collection.metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ef913991609c3..849d075cc52b5 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -175,20 +175,11 @@ def __init__(self, training: bool, device: Optional[torch.device] = None) -> Non self._on_epoch_end_reached = False self._minimize = None self._current_hook_name: Optional[str] = None - self._batch_size: Optional[int] = None + self.batch_size: int = 1 self.batch_idx: Optional[int] = None self.device: Optional[torch.device] = device self.fx_validator = FxValidator() - @property - def batch_size(self) -> int: - # FIXME - return self._batch_size or 1 - - @batch_size.setter - def batch_size(self, batch_size: int) -> None: - self._batch_size = batch_size - @property def on_epoch_end_reached(self) -> bool: return self._on_epoch_end_reached @@ -514,9 +505,9 @@ def reset(self): def extract_batch_size(self, batch: Any) -> None: try: - self._batch_size = 
self._extract_batch_size(batch) + self.batch_size = self._extract_batch_size(batch) except RecursionError: - self._batch_size = 1 + self.batch_size = 1 def _extract_batch_size(self, batch: Any) -> int: """ diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 641ee14ca1de5..e300c08c61ac8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -184,9 +184,6 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) - # reset batch logger internals - self.trainer.logger_connector.on_train_batch_end() - def reset_train_val_dataloaders(self, model) -> None: """ Resets train and val dataloaders if none are attached to the trainer. From 8fbbcff6dc2c84a51e2abc35449270a3e32dff1f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 02:20:02 +0200 Subject: [PATCH 265/455] Replace repr with str --- .../trainer/connectors/logger_connector/result.py | 8 ++------ tests/core/test_metric_result_integration.py | 8 ++++++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 849d075cc52b5..f88c0ca1744cc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -530,12 +530,8 @@ def _extract_batch_size(self, batch: Any) -> int: size = 1 return size - def __repr__(self) -> str: - repr = f'{self.__class__.__name__}' + '{\n' - for k in sorted(self.keys()): - v = self[k] - repr += f" {k}: {v},\n" - return repr[:-1] + '\n}' + def __str__(self) -> str: + return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' def state_dict(self): diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 728da426ba9ff..ccfb2756df1d7 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -160,6 +160,14 @@ def test_result_metric_integration(): for k in epoch_expected.keys(): assert epoch_expected[k] == epoch_log[k] + assert str(result) == ( + "ResultCollection(True, cpu, {" + "'h.a': ResultMetric(value=DummyMetric()), " + "'h.b': ResultMetric(value=DummyMetric()), " + "'h.c': ResultMetric(value=DummyMetric())" + "})" + ) + def test_result_collection_restoration(): From df0ff74922ecfd189c8129aad53f8b31f66afa78 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 02:52:09 +0200 Subject: [PATCH 266/455] Refactor reset --- .../logger_connector/logger_connector.py | 2 +- .../connectors/logger_connector/result.py | 40 ++++++++----------- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 3 files changed, 19 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index e6851057b6aea..ef63870e4b9e1 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -273,7 +273,7 @@ def update_train_epoch_metrics(self) -> None: self.log_metrics(epoch_log_metrics, {}) # reset result collection for next epoch - self.trainer.result_collection.reset_metrics() + 
self.trainer.result_collection.reset(metrics=True) """ Utilities and properties diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f88c0ca1744cc..ac7074aaa7ec0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -326,7 +326,7 @@ def update_metrics( if self.should_reset_tensors(hook_name): # when restarting an new epoch, reset the tensor hooks dynamically. - self._reset_metrics(hook_name, is_tensor=True) + self._reset(hook_name, metrics=False) # this function is used to call the forward function of ResultMetric object. def fn(result_metric, v): @@ -362,7 +362,7 @@ def valid_metrics(self) -> Generator: elif isinstance(item, ResultMetric) and item.meta.has_reset: continue - yield (key, item) + yield key, item def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) -> Tuple: """ @@ -477,29 +477,24 @@ def cpu(self) -> 'ResultCollection': """Move all data to CPU.""" return self.to(device="cpu") - def _reset_metrics(self, hook_name: str = None, is_tensor: Optional[bool] = None) -> None: - """Call at the end of epoch to reset all results provided as `Metric` or `tensor`""" + def _reset(self, fx: Optional[str] = None, metrics: bool = True) -> None: - def reset_fn(item: ResultMetric) -> None: - nonlocal hook_name - nonlocal is_tensor - if is_tensor is None or item.meta.is_tensor == is_tensor: - if isinstance(hook_name, str) and hook_name != item.meta.fx: - return + def fn(item: ResultMetric) -> None: + requested_type = metrics ^ item.meta.is_tensor # logical xor + same_fx = fx is None or fx == item.meta.fx + if requested_type and same_fx: item.reset() - apply_to_collection(dict(self.items()), ResultMetric, reset_fn) + apply_to_collection(self, ResultMetric, fn) - def reset_metrics(self): - self._reset_metrics(is_tensor=False) - self.on_epoch_end_reached = False - self._current_hook_name = None - - def reset(self): + def reset(self, metrics: bool = False) -> None: """ - This function is used to reset entirely the ResultCollection + Reset the result collection + + Args: + metrics: Whether to only reset the `torchmetrics.Metric` results """ - self._reset_metrics() + self._reset(metrics=metrics) self.on_epoch_end_reached = False self._current_hook_name = None @@ -530,9 +525,6 @@ def _extract_batch_size(self, batch: Any) -> int: size = 1 return size - def __str__(self) -> str: - return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' - def state_dict(self): def get_state_dict(item: ResultMetric) -> Dict[str, Any]: @@ -565,7 +557,6 @@ def to_result_metric(item: ResultMeta) -> Dict[str, Any]: self[k] = v if metrics: - # the metric reference are lost during serialization and # they need to be set back during loading @@ -577,6 +568,9 @@ def re_assign_metric(item): apply_to_collection(dict(self.items()), ResultMetric, re_assign_metric) + def __str__(self) -> str: + return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' + def __getstate__(self) -> dict: d = self.__dict__.copy() # can't deepcopy tensors with grad_fn diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index dfe0d8daa40d5..dba920167c6a0 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -103,7 +103,7 @@ def on_evaluation_model_train(self) -> None: def 
on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: assert self.trainer.result_collection is not None - self.trainer.result_collection.reset_metrics() + self.trainer.result_collection.reset(metrics=True) if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) From 20bcf0ccd336fc3737074a86a6c22bf338971b1c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 02:53:21 +0200 Subject: [PATCH 267/455] Docstring --- .../trainer/connectors/logger_connector/result.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ac7074aaa7ec0..81748dd900a24 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -157,11 +157,7 @@ class ResultCollection(dict): result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) result.on_epoch_end_reached = True # indicate epoch end has been reached - result.log('training_epoch_end', 'acc', torch.tensor(...), on_step=False, on_epoch=True) - - # Optionally: - result.reset_metrics() # reset the `torchmetrics.Metric` - result.reset() # reset the entire `ResultCollection` + result.log('training_epoch_end', 'acc', torch.tensor(...), on_step=False, on_epoch=True)` """ # FIXME From 6da9d8fa6ec3721035197c90095c9a8505d43d44 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 02:58:35 +0200 Subject: [PATCH 268/455] hook_name -> fx --- .../connectors/logger_connector/result.py | 37 +++++++++---------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 81748dd900a24..b5217c9f153d4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -146,7 +146,7 @@ class ResultCollection(dict): result = ResultCollection(True, torch.device("cpu")) # you can log to a specific collection. 
- # arguments: hook_name, key, value, metadata + # arguments: fx, key, value, metadata result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) result.log('validation_step', 'recall', torch.tensor(...), on_step=True, on_epoch=True) @@ -170,7 +170,7 @@ def __init__(self, training: bool, device: Optional[torch.device] = None) -> Non self.training = training self._on_epoch_end_reached = False self._minimize = None - self._current_hook_name: Optional[str] = None + self._current_fx: Optional[str] = None self.batch_size: int = 1 self.batch_idx: Optional[int] = None self.device: Optional[torch.device] = device @@ -226,7 +226,7 @@ def detach_fn(v): def log( self, - hook_name: str, + fx: str, name: str, value: Any, prog_bar: bool = False, @@ -254,18 +254,18 @@ def log( ) # storage key - key = f"{hook_name}.{name}" - # add dataloader_suffix to both key and hook_name + key = f"{fx}.{name}" + # add dataloader_suffix to both key and fx if dataloader_idx is not None: # use as ResultCollection key key += f'.{dataloader_idx}' # used to decide when to reset - hook_name += f'.{dataloader_idx}' + fx += f'.{dataloader_idx}' if key not in self: # create metadata object if storage key doesn't exist in self meta = Metadata( - fx=hook_name, + fx=fx, name=name, prog_bar=prog_bar, logger=logger, @@ -283,10 +283,10 @@ def log( batch_size = torch.tensor(batch_size or self.batch_size, device=self.device) # update the ResultMetric - self.update_metrics(hook_name, key, value, batch_size) + self.update_metrics(fx, key, value, batch_size) # save current_hook to know when to reset. - self._current_hook_name = hook_name + self._current_fx = fx def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: @@ -312,17 +312,14 @@ def fn(v: Union[torch.Tensor, Metric]) -> ResultMetric: self[key + '.on_epoch'] = meta.on_epoch self[key + '.dataloader_idx'] = meta.dataloader_idx - def should_reset_tensors(self, hook_name: str) -> bool: - # reset tensor metrics only when hook_name changed and starting a new iteration over dataloader. - return self._current_hook_name != hook_name and self.batch_idx in (None, 0) + def should_reset_tensors(self, fx: str) -> bool: + # reset tensor metrics only when the hook changed and reloading the dataloader + return self._current_fx != fx and self.batch_idx in (None, 0) - def update_metrics( - self, hook_name: str, key: str, value: Union[Dict, torch.Tensor], batch_size: torch.Tensor - ) -> None: - - if self.should_reset_tensors(hook_name): - # when restarting an new epoch, reset the tensor hooks dynamically. - self._reset(hook_name, metrics=False) + def update_metrics(self, fx: str, key: str, value: Union[Dict, torch.Tensor], batch_size: torch.Tensor) -> None: + if self.should_reset_tensors(fx): + # when restarting an new epoch, reset the tensors + self._reset(fx, metrics=False) # this function is used to call the forward function of ResultMetric object. 
def fn(result_metric, v): @@ -492,7 +489,7 @@ def reset(self, metrics: bool = False) -> None: """ self._reset(metrics=metrics) self.on_epoch_end_reached = False - self._current_hook_name = None + self._current_fx = None def extract_batch_size(self, batch: Any) -> None: try: From f10084e05fc7933a5038204297d2001b23bbb0ab Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 13:35:20 +0200 Subject: [PATCH 269/455] Fix reset --- .../trainer/connectors/logger_connector/result.py | 10 ++++++---- tests/core/test_metric_result_integration.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index b5217c9f153d4..292dfc74aaca8 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -470,22 +470,24 @@ def cpu(self) -> 'ResultCollection': """Move all data to CPU.""" return self.to(device="cpu") - def _reset(self, fx: Optional[str] = None, metrics: bool = True) -> None: + def _reset(self, fx: Optional[str] = None, metrics: Optional[bool] = None) -> None: def fn(item: ResultMetric) -> None: - requested_type = metrics ^ item.meta.is_tensor # logical xor + requested_type = metrics is None or metrics ^ item.meta.is_tensor same_fx = fx is None or fx == item.meta.fx if requested_type and same_fx: item.reset() apply_to_collection(self, ResultMetric, fn) - def reset(self, metrics: bool = False) -> None: + def reset(self, metrics: Optional[bool] = None) -> None: """ Reset the result collection Args: - metrics: Whether to only reset the `torchmetrics.Metric` results + metrics: If True, only ``torchmetrics.Metric`` results are reset, + if False, only ``torch.Tensors`` are reset, + if ``None``, both are. """ self._reset(metrics=metrics) self.on_epoch_end_reached = False diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ccfb2756df1d7..8cdf410ee4019 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -251,7 +251,7 @@ def test_result_collection_restoration(): result.reset() # assert metric state reset to default values - assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) + assert metric_a.x == metric_a._defaults['x'] assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] From 13785a0f85b3c7d985b1d8bd96eef8af326d3dab Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 13:35:36 +0200 Subject: [PATCH 270/455] Refactor --- .../connectors/logger_connector/result.py | 58 ++++++------------- 1 file changed, 18 insertions(+), 40 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 292dfc74aaca8..6becb140f1ae8 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -279,13 +279,10 @@ def log( # value can be provided as a nested collection. self.instance_result_metric(key, meta, value) - # compute batch_size - batch_size = torch.tensor(batch_size or self.batch_size, device=self.device) - - # update the ResultMetric - self.update_metrics(fx, key, value, batch_size) - - # save current_hook to know when to reset. 
+ if self.should_reset_tensors(fx): + # when restarting an new epoch, reset the tensors + self._reset(fx, metrics=False) + self.update_metrics(key, value, batch_size) self._current_fx = fx def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: @@ -316,46 +313,32 @@ def should_reset_tensors(self, fx: str) -> bool: # reset tensor metrics only when the hook changed and reloading the dataloader return self._current_fx != fx and self.batch_idx in (None, 0) - def update_metrics(self, fx: str, key: str, value: Union[Dict, torch.Tensor], batch_size: torch.Tensor) -> None: - if self.should_reset_tensors(fx): - # when restarting an new epoch, reset the tensors - self._reset(fx, metrics=False) + def update_metrics(self, key: str, value: Union[Dict, torch.Tensor], batch_size: Optional[int]) -> None: + batch_size = torch.tensor(batch_size or self.batch_size, device=self.device) - # this function is used to call the forward function of ResultMetric object. def fn(result_metric, v): - assert isinstance(v, (torch.Tensor, Metric)) - result_metric(v.to(self.device), batch_size.to(self.device)) + # call the forward function of ResultMetric + result_metric(v.to(self.device), batch_size) result_metric.meta.has_reset = False apply_to_collections(self[key], value, ResultMetric, fn) @staticmethod def _get_forward_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: - # skip if meta `on_step` is False if not result_metric.meta.on_step: return - - # extract `ResultMetric` forward cache return result_metric._forward_cache.detach() @staticmethod def _to_item(t: torch.Tensor) -> float: return t.item() - def valid_metrics(self) -> Generator: - """ - This function is used to iterate over current valid metrics. - """ - for key, item in self.items(): - # skip when item is None, bool or extra arguments from training_step. - if item is None or isinstance(item, bool) or key == "extra": - continue - - # skip when the metrics hasn't been updated. - elif isinstance(item, ResultMetric) and item.meta.has_reset: - continue - - yield key, item + def valid_items(self) -> Generator: + """This function is used to iterate over current valid metrics.""" + return ((k, v) for k, v in self.items() if ( + v is not None and not isinstance(v, bool) and not k == "extra" + and not (isinstance(v, ResultMetric) and v.meta.has_reset) + )) def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) -> Tuple: """ @@ -395,7 +378,7 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: suffix = self.STEP_SUFFIX if on_step else self.EPOCH_SUFFIX # iterate over all stored metrics. - for key, result_metric in self.valid_metrics(): + for key, result_metric in self.valid_items(): # extract forward_cache or computed from the ResultMetric # ignore when the output of fn is None @@ -443,22 +426,17 @@ def is_empty_fn(v): def get_batch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: return self.get_metrics(on_step=True) + def get_epoch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: + return self.get_metrics(on_step=False) + @staticmethod def _get_computed_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: - # skip if meta.on_epoch is False if not result_metric.meta.on_epoch: return - - # perform reduction is not done alrady if not result_metric._computed: result_metric.compute() - - # extract computed from ResultMetric. 
return result_metric._computed.detach() - def get_epoch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: - return self.get_metrics(on_step=False) - def to(self, *args, **kwargs) -> 'ResultCollection': """Move all data to the given device.""" for k, v in self.items(): From cce70232977febd65349b94dc5b8c940ec1cf050 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 14:03:48 +0200 Subject: [PATCH 271/455] Serialization updates --- .../connectors/logger_connector/result.py | 55 ++++++++----------- 1 file changed, 22 insertions(+), 33 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 6becb140f1ae8..1a320d950a146 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -131,9 +131,11 @@ def forward(self, value: _METRIC, *args, **kwargs) -> torch.Tensor: return out -# placeholder for apply_to_collection -class ResultMeta(Dict): - pass +class _SerializationHelper(dict): + """ + Since ``ResultCollection`` can hold ``ResultMetric`` values or dictionaries of them, we need + a class to differentiate between the cases after converting to state dict when saving its state. + """ class ResultCollection(dict): @@ -237,7 +239,7 @@ def log( enable_graph: bool = False, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, - lightning_attribute_name: str = None, + lightning_attribute_name: Optional[str] = None, ): """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -325,9 +327,8 @@ def fn(result_metric, v): @staticmethod def _get_forward_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: - if not result_metric.meta.on_step: - return - return result_metric._forward_cache.detach() + if result_metric.meta.on_step: + return result_metric._forward_cache.detach() @staticmethod def _to_item(t: torch.Tensor) -> float: @@ -500,46 +501,34 @@ def _extract_batch_size(self, batch: Any) -> int: def state_dict(self): - def get_state_dict(item: ResultMetric) -> Dict[str, Any]: - state = item.__getstate__() + def to_state_dict(item: ResultMetric) -> _SerializationHelper: + state = deepcopy(item.__getstate__()) # delete reference to TorchMetrics Metric - state = deepcopy(state) - if 'value' in state['_modules'] and isinstance(state['_modules']["value"], Metric): - del state['_modules']["value"] + state['_modules'].pop('value', None) + return _SerializationHelper(**state) - # ResultMeta is used as a placeholder for making re-loading simpler - return ResultMeta(**state) + return {k: apply_to_collection(v, ResultMetric, to_state_dict) for k, v in self.items()} - return {k: apply_to_collection(v, ResultMetric, get_state_dict) for k, v in self.items()} + def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Optional[Dict[str, Metric]] = None) -> None: - def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Dict[str, Metric]): - - def to_result_metric(item: ResultMeta) -> Dict[str, Any]: - # create a new ResultMetric + def to_result_metric(item: _SerializationHelper) -> ResultMetric: result_metric = ResultMetric(item["meta"]) - # update its state result_metric.__dict__.update(item) - # move result_metric to device return result_metric.to(self.device) - # transform ResultMeta into ResultMetric - state_dict = {k: apply_to_collection(v, ResultMeta, to_result_metric) for k, v in 
state_dict.items()} - - # add the state_dict as new key-value into self + state_dict = {k: apply_to_collection(v, _SerializationHelper, to_result_metric) for k, v in state_dict.items()} for k, v in state_dict.items(): self[k] = v if metrics: - # the metric reference are lost during serialization and - # they need to be set back during loading - def re_assign_metric(item): - nonlocal metrics - lightning_attribute_name = item.meta.lightning_attribute_name - if isinstance(lightning_attribute_name, str) and lightning_attribute_name in metrics: - item.value = metrics[lightning_attribute_name] + def re_assign_metric(item: ResultMetric) -> None: + # metric references are lost during serialization and need to be set back during loading + name = item.meta.lightning_attribute_name + if isinstance(name, str) and name in metrics: + item.value = metrics[name] - apply_to_collection(dict(self.items()), ResultMetric, re_assign_metric) + apply_to_collection(self, ResultMetric, re_assign_metric) def __str__(self) -> str: return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' From f51a2aa98c00e8bced1bb941a94d3fea379a3dc1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 16:07:33 +0200 Subject: [PATCH 272/455] Move is_tensor to ResultMetric --- .../connectors/logger_connector/result.py | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 1a320d950a146..d695fe5b97b9a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -44,7 +44,6 @@ class Metadata: on_epoch: bool = True reduce_fx: Callable = torch.mean dataloader_idx: Optional[int] = None - is_tensor: bool = True lightning_attribute_name: Optional[str] = None has_reset: bool = False @@ -73,10 +72,11 @@ def is_min_reduction(self) -> bool: class ResultMetric(Metric, DeviceDtypeModuleMixin): """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" - def __init__(self, metadata: Metadata) -> None: - super().__init__(compute_on_step=metadata.is_tensor) + def __init__(self, metadata: Metadata, is_tensor: bool) -> None: + super().__init__(compute_on_step=is_tensor) + self.is_tensor = is_tensor self.meta = metadata - if self.meta.is_tensor: + if is_tensor: self.add_state("value", torch.tensor(0, dtype=torch.float)) if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) @@ -84,7 +84,7 @@ def __init__(self, metadata: Metadata) -> None: def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: # FIXME: support for non-tensor. 
sync returns tensor always - if self.meta.is_tensor: + if self.is_tensor: if self.meta.is_mean_reduction: self.value += value.float().mean() * batch_size self.cumulated_batch_size += batch_size @@ -99,7 +99,7 @@ def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: self._forward_cache = value._forward_cache def compute(self) -> torch.Tensor: - if self.meta.is_tensor: + if self.is_tensor: if self.meta.is_mean_reduction: return torch.sum(self.value) / torch.sum(self.cumulated_batch_size) elif self.meta.is_max_reduction or self.meta.is_min_reduction: @@ -111,12 +111,12 @@ def compute(self) -> torch.Tensor: def __repr__(self) -> str: state = f"value={self.value}" - if self.meta.is_tensor and self.meta.is_mean_reduction: + if self.is_tensor and self.meta.is_mean_reduction: state += f", cumulated_batch_size={self.cumulated_batch_size}" return f"{self.__class__.__name__}({state})" def reset(self) -> None: - if self.meta.is_tensor: + if self.is_tensor: super().reset() else: self.value.reset() @@ -279,7 +279,7 @@ def log( ) # create one ResultMetric object per value. # value can be provided as a nested collection. - self.instance_result_metric(key, meta, value) + self.to_result_metric(key, meta, value) if self.should_reset_tensors(fx): # when restarting an new epoch, reset the tensors @@ -287,21 +287,16 @@ def log( self.update_metrics(key, value, batch_size) self._current_fx = fx - def instance_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: + def to_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: def fn(v: Union[torch.Tensor, Metric]) -> ResultMetric: - # This local function is used to `ResultMetric`. - # The `Metadata` is_tensor is modified on the fly - assert self.device is not None - nonlocal meta - meta = deepcopy(meta) - meta.is_tensor = torch.is_tensor(v) - metric = ResultMetric(meta) + metric = ResultMetric(meta, isinstance(v, torch.Tensor)) return metric.to(self.device) # store a mapping between storage key and collection of `ResultMetric` self[key] = apply_to_collection(value, (torch.Tensor, Metric), fn) + # FIXME # when the value was a nested collection, store some metadata # to facilate access for later metrics gathering if not isinstance(self[key], ResultMetric): @@ -452,7 +447,7 @@ def cpu(self) -> 'ResultCollection': def _reset(self, fx: Optional[str] = None, metrics: Optional[bool] = None) -> None: def fn(item: ResultMetric) -> None: - requested_type = metrics is None or metrics ^ item.meta.is_tensor + requested_type = metrics is None or metrics ^ item.is_tensor same_fx = fx is None or fx == item.meta.fx if requested_type and same_fx: item.reset() From 5700df4c30217924f0ff11875f32c39f5d7b6e39 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 17:02:33 +0200 Subject: [PATCH 273/455] Fixes --- .../connectors/logger_connector/result.py | 2 +- tests/core/test_metric_result_integration.py | 16 +++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index d695fe5b97b9a..c4b21be361a83 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -507,7 +507,7 @@ def to_state_dict(item: ResultMetric) -> _SerializationHelper: def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Optional[Dict[str, Metric]] = None) -> 
None: def to_result_metric(item: _SerializationHelper) -> ResultMetric: - result_metric = ResultMetric(item["meta"]) + result_metric = ResultMetric(item["meta"], item["is_tensor"]) result_metric.__dict__.update(item) return result_metric.to(self.device) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 8cdf410ee4019..afb721e3277c3 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -202,14 +202,11 @@ def test_result_collection_restoration(): ) result.log('training_step', 'a_1', a, on_step=True, on_epoch=True) result.log('training_step', 'b_1', b, on_step=False, on_epoch=True) - result.log('training_step', 'c_1', [c, c], on_step=True, on_epoch=False) + result.log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False) batch_log = result.metrics[MetricSource.LOG] - batch_expected = {"a_step": i, "c": i, "a_1_step": i, "c_1": [i, i]} - - assert set(batch_log.keys()) == set(batch_expected.keys()) - for k in batch_expected.keys(): - assert batch_expected[k] == batch_log[k] + assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"} + assert set(batch_log['c_1']) == {'1', '2'} _result = deepcopy(result) state_dict = result.state_dict() @@ -231,13 +228,10 @@ def test_result_collection_restoration(): epoch_log = result.metrics[MetricSource.LOG] _epoch_log = _result.metrics[MetricSource.LOG] - assert epoch_log == _epoch_log - epoch_expected = {'a_1_epoch', 'a_epoch', 'b', 'b_1'} - - assert set(epoch_log.keys()) == epoch_expected, epoch_log.keys() - for k in list(epoch_expected): + assert set(epoch_log) == {'a_1_epoch', 'a_epoch', 'b', 'b_1'} + for k in epoch_log: if k in {'a_epoch', 'b'}: assert epoch_log[k] == cumulative_sum else: From bf8bc526f3e46f98ce799b66c45267bd5ca550ee Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 17:06:42 +0200 Subject: [PATCH 274/455] Rename metric attribute --- pytorch_lightning/core/lightning.py | 2 +- .../connectors/logger_connector/result.py | 19 +++++++++--------- tests/core/test_metric_result_integration.py | 20 +++++++++---------- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d60e8f5d21587..e0d2ac47279e1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -390,7 +390,7 @@ def log( enable_graph=enable_graph, dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, - lightning_attribute_name=metric_attribute, + metric_attribute=metric_attribute, ) def log_dict( diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c4b21be361a83..e97c125f1554d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -44,7 +44,7 @@ class Metadata: on_epoch: bool = True reduce_fx: Callable = torch.mean dataloader_idx: Optional[int] = None - lightning_attribute_name: Optional[str] = None + metric_attribute: Optional[str] = None has_reset: bool = False @property @@ -109,12 +109,6 @@ def compute(self) -> torch.Tensor: ) return self.value.compute() - def __repr__(self) -> str: - state = f"value={self.value}" - if self.is_tensor and self.meta.is_mean_reduction: - state += f", cumulated_batch_size={self.cumulated_batch_size}" - return 
f"{self.__class__.__name__}({state})" - def reset(self) -> None: if self.is_tensor: super().reset() @@ -130,6 +124,11 @@ def forward(self, value: _METRIC, *args, **kwargs) -> torch.Tensor: self._forward_cache = prev_fwd_cache return out + def __repr__(self) -> str: + state = f"value={self.value}" + if self.is_tensor and self.meta.is_mean_reduction: + state += f", cumulated_batch_size={self.cumulated_batch_size}" + return f"{self.__class__.__name__}({state})" class _SerializationHelper(dict): """ @@ -239,7 +238,7 @@ def log( enable_graph: bool = False, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, - lightning_attribute_name: Optional[str] = None, + metric_attribute: Optional[str] = None, ): """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -275,7 +274,7 @@ def log( on_epoch=on_epoch, reduce_fx=reduce_fx, dataloader_idx=dataloader_idx, - lightning_attribute_name=lightning_attribute_name, + metric_attribute=metric_attribute, ) # create one ResultMetric object per value. # value can be provided as a nested collection. @@ -519,7 +518,7 @@ def to_result_metric(item: _SerializationHelper) -> ResultMetric: def re_assign_metric(item: ResultMetric) -> None: # metric references are lost during serialization and need to be set back during loading - name = item.meta.lightning_attribute_name + name = item.meta.metric_attribute if isinstance(name, str) and name in metrics: item.value = metrics[name] diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index afb721e3277c3..858c8744557be 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -75,9 +75,9 @@ def _ddp_test_fn(rank, worldsize): cumulative_sum += i - result.log('h', 'a', metric_a, on_step=True, on_epoch=True, lightning_attribute_name="metric_a") - result.log('h', 'b', metric_b, on_step=False, on_epoch=True, lightning_attribute_name="metric_b") - result.log('h', 'c', metric_c, on_step=True, on_epoch=False, lightning_attribute_name="metric_c") + result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") + result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") batch_log = result.get_batch_metrics()[MetricSource.LOG] batch_expected = {"a_step": i, "c": i} @@ -134,9 +134,9 @@ def test_result_metric_integration(): cumulative_sum += i - result.log('h', 'a', metric_a, on_step=True, on_epoch=True, lightning_attribute_name="metric_a") - result.log('h', 'b', metric_b, on_step=False, on_epoch=True, lightning_attribute_name="metric_b") - result.log('h', 'c', metric_c, on_step=True, on_epoch=False, lightning_attribute_name="metric_c") + result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") + result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") batch_log = result.get_batch_metrics()[MetricSource.LOG] batch_expected = {"a_step": i, "c": i} @@ -193,12 +193,12 @@ def test_result_collection_restoration(): cumulative_sum += i - result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True, lightning_attribute_name="metric_a") + result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") result.log( - 
'training_step', 'b', metric_b, on_step=False, on_epoch=True, lightning_attribute_name="metric_b" + 'training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b" ) result.log( - 'training_step', 'c', metric_c, on_step=True, on_epoch=False, lightning_attribute_name="metric_c" + 'training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c" ) result.log('training_step', 'a_1', a, on_step=True, on_epoch=True) result.log('training_step', 'b_1', b, on_step=False, on_epoch=True) @@ -238,7 +238,7 @@ def test_result_collection_restoration(): assert epoch_log[k] == 1 result.log( - 'train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, lightning_attribute_name="metric_a_end" + 'train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, metric_attribute="metric_a_end" ) _result.reset() From 89149ac342864ae3c288455f437dc1a9f8c58d3e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 17:43:04 +0200 Subject: [PATCH 275/455] Comment getstate --- .../trainer/connectors/logger_connector/result.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index e97c125f1554d..89ba0caf0e4ae 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Generator, Mapping, Sequence -from copy import deepcopy from dataclasses import dataclass from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union @@ -130,6 +129,14 @@ def __repr__(self) -> str: state += f", cumulated_batch_size={self.cumulated_batch_size}" return f"{self.__class__.__name__}({state})" + # FIXME: necessary? tests pass? 
+ # def __getstate__(self) -> dict: + # d = super().__getstate__() + # # delete reference to TorchMetrics Metric + # d['_modules'].pop('value', None) + # return d + + class _SerializationHelper(dict): """ Since ``ResultCollection`` can hold ``ResultMetric`` values or dictionaries of them, we need @@ -496,10 +503,7 @@ def _extract_batch_size(self, batch: Any) -> int: def state_dict(self): def to_state_dict(item: ResultMetric) -> _SerializationHelper: - state = deepcopy(item.__getstate__()) - # delete reference to TorchMetrics Metric - state['_modules'].pop('value', None) - return _SerializationHelper(**state) + return _SerializationHelper(**item.__getstate__()) return {k: apply_to_collection(v, ResultMetric, to_state_dict) for k, v in self.items()} From d7823e35e9036cfbdada4d89ab8e730cd815284f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Jun 2021 15:44:09 +0000 Subject: [PATCH 276/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../connectors/logger_connector/result.py | 20 +++++++------------ tests/core/test_metric_result_integration.py | 12 +++-------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 89ba0caf0e4ae..014152371cb02 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Generator from dataclasses import dataclass from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union @@ -387,20 +387,14 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: value = apply_to_collection(result_metric, ResultMetric, fn, include_none=False) # detect if the value is None. This can be nested. - is_empty = True + is_none = False - def is_empty_fn(v): - nonlocal is_empty - # update is_empty if any value is not None. - if v is not None: - is_empty = False + def any_none(_): + nonlocal is_none + is_none = True - # apply detection. - # TODO(@tchaton): need to find a way to support NamedTuple - apply_to_collection(value, object, is_empty_fn, wrong_dtype=(Mapping, Sequence)) - - # skip is the value was actually empty. 
- if is_empty: + apply_to_collection(value, type(None), any_none) + if is_none: continue # extract metadata diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 858c8744557be..f7c35f0c9aa65 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -194,12 +194,8 @@ def test_result_collection_restoration(): cumulative_sum += i result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") - result.log( - 'training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b" - ) - result.log( - 'training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c" - ) + result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") result.log('training_step', 'a_1', a, on_step=True, on_epoch=True) result.log('training_step', 'b_1', b, on_step=False, on_epoch=True) result.log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False) @@ -237,9 +233,7 @@ def test_result_collection_restoration(): else: assert epoch_log[k] == 1 - result.log( - 'train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, metric_attribute="metric_a_end" - ) + result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, metric_attribute="metric_a_end") _result.reset() result.reset() From 37e75897ac4f16d32a53b603333c8e4360257a1b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 19:04:41 +0200 Subject: [PATCH 277/455] Leave out serialization for a different PR --- .../connectors/logger_connector/result.py | 34 -------- tests/core/test_metric_result_integration.py | 77 ------------------- 2 files changed, 111 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 014152371cb02..edf99f448e7a4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -129,12 +129,6 @@ def __repr__(self) -> str: state += f", cumulated_batch_size={self.cumulated_batch_size}" return f"{self.__class__.__name__}({state})" - # FIXME: necessary? tests pass? 
- # def __getstate__(self) -> dict: - # d = super().__getstate__() - # # delete reference to TorchMetrics Metric - # d['_modules'].pop('value', None) - # return d class _SerializationHelper(dict): @@ -494,34 +488,6 @@ def _extract_batch_size(self, batch: Any) -> int: size = 1 return size - def state_dict(self): - - def to_state_dict(item: ResultMetric) -> _SerializationHelper: - return _SerializationHelper(**item.__getstate__()) - - return {k: apply_to_collection(v, ResultMetric, to_state_dict) for k, v in self.items()} - - def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Optional[Dict[str, Metric]] = None) -> None: - - def to_result_metric(item: _SerializationHelper) -> ResultMetric: - result_metric = ResultMetric(item["meta"], item["is_tensor"]) - result_metric.__dict__.update(item) - return result_metric.to(self.device) - - state_dict = {k: apply_to_collection(v, _SerializationHelper, to_result_metric) for k, v in state_dict.items()} - for k, v in state_dict.items(): - self[k] = v - - if metrics: - - def re_assign_metric(item: ResultMetric) -> None: - # metric references are lost during serialization and need to be set back during loading - name = item.meta.metric_attribute - if isinstance(name, str) and name in metrics: - item.value = metrics[name] - - apply_to_collection(self, ResultMetric, re_assign_metric) - def __str__(self) -> str: return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index f7c35f0c9aa65..b9c0f3c50caef 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from copy import deepcopy - import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -169,81 +167,6 @@ def test_result_metric_integration(): ) -def test_result_collection_restoration(): - - _result = None - metric_a = DummyMetric() - metric_b = DummyMetric() - metric_c = DummyMetric() - - result = ResultCollection(True, torch.device("cpu")) - - for _ in range(2): - - result.on_epoch_end_reached = False - cumulative_sum = 0 - - for i in range(3): - - result.batch_idx = i - - a = metric_a(i) - b = metric_b(i) - c = metric_c(i) - - cumulative_sum += i - - result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") - result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") - result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") - result.log('training_step', 'a_1', a, on_step=True, on_epoch=True) - result.log('training_step', 'b_1', b, on_step=False, on_epoch=True) - result.log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False) - - batch_log = result.metrics[MetricSource.LOG] - assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"} - assert set(batch_log['c_1']) == {'1', '2'} - - _result = deepcopy(result) - state_dict = result.state_dict() - - result = ResultCollection(True, torch.device("cpu")) - result.load_from_state_dict( - state_dict, { - "metric_a": metric_a, - "metric_b": metric_b, - "metric_c": metric_c, - "metric_a_end": metric_a - } - ) - - assert _result.items() == result.items() - - result.on_epoch_end_reached = True - _result.on_epoch_end_reached = True - - epoch_log = result.metrics[MetricSource.LOG] - _epoch_log = _result.metrics[MetricSource.LOG] - assert epoch_log == _epoch_log - - assert set(epoch_log) == {'a_1_epoch', 'a_epoch', 'b', 'b_1'} - for k in epoch_log: - if k in {'a_epoch', 'b'}: - assert epoch_log[k] == cumulative_sum - else: - assert epoch_log[k] == 1 - - result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, metric_attribute="metric_a_end") - - _result.reset() - result.reset() - - # assert metric state reset to default values - assert metric_a.x == metric_a._defaults['x'] - assert metric_b.x == metric_b._defaults['x'] - assert metric_c.x == metric_c._defaults['x'] - - def test_result_collection_simple_loop(): result = ResultCollection(True, torch.device("cpu")) From c6e3e267b784cf1151370f9ed41eb8cb6c47d9ec Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 19:05:37 +0200 Subject: [PATCH 278/455] Encapsulate dict collections --- .../connectors/logger_connector/result.py | 107 +++++++----------- 1 file changed, 41 insertions(+), 66 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index edf99f448e7a4..24ce510b41664 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -130,17 +130,24 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" - -class _SerializationHelper(dict): +class ResultMetricCollection(dict): """ - Since ``ResultCollection`` can hold ``ResultMetric`` values or dictionaries of them, we need - a class to differentiate between the cases after converting to state dict when saving its state. + Dict wrapper for easy access to metadata. 
+ + All of the leaf items should be instances of + :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` + with the same metadata. """ + def __init__(self, metadata: Metadata) -> None: + super().__init__() + self.meta = metadata + class ResultCollection(dict): """ - Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` + Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or + :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection` Example: @@ -162,9 +169,6 @@ class ResultCollection(dict): result.log('training_epoch_end', 'acc', torch.tensor(...), on_step=False, on_epoch=True)` """ - # FIXME - STEP_SUFFIX = "_step" - EPOCH_SUFFIX = "_epoch" DATALOADER_SUFFIX = "/dataloader_idx_{}" def __init__(self, training: bool, device: Optional[torch.device] = None) -> None: @@ -215,6 +219,7 @@ def extra(self) -> Dict: Extras are any keys other than the loss returned by :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` """ + # FIXME: add underscore to key return self.get('extra', {}) @extra.setter @@ -293,18 +298,15 @@ def fn(v: Union[torch.Tensor, Metric]) -> ResultMetric: metric = ResultMetric(meta, isinstance(v, torch.Tensor)) return metric.to(self.device) - # store a mapping between storage key and collection of `ResultMetric` - self[key] = apply_to_collection(value, (torch.Tensor, Metric), fn) + if isinstance(value, dict): + rmc = ResultMetricCollection(meta) + rmc.update(value) + apply_to_collection(rmc, (torch.Tensor, Metric), fn) + value = rmc + else: + value = fn(value) - # FIXME - # when the value was a nested collection, store some metadata - # to facilate access for later metrics gathering - if not isinstance(self[key], ResultMetric): - self[key + '.forked'] = meta.forked - self[key + '.logger'] = meta.logger - self[key + '.prog_bar'] = meta.prog_bar - self[key + '.on_epoch'] = meta.on_epoch - self[key + '.dataloader_idx'] = meta.dataloader_idx + self[key] = value def should_reset_tensors(self, fx: str) -> bool: # reset tensor metrics only when the hook changed and reloading the dataloader @@ -331,38 +333,18 @@ def _to_item(t: torch.Tensor) -> float: def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" - return ((k, v) for k, v in self.items() if ( - v is not None and not isinstance(v, bool) and not k == "extra" - and not (isinstance(v, ResultMetric) and v.meta.has_reset) - )) - - def _extract_metadata(self, key: str, result_metric, on_step: bool, suffix: str) -> Tuple: - """ - This function is used to extract the metadata for `ResultMetric` and `nested ResultMetrics`. - """ - - if isinstance(result_metric, ResultMetric): - name = result_metric.meta.name - name_forked = result_metric.meta.forked_name(on_step) - logger = result_metric.meta.logger - prog_bar = result_metric.meta.prog_bar - metric_on_epoch = result_metric.meta.on_epoch - dataloader_idx = result_metric.meta.dataloader_idx - else: - name = key.split('.')[-1] - name_forked = name + suffix if self[key + '.forked'] else name - logger = self[key + '.logger'] - prog_bar = self[key + '.prog_bar'] - metric_on_epoch = self[key + '.on_epoch'] - dataloader_idx = self[key + '.dataloader_idx'] - - # add dataloader_suffix is provided. 
- if dataloader_idx is not None: - dataloader_suffix = self.DATALOADER_SUFFIX.format(dataloader_idx) + return ((k, v) for k, v in self.items() + if (not k == "extra" and not (isinstance(v, ResultMetric) and v.meta.has_reset))) + + def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]: + name = result_metric.meta.name + forked_name = result_metric.meta.forked_name(on_step) + dl_idx = result_metric.meta.dataloader_idx + if dl_idx is not None: + dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx) name += dataloader_suffix - name_forked += dataloader_suffix - - return name, name_forked, logger, prog_bar, metric_on_epoch + forked_name += dataloader_suffix + return name, forked_name def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: metrics = {k: {} for k in MetricSource} @@ -370,9 +352,6 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: # either extract `forward_cache` or `computed` from `ResultMetric` objects fn = self._get_forward_cache if on_step else self._get_computed_cache - # select suffix - suffix = self.STEP_SUFFIX if on_step else self.EPOCH_SUFFIX - # iterate over all stored metrics. for key, result_metric in self.valid_items(): @@ -391,25 +370,21 @@ def any_none(_): if is_none: continue - # extract metadata - name, name_forked, logger, prog_bar, metric_on_epoch = self._extract_metadata( - key, result_metric, on_step, suffix - ) + name, forked_name = self._forked_name(result_metric, on_step) # populate logging metrics - if logger: - metrics[MetricSource.LOG][name_forked] = value + if result_metric.meta.logger: + metrics[MetricSource.LOG][forked_name] = value - # populate callback metrics - # callback metrics don't take `_step` forked metrics. - if self.training or metric_on_epoch and not on_step: + # populate callback metrics. callback metrics don't take `_step` forked metrics + if self.training or result_metric.meta.on_epoch and not on_step: metrics[MetricSource.CALLBACK][name] = value - metrics[MetricSource.CALLBACK][name_forked] = value + metrics[MetricSource.CALLBACK][forked_name] = value - # populate progress_bar metrics. By default, the value should be converted to a float. - if prog_bar: + # populate progress_bar metrics. 
values should be converted to float + if result_metric.meta.prog_bar: value = apply_to_collection(value, torch.Tensor, self._to_item, include_none=False) - metrics[MetricSource.PBAR][name_forked] = value + metrics[MetricSource.PBAR][forked_name] = value return metrics From a4a690333eebeaf87ff686227964f43b4fd9a39a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 19:19:29 +0200 Subject: [PATCH 279/455] "extra" changes --- .../connectors/logger_connector/result.py | 19 +++++++++++-------- pytorch_lightning/trainer/training_loop.py | 4 ++-- .../logging_/test_train_loop_logging.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 24ce510b41664..379b61895e931 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -219,17 +219,20 @@ def extra(self) -> Dict: Extras are any keys other than the loss returned by :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` """ - # FIXME: add underscore to key - return self.get('extra', {}) + return self.get('_extra', {}) @extra.setter def extra(self, extra: Dict) -> None: - # FIXME: Should probably fail instead of detaching - def detach_fn(v): - return v.detach() - extra = apply_to_collection(extra, torch.Tensor, detach_fn) - self['extra'] = extra + def check_fn(v): + if v.grad_fn is not None: + raise MisconfigurationException( + 'You passed a tensor with `grad_fn` when calling `self.log()`.' + f' The extra values are {extra}' + ) + + apply_to_collection(extra, torch.Tensor, check_fn) + self['_extra'] = extra def log( self, @@ -334,7 +337,7 @@ def _to_item(t: torch.Tensor) -> float: def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" return ((k, v) for k, v in self.items() - if (not k == "extra" and not (isinstance(v, ResultMetric) and v.meta.has_reset))) + if not k == "_extra" and not (isinstance(v, ResultMetric) and v.meta.has_reset)) def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]: name = result_metric.meta.name diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e300c08c61ac8..e5a339011d857 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -300,7 +300,7 @@ def _process_training_step_output(self, training_step_output): loss = None hiddens = None - result["extra"] = {} + result.extra = {} # handle dict return if isinstance(training_step_output, dict): @@ -308,7 +308,7 @@ def _process_training_step_output(self, training_step_output): hiddens = training_step_output.pop("hiddens", None) if hiddens is not None: hiddens = hiddens.detach() - result["extra"] = training_step_output + result.extra = training_step_output # handle scalar return elif isinstance(training_step_output, torch.Tensor): diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 8fee5af1544d7..2e2991a341148 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -811,3 +811,14 @@ def training_step(self, batch, batch_idx): match=r"`self.log\(foo, ..., metric_attribute=name\)` where `name` is one of \['bar'\]" ): trainer.fit(model) + + class TestModel(BoringModel): + + def 
training_step(self, *args): + loss = super().training_step(*args)['loss'] + return {"loss": loss, 'foo': loss} + + trainer = Trainer(default_root_dir=tmpdir) + model = TestModel() + with pytest.raises(MisconfigurationException, match='You passed a tensor with `grad_fn`'): + trainer.fit(model) From f9304fc048e0f23e4d4456a234d800b91eb3dda1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 19:31:27 +0200 Subject: [PATCH 280/455] Fix typing --- .../connectors/logger_connector/result.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 379b61895e931..5d41da303cfff 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -16,7 +16,6 @@ from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch -from torch import Tensor from torchmetrics import Metric from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator @@ -24,7 +23,10 @@ from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _METRIC + +# re-define the ones from pytorch_lightning.utilities.types without the `Number` type +_METRIC = Union[Metric, torch.Tensor] +_METRIC_COLLECTION = Union[_METRIC, Dict[str, _METRIC]] class MetricSource(LightningEnum): @@ -79,10 +81,8 @@ def __init__(self, metadata: Metadata, is_tensor: bool) -> None: self.add_state("value", torch.tensor(0, dtype=torch.float)) if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) - # FIXME: self.value when not tensor? def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: - # FIXME: support for non-tensor. sync returns tensor always if self.is_tensor: if self.meta.is_mean_reduction: self.value += value.float().mean() * batch_size @@ -192,12 +192,12 @@ def on_epoch_end_reached(self, on_epoch_end_reached): self.batch_idx = None @property - def metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: + def metrics(self) -> Dict[str, _METRIC_COLLECTION]: """This function returns either batch or epoch metrics depending on ``on_epoch_end_reached``.""" return self.get_epoch_metrics() if self.on_epoch_end_reached else self.get_batch_metrics() @property - def minimize(self) -> Optional[Tensor]: + def minimize(self) -> Optional[torch.Tensor]: """ The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss will be saved as the ``minimize`` attribute. 
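# Aside: a minimal, runnable sketch of the `grad_fn` distinction the nearby checks rely on.
# Illustrative only and not taken from the patches (the names `w`, `loss` and `extra` are made up):
# the loss kept as `minimize` must still be attached to the autograd graph, while values logged
# as "extra" must not carry a `grad_fn`.
import torch

w = torch.nn.Parameter(torch.zeros(1))
loss = (w * 2).sum()     # built from a parameter, so it has a grad_fn -> acceptable as `minimize`
extra = loss.detach()    # detached from the graph, so grad_fn is None -> acceptable as an "extra" value

assert loss.grad_fn is not None
assert extra.grad_fn is None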
@@ -207,14 +207,14 @@ def minimize(self) -> Optional[Tensor]: @minimize.setter def minimize(self, loss: Optional[torch.Tensor]) -> None: if loss is not None: - if not isinstance(loss, Tensor): + if not isinstance(loss, torch.Tensor): raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") if loss.grad_fn is None: raise RuntimeError("`Result.minimize` must have a `grad_fn`") self._minimize = loss @property - def extra(self) -> Dict: + def extra(self) -> Dict[str, Any]: """ Extras are any keys other than the loss returned by :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` @@ -222,7 +222,7 @@ def extra(self) -> Dict: return self.get('_extra', {}) @extra.setter - def extra(self, extra: Dict) -> None: + def extra(self, extra: Dict[str, Any]) -> None: def check_fn(v): if v.grad_fn is not None: @@ -238,7 +238,7 @@ def log( self, fx: str, name: str, - value: Any, + value: _METRIC_COLLECTION, prog_bar: bool = False, logger: bool = True, on_step: bool = False, @@ -248,7 +248,7 @@ def log( dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, metric_attribute: Optional[str] = None, - ): + ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): @@ -295,9 +295,9 @@ def log( self.update_metrics(key, value, batch_size) self._current_fx = fx - def to_result_metric(self, key: str, meta: Metadata, value: Union[Dict, torch.Tensor]) -> None: + def to_result_metric(self, key: str, meta: Metadata, value: _METRIC_COLLECTION) -> None: - def fn(v: Union[torch.Tensor, Metric]) -> ResultMetric: + def fn(v: _METRIC) -> ResultMetric: metric = ResultMetric(meta, isinstance(v, torch.Tensor)) return metric.to(self.device) @@ -315,7 +315,7 @@ def should_reset_tensors(self, fx: str) -> bool: # reset tensor metrics only when the hook changed and reloading the dataloader return self._current_fx != fx and self.batch_idx in (None, 0) - def update_metrics(self, key: str, value: Union[Dict, torch.Tensor], batch_size: Optional[int]) -> None: + def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: Optional[int]) -> None: batch_size = torch.tensor(batch_size or self.batch_size, device=self.device) def fn(result_metric, v): @@ -349,7 +349,7 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, forked_name += dataloader_suffix return name, forked_name - def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]: + def get_metrics(self, on_step: bool) -> Dict[str, _METRIC_COLLECTION]: metrics = {k: {} for k in MetricSource} # either extract `forward_cache` or `computed` from `ResultMetric` objects @@ -391,10 +391,10 @@ def any_none(_): return metrics - def get_batch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: + def get_batch_metrics(self) -> Dict[str, _METRIC_COLLECTION]: return self.get_metrics(on_step=True) - def get_epoch_metrics(self) -> Dict[str, Dict[str, torch.Tensor]]: + def get_epoch_metrics(self) -> Dict[str, _METRIC_COLLECTION]: return self.get_metrics(on_step=False) @staticmethod From 42373071c7d3ff01983ca2921a5267881e4dbef7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 19:39:24 +0200 Subject: [PATCH 281/455] Enable mypy --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index 5a68adb27b443..f12f3e2a03fe0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -163,6 +163,8 @@ ignore_errors = True # whitelist 
[mypy-pytorch_lightning.trainer.evaluation_loop] ignore_errors = False +[mypy-pytorch_lightning.trainer.connectors.logger_connector] +ignore_errors = False # todo: add proper typing to this module... [mypy-pytorch_lightning.distributed.*] From 24eb614709dff55b7fb79ba52d851586e31357fd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 19:46:33 +0200 Subject: [PATCH 282/455] Legacy code --- tests/base/model_train_steps.py | 31 ++----------------------------- 1 file changed, 2 insertions(+), 29 deletions(-) diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index c24cf5ded575a..2ef83ffd5a2de 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC -from collections import OrderedDict class TrainingStepVariations(ABC): @@ -31,18 +30,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): # calculate loss loss_train = self.loss(y, y_hat) - log_train = loss_train - - # alternate between tensors and scalars for "log" and "progress_bar" - if batch_idx % 2 == 0: - log_train = log_train.item() - - output = OrderedDict({ - 'loss': loss_train, - 'progress_bar': dict(some_val=log_train * log_train), - 'log': dict(train_some_val=log_train * log_train), - }) - return output + return {'loss': loss_train} def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=None): """Training step for multiple train loaders""" @@ -61,19 +49,4 @@ def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=No # calculate loss loss_val = self.loss(y, y_hat) - log_val = loss_val - - # alternate between tensors and scalars for "log" and "progress_bar" - if batch_idx % 2 == 0: - log_val = log_val.item() - - output = OrderedDict({ - 'loss': loss_val, - 'progress_bar': { - 'some_val': log_val * log_val - }, - 'log': { - 'train_some_val': log_val * log_val - }, - }) - return output + return {'loss': loss_val} From 092fec014c34b5433a27260955611c0ab4052799 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 20:00:17 +0200 Subject: [PATCH 283/455] Stricter type checking --- pytorch_lightning/core/lightning.py | 13 +++++++++---- tests/trainer/logging_/test_train_loop_logging.py | 5 +++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e0d2ac47279e1..9c0d9b3c5bf3a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -328,8 +328,13 @@ def log( ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) - # check for none values - apply_to_collection(value, type(None), partial(self.__check_none, name, value)) + # check for invalid values + apply_to_collection( + value, + object, + partial(self.__check_allowed, name, value), + wrong_dtype=(numbers.Number, Metric, Tensor, dict) + ) # set the default depending on the fx_name on_step = self.__auto_choose_log_on_step(on_step) @@ -469,8 +474,8 @@ def __sync( return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) @staticmethod - def __check_none(name: str, value: Any, _) -> Any: - raise ValueError(f'`self.log({name}, {value})` was called, but `None` values cannot be logged') + def __check_allowed(name: str, value: Any, v) -> Any: + raise ValueError(f'`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged') def 
write_prediction( self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt' diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 2e2991a341148..8278a6f48c76e 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -17,6 +17,7 @@ import collections import itertools +from re import escape import numpy as np import pytest @@ -759,7 +760,7 @@ def validation_step(self, batch, batch_idx): assert "train_loss" in trainer.callback_metrics -@pytest.mark.parametrize('value', [None, {'a': {'b': None}}]) +@pytest.mark.parametrize('value', [None, {'a': {'b': None}}, 'foo', [1, 2, 3], (1, 2, 3), [[1, 2], 3]]) def test_log_none_raises(tmpdir, value): class TestModel(BoringModel): @@ -769,7 +770,7 @@ def training_step(self, *args): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) model = TestModel() - with pytest.raises(ValueError, match=rf"self.log\(foo, {value}\)` was called"): + with pytest.raises(ValueError, match=rf"self.log\(foo, {escape(str(value))}\)` was called"): trainer.fit(model) From 935fe5b41aee2f9909a32127417e4df76d34c6ae Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 20:07:16 +0200 Subject: [PATCH 284/455] Minor changes --- pytorch_lightning/core/lightning.py | 2 +- .../trainer/connectors/logger_connector/result.py | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9c0d9b3c5bf3a..a1f0a4c56238e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -474,7 +474,7 @@ def __sync( return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) @staticmethod - def __check_allowed(name: str, value: Any, v) -> Any: + def __check_allowed(name: str, value: Any, v: Any) -> None: raise ValueError(f'`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged') def write_prediction( diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 5d41da303cfff..606a329ec7aca 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -267,13 +267,10 @@ def log( key = f"{fx}.{name}" # add dataloader_suffix to both key and fx if dataloader_idx is not None: - # use as ResultCollection key key += f'.{dataloader_idx}' - # used to decide when to reset fx += f'.{dataloader_idx}' if key not in self: - # create metadata object if storage key doesn't exist in self meta = Metadata( fx=fx, name=name, @@ -285,9 +282,7 @@ def log( dataloader_idx=dataloader_idx, metric_attribute=metric_attribute, ) - # create one ResultMetric object per value. - # value can be provided as a nested collection. - self.to_result_metric(key, meta, value) + self.register_key(key, meta, value) if self.should_reset_tensors(fx): # when restarting an new epoch, reset the tensors @@ -295,7 +290,8 @@ def log( self.update_metrics(key, value, batch_size) self._current_fx = fx - def to_result_metric(self, key: str, meta: Metadata, value: _METRIC_COLLECTION) -> None: + def register_key(self, key: str, meta: Metadata, value: _METRIC_COLLECTION) -> None: + """Create one ResultMetric object per value. 
Value can be provided as a nested collection""" def fn(v: _METRIC) -> ResultMetric: metric = ResultMetric(meta, isinstance(v, torch.Tensor)) From eac0b822509ebda9c0a3016c8fd0446f8460187f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 1 Jun 2021 22:21:44 +0200 Subject: [PATCH 285/455] Performance optimizations --- .../connectors/logger_connector/result.py | 42 ++++++++++++------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 606a329ec7aca..61cdcc393cdbc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -74,7 +74,7 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin): """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" def __init__(self, metadata: Metadata, is_tensor: bool) -> None: - super().__init__(compute_on_step=is_tensor) + super().__init__() self.is_tensor = is_tensor self.meta = metadata if is_tensor: @@ -82,6 +82,10 @@ def __init__(self, metadata: Metadata, is_tensor: bool) -> None: if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) + def __setattr__(self, key, value): + # performance: skip the `torch.nn.Module.__setattr__` checks + object.__setattr__(self, key, value) + def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: if self.is_tensor: if self.meta.is_mean_reduction: @@ -115,13 +119,11 @@ def reset(self) -> None: self.value.reset() self.meta.has_reset = True - def forward(self, value: _METRIC, *args, **kwargs) -> torch.Tensor: - """Overridden to avoid `self._forward_cache = None` after `update`""" - prev_fwd_cache = getattr(value, '_forward_cache', None) - out = super().forward(value, *args, **kwargs) - if out is None: - self._forward_cache = prev_fwd_cache - return out + def forward(self, value: _METRIC, *args, **kwargs) -> None: + # performance: skip the `torch.no_grad` context manager by calling `update` directly + # as `backward` shouldn't be run on metrics + self._forward_cache = value._forward_cache if isinstance(value, Metric) else value + self.update(value, *args, **kwargs) def __repr__(self) -> str: state = f"value={self.value}" @@ -177,11 +179,20 @@ def __init__(self, training: bool, device: Optional[torch.device] = None) -> Non self._on_epoch_end_reached = False self._minimize = None self._current_fx: Optional[str] = None - self.batch_size: int = 1 + self._batch_size = torch.tensor(1, device=device) self.batch_idx: Optional[int] = None self.device: Optional[torch.device] = device self.fx_validator = FxValidator() + @property + def batch_size(self) -> torch.Tensor: + # performance: cache the `batch_size` tensor instead of re-creating it + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int) -> None: + self._batch_size = torch.tensor(value, device=self.device) + @property def on_epoch_end_reached(self) -> bool: return self._on_epoch_end_reached @@ -287,7 +298,11 @@ def log( if self.should_reset_tensors(fx): # when restarting an new epoch, reset the tensors self._reset(fx, metrics=False) - self.update_metrics(key, value, batch_size) + + if batch_size is not None: + self.batch_size = batch_size + + self.update_metrics(key, value) self._current_fx = fx def register_key(self, key: str, meta: Metadata, value: _METRIC_COLLECTION) -> None: @@ -311,12 +326,11 @@ def 
should_reset_tensors(self, fx: str) -> bool: # reset tensor metrics only when the hook changed and reloading the dataloader return self._current_fx != fx and self.batch_idx in (None, 0) - def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: Optional[int]) -> None: - batch_size = torch.tensor(batch_size or self.batch_size, device=self.device) + def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None: def fn(result_metric, v): - # call the forward function of ResultMetric - result_metric(v.to(self.device), batch_size) + # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` + result_metric.forward(v.to(self.device), self.batch_size) result_metric.meta.has_reset = False apply_to_collections(self[key], value, ResultMetric, fn) From e789cef0cdb7a94d6cf848c562439df4336cbaa4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Jun 2021 15:20:26 +0200 Subject: [PATCH 286/455] Fixes and remove support for nested dict --- pytorch_lightning/core/lightning.py | 7 +++ .../connectors/logger_connector/result.py | 44 +++++++------------ pytorch_lightning/utilities/types.py | 1 - .../logging_/test_train_loop_logging.py | 7 ++- 4 files changed, 30 insertions(+), 29 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a1f0a4c56238e..8c5ecf2515fae 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -329,6 +329,7 @@ def log( ) # check for invalid values + apply_to_collection(value, dict, partial(self.__check_not_nested, name)) apply_to_collection( value, object, @@ -473,6 +474,12 @@ def __sync( return value return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) + @staticmethod + def __check_not_nested(name: str, value: dict) -> None: + if any(isinstance(v, dict) for v in value.values()): + raise ValueError(f'`self.log({name}, {value})` was called, but nested dictionaries cannot be logged') + return value + @staticmethod def __check_allowed(name: str, value: Any, v: Any) -> None: raise ValueError(f'`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged') diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 61cdcc393cdbc..763eb6df319a9 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -13,6 +13,7 @@ # limitations under the License. from collections.abc import Generator from dataclasses import dataclass +from functools import partial from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch @@ -141,8 +142,8 @@ class ResultMetricCollection(dict): with the same metadata. """ - def __init__(self, metadata: Metadata) -> None: - super().__init__() + def __init__(self, *args, metadata: Optional[Metadata] = None) -> None: + super().__init__(*args) self.meta = metadata @@ -270,6 +271,7 @@ def log( value = value.cpu() if on_step and self.on_epoch_end_reached: + # `FxValidator` should avoid this ever happening. Either a bug there or a bug in the logic order. raise RuntimeError( "Logging `on_step` when `on_epoch_end_reached` isn't allowed. This shouldn't have happened." 
) @@ -312,14 +314,9 @@ def fn(v: _METRIC) -> ResultMetric: metric = ResultMetric(meta, isinstance(v, torch.Tensor)) return metric.to(self.device) + value = apply_to_collection(value, (torch.Tensor, Metric), fn) if isinstance(value, dict): - rmc = ResultMetricCollection(meta) - rmc.update(value) - apply_to_collection(rmc, (torch.Tensor, Metric), fn) - value = rmc - else: - value = fn(value) - + value = ResultMetricCollection(value, metadata=meta) self[key] = value def should_reset_tensors(self, fx: str) -> bool: @@ -336,9 +333,13 @@ def fn(result_metric, v): apply_to_collections(self[key], value, ResultMetric, fn) @staticmethod - def _get_forward_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: - if result_metric.meta.on_step: + def _get_cache(on_step: bool, result_metric: ResultMetric) -> Optional[torch.Tensor]: + if on_step and result_metric.meta.on_step: return result_metric._forward_cache.detach() + elif not on_step and result_metric.meta.on_epoch: + if not result_metric._computed: + result_metric.compute() + return result_metric._computed.detach() @staticmethod def _to_item(t: torch.Tensor) -> float: @@ -362,15 +363,12 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, def get_metrics(self, on_step: bool) -> Dict[str, _METRIC_COLLECTION]: metrics = {k: {} for k in MetricSource} - # either extract `forward_cache` or `computed` from `ResultMetric` objects - fn = self._get_forward_cache if on_step else self._get_computed_cache - - # iterate over all stored metrics. for key, result_metric in self.valid_items(): - # extract forward_cache or computed from the ResultMetric - # ignore when the output of fn is None - value = apply_to_collection(result_metric, ResultMetric, fn, include_none=False) + # extract forward_cache or computed from the ResultMetric. ignore when the output is None + value = apply_to_collection( + result_metric, ResultMetric, partial(self._get_cache, on_step), include_none=False + ) # detect if the value is None. This can be nested. is_none = False @@ -394,7 +392,7 @@ def any_none(_): metrics[MetricSource.CALLBACK][name] = value metrics[MetricSource.CALLBACK][forked_name] = value - # populate progress_bar metrics. values should be converted to float + # populate progress_bar metrics. 
convert tensors to numbers if result_metric.meta.prog_bar: value = apply_to_collection(value, torch.Tensor, self._to_item, include_none=False) metrics[MetricSource.PBAR][forked_name] = value @@ -407,14 +405,6 @@ def get_batch_metrics(self) -> Dict[str, _METRIC_COLLECTION]: def get_epoch_metrics(self) -> Dict[str, _METRIC_COLLECTION]: return self.get_metrics(on_step=False) - @staticmethod - def _get_computed_cache(result_metric: ResultMetric) -> Optional[torch.Tensor]: - if not result_metric.meta.on_epoch: - return - if not result_metric._computed: - result_metric.compute() - return result_metric._computed.detach() - def to(self, *args, **kwargs) -> 'ResultCollection': """Move all data to the given device.""" for k, v in self.items(): diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 4a98956b71c57..945ee3b74218f 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -23,7 +23,6 @@ from torchmetrics import Metric _METRIC = Union[Metric, torch.Tensor, Number] -# real type is `Union[_METRIC, Dict[str, '_METRIC_COLLECTION']]` but Sphinx fails with `RecursionError` _METRIC_COLLECTION = Union[_METRIC, Dict[str, _METRIC]] STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]] EPOCH_OUTPUT = List[STEP_OUTPUT] diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 8278a6f48c76e..c6ef47d9e8162 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -760,7 +760,12 @@ def validation_step(self, batch, batch_idx): assert "train_loss" in trainer.callback_metrics -@pytest.mark.parametrize('value', [None, {'a': {'b': None}}, 'foo', [1, 2, 3], (1, 2, 3), [[1, 2], 3]]) +@pytest.mark.parametrize( + 'value', + [None, dict(a=None), + dict(a=dict(b=None)), + dict(a=dict(b=1)), 'foo', [1, 2, 3], (1, 2, 3), [[1, 2], 3]] +) def test_log_none_raises(tmpdir, value): class TestModel(BoringModel): From 1f792124e2707a6261f6e7ea4ef2bfd4845b8c81 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Jun 2021 16:40:50 +0200 Subject: [PATCH 287/455] Make sure _forward_cache gets float values --- .../trainer/connectors/logger_connector/result.py | 12 ++++++++---- tests/callbacks/test_progress_bar.py | 7 ++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 763eb6df319a9..d630c4d85aeda 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -90,14 +90,14 @@ def __setattr__(self, key, value): def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: if self.is_tensor: if self.meta.is_mean_reduction: - self.value += value.float().mean() * batch_size + self.value += value.mean() * batch_size self.cumulated_batch_size += batch_size elif self.meta.is_max_reduction: - self.value = max(self.value, value.float().mean()) + self.value = max(self.value, value.mean()) elif self.meta.is_min_reduction: - self.value = min(self.value, value.float().mean()) + self.value = min(self.value, value.mean()) else: self.value = value # noqa: attribute-defined-outside-init self._forward_cache = value._forward_cache @@ -121,9 +121,13 @@ def reset(self) -> None: self.meta.has_reset = True def forward(self, value: _METRIC, *args, **kwargs) -> None: + if isinstance(value, Metric): + 
self._forward_cache = value._forward_cache + else: + value = value.float() + self._forward_cache = value # performance: skip the `torch.no_grad` context manager by calling `update` directly # as `backward` shouldn't be run on metrics - self._forward_cache = value._forward_cache if isinstance(value, Metric) else value self.update(value, *args, **kwargs) def __repr__(self) -> str: diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 965f74f802f05..96b46f9bf104c 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -384,8 +384,9 @@ def test_tensor_to_float_conversion(tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx): - self.log('foo', torch.tensor(0.123), prog_bar=True) - self.log('bar', {"baz": torch.tensor([1])}, prog_bar=True) + self.log('a', torch.tensor(0.123), prog_bar=True) + self.log('b', {"b1": torch.tensor([1])}, prog_bar=True) + self.log('c', {"c1": 2}, prog_bar=True) return super().training_step(batch, batch_idx) trainer = Trainer( @@ -399,7 +400,7 @@ def training_step(self, batch, batch_idx): pbar = trainer.progress_bar_callback.main_progress_bar actual = str(pbar.postfix) - assert actual.endswith("foo=0.123, bar={'baz': 1.0}") + assert actual.endswith("a=0.123, b={'b1': 1.0}, c={'c1': 2.0}"), actual @pytest.mark.parametrize( From b24b4108668fa92e0e78019ba480f7d36c80491d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Jun 2021 17:08:20 +0200 Subject: [PATCH 288/455] Update legacy test --- tests/helpers/advanced_models.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/tests/helpers/advanced_models.py b/tests/helpers/advanced_models.py index 2b0146e1ee099..8f3b9663aa2d7 100644 --- a/tests/helpers/advanced_models.py +++ b/tests/helpers/advanced_models.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import OrderedDict import numpy as np import torch @@ -122,13 +121,8 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): # adversarial loss is binary cross-entropy g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) - tqdm_dict = {'g_loss': g_loss} - output = OrderedDict({ - 'loss': g_loss, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict, - }) - return output + self.log('g_loss', g_loss, prog_bar=True, logger=True) + return g_loss # train discriminator if optimizer_idx == 1: @@ -148,13 +142,8 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): # discriminator loss is the average of these d_loss = (real_loss + fake_loss) / 2 - tqdm_dict = {'d_loss': d_loss} - output = OrderedDict({ - 'loss': d_loss, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict, - }) - return output + self.log('d_loss', d_loss, prog_bar=True, logger=True) + return d_loss def configure_optimizers(self): lr = self.learning_rate From 92c5ac0df355d7deba1e5020a4313c559bbf8481 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Jun 2021 18:50:12 +0200 Subject: [PATCH 289/455] Sync improvements --- pytorch_lightning/core/lightning.py | 5 ++--- pytorch_lightning/plugins/training_type/ddp.py | 4 ++-- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/utilities/distributed.py | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 8c5ecf2515fae..9b4b1647960ca 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -43,7 +43,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -469,8 +469,7 @@ def __sync( if isinstance(value, numbers.Number): value = torch.tensor(value, device=device, dtype=torch.float) sync_fn = sync_fn or sync_ddp_if_available - dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() - if not sync_dist or not dist_available: + if not sync_dist: return value return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index e65a6512d3846..4990f95f14ac0 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -328,7 +328,7 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: """ Reduces a tensor from several distributed processes to one aggregated tensor. 
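# Aside: a rough single-process stand-in for what the default `reduce_op="mean"` means in the
# `reduce` methods below. This is an assumption for illustration only (the list fakes one value
# per rank); the real reduction is carried out by `sync_ddp_if_available` via `torch.distributed`.
import torch

per_rank = [torch.tensor(1.0), torch.tensor(3.0)]  # pretend: the same scalar gathered from two processes
reduced = torch.stack(per_rank).mean()             # a "mean" reduce averages the per-process values
assert reduced.item() == 2.0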
@@ -342,7 +342,7 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ reduced value, except when the input was not a tensor the output remains is unchanged """ if isinstance(tensor, torch.Tensor): - tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) + tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor def training_step(self, *args, **kwargs): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index df9f0ee158ba3..8d2cc217835fb 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -319,7 +319,7 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: prepare_for_backward(self.model, closure_loss) - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: """ Reduces a tensor from several distributed processes to one aggregated tensor. @@ -333,7 +333,7 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ reduced value, except when the input was not a tensor the output remains is unchanged """ if isinstance(tensor, torch.Tensor): - tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) + tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor def training_step(self, *args, **kwargs): diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index a54d00a983d9e..a507afa6bc895 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -129,7 +129,7 @@ def sync_ddp_if_available( Return: reduced value """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): + if torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed(): return sync_ddp(result, group=group, reduce_op=reduce_op) return result From e628b1fc2526ea0a0b69855c8bfb9d75855d9a58 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Jun 2021 18:52:16 +0200 Subject: [PATCH 290/455] Convert to float outside of sync --- pytorch_lightning/core/lightning.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9b4b1647960ca..a4cd92c98ce7a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -374,15 +374,7 @@ def log( f" of {list(self._metric_attributes.values())}" ) - sync_fn = partial( - self.__sync, - sync_fn=self.trainer.training_type_plugin.reduce, - sync_dist=sync_dist, - sync_dist_op=sync_dist_op, - sync_dist_group=sync_dist_group, - device=self.device, - ) - value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn) + value = apply_to_collection(value, numbers.Number, self.__to_float) result_collection.log( self._current_fx_name, @@ -458,16 +450,13 @@ def log_dict( @staticmethod def __sync( - value: Union[torch.Tensor, numbers.Number], + value: torch.Tensor, sync_fn: Optional[Callable] = None, sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, - device: torch.device = None, ) 
-> torch.Tensor: """Sync across workers when using distributed training""" - if isinstance(value, numbers.Number): - value = torch.tensor(value, device=device, dtype=torch.float) sync_fn = sync_fn or sync_ddp_if_available if not sync_dist: return value @@ -483,6 +472,9 @@ def __check_not_nested(name: str, value: dict) -> None: def __check_allowed(name: str, value: Any, v: Any) -> None: raise ValueError(f'`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged') + def __to_float(self, value: numbers.Number) -> torch.Tensor: + return torch.tensor(value, device=self.device, dtype=torch.float) + def write_prediction( self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt' ): From 4efc01af3ca8e5f3704cf5fbffb6fa1071e6e967 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 11:19:16 +0200 Subject: [PATCH 291/455] Fix empty collections --- .../trainer/connectors/logger_connector/result.py | 14 +++++++------- tests/callbacks/test_progress_bar.py | 9 ++++++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index d630c4d85aeda..a7c5cc0ec1b84 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -374,15 +374,15 @@ def get_metrics(self, on_step: bool) -> Dict[str, _METRIC_COLLECTION]: result_metric, ResultMetric, partial(self._get_cache, on_step), include_none=False ) - # detect if the value is None. This can be nested. - is_none = False + # check if the collection is empty + has_tensor = False - def any_none(_): - nonlocal is_none - is_none = True + def any_tensor(_): + nonlocal has_tensor + has_tensor = True - apply_to_collection(value, type(None), any_none) - if is_none: + apply_to_collection(value, torch.Tensor, any_tensor) + if not has_tensor: continue name, forked_name = self._forked_name(result_metric, on_step) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 96b46f9bf104c..1a46870baaf8c 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -384,9 +384,9 @@ def test_tensor_to_float_conversion(tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx): - self.log('a', torch.tensor(0.123), prog_bar=True) - self.log('b', {"b1": torch.tensor([1])}, prog_bar=True) - self.log('c', {"c1": 2}, prog_bar=True) + self.log('a', torch.tensor(0.123), prog_bar=True, on_epoch=False) + self.log('b', {"b1": torch.tensor([1])}, prog_bar=True, on_epoch=False) + self.log('c', {"c1": 2}, prog_bar=True, on_epoch=False) return super().training_step(batch, batch_idx) trainer = Trainer( @@ -398,6 +398,9 @@ def training_step(self, batch, batch_idx): ) trainer.fit(TestModel()) + torch.testing.assert_allclose(trainer.progress_bar_metrics['a'], 0.123) + assert trainer.progress_bar_metrics['b'] == {'b1': 1.0} + assert trainer.progress_bar_metrics['c'] == {'c1': 2.0} pbar = trainer.progress_bar_callback.main_progress_bar actual = str(pbar.postfix) assert actual.endswith("a=0.123, b={'b1': 1.0}, c={'c1': 2.0}"), actual From 8509c4ae410f06020e5900368eb9b7f66b7632f4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 11:21:41 +0200 Subject: [PATCH 292/455] Install latest torchmetrics --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/requirements.txt b/requirements.txt index 964bb493a2637..5b2e377397791 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tqdm>=4.41.0 PyYAML>=5.1,<=5.4.1 fsspec[http]>=2021.4.0 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!' -torchmetrics>=0.2.0 +torchmetrics>=0.3.2 pyDeprecate==0.3.0 packaging typing-extensions # TypedDict support for python<3.8 From f8b824fb9b00a51bc90d79a6753b5585994f06b4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 13:04:06 +0200 Subject: [PATCH 293/455] Update CODEOWNERS --- .github/CODEOWNERS | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 39f38bf266af0..e337621f0768e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -33,6 +33,9 @@ /pytorch_lightning/tuner @SkafteNicki @borda @awaelchli /pytorch_lightning/utilities @borda @tchaton @SeanNaren @carmocca +# Specifics +/pytorch_lightning/trainer/connectors/logger_connector @tchaton @carmocca + # Metrics /pytorch_lightning/metrics/ @SkafteNicki @ananyahjha93 @justusschock /tests/metrics/ @SkafteNicki @ananyahjha93 @justusschock From 5ef05e3af06f7ea7757da7290efc84e348376a2e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 12:01:09 +0200 Subject: [PATCH 294/455] Keep no_grad ctx manager if enable graph --- .../trainer/connectors/logger_connector/result.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index a7c5cc0ec1b84..ed385b5381646 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -45,6 +45,7 @@ class Metadata: on_step: bool = False on_epoch: bool = True reduce_fx: Callable = torch.mean + enable_graph: bool = False dataloader_idx: Optional[int] = None metric_attribute: Optional[str] = None has_reset: bool = False @@ -87,7 +88,7 @@ def __setattr__(self, key, value): # performance: skip the `torch.nn.Module.__setattr__` checks object.__setattr__(self, key, value) - def update(self, value: _METRIC, batch_size: Optional[int] = None) -> None: + def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: if self.is_tensor: if self.meta.is_mean_reduction: self.value += value.mean() * batch_size @@ -120,15 +121,18 @@ def reset(self) -> None: self.value.reset() self.meta.has_reset = True - def forward(self, value: _METRIC, *args, **kwargs) -> None: + def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None: if isinstance(value, Metric): self._forward_cache = value._forward_cache else: value = value.float() self._forward_cache = value - # performance: skip the `torch.no_grad` context manager by calling `update` directly - # as `backward` shouldn't be run on metrics - self.update(value, *args, **kwargs) + if self.meta.enable_graph: + with torch.no_grad(): + self.update(value, batch_size) + else: + # performance: skip the `torch.no_grad` context manager by calling `update` directly + self.update(value, batch_size) def __repr__(self) -> str: state = f"value={self.value}" @@ -296,6 +300,7 @@ def log( on_step=on_step, on_epoch=on_epoch, reduce_fx=reduce_fx, + enable_graph=enable_graph, dataloader_idx=dataloader_idx, metric_attribute=metric_attribute, ) From 33b06d09c7183c02182be08226913059426ee09d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 13:03:49 +0200 
Subject: [PATCH 295/455] Move has_reset to ResultMetric --- .../trainer/connectors/logger_connector/result.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ed385b5381646..aaa8457a02d6d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -48,7 +48,6 @@ class Metadata: enable_graph: bool = False dataloader_idx: Optional[int] = None metric_attribute: Optional[str] = None - has_reset: bool = False @property def forked(self) -> bool: @@ -79,6 +78,7 @@ def __init__(self, metadata: Metadata, is_tensor: bool) -> None: super().__init__() self.is_tensor = is_tensor self.meta = metadata + self.has_reset = False if is_tensor: self.add_state("value", torch.tensor(0, dtype=torch.float)) if self.meta.is_mean_reduction: @@ -119,7 +119,7 @@ def reset(self) -> None: super().reset() else: self.value.reset() - self.meta.has_reset = True + self.has_reset = True def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None: if isinstance(value, Metric): @@ -337,7 +337,7 @@ def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None: def fn(result_metric, v): # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` result_metric.forward(v.to(self.device), self.batch_size) - result_metric.meta.has_reset = False + result_metric.has_reset = False apply_to_collections(self[key], value, ResultMetric, fn) @@ -357,7 +357,7 @@ def _to_item(t: torch.Tensor) -> float: def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" return ((k, v) for k, v in self.items() - if not k == "_extra" and not (isinstance(v, ResultMetric) and v.meta.has_reset)) + if not k == "_extra" and not (isinstance(v, ResultMetric) and v.has_reset)) def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]: name = result_metric.meta.name From 41708751bb16e032afd22357d113bd65c9cb9b82 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 13:02:24 +0200 Subject: [PATCH 296/455] Sync dataclass --- pytorch_lightning/core/lightning.py | 18 +--- .../connectors/logger_connector/result.py | 100 ++++++++++++++---- 2 files changed, 83 insertions(+), 35 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a4cd92c98ce7a..8dcdee832c4d6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -389,6 +389,10 @@ def log( dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, metric_attribute=metric_attribute, + sync_dist=sync_dist, + sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, + sync_dist_op=sync_dist_op, + sync_dist_group=sync_dist_group, ) def log_dict( @@ -448,20 +452,6 @@ def log_dict( add_dataloader_idx=add_dataloader_idx ) - @staticmethod - def __sync( - value: torch.Tensor, - sync_fn: Optional[Callable] = None, - sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', - sync_dist_group: Optional[Any] = None, - ) -> torch.Tensor: - """Sync across workers when using distributed training""" - sync_fn = sync_fn or sync_ddp_if_available - if not sync_dist: - return value - return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) - @staticmethod def __check_not_nested(name: str, 
value: dict) -> None: if any(isinstance(v, dict) for v in value.values()): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index aaa8457a02d6d..72e080371495c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Generator -from dataclasses import dataclass -from functools import partial +from dataclasses import dataclass, field +from functools import partial, wraps from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch from torchmetrics import Metric from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.enums import LightningEnum @@ -36,6 +37,22 @@ class MetricSource(LightningEnum): LOG = "log" +@dataclass +class Sync: + fn: Callable + should: bool = False + op: Union[Any, str] = 'mean' + group: Optional[Any] = None + + @property + def __call__(self) -> Callable: + return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op + + @staticmethod + def no_op(value: Any, *_, **__) -> Any: + return value + + @dataclass class Metadata: fx: str @@ -48,6 +65,7 @@ class Metadata: enable_graph: bool = False dataloader_idx: Optional[int] = None metric_attribute: Optional[str] = None + sync: Sync = field(default_factory=Sync) @property def forked(self) -> bool: @@ -90,13 +108,16 @@ def __setattr__(self, key, value): def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: if self.is_tensor: + # performance: no need to accumulate on values only logged on_step + if self.meta.on_step and not self.meta.on_epoch: + self.value = self.meta.sync(value) + return + # perform accumulation with reduction if self.meta.is_mean_reduction: self.value += value.mean() * batch_size self.cumulated_batch_size += batch_size - elif self.meta.is_max_reduction: self.value = max(self.value, value.mean()) - elif self.meta.is_min_reduction: self.value = min(self.value, value.mean()) else: @@ -105,15 +126,37 @@ def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: def compute(self) -> torch.Tensor: if self.is_tensor: + value = self.meta.sync(self.value) if self.meta.is_mean_reduction: - return torch.sum(self.value) / torch.sum(self.cumulated_batch_size) + cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) + # FIXME: might need sum + return value / cumulated_batch_size elif self.meta.is_max_reduction or self.meta.is_min_reduction: - return self.value + return value raise MisconfigurationException( f"Only [min, max, mean] reductions are supported. Found {self.meta.reduce_fx}" ) return self.value.compute() + def _wrap_compute(self, compute: Any) -> Any: + # Override to avoid syncing - we handle it ourselves. 
+ @wraps(compute) + def wrapped_func(*args, **kwargs): + if not self._update_called: + rank_zero_warn( + f"The ``compute`` method of metric {self.__class__.__name__}" + " was called before the ``update`` method which may lead to errors," + " as metric states have not yet been updated.", UserWarning + ) + + # return cached value + if self._computed is not None: + return self._computed + self._computed = compute(*args, **kwargs) + return self._computed + + return wrapped_func + def reset(self) -> None: if self.is_tensor: super().reset() @@ -122,11 +165,12 @@ def reset(self) -> None: self.has_reset = True def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None: - if isinstance(value, Metric): - self._forward_cache = value._forward_cache - else: + if self.is_tensor: value = value.float() self._forward_cache = value + else: + self._forward_cache = value._forward_cache + if self.meta.enable_graph: with torch.no_grad(): self.update(value, batch_size) @@ -265,6 +309,10 @@ def log( on_epoch: bool = True, reduce_fx: Callable = torch.mean, enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_fn: Callable = Sync.no_op, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, metric_attribute: Optional[str] = None, @@ -291,20 +339,30 @@ def log( key += f'.{dataloader_idx}' fx += f'.{dataloader_idx}' - if key not in self: - meta = Metadata( - fx=fx, - name=name, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - enable_graph=enable_graph, - dataloader_idx=dataloader_idx, - metric_attribute=metric_attribute, + meta = Metadata( + fx=fx, + name=name, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + dataloader_idx=dataloader_idx, + metric_attribute=metric_attribute, + sync=Sync( + should=sync_dist, + fn=sync_dist_fn, + op=sync_dist_op, + group=sync_dist_group, ) + ) + if key not in self: self.register_key(key, meta, value) + elif meta != self[key].meta: + raise MisconfigurationException( + f'You called `self.log({key}, ...)` twice with different arguments. 
This is not allowed' + ) if self.should_reset_tensors(fx): # when restarting an new epoch, reset the tensors From 86127466a82da85165fecd015e6e12d2f79c6390 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 13:13:21 +0200 Subject: [PATCH 297/455] Detach if not enable graph --- .../trainer/connectors/logger_connector/result.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 72e080371495c..c85243335294d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -401,12 +401,16 @@ def fn(result_metric, v): @staticmethod def _get_cache(on_step: bool, result_metric: ResultMetric) -> Optional[torch.Tensor]: + cache = None if on_step and result_metric.meta.on_step: - return result_metric._forward_cache.detach() + cache = result_metric._forward_cache elif not on_step and result_metric.meta.on_epoch: if not result_metric._computed: result_metric.compute() - return result_metric._computed.detach() + cache = result_metric._computed + if cache is not None and not result_metric.meta.enable_graph: + return cache.detach() + return cache @staticmethod def _to_item(t: torch.Tensor) -> float: From 0f8e6202c47be127f2ae0aa060451d57409c6d19 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 13:41:35 +0200 Subject: [PATCH 298/455] Prepend dataclasses with underscore --- .../connectors/logger_connector/result.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c85243335294d..6e929a2a4405e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -38,7 +38,7 @@ class MetricSource(LightningEnum): @dataclass -class Sync: +class _Sync: fn: Callable should: bool = False op: Union[Any, str] = 'mean' @@ -54,7 +54,7 @@ def no_op(value: Any, *_, **__) -> Any: @dataclass -class Metadata: +class _Metadata: fx: str name: str prog_bar: bool = False @@ -65,7 +65,7 @@ class Metadata: enable_graph: bool = False dataloader_idx: Optional[int] = None metric_attribute: Optional[str] = None - sync: Sync = field(default_factory=Sync) + sync: _Sync = field(default_factory=_Sync) @property def forked(self) -> bool: @@ -92,7 +92,7 @@ def is_min_reduction(self) -> bool: class ResultMetric(Metric, DeviceDtypeModuleMixin): """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" - def __init__(self, metadata: Metadata, is_tensor: bool) -> None: + def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: super().__init__() self.is_tensor = is_tensor self.meta = metadata @@ -194,7 +194,7 @@ class ResultMetricCollection(dict): with the same metadata. 
""" - def __init__(self, *args, metadata: Optional[Metadata] = None) -> None: + def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: super().__init__(*args) self.meta = metadata @@ -310,7 +310,7 @@ def log( reduce_fx: Callable = torch.mean, enable_graph: bool = False, sync_dist: bool = False, - sync_dist_fn: Callable = Sync.no_op, + sync_dist_fn: Callable = _Sync.no_op, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, @@ -339,7 +339,7 @@ def log( key += f'.{dataloader_idx}' fx += f'.{dataloader_idx}' - meta = Metadata( + meta = _Metadata( fx=fx, name=name, prog_bar=prog_bar, @@ -350,7 +350,7 @@ def log( enable_graph=enable_graph, dataloader_idx=dataloader_idx, metric_attribute=metric_attribute, - sync=Sync( + sync=_Sync( should=sync_dist, fn=sync_dist_fn, op=sync_dist_op, @@ -374,7 +374,7 @@ def log( self.update_metrics(key, value) self._current_fx = fx - def register_key(self, key: str, meta: Metadata, value: _METRIC_COLLECTION) -> None: + def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: """Create one ResultMetric object per value. Value can be provided as a nested collection""" def fn(v: _METRIC) -> ResultMetric: From 7008fe962f0555b9a99aab86bf0f8302cb1f958d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 13:46:44 +0200 Subject: [PATCH 299/455] Add test --- .../trainer/connectors/logger_connector/result.py | 2 +- tests/trainer/logging_/test_train_loop_logging.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 6e929a2a4405e..c77e004f1a708 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -361,7 +361,7 @@ def log( self.register_key(key, meta, value) elif meta != self[key].meta: raise MisconfigurationException( - f'You called `self.log({key}, ...)` twice with different arguments. This is not allowed' + f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. 
This is not allowed' ) if self.should_reset_tensors(fx): diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index c6ef47d9e8162..becc2bdad24e7 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -828,3 +828,15 @@ def training_step(self, *args): model = TestModel() with pytest.raises(MisconfigurationException, match='You passed a tensor with `grad_fn`'): trainer.fit(model) + + class TestModel(BoringModel): + + def training_step(self, *args): + self.log('foo', -1, prog_bar=False) + self.log('foo', -1, prog_bar=True) + return super().training_step(*args) + + trainer = Trainer(default_root_dir=tmpdir) + model = TestModel() + with pytest.raises(MisconfigurationException, match=r'self.log\(foo, ...\)` twice in `training_step`'): + trainer.fit(model) From 6db7d668bc80b66c5377ce5d838ea1d280b225b7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 13:54:47 +0200 Subject: [PATCH 300/455] Update code after torchmetrics==0.3.2 --- .../trainer/connectors/logger_connector/result.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c77e004f1a708..9e43ff811a36d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -108,6 +108,8 @@ def __setattr__(self, key, value): def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: if self.is_tensor: + value = value.float() + self._forward_cache = value # performance: no need to accumulate on values only logged on_step if self.meta.on_step and not self.meta.on_epoch: self.value = self.meta.sync(value) @@ -165,12 +167,6 @@ def reset(self) -> None: self.has_reset = True def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None: - if self.is_tensor: - value = value.float() - self._forward_cache = value - else: - self._forward_cache = value._forward_cache - if self.meta.enable_graph: with torch.no_grad(): self.update(value, batch_size) From b4582e33145ec678469ba356613bf969034730f4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 14:00:40 +0200 Subject: [PATCH 301/455] Merge two ifs --- .../trainer/connectors/logger_connector/result.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 9e43ff811a36d..173f04f8687e2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -118,10 +118,8 @@ def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: if self.meta.is_mean_reduction: self.value += value.mean() * batch_size self.cumulated_batch_size += batch_size - elif self.meta.is_max_reduction: - self.value = max(self.value, value.mean()) - elif self.meta.is_min_reduction: - self.value = min(self.value, value.mean()) + elif self.meta.is_max_reduction or self.meta.is_min_reduction: + self.value = self.meta.reduce_fx(self.value, value.mean()) else: self.value = value # noqa: attribute-defined-outside-init self._forward_cache = value._forward_cache From fa91cdfc2dfbeb60ddaf03203b8bec6d8599ef16 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 14:05:24 +0200 Subject: 
[PATCH 302/455] Reorder code --- .../connectors/logger_connector/result.py | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 173f04f8687e2..21b973dd2dbe2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -102,10 +102,6 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) - def __setattr__(self, key, value): - # performance: skip the `torch.nn.Module.__setattr__` checks - object.__setattr__(self, key, value) - def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: if self.is_tensor: value = value.float() @@ -138,6 +134,21 @@ def compute(self) -> torch.Tensor: ) return self.value.compute() + def reset(self) -> None: + if self.is_tensor: + super().reset() + else: + self.value.reset() + self.has_reset = True + + def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None: + if self.meta.enable_graph: + with torch.no_grad(): + self.update(value, batch_size) + else: + # performance: skip the `torch.no_grad` context manager by calling `update` directly + self.update(value, batch_size) + def _wrap_compute(self, compute: Any) -> Any: # Override to avoid syncing - we handle it ourselves. @wraps(compute) @@ -157,20 +168,9 @@ def wrapped_func(*args, **kwargs): return wrapped_func - def reset(self) -> None: - if self.is_tensor: - super().reset() - else: - self.value.reset() - self.has_reset = True - - def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None: - if self.meta.enable_graph: - with torch.no_grad(): - self.update(value, batch_size) - else: - # performance: skip the `torch.no_grad` context manager by calling `update` directly - self.update(value, batch_size) + def __setattr__(self, key, value): + # performance: skip the `torch.nn.Module.__setattr__` checks + object.__setattr__(self, key, value) def __repr__(self) -> str: state = f"value={self.value}" @@ -407,7 +407,7 @@ def _get_cache(on_step: bool, result_metric: ResultMetric) -> Optional[torch.Ten return cache @staticmethod - def _to_item(t: torch.Tensor) -> float: + def __to_item(t: torch.Tensor) -> float: return t.item() def valid_items(self) -> Generator: @@ -459,7 +459,7 @@ def any_tensor(_): # populate progress_bar metrics. 
convert tensors to numbers if result_metric.meta.prog_bar: - value = apply_to_collection(value, torch.Tensor, self._to_item, include_none=False) + value = apply_to_collection(value, torch.Tensor, self.__to_item, include_none=False) metrics[MetricSource.PBAR][forked_name] = value return metrics @@ -470,17 +470,6 @@ def get_batch_metrics(self) -> Dict[str, _METRIC_COLLECTION]: def get_epoch_metrics(self) -> Dict[str, _METRIC_COLLECTION]: return self.get_metrics(on_step=False) - def to(self, *args, **kwargs) -> 'ResultCollection': - """Move all data to the given device.""" - for k, v in self.items(): - if isinstance(v, (torch.Tensor, Metric)): - self[k] = v.to(*args, **kwargs) - return self - - def cpu(self) -> 'ResultCollection': - """Move all data to CPU.""" - return self.to(device="cpu") - def _reset(self, fx: Optional[str] = None, metrics: Optional[bool] = None) -> None: def fn(item: ResultMetric) -> None: @@ -531,6 +520,17 @@ def _extract_batch_size(self, batch: Any) -> int: size = 1 return size + def to(self, *args, **kwargs) -> 'ResultCollection': + """Move all data to the given device.""" + for k, v in self.items(): + if isinstance(v, (torch.Tensor, Metric)): + self[k] = v.to(*args, **kwargs) + return self + + def cpu(self) -> 'ResultCollection': + """Move all data to CPU.""" + return self.to(device="cpu") + def __str__(self) -> str: return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' From 03f769c783d26987da632061802136954b792c20 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 14:15:55 +0200 Subject: [PATCH 303/455] Error for custom reductions --- pytorch_lightning/core/lightning.py | 8 ++++---- .../trainer/connectors/logger_connector/result.py | 9 +++++++++ tests/trainer/logging_/test_train_loop_logging.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 8dcdee832c4d6..09dc3625dbe50 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -296,13 +296,13 @@ def log( "validation_epoch_end*", "F", "T", "F", "T" Args: - name: key name - value: value name + name: key to log + value: value to log prog_bar: if True logs to the progress bar logger: if True logs to the logger on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step - reduce_fx: reduction function over step values for end of epoch. Torch.mean by default + reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default. enable_graph: if True, will not auto detach the graph sync_dist: if True, reduces the metric across GPUs/TPUs sync_dist_op: the op to sync across GPUs/TPUs @@ -425,7 +425,7 @@ def log_dict( logger: if True logs to the logger on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step - reduce_fx: reduction function over step values for end of epoch. Torch.mean by default + reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default. 
enable_graph: if True, will not auto detach the graph sync_dist: if True, reduces the metric across GPUs/TPUs sync_dist_op: the op to sync across GPUs/TPUs diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 21b973dd2dbe2..58c6609c22b39 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -88,6 +88,10 @@ def is_max_reduction(self) -> bool: def is_min_reduction(self) -> bool: return self.reduce_fx in (torch.min, min) + @property + def is_custom_reduction(self) -> bool: + return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction) + class ResultMetric(Metric, DeviceDtypeModuleMixin): """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" @@ -352,6 +356,11 @@ def log( ) ) if key not in self: + if meta.is_custom_reduction: + raise MisconfigurationException( + 'Only `self.log(..., reduce_fx={min,max,mean})` are currently supported.' + ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`' + ) self.register_key(key, meta, value) elif meta != self[key].meta: raise MisconfigurationException( diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index becc2bdad24e7..54d6359a4a959 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -840,3 +840,14 @@ def training_step(self, *args): model = TestModel() with pytest.raises(MisconfigurationException, match=r'self.log\(foo, ...\)` twice in `training_step`'): trainer.fit(model) + + class TestModel(BoringModel): + + def training_step(self, *args): + self.log('foo', -1, reduce_fx=torch.argmax) + return super().training_step(*args) + + trainer = Trainer(default_root_dir=tmpdir) + model = TestModel() + with pytest.raises(MisconfigurationException, match=r'reduce_fx={min,max,mean}\)` are currently supported'): + trainer.fit(model) From 7cf8b8b41b6b339bcbdd5bae01bf9bec4129beba Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 14:19:05 +0200 Subject: [PATCH 304/455] improve reduce_fx test --- tests/trainer/logging_/test_train_loop_logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 54d6359a4a959..f7485f8eb0d86 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -193,7 +193,7 @@ def training_epoch_end(self, outputs): assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics | {'a', 'b'}) - {'epoch'} -@pytest.mark.parametrize(['batches', 'fx', 'result'], [(1, min, 0), (2, max, 1), (11, max, 10)]) +@pytest.mark.parametrize(['batches', 'fx', 'result'], [(3, min, 0), (3, max, 2), (11, max, 10)]) def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result): """ Tests that log works correctly with different tensor types From 23a628dfe155001c5f552e082147731397e8f627 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 15:31:30 +0200 Subject: [PATCH 305/455] Minor logger connector refactoring --- .../logger_connector/logger_connector.py | 24 +++++++------------ .../connectors/logger_connector/result.py | 2 +- .../logging_/test_eval_loop_logging.py | 9 ++----- 3 files changed, 12 insertions(+), 23 
deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index ef63870e4b9e1..bf2ba932ba453 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -20,7 +20,7 @@ from pytorch_lightning.core import memory from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource +from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars @@ -36,8 +36,8 @@ def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self._val_log_step: int = 0 self._test_log_step: int = 0 self._progress_bar_metrics: Dict[str, float] = {} - self._logged_metrics: Dict[str, float] = {} - self._callback_metrics: Dict[str, float] = {} + self._logged_metrics: Dict[str, _METRIC] = {} + self._callback_metrics: Dict[str, _METRIC] = {} def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): # logging @@ -149,10 +149,9 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: self._callback_metrics.update(metrics[MetricSource.CALLBACK]) if not self.trainer.sanity_checking: - # log all the metrics as a single dict metrics_to_log = metrics[MetricSource.LOG] - if len(metrics_to_log) > 0: + if metrics_to_log: self.log_metrics(metrics_to_log, {}) self.prepare_eval_loop_results() @@ -216,9 +215,8 @@ def update_evaluation_step_metrics(self) -> None: # logs user requested information to logger batch_log_metrics = metrics[MetricSource.LOG] - if len(batch_log_metrics) > 0: - kwargs = dict() if "step" in batch_log_metrics else dict(step=self.evaluation_log_step) - self.log_metrics(batch_log_metrics, {}, **kwargs) + if batch_log_metrics: + self.log_metrics(batch_log_metrics, {}, step=self.evaluation_log_step) # increment the step even if nothing was logged self.increment_evaluation_log_step() @@ -248,10 +246,8 @@ def update_train_step_metrics(self, batch_output): batch_log_metrics = metrics[MetricSource.LOG] if self.should_update_logs or self.trainer.fast_dev_run is True: # logs user requested information to logger - grad_norm_dict = batch_output.grad_norm_dict - if grad_norm_dict is None: - grad_norm_dict = {} - if len(batch_log_metrics) > 0 or len(grad_norm_dict) > 0: + grad_norm_dict = batch_output.grad_norm_dict or {} + if batch_log_metrics or grad_norm_dict: self.log_metrics(batch_log_metrics, grad_norm_dict) def on_train_epoch_end(self): @@ -267,9 +263,7 @@ def update_train_epoch_metrics(self) -> None: # add the metrics to the loggers epoch_log_metrics = metrics[MetricSource.LOG] - if epoch_log_metrics and len(epoch_log_metrics) > 0: - epoch_log_metrics["epoch"] = self.trainer.current_epoch - self._logged_metrics.update(epoch_log_metrics) + if epoch_log_metrics: self.log_metrics(epoch_log_metrics, {}) # reset result collection for next epoch diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 58c6609c22b39..652f8536058a6 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ 
b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -434,7 +434,7 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, forked_name += dataloader_suffix return name, forked_name - def get_metrics(self, on_step: bool) -> Dict[str, _METRIC_COLLECTION]: + def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, _METRIC]]: metrics = {k: {} for k in MetricSource} for key, result_metric in self.valid_items(): diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index de29a0f44cf5d..b7950c9290ae3 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -60,8 +60,7 @@ def validation_step(self, batch, batch_idx): ) trainer.fit(model) - # make sure all the metrics are available for callbacks - expected_logged_metrics = { + assert set(trainer.logged_metrics) == { 'a2', 'a_step', 'a_epoch', @@ -69,14 +68,10 @@ def validation_step(self, batch, batch_idx): 'b_epoch', 'epoch', } - logged_metrics = set(trainer.logged_metrics.keys()) - assert expected_logged_metrics == logged_metrics # we don't want to enable val metrics during steps because it is not something that users should do # on purpose DO NOT allow b_step... it's silly to monitor val step metrics - callback_metrics = set(trainer.callback_metrics.keys()) - expected_cb_metrics = {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} - assert expected_cb_metrics == callback_metrics + assert set(trainer.callback_metrics) == {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} def test__validation_step__epoch_end__log(tmpdir): From 13d57aff2f41238e9fb76b60de00ed0f1b48ed6b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 16:00:34 +0200 Subject: [PATCH 306/455] Refactor eval loop results --- .../logger_connector/logger_connector.py | 32 ++++++++----------- .../connectors/logger_connector/result.py | 2 +- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index bf2ba932ba453..79201c7ef629f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from copy import deepcopy from pprint import pprint from typing import Any, Dict, Iterable, Optional @@ -120,31 +119,25 @@ def evaluation_epoch_end(self): model_ref._current_dataloader_idx = None self.trainer.result_collection.on_epoch_end_reached = True - def add_to_eval_loop_results(self, dl_idx, has_been_initialized): + def prepare_eval_loop_results(self, metrics: Dict[str, _METRIC]) -> None: if self.trainer.sanity_checking: return - callback_metrics = self.trainer.result_collection.metrics[MetricSource.CALLBACK] - callback_metrics = deepcopy(callback_metrics) - for key in list(callback_metrics.keys()): - if "dataloader_idx" in key and f"dataloader_idx_{dl_idx}" not in key: - # remove callback metrics that don't belong to this dataloader - del callback_metrics[key] - if has_been_initialized: - self.eval_loop_results[dl_idx].update(callback_metrics) - else: - self.eval_loop_results.append(callback_metrics) - - def prepare_eval_loop_results(self): num_dataloaders = self.trainer.evaluation_loop.num_dataloaders has_been_initialized = len(self.eval_loop_results) == num_dataloaders for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): - self.add_to_eval_loop_results(dl_idx, has_been_initialized) + # remove callback metrics that don't belong to this dataloader + callback_metrics = { + k: v + for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k + } + if has_been_initialized: + self.eval_loop_results[dl_idx].update(callback_metrics) + else: + self.eval_loop_results.append(callback_metrics) def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: - metrics = self.trainer.result_collection.metrics - - # update metrics + metrics = self.trainer.result_collection.get_metrics(False) self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) self._callback_metrics.update(metrics[MetricSource.CALLBACK]) @@ -154,7 +147,8 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if metrics_to_log: self.log_metrics(metrics_to_log, {}) - self.prepare_eval_loop_results() + # FIXME: use self.callback_metrics instead? 
+ self.prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) # log results of evaluation if ( diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 652f8536058a6..f0c3c9ab3af0a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -434,7 +434,7 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, forked_name += dataloader_suffix return name, forked_name - def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, _METRIC]]: + def get_metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: metrics = {k: {} for k in MetricSource} for key, result_metric in self.valid_items(): From fcbf408dd607687b70c42556d52cddaeec656b7d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 16:03:50 +0200 Subject: [PATCH 307/455] Formatting --- .../trainer/connectors/logger_connector/logger_connector.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 79201c7ef629f..bdb75e0ecf125 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -228,8 +228,6 @@ def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: def update_train_step_metrics(self, batch_output): metrics = self.trainer.result_collection.metrics - - # update metrics self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) self._callback_metrics.update(metrics[MetricSource.CALLBACK]) @@ -249,9 +247,7 @@ def on_train_epoch_end(self): self.trainer.result_collection.on_epoch_end_reached = True def update_train_epoch_metrics(self) -> None: - metrics = self.trainer.result_collection.metrics - self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) self._callback_metrics.update(metrics[MetricSource.CALLBACK]) From 574e713d2b40e48745593a7eae3b288a63ea3301 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 16:11:37 +0200 Subject: [PATCH 308/455] Typing --- .../trainer/connectors/logger_connector/logger_connector.py | 4 ++-- tests/trainer/logging_/test_train_loop_logging.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index bdb75e0ecf125..522bb9805ed04 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -264,14 +264,14 @@ def update_train_epoch_metrics(self) -> None: """ @property - def callback_metrics(self) -> Dict[str, float]: + def callback_metrics(self) -> Dict[str, _METRIC]: if self.trainer.result_collection: metrics = self.trainer.result_collection.metrics[MetricSource.CALLBACK] self._callback_metrics.update(metrics) return self._callback_metrics @property - def logged_metrics(self) -> Dict[str, float]: + def logged_metrics(self) -> Dict[str, _METRIC]: if self.trainer.result_collection: metrics = self.trainer.result_collection.metrics[MetricSource.LOG] self._logged_metrics.update(metrics) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py 
index f7485f8eb0d86..f570d72b29c93 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -502,7 +502,7 @@ def get_expected_output(func_attr, original_values): # get creation attr func_attr = test_callback.funcs_attr[func_name] - # retrived orginal logged values + # retrieved original logged values values = test_callback.callback_funcs_called[func_name] if len(values) > 0: original_values = values[len(values) - 1] From e74a8add814da810c949ba2ae5b021a52a7df425 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 16:19:37 +0200 Subject: [PATCH 309/455] Formatting --- .../trainer/connectors/logger_connector/logger_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 522bb9805ed04..b831904831bb7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -199,8 +199,6 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: def update_evaluation_step_metrics(self) -> None: metrics = self.trainer.result_collection.metrics - - # update metrics self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) self._callback_metrics.update(metrics[MetricSource.CALLBACK]) From 585e62c5020f4caab00d04ce0e434dba69bbbe63 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 17:00:28 +0200 Subject: [PATCH 310/455] Move back call --- pytorch_lightning/trainer/training_loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e5a339011d857..ea55186449abf 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -464,11 +464,6 @@ def run_training_epoch(self): if batch_output.signal == -1: break - # ----------------------------------------- - # SAVE METRICS TO LOGGERS AND PROGRESS_BAR - # ----------------------------------------- - self.trainer.logger_connector.update_train_step_metrics(batch_output) - # hook self.on_train_batch_end( epoch_output, @@ -478,6 +473,11 @@ def run_training_epoch(self): dataloader_idx, ) + # ----------------------------------------- + # SAVE METRICS TO LOGGERS AND PROGRESS_BAR + # ----------------------------------------- + self.trainer.logger_connector.update_train_step_metrics(batch_output) + # ----------------------------------------- # VALIDATE IF NEEDED # ----------------------------------------- From 7cce418482eaa3c50d6e515a2ec391d4e15f8f21 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Jun 2021 17:17:54 +0200 Subject: [PATCH 311/455] Fix FIXME --- .../trainer/connectors/logger_connector/logger_connector.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index b831904831bb7..1ee86bafc832b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -139,7 +139,6 @@ def prepare_eval_loop_results(self, metrics: Dict[str, _METRIC]) -> None: def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: metrics = 
self.trainer.result_collection.get_metrics(False) self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) - self._callback_metrics.update(metrics[MetricSource.CALLBACK]) if not self.trainer.sanity_checking: # log all the metrics as a single dict @@ -147,8 +146,7 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if metrics_to_log: self.log_metrics(metrics_to_log, {}) - # FIXME: use self.callback_metrics instead? - self.prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) + self.prepare_eval_loop_results(self.callback_metrics) # log results of evaluation if ( From a87c68f28b787e690648b63f021eeaa2b3e92ea2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 4 Jun 2021 17:11:03 +0200 Subject: [PATCH 312/455] integrate #7631 (logger connector refactor) integrate training loop changes from #7631 finish integration of #7631 --- pytorch_lightning/loops/batch_loop.py | 72 ++++++------------- .../dataloader/evaluation_dataloader_loop.py | 16 +++-- pytorch_lightning/loops/epoch_loop.py | 6 +- pytorch_lightning/loops/evaluation_loop.py | 23 ++---- pytorch_lightning/loops/training_loop.py | 70 +++++++++--------- 5 files changed, 80 insertions(+), 107 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 9aae8b70ff1f7..72290be36fbf5 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -11,7 +11,6 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -56,7 +55,7 @@ def done(self): def run(self, batch, batch_idx, dataloader_idx): if batch is None: self.warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") - return AttributeDict(signal=0, grad_norm_dic={}, training_step_output_for_epoch_end=[[]]) + return AttributeDict(signal=0, grad_norm_dic={}, training_step_output=[[]]) # hook response = self.trainer.call_hook("on_batch_start") @@ -70,13 +69,11 @@ def run(self, batch, batch_idx, dataloader_idx): super().run(batch, batch_idx, dataloader_idx) - output = AttributeDict( + return AttributeDict( signal=0, - # TODO: Properly aggregate grad_norm accross opt_idx and split_idx grad_norm_dict=self.grad_norm_dicts[-1], - training_step_output_for_epoch_end=self.batch_outputs, + training_step_output=self.batch_outputs, ) - return output def reset(self) -> None: # self.iteration_count = 0 @@ -101,13 +98,13 @@ def advance(self, batch, batch_idx, dataloader_idx): for opt_idx, optimizer in self.get_active_optimizers(batch_idx): result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) if result: - self.batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) + self.batch_outputs[opt_idx].append(result.training_step_output) grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_idx, split_batch) if result: - self.batch_outputs[0].append(result.training_step_output_for_epoch_end) + self.batch_outputs[0].append(result.training_step_output) # TODO: Properly aggregate grad_norm accross opt_idx and split_idx self.grad_norm_dicts.append(grad_norm_dict) @@ -194,16 +191,11 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) - if not opt_closure_result: return - # cache metrics - self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) - # check if loss or model weights are nan if self.trainer.terminate_on_nan: self._check_finite(opt_closure_result.loss) def on_after_backward(self, training_step_output, batch_idx, untouched_loss): - training_step_output.detach() - # insert after step hook self.trainer.call_hook("on_after_backward") @@ -225,55 +217,36 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # manually capture logged metrics model_ref._current_fx_name = 'training_step' - model_ref._results = Result() with self.trainer.profiler.profile("training_step"): training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() - self.trainer.logger_connector.cache_logged_metrics() - self._check_training_step_output(training_step_output) training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( - training_step_output, split_batch - ) - if training_step_output_for_epoch_end is None: + training_step_output = self._process_training_step_output(training_step_output) + if training_step_output is None: return - # enable empty loss when using manual opt closure_loss = None - untouched_loss = None - + loss = None if self.trainer.lightning_module.automatic_optimization: # accumulate loss. if accumulate_grad_batches==1, no effect closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches - # the loss will get scaled for amp. 
avoid any modifications to it - untouched_loss = closure_loss.detach().clone() - - # result - result = AttributeDict( - closure_loss=closure_loss, - loss=untouched_loss, - training_step_output=training_step_output, # Result object - training_step_output_for_epoch_end=training_step_output_for_epoch_end, # Result object - ) - return result + loss = closure_loss.detach().clone() + return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output) - def _process_training_step_output(self, training_step_output, split_batch): - training_step_output_for_epoch_end = training_step_output + def _process_training_step_output(self, training_step_output): + if training_step_output is None: + return None - # enable validation_step return None - if training_step_output_for_epoch_end is None: - return None, None - - result = self.trainer.lightning_module._results + result = self.trainer.result_collection loss = None hiddens = None - result["extra"] = {} + result.extra = {} # handle dict return if isinstance(training_step_output, dict): @@ -281,7 +254,7 @@ def _process_training_step_output(self, training_step_output, split_batch): hiddens = training_step_output.pop("hiddens", None) if hiddens is not None: hiddens = hiddens.detach() - result["extra"] = training_step_output + result.extra = training_step_output # handle scalar return elif isinstance(training_step_output, torch.Tensor): @@ -291,16 +264,10 @@ def _process_training_step_output(self, training_step_output, split_batch): result.minimize = loss self._hiddens = hiddens - # track batch for manual reduction with result - result.track_batch_size(len(split_batch)) - - # track metrics without grads for epoch reduction - training_step_output_for_epoch_end = copy(result) - training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() if self.trainer.move_metrics_to_cpu: - training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() + result = result.cpu() - return training_step_output_for_epoch_end, result + return result def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): model_ref = self.trainer.lightning_module @@ -415,7 +382,8 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally - self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) + # TODO: pass split_idx here? 
+ self.trainer.logger_connector.on_train_split_start(batch_idx=split_idx, split_batch=split_batch) @contextmanager def block_ddp_sync_behaviour(self, should_block_sync: bool = False): diff --git a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py index e83d2f43422df..70030756b39c5 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py @@ -4,7 +4,7 @@ from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop from pytorch_lightning.loops.evaluation_loop import EvaluationLoop -from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT @@ -19,6 +19,9 @@ def __init__(self): self.outputs = [] self.evaluation_loop = EvaluationLoop() + self.validation_results = ResultCollection(False) + self.test_results = ResultCollection(False) + @property def num_dataloaders(self) -> int: return self._get_num_dataloaders(self.dataloaders) @@ -147,6 +150,9 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() + + self.trainer.logger_connector.on_evaluation_start() + if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) else: @@ -167,6 +173,9 @@ def on_evaluation_model_train(self) -> None: model_ref.on_validation_model_train() def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: + assert self.trainer.result_collection is not None + self.trainer.result_collection.reset(metrics=True) + if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) else: @@ -226,13 +235,10 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: model._current_fx_name = 'validation_epoch_end' model.validation_epoch_end(outputs) - # capture logging - self.trainer.logger_connector.cache_logged_metrics() - def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: # Add step predictions to prediction collection to write later if output is not None and self.predictions is not None: - if isinstance(output, Result) and self.trainer.testing: + if isinstance(output, ResultCollection) and self.trainer.testing: self.predictions.add(output.pop('predictions', None)) # track debug metrics diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 349a7da88e9f6..60fda5447bacf 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -10,6 +10,7 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.training_loop import TrainingLoop +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -35,6 +36,8 @@ def __init__(self, min_epochs, max_epochs, 
min_steps, max_steps): self.training_loop = TrainingLoop(min_steps, max_steps) + self.train_results = ResultCollection(True) + @property def current_epoch(self) -> int: return self.iteration_count @@ -127,6 +130,7 @@ def run(self): return super().run() def on_run_start(self): + self.trainer.logger_connector.on_train_start() self.trainer.call_hook("on_train_start") def on_advance_start(self): # equal to old on_train_epoch_start @@ -171,7 +175,7 @@ def advance(self): # TODO(@carmocca): deprecate and rename so users don't get confused self.global_step -= 1 # log epoch metrics - self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) + self.trainer.logger_connector.update_train_epoch_metrics() self.global_step += 1 def on_advance_end(self): diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py index e3883e337ddc1..cd3b0983f2793 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Iterator, Optional, Union from pytorch_lightning.loops.base import Loop -from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -56,7 +56,7 @@ def advance(self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloade self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) # log batch metrics - self.trainer.logger_connector.log_evaluation_step_metrics() + self.trainer.logger_connector.update_evaluation_step_metrics() # track epoch level outputs self.outputs = self.trainer._track_output_for_epoch_end(self.outputs, output) @@ -73,24 +73,15 @@ def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Op # configure step_kwargs step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - model_ref = self.trainer.lightning_module - model_ref._results = Result() - if self.trainer.testing: - model_ref._current_fx_name = "test_step" + self.trainer.lightning_module._current_fx_name = "test_step" with self.trainer.profiler.profile("test_step"): output = self.trainer.accelerator.test_step(step_kwargs) else: - model_ref._current_fx_name = "validation_step" + self.trainer.lightning_module._current_fx_name = "validation_step" with self.trainer.profiler.profile("validation_step"): output = self.trainer.accelerator.validation_step(step_kwargs) - # capture any logged information - self.trainer.logger_connector.cache_logged_metrics() - # track batch size for weighted average - if isinstance(output, Result): - output.track_batch_size(batch) - return output def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: @@ -101,8 +92,8 @@ def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT return output def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - # set dataloader_idx to model and track batch_size - self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) + assert self.num_dataloaders is not None + self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self.num_dataloaders) if self.trainer.testing: self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) @@ -127,7 +118,7 @@ def 
on_evaluation_batch_end( def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: # Add step predictions to prediction collection to write later if output is not None and self.predictions is not None: - if isinstance(output, Result) and self.trainer.testing: + if isinstance(output, ResultCollection) and self.trainer.testing: self.predictions.add(output.pop('predictions', None)) # track debug metrics diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index d5bfbf9055822..0e4fbde1ac097 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -3,7 +3,7 @@ import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.batch_loop import BatchLoop -from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -99,16 +99,16 @@ def advance(self, dataloader_iter: Iterator, **kwargs): # hook self.on_train_batch_end( self.epoch_output, - batch_output.training_step_output_for_epoch_end, + batch_output.training_step_output, batch, self.iteration_count, self._dataloader_idx, ) # ----------------------------------------- - # SAVE METRICS TO LOGGERS + # SAVE METRICS TO LOGGERS AND PROGRESS_BAR # ----------------------------------------- - self.trainer.logger_connector.log_train_step_metrics(batch_output) + self.trainer.logger_connector.update_train_step_metrics(batch_output) def on_advance_end(self): # ----------------------------------------- @@ -166,9 +166,6 @@ def on_run_end(self): 'HINT: remove the return statement in training_epoch_end' ) - # capture logging - self.trainer.logger_connector.cache_logged_metrics() - # call train epoch end hooks self._on_train_epoch_end_hook(processed_outputs) self.trainer.call_hook('on_epoch_end') @@ -187,8 +184,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # This implementation is copied from Trainer.call_hook hook_name = "on_train_epoch_end" - # set hook_name to model + reset Result obj - skip = self.trainer._reset_result_and_set_fx_name(hook_name) + self.trainer.lightning_module._current_fx_name = hook_name # always profile hooks with self.trainer.profiler.profile(hook_name): @@ -218,8 +214,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: accelerator_hook = getattr(self.trainer.accelerator, hook_name) accelerator_hook() - if not skip: - self.trainer._cache_logged_metrics() + self.trainer.lightning_module._current_fx_name = None def _num_training_batches_reached(self, is_last_batch=False): return self.batches_seen == self.trainer.num_training_batches or is_last_batch @@ -241,49 +236,55 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) - # reset batch logger internals - self.trainer.logger_connector.on_train_batch_end() - def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): + hook_overridden = self._should_add_batch_output_to_epoch_output() + if not hook_overridden: + return # track the outputs to reduce at the end of the epoch for opt_idx, 
opt_outputs in enumerate(batch_end_outputs): - sample_output = opt_outputs[-1] - - # decide if we need to reduce at the end of the epoch automatically - auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end - hook_overridden = ( - is_overridden("training_epoch_end", model=self.trainer.lightning_module) - or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module) - ) - - # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - if not (hook_overridden or auto_reduce_tng_result): - continue - # with 1 step (no tbptt) don't use a sequence at epoch end - if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): + if ( + isinstance(opt_outputs, list) and len(opt_outputs) == 1 + and not isinstance(opt_outputs[0], ResultCollection) + ): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) + def _should_add_batch_output_to_epoch_output(self) -> bool: + # We add to the epoch outputs if + # 1. The model defines training_epoch_end OR + # 2. The model overrides on_train_epoch_end which has `outputs` in the signature + # TODO: in v1.5 this only needs to check if training_epoch_end is overridden + lightning_module = self.trainer.lightning_module + if is_overridden("training_epoch_end", model=lightning_module): + return True + + if is_overridden("on_train_epoch_end", model=lightning_module): + model_hook_fx = getattr(lightning_module, "on_train_epoch_end") + if is_param_in_hook_signature(model_hook_fx, "outputs"): + return True + + return False + @staticmethod def _prepare_outputs( - outputs: List[List[List[Result]]], + outputs: List[List[List['ResultCollection']]], batch_mode: bool, ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: """ Extract required information from batch or epoch end results. Args: - outputs: A 3-dimensional list of ``Result`` objects with dimensions: - [optimizer outs][batch outs][tbptt steps]. + outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions: + ``[optimizer outs][batch outs][tbptt steps]``. batch_mode: If True, ignore the batch output dimension. Returns: - The cleaned outputs with ``Result`` objects converted to dictionaries. All list dimensions of size one will - be collapsed. + The cleaned outputs with ``ResultCollection`` objects converted to dictionaries. + All list dimensions of size one will be collapsed. 
""" processed_outputs = [] for opt_outputs in outputs: @@ -299,6 +300,9 @@ def _prepare_outputs( for batch_outputs in opt_outputs: processed_tbptt_outputs = [] + if isinstance(batch_outputs, ResultCollection): + batch_outputs = [batch_outputs] + for tbptt_output in batch_outputs: out = tbptt_output.extra if tbptt_output.minimize is not None: From d180bb2a6ac6820d6d541e5e2b4d762ad87e352b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 14:23:06 +0200 Subject: [PATCH 313/455] call logger connector on_train_split_start at start of train split --- pytorch_lightning/trainer/training_loop.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ea55186449abf..90f9bd52c1545 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -632,6 +632,9 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): for split_idx, split_batch in enumerate(splits): self.split_idx = split_idx + # let logger connector extract batch size + self.trainer.logger_connector.on_train_split_start(batch_idx, split_batch) + if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) @@ -654,8 +657,8 @@ def _run_optimization(self, batch_idx, split_batch, opt_idx=0, optimizer=None): # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change # opt_idx=0 to opt_idx=None in the signature here - # toggle model params + set info to logger_connector - self.run_train_split_start(batch_idx, split_batch, opt_idx, optimizer) + # toggle model params + self.run_optimization_start(opt_idx, optimizer) result = AttributeDict() closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result) @@ -905,16 +908,13 @@ def save_loggers_on_train_batch_end(self): if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() - def run_train_split_start(self, batch_idx, split_batch, opt_idx, optimizer): + def run_optimization_start(self, opt_idx, optimizer): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. 
if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1: model = self.trainer.lightning_module model.toggle_optimizer(optimizer, opt_idx) - # use to track metrics internally - self.trainer.logger_connector.on_train_split_start(batch_idx, split_batch) - def update_running_loss(self, current_loss: torch.Tensor) -> None: if self.trainer.lightning_module.automatic_optimization: # track total loss for logging (avoid mem leaks) From a386347aa9dda058b5ed6c3bb1a13ece83ad6cd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 14:28:51 +0200 Subject: [PATCH 314/455] integrate d180bb2 --- pytorch_lightning/loops/batch_loop.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 72290be36fbf5..61f76d98b9db6 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -90,19 +90,22 @@ def advance(self, batch, batch_idx, dataloader_idx): split_idx, split_batch = self._remaining_splits.pop(0) self.split_idx = split_idx + # let logger connector extract current batch size + self.trainer.logger_connector.on_train_split_start(batch_idx, split_batch) + # TODO: this list needs to go outside this loop # batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] grad_norm_dict = {} if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): - result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) + result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: self.batch_outputs[opt_idx].append(result.training_step_output) grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers - result = self._run_optimization(batch_idx, split_idx, split_batch) + result = self._run_optimization(batch_idx, split_batch) if result: self.batch_outputs[0].append(result.training_step_output) @@ -123,12 +126,12 @@ def optimizer_freq_cumsum(self): self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) return self._optimizer_freq_cumsum - def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimizer=None): + def _run_optimization(self, batch_idx, split_batch, opt_idx=0, optimizer=None): # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change # opt_idx=0 to opt_idx=None in the signature here - # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + # toggle model params + self.run_optimization_start(opt_idx, optimizer) result = AttributeDict() closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result) @@ -371,20 +374,13 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): return args - def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): - # set split_idx to trainer for tracking - self.trainer.split_idx = split_idx - + def run_optimization_start(self, opt_idx, optimizer): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. 
if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1: model = self.trainer.lightning_module model.toggle_optimizer(optimizer, opt_idx) - # use to track metrics internally - # TODO: pass split_idx here? - self.trainer.logger_connector.on_train_split_start(batch_idx=split_idx, split_batch=split_batch) - @contextmanager def block_ddp_sync_behaviour(self, should_block_sync: bool = False): """ From 42210fc4831d2fd4dbac8ed102fec6a5e2e9f3a6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 14:36:04 +0200 Subject: [PATCH 315/455] Minor changes --- pytorch_lightning/trainer/properties.py | 8 +++----- tests/trainer/logging_/test_train_loop_logging.py | 7 +++---- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 49029432f481e..a679e2f99db1f 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -274,8 +274,8 @@ def progress_bar_dict(self) -> dict: ref_model = cast(LightningModule, ref_model) standard_metrics = ref_model.get_progress_bar_dict() - logged_metrics = self.progress_bar_metrics - duplicates = list(standard_metrics.keys() & logged_metrics.keys()) + pbar_metrics = self.progress_bar_metrics + duplicates = list(standard_metrics.keys() & pbar_metrics.keys()) if duplicates: rank_zero_warn( f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and" @@ -283,9 +283,7 @@ def progress_bar_dict(self) -> dict: f" If this is undesired, change the name or override `get_progress_bar_dict()`" f" in `LightingModule`.", UserWarning ) - all_metrics = dict(**standard_metrics) - all_metrics.update(**logged_metrics) - return all_metrics + return {**standard_metrics, **pbar_metrics} @property def disable_validation(self) -> bool: diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index f570d72b29c93..30c995a061d7a 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -633,8 +633,6 @@ def training_step(self, *args): return super().training_step(*args) def on_train_epoch_end(self, *_): - self.on_train_epoch_end_called = True - self.epoch_end_called = True self.log( 'foo_2', torch.tensor(self.current_epoch), @@ -643,11 +641,12 @@ def on_train_epoch_end(self, *_): sync_dist=True, sync_dist_op='sum' ) + self.on_train_epoch_end_called = True def on_epoch_end(self): - self.epoch_end_called = True assert self.trainer.progress_bar_dict["foo"] == self.current_epoch assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch + self.on_epoch_end_called = True trainer = Trainer( default_root_dir=tmpdir, @@ -661,8 +660,8 @@ def on_epoch_end(self): ) model = TestModel() trainer.fit(model) - assert model.epoch_end_called assert model.on_train_epoch_end_called + assert model.on_epoch_end_called def test_logging_in_callbacks_with_log_function(tmpdir): From e76010a533ba0c6c66729035072688ad1aaf7edb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 15:59:25 +0200 Subject: [PATCH 316/455] Refactor loop logic into logger connector --- pytorch_lightning/core/lightning.py | 6 + .../logger_connector/logger_connector.py | 108 ++++++++++-------- .../connectors/logger_connector/result.py | 71 ++---------- pytorch_lightning/trainer/evaluation_loop.py | 18 ++- pytorch_lightning/trainer/trainer.py | 8 +- pytorch_lightning/trainer/training_loop.py | 8 +- 
tests/core/test_metric_result_integration.py | 98 ++++++---------- .../trainer/logging_/test_logger_connector.py | 6 +- 8 files changed, 132 insertions(+), 191 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 09dc3625dbe50..05dbf71831b10 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -376,6 +376,10 @@ def log( value = apply_to_collection(value, numbers.Number, self.__to_float) + if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name): + # when restarting an new epoch, reset the tensors + result_collection.reset(metrics=False, fx=self._current_fx_name) + result_collection.log( self._current_fx_name, name, @@ -395,6 +399,8 @@ def log( sync_dist_group=sync_dist_group, ) + self.trainer.logger_connector._current_fx = self._current_fx_name + def log_dict( self, dictionary: Dict[str, _METRIC_COLLECTION], diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 1ee86bafc832b..77cee7fe93124 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -37,6 +37,10 @@ def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self._progress_bar_metrics: Dict[str, float] = {} self._logged_metrics: Dict[str, _METRIC] = {} self._callback_metrics: Dict[str, _METRIC] = {} + self._epoch_end_reached = False + self._current_fx: Optional[str] = None + # FIXME: use _epoch_end_reached? + self._batch_idx: Optional[int] = None def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): # logging @@ -113,12 +117,6 @@ def log_metrics(self, metrics, grad_norm_dict, step=None): Evaluation metric updates """ - def evaluation_epoch_end(self): - # reset dataloader idx - model_ref = self.trainer.lightning_module - model_ref._current_dataloader_idx = None - self.trainer.result_collection.on_epoch_end_reached = True - def prepare_eval_loop_results(self, metrics: Dict[str, _METRIC]) -> None: if self.trainer.sanity_checking: return @@ -137,16 +135,16 @@ def prepare_eval_loop_results(self, metrics: Dict[str, _METRIC]) -> None: self.eval_loop_results.append(callback_metrics) def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: - metrics = self.trainer.result_collection.get_metrics(False) - self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) + assert self._epoch_end_reached + metrics = self.metrics if not self.trainer.sanity_checking: # log all the metrics as a single dict - metrics_to_log = metrics[MetricSource.LOG] - if metrics_to_log: - self.log_metrics(metrics_to_log, {}) + log_metrics = metrics[MetricSource.LOG] + if log_metrics: + self.log_metrics(log_metrics, {}) - self.prepare_eval_loop_results(self.callback_metrics) + self.prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) # log results of evaluation if ( @@ -183,9 +181,6 @@ def increment_evaluation_log_step(self) -> None: elif self.trainer.state.stage is RunningStage.TESTING: self._test_log_step += 1 - def on_evaluation_start(self) -> None: - self.trainer.result_collection.device = self.trainer.lightning_module.device - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: model = self.trainer.lightning_module # set dataloader_idx only if multiple ones @@ -193,20 +188,17 @@ def 
on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: # track batch_size self.trainer.result_collection.extract_batch_size(batch) - self.trainer.result_collection.batch_idx = batch_idx + self._batch_idx = batch_idx def update_evaluation_step_metrics(self) -> None: - metrics = self.trainer.result_collection.metrics - self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) - self._callback_metrics.update(metrics[MetricSource.CALLBACK]) - if self.trainer.sanity_checking: return # logs user requested information to logger - batch_log_metrics = metrics[MetricSource.LOG] - if batch_log_metrics: - self.log_metrics(batch_log_metrics, {}, step=self.evaluation_log_step) + assert not self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if metrics: + self.log_metrics(metrics, {}, step=self.evaluation_log_step) # increment the step even if nothing was logged self.increment_evaluation_log_step() @@ -215,42 +207,29 @@ def update_evaluation_step_metrics(self) -> None: Train metric updates """ - def on_train_start(self) -> None: - self.trainer.result_collection.device = self.trainer.lightning_module.device - def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: self.trainer.result_collection.extract_batch_size(split_batch) - self.trainer.result_collection.batch_idx = batch_idx + self._batch_idx = batch_idx def update_train_step_metrics(self, batch_output): - metrics = self.trainer.result_collection.metrics - self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) - self._callback_metrics.update(metrics[MetricSource.CALLBACK]) - if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: return # when metrics should be logged - batch_log_metrics = metrics[MetricSource.LOG] + assert not self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] if self.should_update_logs or self.trainer.fast_dev_run is True: # logs user requested information to logger grad_norm_dict = batch_output.grad_norm_dict or {} - if batch_log_metrics or grad_norm_dict: - self.log_metrics(batch_log_metrics, grad_norm_dict) - - def on_train_epoch_end(self): - # inform cached logger connector epoch finished - self.trainer.result_collection.on_epoch_end_reached = True + if metrics or grad_norm_dict: + self.log_metrics(metrics, grad_norm_dict) def update_train_epoch_metrics(self) -> None: - metrics = self.trainer.result_collection.metrics - self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) - self._callback_metrics.update(metrics[MetricSource.CALLBACK]) - # add the metrics to the loggers - epoch_log_metrics = metrics[MetricSource.LOG] - if epoch_log_metrics: - self.log_metrics(epoch_log_metrics, {}) + assert self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if metrics: + self.log_metrics(metrics, {}) # reset result collection for next epoch self.trainer.result_collection.reset(metrics=True) @@ -259,23 +238,56 @@ def update_train_epoch_metrics(self) -> None: Utilities and properties """ + def on_epoch_start(self) -> None: + self._epoch_end_reached = False + + def on_epoch_end(self) -> None: + assert self._epoch_end_reached + metrics = self.metrics + self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) + self._callback_metrics.update(metrics[MetricSource.CALLBACK]) + self._logged_metrics.update(metrics[MetricSource.LOG]) + self._batch_idx = None + + def on_batch_end(self) -> None: + assert not self._epoch_end_reached + metrics = self.metrics + 
self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) + self._callback_metrics.update(metrics[MetricSource.CALLBACK]) + self._logged_metrics.update(metrics[MetricSource.LOG]) + + def should_reset_tensors(self, fx: str) -> bool: + # reset tensor metrics only when the hook changed and reloading the dataloader + return self._current_fx != fx and self._batch_idx in (None, 0) + + def reset(self, metrics: Optional[bool] = None) -> None: + self.trainer.result_collection.reset(metrics=metrics) + self._batch_idx = None + self._current_fx = None + + @property + def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]: + """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``.""" + on_step = not self._epoch_end_reached + return self.trainer.result_collection.metrics(on_step) + @property def callback_metrics(self) -> Dict[str, _METRIC]: if self.trainer.result_collection: - metrics = self.trainer.result_collection.metrics[MetricSource.CALLBACK] + metrics = self.metrics[MetricSource.CALLBACK] self._callback_metrics.update(metrics) return self._callback_metrics @property def logged_metrics(self) -> Dict[str, _METRIC]: if self.trainer.result_collection: - metrics = self.trainer.result_collection.metrics[MetricSource.LOG] + metrics = self.metrics[MetricSource.LOG] self._logged_metrics.update(metrics) return self._logged_metrics @property def progress_bar_metrics(self) -> Dict[str, float]: if self.trainer.result_collection: - metrics = self.trainer.result_collection.metrics[MetricSource.PBAR] + metrics = self.metrics[MetricSource.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f0c3c9ab3af0a..7f79c2a050805 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -211,15 +211,6 @@ class ResultCollection(dict): # arguments: fx, key, value, metadata result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) result.log('validation_step', 'recall', torch.tensor(...), on_step=True, on_epoch=True) - - for epoch in epochs: - for batch_idx, batch in enumerate(dataloader): - # the batch_idx is used to reset the tensor metrics - result.batch_idx = batch_idx - result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) - - result.on_epoch_end_reached = True # indicate epoch end has been reached - result.log('training_epoch_end', 'acc', torch.tensor(...), on_step=False, on_epoch=True)` """ DATALOADER_SUFFIX = "/dataloader_idx_{}" @@ -227,11 +218,8 @@ class ResultCollection(dict): def __init__(self, training: bool, device: Optional[torch.device] = None) -> None: super().__init__() self.training = training - self._on_epoch_end_reached = False self._minimize = None - self._current_fx: Optional[str] = None self._batch_size = torch.tensor(1, device=device) - self.batch_idx: Optional[int] = None self.device: Optional[torch.device] = device self.fx_validator = FxValidator() @@ -244,20 +232,6 @@ def batch_size(self) -> torch.Tensor: def batch_size(self, value: int) -> None: self._batch_size = torch.tensor(value, device=self.device) - @property - def on_epoch_end_reached(self) -> bool: - return self._on_epoch_end_reached - - @on_epoch_end_reached.setter - def on_epoch_end_reached(self, on_epoch_end_reached): - self._on_epoch_end_reached = on_epoch_end_reached - 
self.batch_idx = None - - @property - def metrics(self) -> Dict[str, _METRIC_COLLECTION]: - """This function returns either batch or epoch metrics depending on ``on_epoch_end_reached``.""" - return self.get_epoch_metrics() if self.on_epoch_end_reached else self.get_batch_metrics() - @property def minimize(self) -> Optional[torch.Tensor]: """ @@ -324,12 +298,6 @@ def log( if isinstance(value, torch.Tensor) and value.device.type == "xla": value = value.cpu() - if on_step and self.on_epoch_end_reached: - # `FxValidator` should avoid this ever happening. Either a bug there or a bug in the logic order. - raise RuntimeError( - "Logging `on_step` when `on_epoch_end_reached` isn't allowed. This shouldn't have happened." - ) - # storage key key = f"{fx}.{name}" # add dataloader_suffix to both key and fx @@ -367,15 +335,10 @@ def log( f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed' ) - if self.should_reset_tensors(fx): - # when restarting an new epoch, reset the tensors - self._reset(fx, metrics=False) - if batch_size is not None: self.batch_size = batch_size self.update_metrics(key, value) - self._current_fx = fx def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: """Create one ResultMetric object per value. Value can be provided as a nested collection""" @@ -389,10 +352,6 @@ def fn(v: _METRIC) -> ResultMetric: value = ResultMetricCollection(value, metadata=meta) self[key] = value - def should_reset_tensors(self, fx: str) -> bool: - # reset tensor metrics only when the hook changed and reloading the dataloader - return self._current_fx != fx and self.batch_idx in (None, 0) - def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None: def fn(result_metric, v): @@ -434,7 +393,7 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, forked_name += dataloader_suffix return name, forked_name - def get_metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: + def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: metrics = {k: {} for k in MetricSource} for key, result_metric in self.valid_items(): @@ -473,13 +432,16 @@ def any_tensor(_): return metrics - def get_batch_metrics(self) -> Dict[str, _METRIC_COLLECTION]: - return self.get_metrics(on_step=True) - - def get_epoch_metrics(self) -> Dict[str, _METRIC_COLLECTION]: - return self.get_metrics(on_step=False) + def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None: + """ + Reset the result collection - def _reset(self, fx: Optional[str] = None, metrics: Optional[bool] = None) -> None: + Args: + metrics: If True, only ``torchmetrics.Metric`` results are reset, + if False, only ``torch.Tensors`` are reset, + if ``None``, both are. + fx: Function to reset + """ def fn(item: ResultMetric) -> None: requested_type = metrics is None or metrics ^ item.is_tensor @@ -489,19 +451,6 @@ def fn(item: ResultMetric) -> None: apply_to_collection(self, ResultMetric, fn) - def reset(self, metrics: Optional[bool] = None) -> None: - """ - Reset the result collection - - Args: - metrics: If True, only ``torchmetrics.Metric`` results are reset, - if False, only ``torch.Tensors`` are reset, - if ``None``, both are. 
- """ - self._reset(metrics=metrics) - self.on_epoch_end_reached = False - self._current_fx = None - def extract_batch_size(self, batch: Any) -> None: try: self.batch_size = self._extract_batch_size(batch) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index dba920167c6a0..72d3bf777d86d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -80,7 +80,7 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() - self.trainer.logger_connector.on_evaluation_start() + self.trainer.result_collection.device = self.trainer.lightning_module.device if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) @@ -102,8 +102,7 @@ def on_evaluation_model_train(self) -> None: model_ref.on_validation_model_train() def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: - assert self.trainer.result_collection is not None - self.trainer.result_collection.reset(metrics=True) + self.trainer.logger_connector.reset(metrics=True) if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) @@ -134,6 +133,7 @@ def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoad self.num_dataloaders = self._get_num_dataloaders(dataloaders) def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: + self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook('on_epoch_start', *args, **kwargs) if self.trainer.testing: @@ -196,12 +196,15 @@ def _should_track_batch_outputs_for_epoch_end(self) -> bool: return is_overridden('validation_epoch_end', model=model) def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - # unset dataloder_idx in model - self.trainer.logger_connector.evaluation_epoch_end() + # inform logger the batch loop has finished + self.trainer.logger_connector._epoch_end_reached = True # call the model epoch end model = self.trainer.lightning_module + # unset dataloader_idx in model + model._current_dataloader_idx = None + if self.trainer.testing: if is_overridden('test_epoch_end', model=model): model._current_fx_name = 'test_epoch_end' @@ -234,6 +237,10 @@ def on_evaluation_batch_end( else: self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) + # FIXME: missing hook? 
+ # self.trainer.call_hook('on_batch_end') + self.trainer.logger_connector.on_batch_end() + # store predicitons if do_write_predictions and track eval loss history self.store_predictions(output, batch_idx, dataloader_idx) @@ -250,3 +257,4 @@ def on_evaluation_epoch_end(self) -> None: hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" self.trainer.call_hook(hook_name) self.trainer.call_hook('on_epoch_end') + self.trainer.logger_connector.on_epoch_end() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7861bb8005ade..e2b01e5c813c4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1010,12 +1010,12 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: # hook self.evaluation_loop.on_evaluation_epoch_end() - # log epoch metrics - eval_loop_results = self.logger_connector.get_evaluate_epoch_results() - # hook self.evaluation_loop.on_evaluation_end() + # log epoch metrics + eval_loop_results = self.logger_connector.get_evaluate_epoch_results() + # save predictions to disk self.evaluation_loop.predictions.to_disk() @@ -1114,7 +1114,7 @@ def _run_sanity_check(self, ref_model): self.on_sanity_check_end() # reset validation metrics - self.result_collection.reset() + self.logger_connector.reset() # reset the seed to what it was before sanity check # prevents sanity check to affect random sampling in training diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ea55186449abf..ea38462f0b110 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -100,8 +100,7 @@ def should_skip_training(self) -> bool: return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0 def on_train_start(self): - # hook - self.trainer.logger_connector.on_train_start() + self.trainer.result_collection.device = self.trainer.lightning_module.device self.trainer.call_hook("on_train_start") def on_train_end(self): @@ -169,6 +168,7 @@ def on_train_epoch_start(self, epoch): self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) # hook + self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") @@ -180,6 +180,7 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, # hook self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, batch_idx, dataloader_idx) self.trainer.call_hook('on_batch_end') + self.trainer.logger_connector.on_batch_end() # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) @@ -533,7 +534,7 @@ def run_training_epoch(self): def on_train_epoch_end(self, epoch_output: List[List[List['ResultCollection']]]) -> None: # inform logger the batch loop has finished - self.trainer.logger_connector.on_train_epoch_end() + self.trainer.logger_connector._epoch_end_reached = True # prepare epoch output processed_epoch_output = TrainLoop._prepare_outputs(epoch_output, batch_mode=False) @@ -556,6 +557,7 @@ def on_train_epoch_end(self, epoch_output: List[List[List['ResultCollection']]]) # call train epoch end hooks self._on_train_epoch_end_hook(processed_epoch_output) self.trainer.call_hook('on_epoch_end') + self.trainer.logger_connector.on_epoch_end() def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # We cannot rely on Trainer.call_hook because 
the signatures might be different across diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index b9c0f3c50caef..d34e548836c00 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -55,18 +55,11 @@ def _ddp_test_fn(rank, worldsize): metric_b = metric_b.to(f"cuda:{rank}") metric_c = metric_c.to(f"cuda:{rank}") - # dist_sync_on_step is False by default result = ResultCollection(True, torch.device(f"cuda:{rank}")) for _ in range(3): cumulative_sum = 0 - - result.on_epoch_end_reached = False - for i in range(5): - - result.batch_idx = i - metric_a(i) metric_b(i) metric_c(i) @@ -77,15 +70,10 @@ def _ddp_test_fn(rank, worldsize): result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") - batch_log = result.get_batch_metrics()[MetricSource.LOG] - batch_expected = {"a_step": i, "c": i} - assert set(batch_log.keys()) == set(batch_expected.keys()) - for k in batch_expected.keys(): - assert batch_expected[k] == batch_log[k] - - result.on_epoch_end_reached = True + batch_log = result.metrics(True)[MetricSource.LOG] + assert batch_log == {"a_step": i, "c": i} - epoch_log = result.get_epoch_metrics()[MetricSource.LOG] + epoch_log = result.metrics(False)[MetricSource.LOG] result.reset() # assert metric state reset to default values @@ -93,11 +81,7 @@ def _ddp_test_fn(rank, worldsize): assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] - epoch_expected = {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize} - - assert set(epoch_log.keys()) == set(epoch_expected.keys()) - for k in epoch_expected.keys(): - assert epoch_expected[k] == epoch_log[k] + assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize} @RunIf(skip_windows=True, min_gpus=2) @@ -118,14 +102,7 @@ def test_result_metric_integration(): for _ in range(3): cumulative_sum = 0 - - result.on_epoch_end_reached = False - for i in range(5): - - # need to set batch_idx - result.batch_idx = i - metric_a(i) metric_b(i) metric_c(i) @@ -136,15 +113,10 @@ def test_result_metric_integration(): result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") - batch_log = result.get_batch_metrics()[MetricSource.LOG] - batch_expected = {"a_step": i, "c": i} - assert set(batch_log.keys()) == set(batch_expected.keys()) - for k in batch_expected.keys(): - assert batch_expected[k] == batch_log[k] - - result.on_epoch_end_reached = True + batch_log = result.metrics(True)[MetricSource.LOG] + assert batch_log == {"a_step": i, "c": i} - epoch_log = result.get_epoch_metrics()[MetricSource.LOG] + epoch_log = result.metrics(False)[MetricSource.LOG] result.reset() # assert metric state reset to default values @@ -152,11 +124,7 @@ def test_result_metric_integration(): assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] - epoch_expected = {"b": cumulative_sum, "a_epoch": cumulative_sum} - - assert set(epoch_log.keys()) == set(epoch_expected.keys()) - for k in epoch_expected.keys(): - assert epoch_expected[k] == epoch_log[k] + assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum} assert str(result) == ( "ResultCollection(True, cpu, {" @@ -168,31 +136,29 @@ def 
test_result_metric_integration(): def test_result_collection_simple_loop(): - result = ResultCollection(True, torch.device("cpu")) - - result.log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True) - result.log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True) - + current_fx_name = None + batch_idx = None + + def lightning_log(fx, *args, **kwargs): + nonlocal current_fx_name + if current_fx_name != fx and batch_idx in (None, 0): + result.reset(metrics=False, fx=fx) + result.log(fx, *args, **kwargs) + current_fx_name = fx + + lightning_log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True) + lightning_log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True) for epoch in range(2): - - result.on_epoch_end_reached = False - - result.log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) - result.log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) - - for batch_idx, batch_size in enumerate(range(2)): - - result.batch_idx = batch_idx - - result.log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) - result.log('c1', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) - result.log('c2', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) - - result.on_epoch_end_reached = True - - result.log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) - result.log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + lightning_log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + lightning_log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + for batch_idx in range(2): + lightning_log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + lightning_log('c1', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + lightning_log('c2', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + batch_idx = None + lightning_log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + lightning_log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) assert result['a0.a'].value == torch.tensor(0.) assert result['a0.a'].cumulated_batch_size == torch.tensor(1.) @@ -204,11 +170,11 @@ def test_result_collection_simple_loop(): assert result['b1.a'].value == torch.tensor(1.) + epoch assert result['b1.a'].cumulated_batch_size == torch.tensor(1.) - assert result['c0.a'].value == torch.tensor(4.) + epoch * (batch_size + 1) + assert result['c0.a'].value == torch.tensor(4.) + epoch * 2 assert result['c0.a'].cumulated_batch_size == torch.tensor(2.) - assert result['c1.a'].value == torch.tensor(4.) + epoch * (batch_size + 1) + assert result['c1.a'].value == torch.tensor(4.) + epoch * 2 assert result['c1.a'].cumulated_batch_size == torch.tensor(2.) - assert result['c2.a'].value == torch.tensor(4.) + epoch * (batch_size + 1) + assert result['c2.a'].value == torch.tensor(4.) + epoch * 2 assert result['c2.a'].cumulated_batch_size == torch.tensor(2.) assert result['d0.a'].value == torch.tensor(3.) 
+ epoch diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index eae2dc3e34937..da1cb501c7af6 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -470,7 +470,7 @@ def test_result_collection_on_tensor_with_mean_reduction(): assert result_collection["training_step.loss_1_0_0"].value == sum(total_value) assert result_collection["training_step.loss_1_0_0"].cumulated_batch_size == sum(excepted_batches) - batch_metrics = result_collection.get_batch_metrics() + batch_metrics = result_collection.metrics(True) expected = { 'loss_1_1_0_step': torch.tensor([9.]), @@ -504,9 +504,7 @@ def test_result_collection_on_tensor_with_mean_reduction(): } assert batch_metrics[MetricSource.CALLBACK] == excepted - result_collection.on_epoch_end_reached = True - - epoch_metrics = result_collection.get_epoch_metrics() + epoch_metrics = result_collection.metrics(False) mean = (torch.tensor(excepted_values) * torch.tensor(excepted_batches)).sum() / sum(excepted_batches) From d65e152c5753dc68f4b328bde10630ae24524a35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 16:19:23 +0200 Subject: [PATCH 317/455] deletions --- pytorch_lightning/loops/cache.py | 24 -- .../loops/dataloader/__init__.py | 0 .../loops/dataloader/dataloader_loop.py | 37 --- .../dataloader/evaluation_dataloader_loop.py | 250 ------------------ .../dataloader/prediction_dataloader_loop.py | 144 ---------- pytorch_lightning/loops/evaluation_loop.py | 137 ---------- pytorch_lightning/loops/prediction_loop.py | 91 ------- pytorch_lightning/trainer/trainer.py | 135 +--------- 8 files changed, 1 insertion(+), 817 deletions(-) delete mode 100644 pytorch_lightning/loops/cache.py delete mode 100644 pytorch_lightning/loops/dataloader/__init__.py delete mode 100644 pytorch_lightning/loops/dataloader/dataloader_loop.py delete mode 100644 pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py delete mode 100644 pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py delete mode 100644 pytorch_lightning/loops/evaluation_loop.py delete mode 100644 pytorch_lightning/loops/prediction_loop.py diff --git a/pytorch_lightning/loops/cache.py b/pytorch_lightning/loops/cache.py deleted file mode 100644 index da0080de8744f..0000000000000 --- a/pytorch_lightning/loops/cache.py +++ /dev/null @@ -1,24 +0,0 @@ -# from typing import Tuple -# -# -# class Cache: -# -# def __init__(self): -# self._store = ... -# -# def add(self, obj: object, **tags): -# pass -# -# def merge(self, cache: "Cache"): -# pass -# -# def filter_by(self, tags: Tuple[str]): -# pass -# -# -# -# self.cache = Cache() -# self.cache.add("abc", result, batch_idx=, opt_idx=..) 
-# self.cache.add("abc", result, batch_idx=) -# -# self.cache.group_by("abc", ("batch_idx", "opt_idx")) diff --git a/pytorch_lightning/loops/dataloader/__init__.py b/pytorch_lightning/loops/dataloader/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py deleted file mode 100644 index c6449cb6aaeeb..0000000000000 --- a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ /dev/null @@ -1,37 +0,0 @@ -from abc import abstractmethod -from typing import Sequence - -from torch.utils.data import DataLoader - -from pytorch_lightning.loops.base import Loop - - -# TODO: Handle max_batches also in base class here -class DataLoaderLoop(Loop): - - def __init__(self): - super().__init__() - - @property - @abstractmethod - def dataloaders(self) -> Sequence[DataLoader]: - pass - - @property - def current_dataloader_idx(self) -> int: - return self.iteration_count - - @property - def current_dataloader(self) -> DataLoader: - return self.dataloaders[self.current_dataloader_idx] - - @property - def num_dataloaders(self) -> int: - return len(self.dataloaders) - - @property - def done(self) -> bool: - return self.current_dataloader_idx >= self.num_dataloaders - - def reset(self) -> None: - self.iteration_count = 0 diff --git a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py deleted file mode 100644 index 70030756b39c5..0000000000000 --- a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py +++ /dev/null @@ -1,250 +0,0 @@ -from typing import Any, List, Optional, Sequence, Tuple, Union - -from torch.utils.data.dataloader import DataLoader - -from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop -from pytorch_lightning.loops.evaluation_loop import EvaluationLoop -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT - - -class EvaluationDataLoaderLoop(DataLoaderLoop): - - def __init__(self): - super().__init__() - self._dataloaders: Optional[Union[DataLoader, Sequence[DataLoader]]] = None - self._max_batches: Optional[Union[int, Sequence[int]]] = None - self.outputs = [] - self.evaluation_loop = EvaluationLoop() - - self.validation_results = ResultCollection(False) - self.test_results = ResultCollection(False) - - @property - def num_dataloaders(self) -> int: - return self._get_num_dataloaders(self.dataloaders) - - @property - def dataloaders(self) -> Sequence[DataLoader]: - return self._dataloaders - - @property - def predictions(self): - # TODO: fixme - return self.evaluation_loop.predictions - - def connect(self, trainer, *args, **kwargs) -> None: - super().connect(trainer, *args, **kwargs) - self.evaluation_loop.connect(trainer, *args, **kwargs) - - @property - def done(self) -> bool: - return (self.current_dataloader_idx >= len(self.dataloaders)) or self.should_skip_evaluation(self._max_batches) - - def reset(self) -> None: - self.iteration_count = 0 - - # prepare dataloaders - self._dataloaders, self._max_batches = self.get_eval_dataloaders(), self.get_max_batches() - # bookkeeping - self.outputs = [] - - if isinstance(self._max_batches, int): - self._max_batches = [self._max_batches] * len(self._dataloaders) - - def 
advance(self, *args: Any, **kwargs: Any) -> None: - dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - dataloader_iter = enumerate(dataloader) - dl_max_batches = self._max_batches[self.current_dataloader_idx] - - dl_outputs = self.evaluation_loop.run( - dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders - ) - - # store batch level output per dataloader - if self.should_track_batch_outputs_for_epoch_end: - self.outputs.append(dl_outputs) - - def on_run_start(self, *args: Any, **kwargs: Any) -> None: - # hook - self.on_evaluation_start() - self.on_evaluation_epoch_start() - - def on_run_end(self) -> Any: - outputs = self.outputs - - # free memory - self.outputs = [] - - # with a single dataloader don't pass a 2D list - if len(outputs) > 0 and self.num_dataloaders == 1: - outputs = outputs[0] - - # lightning module method - self.evaluation_epoch_end(outputs) - - # hook - self.on_evaluation_epoch_end() - - # log epoch metrics - eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results() - - # hook - self.on_evaluation_end() - - return eval_loop_results - - -# ------------------------------------------------------------------------------------------------------------ -# HELPER --- TO BE CLEANED UP -# ------------------------------------------------------------------------------------------------------------ - - def get_max_batches(self): - # select dataloaders - if self.trainer.testing: - max_batches = self.trainer.num_test_batches - else: - if self.trainer.sanity_checking: - self.trainer.num_sanity_val_batches = [ - min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches - ] - max_batches = self.trainer.num_sanity_val_batches - else: - max_batches = self.trainer.num_val_batches - return max_batches - - def get_eval_dataloaders(self): - if self.trainer.testing: - return self.trainer.test_dataloaders - return self.trainer.val_dataloaders - - # TODO: remove this method, got split into two above - def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: - model = self.trainer.lightning_module - - # select dataloaders - if self.trainer.testing: - self.trainer.reset_test_dataloader(model) - - dataloaders = self.trainer.test_dataloaders - max_batches = self.trainer.num_test_batches - else: - # val - if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: - self.trainer.reset_val_dataloader(model) - if self.trainer.sanity_checking: - self.trainer.num_sanity_val_batches = [ - min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches - ] - max_batches = self.trainer.num_sanity_val_batches - else: - max_batches = self.trainer.num_val_batches - dataloaders = self.trainer.val_dataloaders - return dataloaders, max_batches - - # TODO: this is currently also used in the new and old TrainingLoop - def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: - return sum(max_batches) == 0 - - def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: - self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() - - self.trainer.logger_connector.on_evaluation_start() - - if self.trainer.testing: - self.trainer.call_hook('on_test_start', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_start', *args, **kwargs) - - def on_evaluation_model_eval(self) -> None: - model_ref = 
self.trainer.lightning_module - if self.trainer.testing: - model_ref.on_test_model_eval() - else: - model_ref.on_validation_model_eval() - - def on_evaluation_model_train(self) -> None: - model_ref = self.trainer.lightning_module - if self.trainer.testing: - model_ref.on_test_model_train() - else: - model_ref.on_validation_model_train() - - def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: - assert self.trainer.result_collection is not None - self.trainer.result_collection.reset(metrics=True) - - if self.trainer.testing: - self.trainer.call_hook('on_test_end', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_end', *args, **kwargs) - - if self.trainer.state.fn != TrainerFn.FITTING: - # summarize profile results - self.trainer.profiler.describe() - - def reload_evaluation_dataloaders(self) -> None: - model = self.trainer.lightning_module - if self.trainer.testing: - self.trainer.reset_test_dataloader(model) - else: - self.trainer.reset_val_dataloader(model) - - def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: - self.trainer.call_hook('on_epoch_start', *args, **kwargs) - - if self.trainer.testing: - self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - - def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int: - # case where user does: - # return dl1, dl2 - if dataloaders is not None: - length = len(dataloaders) - if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): - length = len(dataloaders[0]) - return length - else: - return 0 - - def _should_track_batch_outputs_for_epoch_end(self) -> bool: - model = self.trainer.lightning_module - if self.trainer.testing: - return is_overridden('test_epoch_end', model=model) - else: - return is_overridden('validation_epoch_end', model=model) - - def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - # unset dataloder_idx in model - self.trainer.logger_connector.evaluation_epoch_end() - - # call the model epoch end - model = self.trainer.lightning_module - - if self.trainer.testing: - if is_overridden('test_epoch_end', model=model): - model._current_fx_name = 'test_epoch_end' - model.test_epoch_end(outputs) - - else: - if is_overridden('validation_epoch_end', model=model): - model._current_fx_name = 'validation_epoch_end' - model.validation_epoch_end(outputs) - - def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: - # Add step predictions to prediction collection to write later - if output is not None and self.predictions is not None: - if isinstance(output, ResultCollection) and self.trainer.testing: - self.predictions.add(output.pop('predictions', None)) - - # track debug metrics - self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) - - def on_evaluation_epoch_end(self) -> None: - hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" - self.trainer.call_hook(hook_name) - self.trainer.call_hook('on_epoch_end') diff --git a/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py b/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py deleted file mode 100644 index 42b086a7205c5..0000000000000 --- a/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py +++ /dev/null @@ -1,144 +0,0 @@ -from typing import Any, List, Optional, Sequence - -import torch -from torch.utils.data import DataLoader - -from 
pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop -from pytorch_lightning.loops.prediction_loop import PredictionLoop -from pytorch_lightning.plugins import DDPSpawnPlugin -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _PREDICT_OUTPUT - - -class PredictionDataLoaderLoop(DataLoaderLoop): - - def __init__(self): - super().__init__() - self.prediction_loop = PredictionLoop() - self._return_predictions = False - self.predictions = None - self.epoch_batch_indices = None - self._dataloaders = None - self._max_batches = None - - @property - def return_predictions(self) -> bool: - return self._return_predictions - - @return_predictions.setter - def return_predictions(self, return_predictions: Optional[bool] = None) -> None: - # ``DDPSpawnPlugin`` plugins and derivate don't support return predictions. - is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin) - if return_predictions and is_ddp_spawn: - raise MisconfigurationException( - "`return_predictions` should be set to `False` when using the `DDPSpawnPlugin` or children class. " - f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}." - ) - # For non ``DDPSpawnPlugin`` plugin, the `return_predictions` is True by default unless user decide otherwise. - self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions - - @property - def num_dataloaders(self) -> int: - return self._get_num_dataloaders(self.dataloaders) - - @property - def dataloaders(self) -> Sequence[DataLoader]: - return self._dataloaders - - @property - def done(self) -> bool: - return (self.current_dataloader_idx >= len(self.dataloaders)) or self.should_skip_predict(self._max_batches) - - def connect(self, trainer, *args, **kwargs) -> None: - super().connect(trainer, *args, **kwargs) - self.prediction_loop.connect(trainer, *args, **kwargs) - - def reset(self) -> None: - super().reset() - self._dataloaders, self._max_batches = self.get_predict_dataloaders() - - # convert max_batches to list - if isinstance(self._max_batches, int): - self._max_batches = [self._max_batches] * len(self.dataloaders) - - self.predictions = [] - self.epoch_batch_indices = [] - - def on_run_start(self) -> None: - self.on_predict_start() - - def advance(self, *args, **kwargs) -> None: - dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - dataloader_iter = enumerate(dataloader) - dl_max_batches = self._max_batches[self.current_dataloader_idx] - - dl_predictions, dl_batch_indices = self.prediction_loop.run( - dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders, self.return_predictions - ) - self.predictions.append(dl_predictions) - self.epoch_batch_indices.append(dl_batch_indices) - - def on_run_end(self): - results = self.on_predict_epoch_end() - self.on_predict_end() - return results - - -# ------------------------------------------------------------------------------------------------------------ -# HELPER --- TO BE CLEANED UP -# ------------------------------------------------------------------------------------------------------------ - - def get_predict_dataloaders(self): - self.trainer.reset_predict_dataloader(self.trainer.lightning_module) - - dataloaders = self.trainer.predict_dataloaders - max_batches = self.trainer.num_predict_batches - - return dataloaders, max_batches - - def should_skip_predict(self, max_batches): - return 
sum(max_batches) == 0 - - def on_predict_start(self) -> None: - # enable eval mode + no grads - self.on_predict_model_eval() - self.trainer.lightning_module.zero_grad() - self._previous_grad_status = torch.is_grad_enabled() - torch.set_grad_enabled(False) - - # hook - self.trainer.call_hook("on_predict_start") - self.trainer.call_hook("on_predict_epoch_start") - - def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: - self.trainer.profiler.describe() - - results = self.predictions - - self.trainer.call_hook("on_predict_epoch_end", results) - - if self.return_predictions: - return results[0] if self.num_dataloaders == 1 else results - - def on_predict_end(self): - # clear memory. the predictions are extracted in `on_predict_epoch_end`. - self.predictions = [] - self.epoch_batch_indices = [] - - # reset grad to its previous status. - torch.set_grad_enabled(self._previous_grad_status) - - # hook - self.trainer.call_hook("on_predict_end") - - def on_predict_model_eval(self): - model_ref = self.trainer.lightning_module - model_ref.on_predict_model_eval() - - def _get_num_dataloaders(self, dataloaders: List[DataLoader]) -> int: - # case where user does: - # return dl1, dl2 - length = len(dataloaders) - if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): - length = len(dataloaders[0]) - return length diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/evaluation_loop.py deleted file mode 100644 index cd3b0983f2793..0000000000000 --- a/pytorch_lightning/loops/evaluation_loop.py +++ /dev/null @@ -1,137 +0,0 @@ -from collections import OrderedDict -from typing import Any, Dict, Iterator, Optional, Union - -from pytorch_lightning.loops.base import Loop -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.supporters import PredictionCollection -from pytorch_lightning.utilities.types import STEP_OUTPUT - - -class EvaluationLoop(Loop): - - def __init__(self): - super().__init__() - self.predictions: Optional[PredictionCollection] = None - self.dataloader: Optional[Iterator] = None - self.dl_max_batches: Optional[int] = None - self.dataloader_idx: Optional[int] = None - self.num_dataloaders: Optional[int] = None - self.outputs = [] - - def connect(self, trainer, *args, **kwargs): - super().connect(trainer, *args, **kwargs) - - @property - def done(self) -> bool: - return self.iteration_count >= self.dl_max_batches - - def reset(self) -> None: - self.iteration_count = 0 - self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) - self.dl_max_batches = None - self.dataloader_idx = None - self.num_dataloaders = None - self.outputs = [] - - def on_run_start(self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloaders) -> None: - self.dl_max_batches = dl_max_batches - self.dataloader_idx = dataloader_idx - self.num_dataloaders = num_dataloaders - - def advance(self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloaders) -> None: - batch_idx, batch = next(dataloader_iter) - - if batch is None: - raise StopIteration - - # hook - self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) - - # lightning module methods - with self.trainer.profiler.profile("evaluation_step_and_end"): - output = self.evaluation_step(batch, batch_idx, dataloader_idx) - output = self.evaluation_step_end(output) - - # hook + store predictions - self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) - - # log batch metrics - 
self.trainer.logger_connector.update_evaluation_step_metrics() - - # track epoch level outputs - self.outputs = self.trainer._track_output_for_epoch_end(self.outputs, output) - - def on_run_end(self) -> Any: - return self.outputs - - -# ------------------------------------------------------------------------------------------------------------ -# HELPER --- TO BE CLEANED UP -# ------------------------------------------------------------------------------------------------------------ - - def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: - # configure step_kwargs - step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - - if self.trainer.testing: - self.trainer.lightning_module._current_fx_name = "test_step" - with self.trainer.profiler.profile("test_step"): - output = self.trainer.accelerator.test_step(step_kwargs) - else: - self.trainer.lightning_module._current_fx_name = "validation_step" - with self.trainer.profiler.profile("validation_step"): - output = self.trainer.accelerator.validation_step(step_kwargs) - - return output - - def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - if self.trainer.testing: - output = self.trainer.call_hook('test_step_end', *args, **kwargs) - else: - output = self.trainer.call_hook('validation_step_end', *args, **kwargs) - return output - - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - assert self.num_dataloaders is not None - self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self.num_dataloaders) - - if self.trainer.testing: - self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) - else: - self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) - - def on_evaluation_batch_end( - self, - output: Optional[STEP_OUTPUT], - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - if self.trainer.testing: - self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) - else: - self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) - - # store predicitons if do_write_predictions and track eval loss history - self.store_predictions(output, batch_idx, dataloader_idx) - - def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: - # Add step predictions to prediction collection to write later - if output is not None and self.predictions is not None: - if isinstance(output, ResultCollection) and self.trainer.testing: - self.predictions.add(output.pop('predictions', None)) - - # track debug metrics - self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) - - def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]: - # make dataloader_idx arg in validation_step optional - step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) - - multiple_val_loaders = (not self.trainer.testing and self.num_dataloaders > 1) - multiple_test_loaders = (self.trainer.testing and self.num_dataloaders > 1) - - if multiple_test_loaders or multiple_val_loaders: - step_kwargs['dataloader_idx'] = dataloader_idx - - return step_kwargs diff --git a/pytorch_lightning/loops/prediction_loop.py b/pytorch_lightning/loops/prediction_loop.py deleted file mode 100644 index fe51cfd3e92d5..0000000000000 --- a/pytorch_lightning/loops/prediction_loop.py +++ 
/dev/null @@ -1,91 +0,0 @@ -from collections import OrderedDict -from typing import Any, List - -from pytorch_lightning.loops.base import Loop -from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper -from pytorch_lightning.utilities.warnings import WarningCache - - -class PredictionLoop(Loop): - - def __init__(self): - super().__init__() - self.warning_cache = WarningCache() - self.dl_max_batches = None - self.num_dataloaders = None - self.return_predictions = False - self.predictions: List[Any] = [] - self.current_batch_indices: [List[int]] = [] - self.all_batch_indices: [List[int]] = [] - - @property - def should_store_predictions(self) -> bool: - any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks) - return self.return_predictions or any_pred - - @property - def done(self) -> bool: - return self.iteration_count >= self.dl_max_batches - - def reset(self) -> None: - self.iteration_count = 0 - self.all_batch_indices: List[int] = [] - self.predictions: List[Any] = [] - - def on_run_start( - self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloaders, return_predictions=False - ) -> None: - self.dl_max_batches = dl_max_batches - self.num_dataloaders = num_dataloaders - self.return_predictions = return_predictions - - def advance(self, dataloader_iter, dataloader_idx, dl_max_batches, *args, **kwargs) -> None: - batch_idx, batch = next(dataloader_iter) - if batch is None: - raise StopIteration - - with self.trainer.profiler.profile("predict_step"): - self.predict_step(batch, batch_idx, dataloader_idx) - - def on_run_end(self) -> Any: - return self.predictions, self.all_batch_indices - - -# ------------------------------------------------------------------------------------------------------------ -# HELPER --- TO BE CLEANED UP -# ------------------------------------------------------------------------------------------------------------ - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - # configure step_kwargs - step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - - # extract batch_indices and store them - self._store_batch_indices(dataloader_idx) - - model_ref = self.trainer.lightning_module - - self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx) - - model_ref._current_fx_name = "predict_step" - predictions = self.trainer.accelerator.predict_step(step_kwargs) - - if predictions is None: - self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") - - self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx) - - if self.should_store_predictions: - self.predictions.append(predictions) - - def _build_kwargs(self, batch, batch_idx, dataloader_idx): - step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) - if self.num_dataloaders: - step_kwargs['dataloader_idx'] = dataloader_idx - return step_kwargs - - def _store_batch_indices(self, dataloader_idx: int) -> None: - batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler - if isinstance(batch_sampler, IndexBatchSamplerWrapper): - self.current_batch_indices = batch_sampler.batch_indices - if self.should_store_predictions: - self.all_batch_indices.append(batch_sampler.batch_indices) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7e890db4332ec..97e6a4cb0995e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ 
b/pytorch_lightning/trainer/trainer.py @@ -59,7 +59,6 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin -from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_loop import TrainLoop @@ -83,12 +82,6 @@ NEW_LOOP = True -if NEW_LOOP: - from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop - from pytorch_lightning.loops.dataloader.prediction_dataloader_loop import PredictionDataLoaderLoop -else: - from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop - class Trainer( TrainerProperties, @@ -346,16 +339,10 @@ def __init__( if NEW_LOOP: self.train_loop = EpochLoop(min_epochs, max_epochs, min_steps, max_steps) - self.evaluation_loop = EvaluationDataLoaderLoop() - self.predict_loop = PredictionDataLoaderLoop() self.train_loop.connect(self) - self.evaluation_loop.connect(self) - self.predict_loop.connect(self) else: # old loops: self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) - self.evaluation_loop = EvaluationLoop(self) - self.predict_loop = PredictLoop(self) # training state if weights_summary is not None and weights_summary not in ModelSummary.MODES: @@ -1055,88 +1042,7 @@ def _run_train_old_loop(self) -> None: self.state.stage = None raise - def _run_evaluatin_old_loop(self) -> _EVALUATE_OUTPUT: - # prepare dataloaders - dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() - - # check if we want to skip this evaluation - if self.evaluation_loop.should_skip_evaluation(max_batches): - return [], [] - - # enable eval mode + no grads - self.evaluation_loop.on_evaluation_model_eval() - # ref model - model = self.lightning_module - model.zero_grad() - torch.set_grad_enabled(False) - - # hook - self.evaluation_loop.on_evaluation_start() - - # set up the eval loop - self.evaluation_loop.setup(max_batches, dataloaders) - - # hook - self.evaluation_loop.on_evaluation_epoch_start() - - # run validation/testing - for dataloader_idx, dataloader in enumerate(dataloaders): - # bookkeeping - dl_outputs = [] - dataloader = self.accelerator.process_dataloader(dataloader) - dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] - - for batch_idx, batch in enumerate(dataloader): - if batch is None: - continue - - # stop short when running on limited batches - if batch_idx >= dl_max_batches: - break - - # hook - self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) - - # lightning module methods - with self.profiler.profile("evaluation_step_and_end"): - output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx) - output = self.evaluation_loop.evaluation_step_end(output) - - # hook + store predictions - self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) - - # log batch metrics - self.logger_connector.update_evaluation_step_metrics() - - # track epoch level outputs - dl_outputs = self._track_output_for_epoch_end(dl_outputs, output) - - # store batch level output per dataloader - if self.evaluation_loop.should_track_batch_outputs_for_epoch_end: - self.evaluation_loop.outputs.append(dl_outputs) - - outputs = 
self.evaluation_loop.outputs - # reset outputs - self.evaluation_loop.outputs = [] - - # with a single dataloader don't pass a 2D list - if len(outputs) > 0 and self.evaluation_loop.num_dataloaders == 1: - outputs = outputs[0] - - # lightning module method - self.evaluation_loop.evaluation_epoch_end(outputs) - - # hook - self.evaluation_loop.on_evaluation_epoch_end() - - # log epoch metrics - eval_loop_results = self.logger_connector.get_evaluate_epoch_results() - - # hook - self.evaluation_loop.on_evaluation_end() - - return eval_loop_results def _run_evaluation(self) -> _EVALUATE_OUTPUT: if not (self.evaluating or self.sanity_checking): @@ -1212,49 +1118,10 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: return eval_loop_results - def _run_predict_old_loop(self) -> Optional[_PREDICT_OUTPUT]: - # prepare dataloaders - dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() - - # check if we want to skip this evaluation - if self.predict_loop.should_skip_predict(max_batches): - return [] - # set up the eval loop - self.predict_loop.setup(max_batches, dataloaders) - - # call hook - self.predict_loop.on_predict_start() - - # run validation/testing - for dataloader_idx, dataloader in enumerate(dataloaders): - dataloader = self.accelerator.process_dataloader(dataloader) - dl_max_batches = self.predict_loop.max_batches[dataloader_idx] - for batch_idx, batch in enumerate(dataloader): - if batch is None: - continue - - # stop short when running on limited batches - if batch_idx >= dl_max_batches: - break - - # lightning module methods - with self.profiler.profile("predict_step"): - self.predict_loop.predict_step(batch, batch_idx, dataloader_idx) - - # call hook - results = self.predict_loop.on_predict_epoch_end() - - # call hook - self.predict_loop.on_predict_end() - - return results def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: - if NEW_LOOP: - return self.predict_loop.run() - else: - return self._run_predict_old_loop() + pass def _run_sanity_check(self, ref_model): using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) From 4a1218c74a863a34704bdf2ea6be92b876bae025 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 16:21:19 +0200 Subject: [PATCH 318/455] restore predict --- pytorch_lightning/trainer/trainer.py | 37 ++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 97e6a4cb0995e..7826d37016466 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1118,10 +1118,43 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: return eval_loop_results + def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: + # prepare dataloaders + dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() + # check if we want to skip this evaluation + if self.predict_loop.should_skip_predict(max_batches): + return [] - def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: - pass + # set up the eval loop + self.predict_loop.setup(max_batches, dataloaders) + + # call hook + self.predict_loop.on_predict_start() + + # run validation/testing + for dataloader_idx, dataloader in enumerate(dataloaders): + dataloader = self.accelerator.process_dataloader(dataloader) + dl_max_batches = self.predict_loop.max_batches[dataloader_idx] + for batch_idx, batch in enumerate(dataloader): + if batch is None: + continue + + # stop short when running on limited batches + if 
batch_idx >= dl_max_batches: + break + + # lightning module methods + with self.profiler.profile("predict_step"): + self.predict_loop.predict_step(batch, batch_idx, dataloader_idx) + + # call hook + results = self.predict_loop.on_predict_epoch_end() + + # call hook + self.predict_loop.on_predict_end() + + return results def _run_sanity_check(self, ref_model): using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) From 6acd71cb29e695d5d23cb62da3708692d5aaaa97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 16:22:42 +0200 Subject: [PATCH 319/455] remove evaluation --- pytorch_lightning/trainer/trainer.py | 108 ++++++++++++++++++++------- 1 file changed, 82 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7826d37016466..fe30d0a172b90 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1042,8 +1042,6 @@ def _run_train_old_loop(self) -> None: self.state.stage = None raise - - def _run_evaluation(self) -> _EVALUATE_OUTPUT: if not (self.evaluating or self.sanity_checking): rank_zero_warn( @@ -1052,29 +1050,85 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: ) self.validating = True - if NEW_LOOP: - # # TODO: move this check inside new loop - # # prepare dataloaders - - dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() - - # max_batches = self.evaluation_loop.get_max_batches() - # - # # TODO: move this check inside new loop - # # check if we want to skip this evaluation - if self.evaluation_loop.should_skip_evaluation(max_batches): - return [], [] - - # enable eval mode + no grads - self.evaluation_loop.on_evaluation_model_eval() - # ref model - model = self.lightning_module - model.zero_grad() - torch.set_grad_enabled(False) - - eval_loop_results = self.evaluation_loop.run() - else: - eval_loop_results = self._run_evaluatin_old_loop() + # prepare dataloaders + dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() + + # check if we want to skip this evaluation + if self.evaluation_loop.should_skip_evaluation(max_batches): + return [], [] + + # enable eval mode + no grads + self.evaluation_loop.on_evaluation_model_eval() + # ref model + model = self.lightning_module + model.zero_grad() + torch.set_grad_enabled(False) + + # hook + self.evaluation_loop.on_evaluation_start() + + # set up the eval loop + self.evaluation_loop.setup(max_batches, dataloaders) + + # hook + self.evaluation_loop.on_evaluation_epoch_start() + + # run validation/testing + for dataloader_idx, dataloader in enumerate(dataloaders): + # bookkeeping + dl_outputs = [] + dataloader = self.accelerator.process_dataloader(dataloader) + dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] + + for batch_idx, batch in enumerate(dataloader): + if batch is None: + continue + + # stop short when running on limited batches + if batch_idx >= dl_max_batches: + break + + # hook + self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) + + # lightning module methods + with self.profiler.profile("evaluation_step_and_end"): + output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx) + output = self.evaluation_loop.evaluation_step_end(output) + + # hook + store predictions + self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) + + # log batch metrics + self.logger_connector.log_evaluation_step_metrics() 
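Both the predict loop restored above and the evaluation loop being inlined here reduce to the same nested iteration: enumerate the dataloaders, cap each one at its own entry of ``max_batches``, skip ``None`` batches, and run the step together with its surrounding hooks. A minimal standalone sketch of that pattern (the ``run_step`` callback is a placeholder for the hook/step calls, not a Trainer API):

    from typing import Any, Callable, List, Sequence, Union

    def iterate_capped(
        dataloaders: Sequence[Any],
        max_batches: List[Union[int, float]],
        run_step: Callable[[Any, int, int], Any],
    ) -> List[List[Any]]:
        # Collect per-dataloader outputs, honoring a per-dataloader batch cap.
        all_outputs = []
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []
            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue
                if batch_idx >= max_batches[dataloader_idx]:
                    break
                dl_outputs.append(run_step(batch, batch_idx, dataloader_idx))
            all_outputs.append(dl_outputs)
        return all_outputs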
+ + # track epoch level outputs + dl_outputs = self._track_output_for_epoch_end(dl_outputs, output) + + # store batch level output per dataloader + if self.evaluation_loop.should_track_batch_outputs_for_epoch_end: + self.evaluation_loop.outputs.append(dl_outputs) + + outputs = self.evaluation_loop.outputs + + # reset outputs + self.evaluation_loop.outputs = [] + + # with a single dataloader don't pass a 2D list + if len(outputs) > 0 and self.evaluation_loop.num_dataloaders == 1: + outputs = outputs[0] + + # lightning module method + self.evaluation_loop.evaluation_epoch_end(outputs) + + # hook + self.evaluation_loop.on_evaluation_epoch_end() + + # log epoch metrics + eval_loop_results = self.logger_connector.get_evaluate_epoch_results() + + # hook + self.evaluation_loop.on_evaluation_end() # save predictions to disk self.evaluation_loop.predictions.to_disk() @@ -1082,11 +1136,13 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: # enable train mode again self.evaluation_loop.on_evaluation_model_train() + # reset cached results + self.logger_connector.reset() + torch.set_grad_enabled(True) return eval_loop_results - # TODO: move inside evaluation loop def _track_output_for_epoch_end(self, outputs, output): if output is not None: if isinstance(output, ResultCollection): From 0778b366424215d323af8326807fa095029c8e65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 16:29:14 +0200 Subject: [PATCH 320/455] x --- pytorch_lightning/trainer/trainer.py | 1 - tests/trainer/loops/test_evaluation_loop.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fe30d0a172b90..232b01a665d71 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,7 +29,6 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.loops.dataloader.prediction_dataloader_loop import PredictionDataLoaderLoop from pytorch_lightning.loops.epoch_loop import EpochLoop from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 1b94923dd2ddf..073d8db45e548 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -22,7 +22,7 @@ @mock.patch( - "pytorch_lightning.loops.dataloader.evaluation_dataloader_loop.EvaluationDataLoaderLoop.on_evaluation_epoch_end" + "pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.on_evaluation_epoch_end" ) def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """ From 0ef000fc442cc946af9460be59309913e316eb2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 16:31:24 +0200 Subject: [PATCH 321/455] loop --- pytorch_lightning/trainer/trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 232b01a665d71..b5d65203837b4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -55,9 +55,11 @@ from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin from pytorch_lightning.trainer.deprecated_api 
import DeprecatedTrainerAttributes +from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin +from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_loop import TrainLoop @@ -343,6 +345,9 @@ def __init__( # old loops: self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) + self.evaluation_loop = EvaluationLoop(self) + self.predict_loop = PredictLoop(self) + # training state if weights_summary is not None and weights_summary not in ModelSummary.MODES: raise MisconfigurationException( From d421b56b5997fbb383f0caa6de9f5b1dd5a6dbff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 16:40:02 +0200 Subject: [PATCH 322/455] logger --- pytorch_lightning/trainer/trainer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b5d65203837b4..2ea5655214ab3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1104,7 +1104,7 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) # log batch metrics - self.logger_connector.log_evaluation_step_metrics() + self.logger_connector.update_evaluation_step_metrics() # track epoch level outputs dl_outputs = self._track_output_for_epoch_end(dl_outputs, output) @@ -1128,21 +1128,18 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: # hook self.evaluation_loop.on_evaluation_epoch_end() - # log epoch metrics - eval_loop_results = self.logger_connector.get_evaluate_epoch_results() - # hook self.evaluation_loop.on_evaluation_end() + # log epoch metrics + eval_loop_results = self.logger_connector.get_evaluate_epoch_results() + # save predictions to disk self.evaluation_loop.predictions.to_disk() # enable train mode again self.evaluation_loop.on_evaluation_model_train() - # reset cached results - self.logger_connector.reset() - torch.set_grad_enabled(True) return eval_loop_results From fa9ff739795f56feb8dbfec9c74e8c0f21bda838 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 18:47:49 +0200 Subject: [PATCH 323/455] Refactor test --- .../logger_connector/logger_connector.py | 1 - .../logging_/test_train_loop_logging.py | 176 ++++++------------ 2 files changed, 57 insertions(+), 120 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 77cee7fe93124..8054b363309d9 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -39,7 +39,6 @@ def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self._callback_metrics: Dict[str, _METRIC] = {} self._epoch_end_reached = False self._current_fx: Optional[str] = None - # FIXME: use _epoch_end_reached? 
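The rename from ``log_evaluation_step_metrics`` to ``update_evaluation_step_metrics`` above reflects the buffer-and-flush split the connector is converging on: ``update_*`` methods only refresh internal state, and writing to the logger happens in one place at a controlled cadence. A rough illustration of that split (a toy buffer, not the actual ``LoggerConnector`` API; the logger is only assumed to expose ``log_metrics``):

    from typing import Dict, Optional

    class MetricsBuffer:
        # Collects step metrics and flushes them to a logger at a chosen cadence.

        def __init__(self, log_every_n_steps: int = 50) -> None:
            self.log_every_n_steps = log_every_n_steps
            self._pending: Dict[str, float] = {}
            self._step = 0

        def update(self, metrics: Dict[str, float]) -> None:
            # collect only; nothing is written to the logger here
            self._pending.update(metrics)
            self._step += 1

        def maybe_flush(self, logger, step: Optional[int] = None) -> None:
            if not self._pending or self._step % self.log_every_n_steps != 0:
                return
            logger.log_metrics(self._pending, step=self._step if step is None else step)
            self._pending = {}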
self._batch_idx: Optional[int] = None def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 30c995a061d7a..c8651a9b55619 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -344,75 +344,38 @@ def test_log_works_in_train_callback(tmpdir): class TestCallback(callbacks.Callback): - # helpers - count = 1 + count = 0 choices = [False, True] - # used to compute expected values - callback_funcs_called = collections.defaultdict(dict) - funcs_called_count = collections.defaultdict(int) - funcs_attr = {} - - def make_logging( - self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[] - ): - self.funcs_called_count[func_name] += 1 - iterate = list(itertools.product(*[on_steps, on_epochs, prob_bars])) - value = self.count * func_idx - - current_epoch = pl_module.trainer.current_epoch - - for idx, (on_step, on_epoch, prog_bar) in enumerate(iterate): - # run logging - custom_func_name = f"{func_idx}_{idx}_{func_name}" - pl_module.log(custom_func_name, value, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) - - self.callback_funcs_called[custom_func_name].setdefault(current_epoch, []) - self.callback_funcs_called[custom_func_name][current_epoch].append(value) - - forked = on_step and on_epoch - - self.funcs_attr[custom_func_name] = { - "on_step": on_step, - "on_epoch": on_epoch, - "prog_bar": prog_bar, - "forked": forked, - "func_name": func_name - } - - if on_step and on_epoch: - self.funcs_attr[f"{custom_func_name}_step"] = { - "on_step": True, - "on_epoch": False, - "prog_bar": prog_bar, - "forked": False, - "func_name": func_name - } - - self.funcs_attr[f"{custom_func_name}_epoch"] = { - "on_step": False, - "on_epoch": True, - "prog_bar": prog_bar, - "forked": False, - "func_name": func_name - } + # used to compute expected values + logged_values = collections.defaultdict(list) + call_counter = collections.Counter() + logged_arguments = {} + + def make_logging(self, pl_module, func_name, on_steps, on_epochs, prob_bars): + self.call_counter.update([func_name]) + + for idx, (on_step, on_epoch, prog_bar) in enumerate(itertools.product(on_steps, on_epochs, prob_bars)): + fx = f"{func_name}_{idx}" + pl_module.log(fx, self.count, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + self.logged_values[fx].append(self.count) + self.logged_arguments[fx] = {"on_step": on_step, "on_epoch": on_epoch, "prog_bar": prog_bar} self.count += 1 def on_train_start(self, trainer, pl_module): self.make_logging( - pl_module, 'on_train_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_train_start', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) def on_epoch_start(self, trainer, pl_module): self.make_logging( - pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_epoch_start', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) def on_train_epoch_start(self, trainer, pl_module): self.make_logging( pl_module, 'on_train_epoch_start', - 3, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices @@ -420,101 +383,76 @@ def on_train_epoch_start(self, trainer, pl_module): def on_batch_end(self, trainer, pl_module): self.make_logging( - pl_module, 'on_batch_end', 
6, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_batch_end', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self.make_logging( - pl_module, - 'on_train_batch_end', - 7, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_train_batch_end', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) def on_train_epoch_end(self, trainer, pl_module): self.make_logging( - pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_train_epoch_end', on_steps=[False], on_epochs=self.choices, prob_bars=self.choices ) def on_epoch_end(self, trainer, pl_module): self.make_logging( - pl_module, 'on_epoch_end', 9, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_epoch_end', on_steps=[False], on_epochs=self.choices, prob_bars=self.choices ) class TestModel(BoringModel): - - manual_loss = [] + seen_losses = [] def training_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.manual_loss.append(loss) + loss = super().training_step(batch, batch_idx)['loss'] + self.seen_losses.append(loss) self.log('train_loss', loss, prog_bar=True) return {"loss": loss} - max_epochs = 2 - limit_train_batches = 2 model = TestModel() - test_callback = TestCallback() - + cb = TestCallback() trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=limit_train_batches, + limit_train_batches=2, limit_val_batches=0, - limit_test_batches=0, - val_check_interval=0., num_sanity_val_steps=0, - max_epochs=max_epochs, - callbacks=[test_callback] + max_epochs=2, + callbacks=[cb] ) trainer.fit(model) - assert test_callback.funcs_called_count["on_train_start"] == 1 - assert test_callback.funcs_called_count["on_epoch_start"] == 2 - assert test_callback.funcs_called_count["on_train_epoch_start"] == 2 - assert test_callback.funcs_called_count["on_batch_end"] == 4 - assert test_callback.funcs_called_count["on_epoch_end"] == 2 - assert test_callback.funcs_called_count["on_train_batch_end"] == 4 - assert test_callback.funcs_called_count["on_epoch_end"] == 2 - assert test_callback.funcs_called_count["on_train_epoch_end"] == 2 - - # function used to describe expected return logic - def get_expected_output(func_attr, original_values): - if func_attr["on_epoch"] and not func_attr["on_step"]: - # Apply mean on values - expected_output = np.mean(original_values) - else: - # Keep the latest value - expected_output = np.max(original_values) - return expected_output - # Make sure the func_name output equals the average from all logged values when on_epoch true - # pop extra keys - assert trainer.progress_bar_dict["train_loss"] == model.manual_loss[-1] - assert trainer.callback_metrics["train_loss"] == model.manual_loss[-1] - trainer.callback_metrics.pop("train_loss") - - for func_name, output_value in trainer.callback_metrics.items(): - if torch.is_tensor(output_value): - output_value = output_value.item() - # get creation attr - func_attr = test_callback.funcs_attr[func_name] - - # retrieved original logged values - values = test_callback.callback_funcs_called[func_name] - if len(values) > 0: - original_values = values[len(values) - 1] - # compute expected output and compare to actual one - expected_output = get_expected_output(func_attr, original_values) - assert float(output_value) == 
float(expected_output) - - for func_name, func_attr in test_callback.funcs_attr.items(): - if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: - assert func_name in trainer.logger_connector.progress_bar_metrics - else: - assert func_name not in trainer.logger_connector.progress_bar_metrics + assert trainer.progress_bar_dict["train_loss"] == model.seen_losses[-1] + assert trainer.callback_metrics["train_loss"] == model.seen_losses[-1] + + assert cb.call_counter == { + 'on_train_batch_end': 4, + 'on_batch_end': 4, + 'on_epoch_start': 2, + 'on_train_epoch_start': 2, + 'on_train_epoch_end': 2, + 'on_epoch_end': 2, + 'on_train_start': 1 + } + + def get_expected(on_epoch, values): + reduction = np.mean if on_epoch else np.max + return reduction(values) + + for fx, value in trainer.callback_metrics.items(): + actual = value.item() + if fx not in cb.logged_arguments: + continue + on_epoch = cb.logged_arguments[fx]['on_epoch'] + values = cb.logged_values[fx] + expected = get_expected(on_epoch, values) + assert actual == expected + + for fx, attrs in cb.logged_arguments.items(): + should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"] + is_included = fx in trainer.logger_connector.progress_bar_metrics + assert is_included if should_include else not is_included def test_logging_sync_dist_true_cpu(tmpdir): From a1f839f4062285c2f95f5c4f05b859c355b2cdbd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 18:59:23 +0200 Subject: [PATCH 324/455] Tighter fx validator --- .../logger_connector/fx_validator.py | 28 ++++++++-------- .../logging_/test_train_loop_logging.py | 32 +++++++------------ 2 files changed, 26 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 7ab288e6041fd..8d079f8b4a637 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -29,26 +29,26 @@ class FxValidator: on_fit_end=None, on_sanity_check_start=None, on_sanity_check_end=None, - on_train_start=dict(on_step=(False, True), on_epoch=(False, True)), + on_train_start=dict(on_step=(False, ), on_epoch=(True, )), on_train_end=None, - on_validation_start=dict(on_step=(False, True), on_epoch=(False, True)), + on_validation_start=dict(on_step=(False, ), on_epoch=(True, )), on_validation_end=None, - on_test_start=dict(on_step=(False, True), on_epoch=(False, True)), + on_test_start=dict(on_step=(False, ), on_epoch=(True, )), on_test_end=None, on_predict_start=None, on_predict_end=None, on_pretrain_routine_start=None, on_pretrain_routine_end=None, - on_train_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)), - on_train_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), - on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)), - on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), - on_test_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)), - on_test_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), + on_train_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), + on_train_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), + on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + on_test_epoch_start=dict(on_step=(False, True), 
on_epoch=(True, )), + on_test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_predict_epoch_start=None, on_predict_epoch_end=None, - on_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)), - on_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), + on_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), + on_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_batch_start=dict(on_step=(False, True), on_epoch=(False, True)), on_batch_end=dict(on_step=(False, True), on_epoch=(False, True)), on_train_batch_start=dict(on_step=(False, True), on_epoch=(False, True)), @@ -72,9 +72,9 @@ class FxValidator: training_step_end=dict(on_step=(False, True), on_epoch=(False, True)), validation_step_end=dict(on_step=(False, True), on_epoch=(False, True)), test_step_end=dict(on_step=(False, True), on_epoch=(False, True)), - training_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), - validation_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), - test_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), + training_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_before_batch_transfer=None, transfer_batch_to_device=None, on_after_batch_transfer=None, diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index c8651a9b55619..ef06d854341ab 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -362,44 +362,36 @@ def make_logging(self, pl_module, func_name, on_steps, on_epochs, prob_bars): self.logged_arguments[fx] = {"on_step": on_step, "on_epoch": on_epoch, "prog_bar": prog_bar} self.count += 1 - def on_train_start(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_train_start', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices - ) + def on_train_start(self, _, pl_module): + self.make_logging(pl_module, 'on_train_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices) - def on_epoch_start(self, trainer, pl_module): + def on_epoch_start(self, _, pl_module): self.make_logging( - pl_module, 'on_epoch_start', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_epoch_start', on_steps=self.choices, on_epochs=[True], prob_bars=self.choices ) - def on_train_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, _, pl_module): self.make_logging( - pl_module, - 'on_train_epoch_start', - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_train_epoch_start', on_steps=self.choices, on_epochs=[True], prob_bars=self.choices ) - def on_batch_end(self, trainer, pl_module): + def on_batch_end(self, _, pl_module): self.make_logging( pl_module, 'on_batch_end', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, _, pl_module, *__): self.make_logging( pl_module, 'on_train_batch_end', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) - def on_train_epoch_end(self, trainer, pl_module): + def on_train_epoch_end(self, _, pl_module): self.make_logging( - pl_module, 'on_train_epoch_end', on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_train_epoch_end', on_steps=[False], 
on_epochs=[True], prob_bars=self.choices ) - def on_epoch_end(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_epoch_end', on_steps=[False], on_epochs=self.choices, prob_bars=self.choices - ) + def on_epoch_end(self, _, pl_module): + self.make_logging(pl_module, 'on_epoch_end', on_steps=[False], on_epochs=[True], prob_bars=self.choices) class TestModel(BoringModel): seen_losses = [] From a58787995df3acab3b930f3fab900ea7cbf3758e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 19:11:43 +0200 Subject: [PATCH 325/455] Add back split idx --- .../logger_connector/logger_connector.py | 14 +++++++++++--- pytorch_lightning/trainer/training_loop.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8054b363309d9..0086fdb9cd9d8 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -40,6 +40,7 @@ def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self._epoch_end_reached = False self._current_fx: Optional[str] = None self._batch_idx: Optional[int] = None + self._split_idx: Optional[int] = None def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): # logging @@ -206,9 +207,10 @@ def update_evaluation_step_metrics(self) -> None: Train metric updates """ - def on_train_split_start(self, batch_idx: int, split_batch: Any) -> None: + def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: self.trainer.result_collection.extract_batch_size(split_batch) self._batch_idx = batch_idx + self._split_idx = split_idx def update_train_step_metrics(self, batch_output): if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: @@ -247,6 +249,7 @@ def on_epoch_end(self) -> None: self._callback_metrics.update(metrics[MetricSource.CALLBACK]) self._logged_metrics.update(metrics[MetricSource.LOG]) self._batch_idx = None + self._split_idx = None def on_batch_end(self) -> None: assert not self._epoch_end_reached @@ -256,12 +259,17 @@ def on_batch_end(self) -> None: self._logged_metrics.update(metrics[MetricSource.LOG]) def should_reset_tensors(self, fx: str) -> bool: - # reset tensor metrics only when the hook changed and reloading the dataloader - return self._current_fx != fx and self._batch_idx in (None, 0) + is_different_fx = self._current_fx != fx + if self._split_idx is None: + is_first_batch = self._batch_idx in (None, 0) + else: + is_first_batch = self._batch_idx + self._split_idx == 0 + return is_different_fx and is_first_batch def reset(self, metrics: Optional[bool] = None) -> None: self.trainer.result_collection.reset(metrics=metrics) self._batch_idx = None + self._split_idx = None self._current_fx = None @property diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3694e9557ca06..358cd24e554c2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -644,7 +644,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): self.split_idx = split_idx # let logger connector extract batch size - self.trainer.logger_connector.on_train_split_start(batch_idx, split_batch) + self.trainer.logger_connector.on_train_split_start(batch_idx, 
split_idx, split_batch) if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): From 631da985a5a6f00172a882e907af62ce276d0c7d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 19:28:08 +0200 Subject: [PATCH 326/455] Typing --- .../logger_connector/logger_connector.py | 33 +++++++++++-------- .../connectors/logger_connector/result.py | 10 +++--- pytorch_lightning/trainer/evaluation_loop.py | 1 + pytorch_lightning/utilities/types.py | 4 +-- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0086fdb9cd9d8..234965ea0c1fe 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -13,22 +13,23 @@ # limitations under the License. import os from pprint import pprint -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Mapping, Optional import torch +import pytorch_lightning as pl from pytorch_lightning.core import memory -from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger +from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities import DeviceType +from pytorch_lightning.utilities import AttributeDict, DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT class LoggerConnector: - def __init__(self, trainer, log_gpu_memory: Optional[str] = None): + def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None: self.trainer = trainer self.log_gpu_memory = log_gpu_memory self.eval_loop_results = [] @@ -42,24 +43,26 @@ def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self._batch_idx: Optional[int] = None self._split_idx: Optional[int] = None - def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): - # logging + def on_trainer_init( + self, logger: LightningLoggerBase, flush_logs_every_n_steps: int, log_every_n_steps: int, + move_metrics_to_cpu: bool + ) -> None: self.configure_logger(logger) self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps self.trainer.log_every_n_steps = log_every_n_steps self.trainer.move_metrics_to_cpu = move_metrics_to_cpu @property - def should_flush_logs(self): + def should_flush_logs(self) -> bool: should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 return should_flush or self.trainer.should_stop @property - def should_update_logs(self): + def should_update_logs(self) -> bool: should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 return should_log_every_n_steps or self.trainer.should_stop - def configure_logger(self, logger): + def configure_logger(self, logger: LightningLoggerBase) -> None: if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) @@ -75,14 +78,16 @@ def configure_logger(self, logger): else: self.trainer.logger = logger - def log_metrics(self, metrics, grad_norm_dict, step=None): + def 
log_metrics( + self, metrics: Dict[str, _METRIC], grad_norm_dict: Dict[str, float], step: Optional[int] = None + ) -> None: """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses metrics["step"] as a step Args: - metrics (dict): Metric values - grad_norm_dict (dict): Gradient norms + metrics: Metric values + grad_norm_dict: Gradient norms step (int): Step for which metrics should be logged. Default value is `self.global_step` during training or the total validation / test log step count during validation and testing. """ @@ -117,7 +122,7 @@ def log_metrics(self, metrics, grad_norm_dict, step=None): Evaluation metric updates """ - def prepare_eval_loop_results(self, metrics: Dict[str, _METRIC]) -> None: + def prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: if self.trainer.sanity_checking: return @@ -212,7 +217,7 @@ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) self._batch_idx = batch_idx self._split_idx = split_idx - def update_train_step_metrics(self, batch_output): + def update_train_step_metrics(self, batch_output: AttributeDict) -> None: if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: return diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 7f79c2a050805..4c2b6bedfd996 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -14,7 +14,7 @@ from collections.abc import Generator from dataclasses import dataclass, field from functools import partial, wraps -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union import torch from torchmetrics import Metric @@ -28,7 +28,7 @@ # re-define the ones from pytorch_lightning.utilities.types without the `Number` type _METRIC = Union[Metric, torch.Tensor] -_METRIC_COLLECTION = Union[_METRIC, Dict[str, _METRIC]] +_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] class MetricSource(LightningEnum): @@ -172,7 +172,7 @@ def wrapped_func(*args, **kwargs): return wrapped_func - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: # performance: skip the `torch.nn.Module.__setattr__` checks object.__setattr__(self, key, value) @@ -220,7 +220,7 @@ def __init__(self, training: bool, device: Optional[torch.device] = None) -> Non self.training = training self._minimize = None self._batch_size = torch.tensor(1, device=device) - self.device: Optional[torch.device] = device + self.device: Optional[Union[str, torch.device]] = device self.fx_validator = FxValidator() @property @@ -258,7 +258,7 @@ def extra(self) -> Dict[str, Any]: return self.get('_extra', {}) @extra.setter - def extra(self, extra: Dict[str, Any]) -> None: + def extra(self, extra: Mapping[str, Any]) -> None: def check_fn(v): if v.grad_fn is not None: diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 72d3bf777d86d..d925ea612c319 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -80,6 +80,7 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: self.should_track_batch_outputs_for_epoch_end: 
bool = self._should_track_batch_outputs_for_epoch_end() + assert self.trainer.result_collection is not None self.trainer.result_collection.device = self.trainer.lightning_module.device if self.trainer.testing: diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 945ee3b74218f..a04f7ba87284d 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,13 +17,13 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`) """ from numbers import Number -from typing import Any, Dict, Iterator, List, Union +from typing import Any, Dict, Iterator, List, Mapping, Union import torch from torchmetrics import Metric _METRIC = Union[Metric, torch.Tensor, Number] -_METRIC_COLLECTION = Union[_METRIC, Dict[str, _METRIC]] +_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]] EPOCH_OUTPUT = List[STEP_OUTPUT] _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader From e098631a305aded8a9f57861fe973eede235cb36 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 18:37:46 +0100 Subject: [PATCH 327/455] update --- tests/trainer/logging_/test_train_loop_logging.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index ef06d854341ab..c4dd93394ff5b 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -409,7 +409,7 @@ def training_step(self, batch, batch_idx): limit_train_batches=2, limit_val_batches=0, num_sanity_val_steps=0, - max_epochs=2, + max_epochs=1, callbacks=[cb] ) trainer.fit(model) @@ -419,12 +419,12 @@ def training_step(self, batch, batch_idx): assert trainer.callback_metrics["train_loss"] == model.seen_losses[-1] assert cb.call_counter == { - 'on_train_batch_end': 4, - 'on_batch_end': 4, - 'on_epoch_start': 2, - 'on_train_epoch_start': 2, - 'on_train_epoch_end': 2, - 'on_epoch_end': 2, + 'on_train_batch_end': 2, + 'on_batch_end': 2, + 'on_epoch_start': 1, + 'on_train_epoch_start': 1, + 'on_train_epoch_end': 1, + 'on_epoch_end': 1, 'on_train_start': 1 } From 2f13234f1e9484f98ff1532244ed51d2cf9fc207 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 20:06:31 +0200 Subject: [PATCH 328/455] Conflict --- tests/trainer/logging_/test_train_loop_logging.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index c4dd93394ff5b..88ad0883d86dd 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -419,13 +419,13 @@ def training_step(self, batch, batch_idx): assert trainer.callback_metrics["train_loss"] == model.seen_losses[-1] assert cb.call_counter == { - 'on_train_batch_end': 2, - 'on_batch_end': 2, + 'on_train_start': 1, 'on_epoch_start': 1, 'on_train_epoch_start': 1, + 'on_train_batch_end': 2, + 'on_batch_end': 2, 'on_train_epoch_end': 1, - 'on_epoch_end': 1, - 'on_train_start': 1 + 'on_epoch_end': 1 } def get_expected(on_epoch, values): From 502dcbd12c51e024857e21bfe9d7e661aaed3b26 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 20:15:24 +0200 Subject: [PATCH 329/455] Fix tests --- .../connectors/logger_connector/logger_connector.py | 8 ++++++-- 
pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- tests/callbacks/test_early_stopping.py | 4 +++- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 234965ea0c1fe..ce14c782cc2a3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -247,14 +247,18 @@ def update_train_epoch_metrics(self) -> None: def on_epoch_start(self) -> None: self._epoch_end_reached = False + def epoch_end_reached(self): + self.trainer.logger_connector._epoch_end_reached = True + self.trainer.logger_connector._batch_idx = None + self.trainer.logger_connector._split_idx = None + def on_epoch_end(self) -> None: assert self._epoch_end_reached metrics = self.metrics self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) self._callback_metrics.update(metrics[MetricSource.CALLBACK]) self._logged_metrics.update(metrics[MetricSource.LOG]) - self._batch_idx = None - self._split_idx = None + self._current_fx = None def on_batch_end(self) -> None: assert not self._epoch_end_reached diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index d925ea612c319..49564c8ba9f88 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -198,7 +198,7 @@ def _should_track_batch_outputs_for_epoch_end(self) -> bool: def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: # inform logger the batch loop has finished - self.trainer.logger_connector._epoch_end_reached = True + self.trainer.logger_connector.epoch_end_reached() # call the model epoch end model = self.trainer.lightning_module diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 358cd24e554c2..ef9872ccb98c0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -544,7 +544,7 @@ def run_training_epoch(self): def on_train_epoch_end(self, epoch_output: List[List[List['ResultCollection']]]) -> None: # inform logger the batch loop has finished - self.trainer.logger_connector._epoch_end_reached = True + self.trainer.logger_connector.epoch_end_reached() # prepare epoch output processed_epoch_output = TrainLoop._prepare_outputs(epoch_output, batch_mode=False) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 7d303e6ed00d6..e62ddb90ff5ac 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -123,7 +123,7 @@ def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expec """Test to ensure that early stopping is not triggered before patience is exhausted.""" class ModelOverrideValidationReturn(BoringModel): - validation_return_values = torch.Tensor(loss_values) + validation_return_values = torch.tensor(loss_values) def validation_epoch_end(self, outputs): loss = self.validation_return_values[self.current_epoch] @@ -137,6 +137,7 @@ def validation_epoch_end(self, outputs): val_check_interval=1.0, num_sanity_val_steps=0, max_epochs=10, + progress_bar_refresh_rate=0, ) trainer.fit(model) assert trainer.current_epoch == expected_stop_epoch @@ -176,6 +177,7 @@ def training_epoch_end(self, outputs): callbacks=[early_stop_callback], 
num_sanity_val_steps=0, max_epochs=10, + progress_bar_refresh_rate=0, ) trainer.fit(model) assert trainer.current_epoch == expected_stop_epoch From c716736dc1ddb18707a69e7f9def1830bf795fda Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 19:16:27 +0100 Subject: [PATCH 330/455] resolve grad_norm --- pytorch_lightning/core/lightning.py | 17 +++++++++++++++++ .../logger_connector/logger_connector.py | 5 ++--- pytorch_lightning/trainer/training_loop.py | 15 ++++++--------- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7ef73405a4b81..bb99bd518a2c4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1530,6 +1530,23 @@ def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): """ optimizer.zero_grad() + def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None: + """Override this method to change the default behaviour of ``log_grad_norm``. + + Args: + grad_norm_dict: Dictionary containing current grad norm metrics + + Examples:: + + # DEFAULT + def log_grad_norm(self, grad_norm_dict): + print(grad_norm_dict) + + """ + self._current_fx_name = "on_after_backward" + self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True) + self._current_fx_name = None + def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: r""" When using truncated backpropagation through time, each batch must be split along the diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 234965ea0c1fe..4be823f7c8bfc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -226,9 +226,8 @@ def update_train_step_metrics(self, batch_output: AttributeDict) -> None: metrics = self.metrics[MetricSource.LOG] if self.should_update_logs or self.trainer.fast_dev_run is True: # logs user requested information to logger - grad_norm_dict = batch_output.grad_norm_dict or {} - if metrics or grad_norm_dict: - self.log_metrics(metrics, grad_norm_dict) + if metrics: + self.log_metrics(metrics, {}) def update_train_epoch_metrics(self) -> None: # add the metrics to the loggers diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 358cd24e554c2..9db9d02d0af16 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -608,8 +608,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: self.trainer.lightning_module._current_fx_name = None def run_training_batch(self, batch, batch_idx, dataloader_idx): - # track grad norms - grad_norm_dict = {} + model_ref = self.trainer.lightning_module # bookkeeping self._hiddens = None @@ -623,19 +622,18 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): self.warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") return AttributeDict( signal=0, - grad_norm_dict={}, training_step_output=batch_outputs, ) # hook response = self.trainer.call_hook("on_batch_start") if response == -1: - return AttributeDict(signal=-1, grad_norm_dict={}) + return AttributeDict(signal=-1) # hook response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) if response == -1: - return AttributeDict(signal=-1, grad_norm_dict={}) + return AttributeDict(signal=-1) # lightning module hook splits = self._tbptt_split_batch(batch) @@ -646,12 +644,11 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # let logger connector extract batch size self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) - if self.trainer.lightning_module.automatic_optimization: + if model_ref.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: batch_outputs[opt_idx].append(result.training_step_output) - grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) @@ -660,7 +657,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): return AttributeDict( signal=0, - grad_norm_dict=grad_norm_dict, training_step_output=batch_outputs, ) @@ -815,7 +811,8 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): if not self.should_accumulate(): # track gradients - result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) + grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) + self.trainer.lightning_module.log_grad_norm(grad_norm_dict) def update_lr_schedulers(self, interval: str) -> None: if interval == "step": From 26f5e037dd7684157d2a2aa6df7f89f8a388d8e1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 19:20:08 +0100 Subject: [PATCH 331/455] update --- tests/models/test_grad_norm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index de6f7c3c2f1db..14a966d211e99 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -108,5 +108,5 @@ def test_grad_tracking_interval(tmpdir, log_every_n_steps): if grad_norm_dict: grad_norm_dicts.append(grad_norm_dict) - assert len(grad_norm_dicts) == expected - assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts) + assert len(grad_norm_dicts) == expected + 1 + assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[:-1]) From badd645c372850bed5f7e8fcc3836acadf5c4e44 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 19:27:36 +0100 Subject: [PATCH 332/455] move to train loop --- pytorch_lightning/core/lightning.py | 2 -- pytorch_lightning/trainer/training_loop.py | 5 ++++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bb99bd518a2c4..7d86517413840 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1543,9 +1543,7 @@ def log_grad_norm(self, grad_norm_dict): print(grad_norm_dict) """ - self._current_fx_name = "on_after_backward" self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True) - self._current_fx_name = None def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: r""" 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c2c0e10d06112..7c85877bb71c0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -812,7 +812,10 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): if not self.should_accumulate(): # track gradients grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) - self.trainer.lightning_module.log_grad_norm(grad_norm_dict) + if grad_norm_dict: + self.trainer.lightning_module._current_fx_name = "on_after_backward" + self.trainer.lightning_module.log_grad_norm(grad_norm_dict) + self.trainer.lightning_module._current_fx_name = None def update_lr_schedulers(self, interval: str) -> None: if interval == "step": From aac11a0dc147102dc4ebaf208f871cadd07a9b42 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 20:30:44 +0200 Subject: [PATCH 333/455] Bye grad_norm_dict parameter --- .../logger_connector/logger_connector.py | 16 +++++----------- .../trainer/logging_/test_distributed_logging.py | 2 +- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index ab8a8e69c500d..22b4acd0f717e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -78,16 +78,13 @@ def configure_logger(self, logger: LightningLoggerBase) -> None: else: self.trainer.logger = logger - def log_metrics( - self, metrics: Dict[str, _METRIC], grad_norm_dict: Dict[str, float], step: Optional[int] = None - ) -> None: + def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None: """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses metrics["step"] as a step Args: metrics: Metric values - grad_norm_dict: Gradient norms step (int): Step for which metrics should be logged. Default value is `self.global_step` during training or the total validation / test log step count during validation and testing. 
""" @@ -96,9 +93,6 @@ def log_metrics( mem_map = memory.get_memory_profile(self.log_gpu_memory) metrics.update(mem_map) - # add norms - metrics.update(grad_norm_dict) - # turn all tensors to scalars scalar_metrics = metrics_to_scalars(metrics) @@ -147,7 +141,7 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: # log all the metrics as a single dict log_metrics = metrics[MetricSource.LOG] if log_metrics: - self.log_metrics(log_metrics, {}) + self.log_metrics(log_metrics) self.prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) @@ -203,7 +197,7 @@ def update_evaluation_step_metrics(self) -> None: assert not self._epoch_end_reached metrics = self.metrics[MetricSource.LOG] if metrics: - self.log_metrics(metrics, {}, step=self.evaluation_log_step) + self.log_metrics(metrics, step=self.evaluation_log_step) # increment the step even if nothing was logged self.increment_evaluation_log_step() @@ -227,14 +221,14 @@ def update_train_step_metrics(self, batch_output: AttributeDict) -> None: if self.should_update_logs or self.trainer.fast_dev_run is True: # logs user requested information to logger if metrics: - self.log_metrics(metrics, {}) + self.log_metrics(metrics) def update_train_epoch_metrics(self) -> None: # add the metrics to the loggers assert self._epoch_end_reached metrics = self.metrics[MetricSource.LOG] if metrics: - self.log_metrics(metrics, {}) + self.log_metrics(metrics) # reset result collection for next epoch self.trainer.result_collection.reset(metrics=True) diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py index 5832f387cc63d..4094fd90021af 100644 --- a/tests/trainer/logging_/test_distributed_logging.py +++ b/tests/trainer/logging_/test_distributed_logging.py @@ -24,7 +24,7 @@ class TestModel(BoringModel): def on_pretrain_routine_end(self) -> None: with mock.patch('pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics') as m: - self.trainer.logger_connector.log_metrics({'a': 2}, {}) + self.trainer.logger_connector.log_metrics({'a': 2}) logged_times = m.call_count expected = int(self.trainer.is_global_zero) msg = f'actual logger called from non-global zero, logged_times: {logged_times}, expected: {expected}' From 919cbfb485630b35de1780003c67608be7159893 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 20:39:53 +0200 Subject: [PATCH 334/455] Fix sync test --- .../trainer/connectors/logger_connector/result.py | 2 +- tests/core/test_results.py | 7 +++++-- tests/models/test_tpu.py | 12 +++++------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 4c2b6bedfd996..2a0388eb7c11f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -45,7 +45,7 @@ class _Sync: group: Optional[Any] = None @property - def __call__(self) -> Callable: + def __call__(self) -> Any: return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op @staticmethod diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 9fce99ffbecc9..5fffb64331ae4 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -20,7 +20,9 @@ import torch.multiprocessing as mp import tests.helpers.utils as tutils -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import Trainer +from 
pytorch_lightning.trainer.connectors.logger_connector.result import _Sync +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -37,7 +39,8 @@ def _setup_ddp(rank, worldsize): def _ddp_test_fn(rank, worldsize): _setup_ddp(rank, worldsize) tensor = torch.tensor([1.0]) - actual = LightningModule._LightningModule__sync(tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM) + sync = _Sync(sync_ddp_if_available, should=True, op=torch.distributed.ReduceOp.SUM) + actual = sync(tensor) assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors" diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 8e3bab7350f7f..2e7db175801b9 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -21,10 +21,11 @@ import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.plugins import TPUSpawnPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -424,12 +425,9 @@ def test_tpu_sync_dist(): """Test tpu spawn sync dist operation """ def test_sync_dist(_): - value = LightningModule._LightningModule__sync( - torch.tensor([1.0]), - sync_fn=TPUSpawnPlugin().reduce, - sync_dist=True, - sync_dist_op=torch.distributed.ReduceOp.SUM - ) + sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM) + value = torch.tensor([1.0]) + value = sync(value), assert value.item() == 8 xmp.spawn(test_sync_dist, nprocs=8, start_method='fork') From 1c75341cb24c12122548af5b0751bdccc59b1500 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 19:41:06 +0100 Subject: [PATCH 335/455] update --- pytorch_lightning/core/lightning.py | 2 +- tests/models/test_grad_norm.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7d86517413840..cef9aed87bb20 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -341,7 +341,7 @@ def log( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) - result_collection = self.trainer.result_collection + result_collection: 'ResultCollection' = self.trainer.result_collection # noqa F821 assert result_collection is not None assert self._current_fx_name is not None result_collection.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index 14a966d211e99..294afd419a288 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -108,5 +108,9 @@ def test_grad_tracking_interval(tmpdir, log_every_n_steps): if grad_norm_dict: grad_norm_dicts.append(grad_norm_dict) + # logging on n steps + 1 epochs assert len(grad_norm_dicts) == expected + 1 + # check all metrics derived from steps have the same keys assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[:-1]) + epoch_end_keys = 
[k.replace("step", "epoch") for k in grad_norm_dicts[0].keys()] + assert epoch_end_keys == list(grad_norm_dicts[-1].keys()) From 23a75100377d5a48c2ab06c686e5c9b31110b457 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 21:01:40 +0200 Subject: [PATCH 336/455] Fix bug when validation is run mid epoch --- .../trainer/connectors/logger_connector/logger_connector.py | 4 +++- .../trainer/connectors/logger_connector/result.py | 1 - pytorch_lightning/trainer/evaluation_loop.py | 6 +++++- pytorch_lightning/trainer/training_loop.py | 1 + tests/models/test_grad_norm.py | 4 ++-- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 22b4acd0f717e..1c848cdeccce2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -160,7 +160,6 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: print('-' * 80) results = self.eval_loop_results - # clear mem self.eval_loop_results = [] return results @@ -240,6 +239,9 @@ def update_train_epoch_metrics(self) -> None: def on_epoch_start(self) -> None: self._epoch_end_reached = False + def on_batch_start(self) -> None: + self._epoch_end_reached = False + def epoch_end_reached(self): self.trainer.logger_connector._epoch_end_reached = True self.trainer.logger_connector._batch_idx = None diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 2a0388eb7c11f..1f7f2d286630f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -129,7 +129,6 @@ def compute(self) -> torch.Tensor: value = self.meta.sync(self.value) if self.meta.is_mean_reduction: cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) - # FIXME: might need sum return value / cumulated_batch_size elif self.meta.is_max_reduction or self.meta.is_min_reduction: return value diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 49564c8ba9f88..dc2b1f4d5e29c 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -217,6 +217,10 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: model.validation_epoch_end(outputs) def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + self.trainer.logger_connector.on_batch_start() + # FIXME(@carmocca): missing hook? + # self.trainer.call_hook('on_batch_start') + # set dataloader_idx to model and track batch_size assert self.num_dataloaders is not None self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self.num_dataloaders) @@ -238,7 +242,7 @@ def on_evaluation_batch_end( else: self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) - # FIXME: missing hook? + # FIXME(@carmocca): missing hook? 
# self.trainer.call_hook('on_batch_end') self.trainer.logger_connector.on_batch_end() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7c85877bb71c0..14c26fc647838 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -626,6 +626,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): ) # hook + self.trainer.logger_connector.on_batch_start() response = self.trainer.call_hook("on_batch_start") if response == -1: return AttributeDict(signal=-1) diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index 294afd419a288..384e643e184fe 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -112,5 +112,5 @@ def test_grad_tracking_interval(tmpdir, log_every_n_steps): assert len(grad_norm_dicts) == expected + 1 # check all metrics derived from steps have the same keys assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[:-1]) - epoch_end_keys = [k.replace("step", "epoch") for k in grad_norm_dicts[0].keys()] - assert epoch_end_keys == list(grad_norm_dicts[-1].keys()) + epoch_end_keys = [k.replace("step", "epoch") for k in grad_norm_dicts[0]] + assert epoch_end_keys == list(grad_norm_dicts[-1]) From 9df0da2a6d627d1c1129a25275c80d318a635432 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 21:03:08 +0200 Subject: [PATCH 337/455] fix grad_norm_dict test --- tests/trainer/loops/test_training_loop_flow_scalar.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index d9d8273eb6429..9a47932b68cad 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -156,7 +156,6 @@ def backward(self, loss, optimizer, optimizer_idx): batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output assert len(train_step_out) == 1 @@ -236,7 +235,6 @@ def backward(self, loss, optimizer, optimizer_idx): batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output assert len(train_step_out) == 1 From 0485a98e4cdbf8d043d5d381e1124d67810e87fc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 21:19:38 +0200 Subject: [PATCH 338/455] Fix fx_validator test --- tests/trainer/logging_/test_logger_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index da1cb501c7af6..1a3f95fea6367 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -117,7 +117,8 @@ def test_fx_validator(tmpdir): # This summarizes where and what is currently possible to log using `self.log` is_stage = "train" in func_name or "test" in func_name or "validation" in func_name is_start = "start" in func_name or "batch" in func_name - on_step = is_stage and is_start + is_epoch = "epoch" in func_name + on_step = is_stage and not is_start and not is_epoch on_epoch = True 
# creating allowed condition allowed = ( From e0702aa5ca859a1fc8b4d1027cec7c1ff23aa88f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 21:22:19 +0200 Subject: [PATCH 339/455] fix grad_norm_dict test --- tests/trainer/loops/test_evaluation_loop_flow.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 61fae95c70312..a7520c90dc869 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -71,7 +71,6 @@ def backward(self, loss, optimizer, optimizer_idx): batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output assert len(train_step_out) == 1 @@ -144,7 +143,6 @@ def backward(self, loss, optimizer, optimizer_idx): batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output assert len(train_step_out) == 1 From 32ca719bd98e8ab5f7b6a5441665a8a62c43a8c5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 21:27:57 +0200 Subject: [PATCH 340/455] Fix order bug --- pytorch_lightning/trainer/evaluation_loop.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index dc2b1f4d5e29c..977ae10dc5464 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -103,8 +103,6 @@ def on_evaluation_model_train(self) -> None: model_ref.on_validation_model_train() def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: - self.trainer.logger_connector.reset(metrics=True) - if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) else: @@ -114,6 +112,8 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: # summarize profile results self.trainer.profiler.describe() + self.trainer.logger_connector.reset(metrics=True) + def reload_evaluation_dataloaders(self) -> None: model = self.trainer.lightning_module if self.trainer.testing: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9a5d9380fc23d..fc858e2fcb564 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1009,12 +1009,12 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: # hook self.evaluation_loop.on_evaluation_epoch_end() - # hook - self.evaluation_loop.on_evaluation_end() - # log epoch metrics eval_loop_results = self.logger_connector.get_evaluate_epoch_results() + # hook + self.evaluation_loop.on_evaluation_end() + # save predictions to disk self.evaluation_loop.predictions.to_disk() From c68825c7e9d0876df7c1078eb1c3bc8254b58476 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 21:32:18 +0200 Subject: [PATCH 341/455] Detach tensors in test --- .../trainer/connectors/logger_connector/result.py | 5 +---- tests/trainer/logging_/test_train_loop_logging.py | 2 +- tests/trainer/optimization/test_manual_optimization.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git 
a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 1f7f2d286630f..f2a0750b9e570 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -261,10 +261,7 @@ def extra(self, extra: Mapping[str, Any]) -> None: def check_fn(v): if v.grad_fn is not None: - raise MisconfigurationException( - 'You passed a tensor with `grad_fn` when calling `self.log()`.' - f' The extra values are {extra}' - ) + raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}') apply_to_collection(extra, torch.Tensor, check_fn) self['_extra'] = extra diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 88ad0883d86dd..b588e17d83f2d 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -755,7 +755,7 @@ def training_step(self, *args): trainer = Trainer(default_root_dir=tmpdir) model = TestModel() - with pytest.raises(MisconfigurationException, match='You passed a tensor with `grad_fn`'): + with pytest.raises(MisconfigurationException, match='You returned a tensor with `grad_fn`'): trainer.fit(model) class TestModel(BoringModel): diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 8589473564f00..9abc9b47ab82a 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -474,7 +474,7 @@ def training_step(self, batch, batch_idx): opt_b.step() opt_b.zero_grad() - return {'loss1': loss_1, 'loss2': loss_2} + return {'loss1': loss_1.detach(), 'loss2': loss_2.detach()} def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer From 37ed74d144d789b06049f44d37005103790fac46 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 19:33:10 +0000 Subject: [PATCH 342/455] resolve some tests --- pytorch_lightning/core/lightning.py | 2 +- .../connectors/logger_connector/fx_validator.py | 10 +++++----- pytorch_lightning/trainer/trainer.py | 7 ++++++- pytorch_lightning/trainer/training_loop.py | 2 +- tests/trainer/logging_/test_eval_loop_logging.py | 10 +++++----- tests/trainer/loops/test_evaluation_loop.py | 2 +- tests/trainer/loops/test_evaluation_loop_flow.py | 2 -- tests/trainer/optimization/test_manual_optimization.py | 2 +- 8 files changed, 20 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index cef9aed87bb20..dd15cc2d9f14b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1540,7 +1540,7 @@ def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None: # DEFAULT def log_grad_norm(self, grad_norm_dict): - print(grad_norm_dict) + self.log_dict(grad_norm_dict, on_step=False, on_epoch=True, prog_bar=False, logger=True) """ self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 8d079f8b4a637..20b9ddb5ca17d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -29,11 +29,11 
@@ class FxValidator: on_fit_end=None, on_sanity_check_start=None, on_sanity_check_end=None, - on_train_start=dict(on_step=(False, ), on_epoch=(True, )), + on_train_start=dict(on_step=(True, False), on_epoch=(False, True)), on_train_end=None, - on_validation_start=dict(on_step=(False, ), on_epoch=(True, )), + on_validation_start=dict(on_step=(False, True), on_epoch=(False, True)), on_validation_end=None, - on_test_start=dict(on_step=(False, ), on_epoch=(True, )), + on_test_start=dict(on_step=(False, True), on_epoch=(False, True)), on_test_end=None, on_predict_start=None, on_predict_end=None, @@ -43,8 +43,8 @@ class FxValidator: on_train_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), - on_test_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), - on_test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + on_test_epoch_start=dict(on_step=(False, True), on_epoch=(False, True, )), + on_test_epoch_end=dict(on_step=(False, ), on_epoch=(True,)), on_predict_epoch_start=None, on_predict_epoch_end=None, on_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9a5d9380fc23d..f32ba9c87266d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1200,6 +1200,11 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # If making changes to this function, ensure that those changes are also made to # TrainLoop._on_train_epoch_end_hook if self.lightning_module: + # restore current_fx when nested context + if self.lightning_module._current_fx_name is not None: + current_fx_name = self.lightning_module._current_fx_name + else: + current_fx_name = None self.lightning_module._current_fx_name = hook_name # always profile hooks @@ -1227,7 +1232,7 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: output = accelerator_output if output is None else output if self.lightning_module: - self.lightning_module._current_fx_name = None + self.lightning_module._current_fx_name = None if (current_fx_name == hook_name) else current_fx_name return output diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 14c26fc647838..1b2b2d3c1e5ea 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -814,9 +814,9 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): # track gradients grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) if grad_norm_dict: + import pdb; pdb.set_trace() self.trainer.lightning_module._current_fx_name = "on_after_backward" self.trainer.lightning_module.log_grad_norm(grad_norm_dict) - self.trainer.lightning_module._current_fx_name = None def update_lr_schedulers(self, interval: str) -> None: if interval == "step": diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index b7950c9290ae3..8859a60dc4496 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -364,7 +364,7 @@ def on_epoch_start(self, trainer, pl_module): 'on_epoch_start', 2, on_steps=self.choices, - on_epochs=self.choices, + on_epochs=[True], prob_bars=self.choices ) @@ -374,7 +374,7 @@ def on_validation_epoch_start(self, trainer, pl_module): 'on_validation_epoch_start', 3, 
on_steps=self.choices, - on_epochs=self.choices, + on_epochs=[True], prob_bars=self.choices ) @@ -395,7 +395,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, def on_epoch_end(self, trainer, pl_module): if trainer.validating: self.make_logging( - pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_validation_epoch_end(self, trainer, pl_module): @@ -404,7 +404,7 @@ def on_validation_epoch_end(self, trainer, pl_module): 'on_validation_epoch_end', 9, on_steps=[False], - on_epochs=self.choices, + on_epochs=[True], prob_bars=self.choices ) @@ -582,7 +582,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal def on_test_epoch_end(self, trainer, pl_module): self.make_logging( - pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) max_epochs = 2 diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 6e601b577d648..33927194f1616 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -68,7 +68,7 @@ def on_validation_end(self): ) trainer.fit(LessBoringModel()) - assert order == ["log_epoch_metrics", "on_validation_end"] + assert order == ["on_validation_end", "log_epoch_metrics"] @RunIf(min_gpus=1) diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 61fae95c70312..a7520c90dc869 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -71,7 +71,6 @@ def backward(self, loss, optimizer, optimizer_idx): batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output assert len(train_step_out) == 1 @@ -144,7 +143,6 @@ def backward(self, loss, optimizer, optimizer_idx): batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.grad_norm_dict) == 0 and isinstance(out.grad_norm_dict, dict) train_step_out = out.training_step_output assert len(train_step_out) == 1 diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 8589473564f00..9abc9b47ab82a 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -474,7 +474,7 @@ def training_step(self, batch, batch_idx): opt_b.step() opt_b.zero_grad() - return {'loss1': loss_1, 'loss2': loss_2} + return {'loss1': loss_1.detach(), 'loss2': loss_2.detach()} def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer From cabd48b36903b6cb92fa3d47675e2fec5c2e8ec4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Jun 2021 19:34:24 +0000 Subject: [PATCH 343/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../connectors/logger_connector/fx_validator.py | 7 +++++-- 
pytorch_lightning/trainer/training_loop.py | 3 ++- tests/trainer/logging_/test_eval_loop_logging.py | 14 ++------------ 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 20b9ddb5ca17d..a185bf65fe015 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -43,8 +43,11 @@ class FxValidator: on_train_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), - on_test_epoch_start=dict(on_step=(False, True), on_epoch=(False, True, )), - on_test_epoch_end=dict(on_step=(False, ), on_epoch=(True,)), + on_test_epoch_start=dict(on_step=(False, True), on_epoch=( + False, + True, + )), + on_test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_predict_epoch_start=None, on_predict_epoch_end=None, on_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1b2b2d3c1e5ea..02cef346aad92 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -814,7 +814,8 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): # track gradients grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) if grad_norm_dict: - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() self.trainer.lightning_module._current_fx_name = "on_after_backward" self.trainer.lightning_module.log_grad_norm(grad_norm_dict) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 8859a60dc4496..a4809375a8d40 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -360,12 +360,7 @@ def on_validation_start(self, trainer, pl_module): def on_epoch_start(self, trainer, pl_module): if trainer.validating: self.make_logging( - pl_module, - 'on_epoch_start', - 2, - on_steps=self.choices, - on_epochs=[True], - prob_bars=self.choices + pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=[True], prob_bars=self.choices ) def on_validation_epoch_start(self, trainer, pl_module): @@ -400,12 +395,7 @@ def on_epoch_end(self, trainer, pl_module): def on_validation_epoch_end(self, trainer, pl_module): self.make_logging( - pl_module, - 'on_validation_epoch_end', - 9, - on_steps=[False], - on_epochs=[True], - prob_bars=self.choices + pl_module, 'on_validation_epoch_end', 9, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) class TestModel(BoringModel): From bc106c333bfa1507608c5a89863992ce143f655a Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 19:41:56 +0000 Subject: [PATCH 344/455] remove pdb --- pytorch_lightning/trainer/training_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1b2b2d3c1e5ea..9b2efa9decc28 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -814,7 +814,6 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): # track gradients grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) if grad_norm_dict: - import pdb; 
pdb.set_trace() self.trainer.lightning_module._current_fx_name = "on_after_backward" self.trainer.lightning_module.log_grad_norm(grad_norm_dict) From 98298e197fcff06daa53491389fc1808328ea0f6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 20:45:29 +0100 Subject: [PATCH 345/455] resolve flake8 --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4d4317d553da2..c1e4237bfe09b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1200,7 +1200,7 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # If making changes to this function, ensure that those changes are also made to # TrainLoop._on_train_epoch_end_hook if self.lightning_module: - # restore current_fx when nested context + # restore current_fx when nested context if self.lightning_module._current_fx_name is not None: current_fx_name = self.lightning_module._current_fx_name else: From d8da3cd01fe176f65d4a8e1bc2ea775c4307f2f0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 22:01:24 +0200 Subject: [PATCH 346/455] Update test --- .../logging_/test_eval_loop_logging.py | 318 +++++------------- 1 file changed, 89 insertions(+), 229 deletions(-) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index b7950c9290ae3..f81481e1ee695 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -300,192 +300,103 @@ def test_log_works_in_val_callback(tmpdir): class TestCallback(callbacks.Callback): - # helpers - count = 1 + count = 0 choices = [False, True] - # used to compute expected values - callback_funcs_called = collections.defaultdict(list) - funcs_called_count = collections.defaultdict(int) - funcs_attr = {} - def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): - self.funcs_called_count[func_name] += 1 - product = [on_steps, on_epochs, prob_bars] - for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*product))): - # run logging - custom_func_name = f"{func_idx}_{idx}_{func_name}" - pl_module.log( - custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar - ) - # catch information for verification - self.callback_funcs_called[func_name].append([self.count * func_idx]) - self.funcs_attr[custom_func_name] = { - "on_step": on_step, - "on_epoch": on_epoch, - "prog_bar": prog_bar, - "forked": on_step and on_epoch, - "func_name": func_name, - "training": self.log.__self__.trainer.training - } + # used to compute expected values + logged_values = collections.defaultdict(list) + call_counter = collections.Counter() + logged_arguments = {} - if on_step and on_epoch: - self.funcs_attr[f"{custom_func_name}_step"] = { - "on_step": True, - "on_epoch": False, - "prog_bar": prog_bar, - "forked": False, - "func_name": func_name, - "training": self.log.__self__.trainer.training - } + def make_logging(self, pl_module, func_name, on_steps, on_epochs, prob_bars): + self.call_counter.update([func_name]) - self.funcs_attr[f"{custom_func_name}_epoch"] = { - "on_step": False, - "on_epoch": True, - "prog_bar": prog_bar, - "forked": False, - "func_name": func_name, - "training": self.log.__self__.trainer.training - } + for idx, (on_step, on_epoch, prog_bar) in enumerate(itertools.product(on_steps, on_epochs, prob_bars)): + 
fx = f"{func_name}_{idx}" + pl_module.log(fx, self.count, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + self.logged_values[fx].append(self.count) + self.logged_arguments[fx] = {"on_step": on_step, "on_epoch": on_epoch, "prog_bar": prog_bar} + self.count += 1 - def on_validation_start(self, trainer, pl_module): + def on_validation_start(self, _, pl_module): self.make_logging( - pl_module, - 'on_validation_start', - 1, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_validation_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_epoch_start(self, trainer, pl_module): if trainer.validating: self.make_logging( - pl_module, - 'on_epoch_start', - 2, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_epoch_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) - def on_validation_epoch_start(self, trainer, pl_module): + def on_validation_epoch_start(self, _, pl_module): self.make_logging( - pl_module, - 'on_validation_epoch_start', - 3, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_validation_epoch_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_validation_batch_end(self, _, pl_module, *__): self.make_logging( pl_module, 'on_validation_batch_end', - 7, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) - # used to make sure aggregation works fine. - # we should obtain func[value * c for c in range(1, max_epochs * limit_validation_batches)]) - # with func = np.mean if on_epoch else func = np.max - self.count += 1 def on_epoch_end(self, trainer, pl_module): if trainer.validating: - self.make_logging( - pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices - ) + self.make_logging(pl_module, 'on_epoch_end', on_steps=[False], on_epochs=[True], prob_bars=self.choices) - def on_validation_epoch_end(self, trainer, pl_module): + def on_validation_epoch_end(self, _, pl_module): self.make_logging( - pl_module, - 'on_validation_epoch_end', - 9, - on_steps=[False], - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_validation_epoch_end', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) class TestModel(BoringModel): def validation_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) + loss = super().validation_step(batch, batch_idx)['x'] self.log('val_loss', loss) - max_epochs = 1 model = TestModel() model.validation_epoch_end = None - test_callback = TestCallback() - + cb = TestCallback() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=4, - limit_test_batches=0, - val_check_interval=0., num_sanity_val_steps=0, - max_epochs=max_epochs, - callbacks=[test_callback], + max_epochs=1, + callbacks=[cb], ) trainer.fit(model) - assert test_callback.funcs_called_count["on_epoch_start"] == 1 - # assert test_callback.funcs_called_count["on_batch_start"] == 1 - assert test_callback.funcs_called_count["on_validation_start"] == 1 - assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1 - # assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 - assert test_callback.funcs_called_count["on_epoch_end"] == 1 - assert test_callback.funcs_called_count["on_validation_batch_end"] == 4 - assert 
test_callback.funcs_called_count["on_validation_epoch_end"] == 1 - - # Make sure the func_name exists within callback_metrics. If not, we missed some - callback_metrics_keys = [*trainer.callback_metrics.keys()] - for func_name in test_callback.callback_funcs_called.keys(): - is_in = False - for callback_metrics_key in callback_metrics_keys: - if func_name in callback_metrics_key: - is_in = True - assert is_in, (func_name, callback_metrics_keys) + assert cb.call_counter == { + 'on_validation_batch_end': 4, + 'on_validation_start': 1, + 'on_epoch_start': 1, + 'on_validation_epoch_start': 1, + 'on_validation_epoch_end': 1, + 'on_epoch_end': 1 + } - # function used to describe expected return logic - def get_expected_output(func_attr, original_values): - if func_attr["on_step"] and not func_attr["on_epoch"]: - # Keep the latest value - expected_output = np.max(original_values) - else: - # Apply mean on values - expected_output = np.mean(original_values) - return expected_output + def get_expected(on_epoch, values): + reduction = np.mean if on_epoch else np.max + return reduction(values) - # Make sure the func_name output equals the average from all logged values when on_epoch true - # pop extra keys - trainer.callback_metrics.pop("val_loss") - for func_name, output_value in trainer.callback_metrics.items(): - # not sure how to handle this now - if "epoch_0" in func_name: - func_name = '/'.join(func_name.split('/')[:-1]) + for fx, value in trainer.callback_metrics.items(): + actual = value.item() + if fx not in cb.logged_arguments: continue + on_epoch = cb.logged_arguments[fx]['on_epoch'] + values = cb.logged_values[fx] + expected = get_expected(on_epoch, values) + assert actual == expected - if torch.is_tensor(output_value): - output_value = output_value.item() - # get creation attr - func_attr = test_callback.funcs_attr[func_name] - - # retrived orginal logged values - original_values = test_callback.callback_funcs_called[func_attr["func_name"]] - - # compute expected output and compare to actual one - expected_output = get_expected_output(func_attr, original_values) - assert float(output_value) == float(expected_output) - - for func_name, func_attr in test_callback.funcs_attr.items(): - if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: - assert func_name in trainer.progress_bar_metrics - else: - assert func_name not in trainer.progress_bar_metrics + for fx, attrs in cb.logged_arguments.items(): + should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"] + is_included = fx in trainer.logger_connector.progress_bar_metrics + assert is_included if should_include else not is_included def test_log_works_in_test_callback(tmpdir): @@ -496,7 +407,7 @@ def test_log_works_in_test_callback(tmpdir): class TestCallback(callbacks.Callback): # helpers - count = 1 + count = 0 choices = [False, True] # used to compute expected values @@ -504,19 +415,15 @@ class TestCallback(callbacks.Callback): funcs_called_count = collections.defaultdict(int) funcs_attr = {} - def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + def make_logging(self, pl_module, func_name, on_steps, on_epochs, prob_bars): original_func_name = func_name[:] self.funcs_called_count[original_func_name] += 1 - product = [on_steps, on_epochs, prob_bars] - for idx, t in enumerate(list(itertools.product(*product))): - # run logging + + for idx, (on_step, on_epoch, prog_bar) in enumerate(itertools.product(on_steps, on_epochs, 
prob_bars)): func_name = original_func_name[:] - on_step, on_epoch, prog_bar = t - custom_func_name = f"{func_idx}_{idx}_{func_name}" + custom_func_name = f"{idx}_{func_name}" - pl_module.log( - custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar - ) + pl_module.log(custom_func_name, self.count, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) num_dl_ext = '' if pl_module._current_dataloader_idx is not None: @@ -525,12 +432,11 @@ def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[] func_name += num_dl_ext # catch information for verification - self.callback_funcs_called[func_name].append([self.count * func_idx]) + self.callback_funcs_called[func_name].append([self.count]) self.funcs_attr[custom_func_name + num_dl_ext] = { "on_step": on_step, "on_epoch": on_epoch, "prog_bar": prog_bar, - "forked": on_step and on_epoch, "func_name": func_name } if on_step and on_epoch: @@ -538,7 +444,6 @@ def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[] "on_step": True, "on_epoch": False, "prog_bar": prog_bar, - "forked": False, "func_name": func_name } @@ -546,131 +451,86 @@ def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[] "on_step": False, "on_epoch": True, "prog_bar": prog_bar, - "forked": False, "func_name": func_name } - def on_test_start(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_test_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices - ) + def on_test_start(self, _, pl_module): + self.make_logging(pl_module, 'on_test_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices) - def on_test_epoch_start(self, trainer, pl_module): + def on_test_epoch_start(self, _, pl_module): self.make_logging( - pl_module, - 'on_test_epoch_start', - 3, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_test_epoch_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) - def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_test_batch_end(self, _, pl_module, *__): self.make_logging( - pl_module, - 'on_test_batch_end', - 5, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_test_batch_end', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) - # used to make sure aggregation works fine. 
- # we should obtain func[value * c for c in range(1, max_epochs * limit_test_batches)]) - # with func = np.mean if on_epoch else func = np.max - self.count += 1 - - def on_test_epoch_end(self, trainer, pl_module): + def on_test_epoch_end(self, _, pl_module): self.make_logging( - pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_test_epoch_end', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) - max_epochs = 2 num_dataloaders = 2 class TestModel(BoringModel): - - manual_mean = collections.defaultdict(list) + seen_losses = {i: [] for i in range(num_dataloaders)} def test_step(self, batch, batch_idx, dataloader_idx=None): - output = self.layer(batch) - loss = self.loss(batch, output) + loss = super().test_step(batch, batch_idx)['y'] self.log('test_loss', loss) - self.manual_mean[str(dataloader_idx)].append(loss) + self.seen_losses[dataloader_idx].append(loss) def test_dataloader(self): return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] model = TestModel() model.test_epoch_end = None - test_callback = TestCallback() - + cb = TestCallback() trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=0, limit_test_batches=2, - val_check_interval=0., num_sanity_val_steps=0, - max_epochs=max_epochs, - callbacks=[test_callback], + max_epochs=2, + callbacks=[cb], ) trainer.test(model) - assert test_callback.funcs_called_count["on_test_start"] == 1 - assert test_callback.funcs_called_count["on_test_epoch_start"] == 1 - assert test_callback.funcs_called_count["on_test_batch_end"] == 4 - assert test_callback.funcs_called_count["on_test_epoch_end"] == 1 - - # Make sure the func_name exists within callback_metrics. If not, we missed some - callback_metrics_keys = [*trainer.callback_metrics.keys()] + assert cb.funcs_called_count["on_test_start"] == 1 + assert cb.funcs_called_count["on_test_epoch_start"] == 1 + assert cb.funcs_called_count["on_test_batch_end"] == 4 + assert cb.funcs_called_count["on_test_epoch_end"] == 1 - for func_name in test_callback.callback_funcs_called.keys(): + callback_metrics_keys = list(trainer.callback_metrics) + for func_name in cb.callback_funcs_called.keys(): is_in = False for callback_metrics_key in callback_metrics_keys: if func_name in callback_metrics_key: is_in = True assert is_in, (func_name, callback_metrics_keys) - # function used to describe expected return logic - def get_expected_output(func_attr, original_values): - if func_attr["on_step"] and not func_attr["on_epoch"]: - expected_output = np.max(original_values) - else: - expected_output = np.mean(original_values) - return expected_output + def get_expected(on_epoch, values): + reduction = np.mean if on_epoch else np.max + return reduction(values) # Make sure the func_name output equals the average from all logged values when on_epoch true for dl_idx in range(num_dataloaders): key = f"test_loss/dataloader_idx_{dl_idx}" assert key in trainer.callback_metrics - assert torch.stack(model.manual_mean[str(dl_idx)]).mean() == trainer.callback_metrics[key] - trainer.callback_metrics.pop(key) + assert torch.stack(model.seen_losses[dl_idx]).mean() == trainer.callback_metrics.pop(key) for func_name, output_value in trainer.callback_metrics.items(): - # not sure how to handle this now - if "epoch_1" in func_name: - func_name = '/'.join(func_name.split('/')[:-1]) - continue - - if torch.is_tensor(output_value): - output_value = output_value.item() - - # get func attr - func_attr 
= test_callback.funcs_attr[func_name] - - # retrived orginal logged values - original_values = test_callback.callback_funcs_called[func_attr["func_name"]] - - # compute expected output and compare to actual one - expected_output = get_expected_output(func_attr, original_values) - assert float(output_value) == float(expected_output) - - for func_name, func_attr in test_callback.funcs_attr.items(): - if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: - assert func_name in trainer.logger_connector.progress_bar_metrics - else: - assert func_name not in trainer.logger_connector.progress_bar_metrics + output_value = output_value.item() + func_attr = cb.funcs_attr[func_name] + original_values = cb.callback_funcs_called[func_attr["func_name"]] + expected_output = get_expected(func_attr['on_epoch'], original_values) + assert output_value == expected_output + + for fx, attrs in cb.funcs_attr.items(): + should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"] + is_included = fx in trainer.logger_connector.progress_bar_metrics + assert is_included if should_include else not is_included @mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") From 49558bf827bc0073f6a2db5026675940b6b2251c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 20:02:54 +0000 Subject: [PATCH 347/455] more tests --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4d4317d553da2..f32ba9c87266d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1009,12 +1009,12 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: # hook self.evaluation_loop.on_evaluation_epoch_end() - # log epoch metrics - eval_loop_results = self.logger_connector.get_evaluate_epoch_results() - # hook self.evaluation_loop.on_evaluation_end() + # log epoch metrics + eval_loop_results = self.logger_connector.get_evaluate_epoch_results() + # save predictions to disk self.evaluation_loop.predictions.to_disk() From f27d0a38562a5ca8b77afbfc5f2b987ad5157f40 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 22:10:00 +0200 Subject: [PATCH 348/455] Revert last thomas' changes --- .../connectors/logger_connector/fx_validator.py | 11 ++++------- pytorch_lightning/trainer/trainer.py | 6 +++--- tests/trainer/loops/test_evaluation_loop.py | 2 +- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index a185bf65fe015..8d079f8b4a637 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -29,11 +29,11 @@ class FxValidator: on_fit_end=None, on_sanity_check_start=None, on_sanity_check_end=None, - on_train_start=dict(on_step=(True, False), on_epoch=(False, True)), + on_train_start=dict(on_step=(False, ), on_epoch=(True, )), on_train_end=None, - on_validation_start=dict(on_step=(False, True), on_epoch=(False, True)), + on_validation_start=dict(on_step=(False, ), on_epoch=(True, )), on_validation_end=None, - on_test_start=dict(on_step=(False, True), on_epoch=(False, True)), + on_test_start=dict(on_step=(False, ), on_epoch=(True, )), on_test_end=None, on_predict_start=None, on_predict_end=None, @@ -43,10 +43,7 @@ class FxValidator: 
on_train_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), - on_test_epoch_start=dict(on_step=(False, True), on_epoch=( - False, - True, - )), + on_test_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), on_test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_predict_epoch_start=None, on_predict_epoch_end=None, diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bd704029d1c21..c1e4237bfe09b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1009,12 +1009,12 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: # hook self.evaluation_loop.on_evaluation_epoch_end() - # hook - self.evaluation_loop.on_evaluation_end() - # log epoch metrics eval_loop_results = self.logger_connector.get_evaluate_epoch_results() + # hook + self.evaluation_loop.on_evaluation_end() + # save predictions to disk self.evaluation_loop.predictions.to_disk() diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 33927194f1616..6e601b577d648 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -68,7 +68,7 @@ def on_validation_end(self): ) trainer.fit(LessBoringModel()) - assert order == ["on_validation_end", "log_epoch_metrics"] + assert order == ["log_epoch_metrics", "on_validation_end"] @RunIf(min_gpus=1) From f4444c468b8db11060239179fb415e0fa6c552de Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 20:22:58 +0000 Subject: [PATCH 349/455] resolve 1 test --- tests/trainer/logging_/test_logger_connector.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 1a3f95fea6367..2c93054497d6c 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -17,7 +17,7 @@ import torch from torch.utils.data import DataLoader from torchmetrics import Accuracy, AveragePrecision - +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning import LightningModule, seed_everything from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer import Trainer @@ -356,15 +356,15 @@ def _assert_epoch_end(self, stage): acc.reset.assert_called_once() ap.reset.assert_called_once() - def on_train_end(self): - self._assert_epoch_end('train') + def teardown(self, stage): + if stage == TrainerFn.FITTING: + self._assert_epoch_end('train') + self._assert_epoch_end('val') - def on_validation_end(self): - if not self.trainer.sanity_checking: + elif stage == TrainerFn.VALIDATING: self._assert_epoch_end('val') - def on_test_end(self): - if not self.trainer.sanity_checking: + elif stage == TrainerFn.TESTING: self._assert_epoch_end('test') def _assert_called(model, stage): From d144771a3efea1fae21d9a00c7ff1d568e9c483c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Jun 2021 20:24:09 +0000 Subject: [PATCH 350/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/logging_/test_logger_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/logging_/test_logger_connector.py 
b/tests/trainer/logging_/test_logger_connector.py index 2c93054497d6c..ce21291289880 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -17,12 +17,13 @@ import torch from torch.utils.data import DataLoader from torchmetrics import Accuracy, AveragePrecision -from pytorch_lightning.trainer.states import TrainerFn + from pytorch_lightning import LightningModule, seed_everything from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf From ffff6f45b309a98b94a941793f835dea33358c42 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Jun 2021 22:24:50 +0200 Subject: [PATCH 351/455] Refactor context restoration --- pytorch_lightning/trainer/trainer.py | 9 +++------ tests/plugins/test_single_device_plugin.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c1e4237bfe09b..84d5d9d0ab706 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1200,11 +1200,7 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # If making changes to this function, ensure that those changes are also made to # TrainLoop._on_train_epoch_end_hook if self.lightning_module: - # restore current_fx when nested context - if self.lightning_module._current_fx_name is not None: - current_fx_name = self.lightning_module._current_fx_name - else: - current_fx_name = None + prev_fx_name = self.lightning_module._current_fx_name self.lightning_module._current_fx_name = hook_name # always profile hooks @@ -1232,7 +1228,8 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: output = accelerator_output if output is None else output if self.lightning_module: - self.lightning_module._current_fx_name = None if (current_fx_name == hook_name) else current_fx_name + # restore current_fx when nested context + self.lightning_module._current_fx_name = prev_fx_name return output diff --git a/tests/plugins/test_single_device_plugin.py b/tests/plugins/test_single_device_plugin.py index a398d960daf91..2e4834233537e 100644 --- a/tests/plugins/test_single_device_plugin.py +++ b/tests/plugins/test_single_device_plugin.py @@ -38,7 +38,7 @@ def on_train_start(self) -> None: @RunIf(skip_windows=True, min_gpus=1) def test_single_gpu(): - """Tests if device is set correctely when training and after teardown for single GPU plugin.""" + """Tests if device is set correctly when training and after teardown for single GPU plugin.""" trainer = Trainer(gpus=1, fast_dev_run=True) # assert training type plugin attributes for device setting assert isinstance(trainer.training_type_plugin, SingleDevicePlugin) From 9d7202876dad8348710075376908b21eaaf56e93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Jun 2021 20:45:45 +0000 Subject: [PATCH 352/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/loops/test_evaluation_loop.py | 4 
+--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 073d8db45e548..6e601b577d648 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -21,9 +21,7 @@ from tests.helpers.runif import RunIf -@mock.patch( - "pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.on_evaluation_epoch_end" -) +@mock.patch("pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.on_evaluation_epoch_end") def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """ Tests that `on_evaluation_epoch_end` is called From d3302ed3d99341b3c0283f4815467c614224629f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 22:51:19 +0200 Subject: [PATCH 353/455] update --- pytorch_lightning/callbacks/prediction_writer.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 13 ++++++++++++- pytorch_lightning/trainer/trainer.py | 13 ------------- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/callbacks/prediction_writer.py b/pytorch_lightning/callbacks/prediction_writer.py index 85a1408104ccd..f277503b2217d 100644 --- a/pytorch_lightning/callbacks/prediction_writer.py +++ b/pytorch_lightning/callbacks/prediction_writer.py @@ -109,7 +109,7 @@ def on_predict_batch_end( if not self.interval.on_batch: return is_distributed = trainer.accelerator_connector.is_distributed - batch_indices = trainer.predict_loop.prediction_loop.current_batch_indices if is_distributed else None + batch_indices = trainer.predict_loop.current_batch_indices if is_distributed else None self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx) def on_predict_epoch_end( diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 6eca3b0cf582d..977ae10dc5464 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -38,7 +38,18 @@ def __init__(self, trainer: 'pl.Trainer'): self.test_results = ResultCollection(False) def on_trainer_init(self) -> None: - pass + self.trainer.num_sanity_val_batches = [] + self.trainer.num_test_batches = [] + self.trainer.num_val_batches = [] + self.trainer.test_dataloaders = None + self.trainer.val_dataloaders = None + + # .validate() and .test() set this when they load a checkpoint + self.trainer.validated_ckpt_path = None + self.trainer.tested_ckpt_path = None + + # when true, print evaluation results in .validate() and .test() + self.trainer.verbose_evaluate = True def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: model = self.trainer.lightning_module diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c386230995e30..80ad6ea33f8ef 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -433,19 +433,6 @@ def _setup_on_init( else: self.num_sanity_val_steps = num_sanity_val_steps - self.num_sanity_val_batches = [] - self.num_test_batches = [] - self.num_val_batches = [] - self.test_dataloaders = None - self.val_dataloaders = None - - # .validate() and .test() set this when they load a checkpoint - self.validated_ckpt_path = None - self.tested_ckpt_path = None - - # when true, print evaluation results in .validate() and .test() - self.verbose_evaluate = True - self.num_predict_batches = [] self.predicted_ckpt_path = 
None From 601aa9592d3d46094530dc3de72c74ad008153a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 22:51:59 +0200 Subject: [PATCH 354/455] x --- pytorch_lightning/callbacks/prediction_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/prediction_writer.py b/pytorch_lightning/callbacks/prediction_writer.py index f277503b2217d..cbcff74ff0278 100644 --- a/pytorch_lightning/callbacks/prediction_writer.py +++ b/pytorch_lightning/callbacks/prediction_writer.py @@ -109,7 +109,7 @@ def on_predict_batch_end( if not self.interval.on_batch: return is_distributed = trainer.accelerator_connector.is_distributed - batch_indices = trainer.predict_loop.current_batch_indices if is_distributed else None + batch_indices = trainer.predict_loop.batch_indices if is_distributed else None self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx) def on_predict_epoch_end( From e272385fccf85186a5f7c6e310a5331de5edd8c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 22:54:57 +0200 Subject: [PATCH 355/455] on trainer init --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 80ad6ea33f8ef..7ff798f154a0e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -389,6 +389,9 @@ def __init__( truncated_bptt_steps, terminate_on_nan, ) + + self.evaluation_loop.on_trainer_init() + self.predict_loop.on_trainer_init() self._setup_on_init(num_sanity_val_steps) # configure tuner @@ -433,9 +436,6 @@ def _setup_on_init( else: self.num_sanity_val_steps = num_sanity_val_steps - self.num_predict_batches = [] - self.predicted_ckpt_path = None - def _setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): # clean hparams if hasattr(model, "hparams"): From 1763d8f59b9ad4a5f54f55d0c9b05e532c442e7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 22:57:30 +0200 Subject: [PATCH 356/455] test --- tests/trainer/test_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a9f95c65ce228..219d911465ff3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1069,9 +1069,9 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): assert trainer.num_sanity_val_steps == num_sanity_val_steps with patch.object( - trainer.evaluation_loop.evaluation_loop, + trainer.evaluation_loop, "evaluation_step", - wraps=trainer.evaluation_loop.evaluation_loop.evaluation_step + wraps=trainer.evaluation_loop.evaluation_step ) as mocked: val_dataloaders = model.val_dataloader__multiple_mixed_length() trainer.fit(model, val_dataloaders=val_dataloaders) @@ -1099,9 +1099,9 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): assert trainer.num_sanity_val_steps == float("inf") with patch.object( - trainer.evaluation_loop.evaluation_loop, + trainer.evaluation_loop, "evaluation_step", - wraps=trainer.evaluation_loop.evaluation_loop.evaluation_step + wraps=trainer.evaluation_loop.evaluation_step ) as mocked: val_dataloaders = model.val_dataloader__multiple() trainer.fit(model, val_dataloaders=val_dataloaders) From d7184982ebef41d0678596afce8c4419395a2e23 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 22:58:30 +0200 Subject: [PATCH 357/455] update trainer --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7ff798f154a0e..2b398eab41d2c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1114,12 +1114,12 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: # hook self.evaluation_loop.on_evaluation_epoch_end() - # hook - self.evaluation_loop.on_evaluation_end() - # log epoch metrics eval_loop_results = self.logger_connector.get_evaluate_epoch_results() + # hook + self.evaluation_loop.on_evaluation_end() + # save predictions to disk self.evaluation_loop.predictions.to_disk() From 6d98a078b56cc67873f975cc493bbc19415a1985 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 23:16:37 +0200 Subject: [PATCH 358/455] integrate latest changes from logger connector refactor poc --- pytorch_lightning/loops/batch_loop.py | 30 +++++++++++++++++------- pytorch_lightning/loops/epoch_loop.py | 3 ++- pytorch_lightning/loops/training_loop.py | 9 +++---- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 61f76d98b9db6..21dcfaf5571fc 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from copy import copy from functools import partial, update_wrapper -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, Mapping import numpy as np import torch @@ -55,23 +55,23 @@ def done(self): def run(self, batch, batch_idx, dataloader_idx): if batch is None: self.warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") - return AttributeDict(signal=0, grad_norm_dic={}, training_step_output=[[]]) + return AttributeDict(signal=0, training_step_output=[[]]) # hook + self.trainer.logger_connector.on_batch_start() response = self.trainer.call_hook("on_batch_start") if response == -1: - return AttributeDict(signal=-1, grad_norm_dic={}) + return AttributeDict(signal=-1) # hook response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) if response == -1: - return AttributeDict(signal=-1, grad_norm_dic={}) + return AttributeDict(signal=-1) super().run(batch, batch_idx, dataloader_idx) return AttributeDict( signal=0, - grad_norm_dict=self.grad_norm_dicts[-1], training_step_output=self.batch_outputs, ) @@ -91,7 +91,7 @@ def advance(self, batch, batch_idx, dataloader_idx): self.split_idx = split_idx # let logger connector extract current batch size - self.trainer.logger_connector.on_train_split_start(batch_idx, split_batch) + self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) # TODO: this list needs to go outside this loop # batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] @@ -102,14 +102,13 @@ def advance(self, batch, batch_idx, dataloader_idx): result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: self.batch_outputs[opt_idx].append(result.training_step_output) - grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) if result: self.batch_outputs[0].append(result.training_step_output) - # TODO: Properly aggregate grad_norm accross opt_idx and split_idx + # TODO: needed? self.grad_norm_dicts.append(grad_norm_dict) @@ -210,6 +209,16 @@ def _check_training_step_output(self, training_step_output): if training_step_output.grad_fn is None: # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") + elif self.trainer.lightning_module.automatic_optimization: + if not any(( + isinstance(training_step_output, torch.Tensor), + (isinstance(training_step_output, Mapping) + and 'loss' in training_step_output), training_step_output is None + )): + raise MisconfigurationException( + "In automatic optimization, `training_step` must either return a Tensor, " + "a dict with key 'loss' or None (where the step will be skipped)." 
+ ) def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging @@ -460,7 +469,10 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): if not self.should_accumulate(): # track gradients - result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) + grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) + if grad_norm_dict: + self.trainer.lightning_module._current_fx_name = "on_after_backward" + self.trainer.lightning_module.log_grad_norm(grad_norm_dict) def update_running_loss(self, current_loss: torch.Tensor) -> None: if self.trainer.lightning_module.automatic_optimization: diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 60fda5447bacf..a84225becb9d4 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -130,7 +130,7 @@ def run(self): return super().run() def on_run_start(self): - self.trainer.logger_connector.on_train_start() + self.trainer.result_collection.device = self.trainer.lightning_module.device self.trainer.call_hook("on_train_start") def on_advance_start(self): # equal to old on_train_epoch_start @@ -154,6 +154,7 @@ def on_advance_start(self): # equal to old on_train_epoch_start ) # hook + self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index 0e4fbde1ac097..cec18fc8dda53 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -144,7 +144,7 @@ def on_run_end(self): return # inform logger the batch loop has finished - self.trainer.logger_connector.on_train_epoch_end() + self.trainer.logger_connector.epoch_end_reached() # prepare epoch output processed_outputs = self._prepare_outputs(self.epoch_output, batch_mode=False) @@ -169,6 +169,7 @@ def on_run_end(self): # call train epoch end hooks self._on_train_epoch_end_hook(processed_outputs) self.trainer.call_hook('on_epoch_end') + self.trainer.logger_connector.on_epoch_end() return self.epoch_output @@ -208,9 +209,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: else: model_ref.on_train_epoch_end() - # if the PL module doesn't have the hook then call the accelerator - # used to auto-reduce things for the user with Results obj - elif hasattr(self.trainer.accelerator, hook_name): + # call the accelerator hook + if hasattr(self.trainer.accelerator, hook_name): accelerator_hook = getattr(self.trainer.accelerator, hook_name) accelerator_hook() @@ -232,6 +232,7 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, # hook self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, batch_idx, dataloader_idx) self.trainer.call_hook('on_batch_end') + self.trainer.logger_connector.on_batch_end() # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) From 7ca1049d679d697a6cd18480375bda7a61021fd3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Jun 2021 21:18:18 +0000 Subject: [PATCH 359/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/batch_loop.py | 8 ++++---- tests/trainer/test_trainer.py | 8 ++------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git 
a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 21dcfaf5571fc..c8c185e1d1966 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from copy import copy from functools import partial, update_wrapper -from typing import Any, Callable, List, Optional, Tuple, Mapping +from typing import Any, Callable, List, Mapping, Optional, Tuple import numpy as np import torch @@ -211,9 +211,9 @@ def _check_training_step_output(self, training_step_output): raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") elif self.trainer.lightning_module.automatic_optimization: if not any(( - isinstance(training_step_output, torch.Tensor), - (isinstance(training_step_output, Mapping) - and 'loss' in training_step_output), training_step_output is None + isinstance(training_step_output, torch.Tensor), + (isinstance(training_step_output, Mapping) + and 'loss' in training_step_output), training_step_output is None )): raise MisconfigurationException( "In automatic optimization, `training_step` must either return a Tensor, " diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 219d911465ff3..aa4a78bf67912 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1069,9 +1069,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): assert trainer.num_sanity_val_steps == num_sanity_val_steps with patch.object( - trainer.evaluation_loop, - "evaluation_step", - wraps=trainer.evaluation_loop.evaluation_step + trainer.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_step ) as mocked: val_dataloaders = model.val_dataloader__multiple_mixed_length() trainer.fit(model, val_dataloaders=val_dataloaders) @@ -1099,9 +1097,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): assert trainer.num_sanity_val_steps == float("inf") with patch.object( - trainer.evaluation_loop, - "evaluation_step", - wraps=trainer.evaluation_loop.evaluation_step + trainer.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_step ) as mocked: val_dataloaders = model.val_dataloader__multiple() trainer.fit(model, val_dataloaders=val_dataloaders) From 515ad9fc2c98210f1b11b468e393b707ef188350 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 00:26:17 +0200 Subject: [PATCH 360/455] Minor changes --- .../trainer/connectors/logger_connector/logger_connector.py | 6 ++---- pytorch_lightning/utilities/metrics.py | 1 - 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 1c848cdeccce2..34dec3460e935 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -217,10 +217,8 @@ def update_train_step_metrics(self, batch_output: AttributeDict) -> None: # when metrics should be logged assert not self._epoch_end_reached metrics = self.metrics[MetricSource.LOG] - if self.should_update_logs or self.trainer.fast_dev_run is True: - # logs user requested information to logger - if metrics: - self.log_metrics(metrics) + if (self.should_update_logs or self.trainer.fast_dev_run is True) and metrics: + self.log_metrics(metrics) def update_train_epoch_metrics(self) -> None: # add the metrics to the
loggers diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py index bd57470dc270e..7ee6caf21743b 100644 --- a/pytorch_lightning/utilities/metrics.py +++ b/pytorch_lightning/utilities/metrics.py @@ -21,7 +21,6 @@ def metrics_to_scalars(metrics: dict) -> dict: """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """ - # TODO: this is duplicated in MetricsHolder. should be unified new_metrics = {} for k, v in metrics.items(): if isinstance(v, torch.Tensor): From b03591c75707a4f9b599c3a17340d0c570c0798a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 8 Jun 2021 00:26:29 +0200 Subject: [PATCH 361/455] update changelog --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 347e1a5d901fe..0bfb95fcfdec2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -96,8 +96,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) * Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701)) ... - * Refactored internal loop interface; added new classes `EpochLoop`, `TrainingLoop`, `BatchLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700)) - + * Refactored internal loop interface; added new classes `EpochLoop`, `TrainingLoop`, `BatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871)) - Refactored logging * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736)) From 0aa8428a79b81d3f429b41e3dd92abbfbc6a5a5e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 00:42:32 +0200 Subject: [PATCH 362/455] Remove unused argument --- .../trainer/connectors/logger_connector/logger_connector.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 34dec3460e935..ac2ecb0c71e03 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -22,7 +22,7 @@ from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities import AttributeDict, DeviceType +from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT @@ -210,7 +210,7 @@ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) self._batch_idx = batch_idx self._split_idx = split_idx - def update_train_step_metrics(self, batch_output: AttributeDict) -> None: + def update_train_step_metrics(self) -> None: if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: return diff --git a/pytorch_lightning/trainer/training_loop.py 
b/pytorch_lightning/trainer/training_loop.py index 9b2efa9decc28..7fd0366af45ac 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -487,7 +487,7 @@ def run_training_epoch(self): # ----------------------------------------- # SAVE METRICS TO LOGGERS AND PROGRESS_BAR # ----------------------------------------- - self.trainer.logger_connector.update_train_step_metrics(batch_output) + self.trainer.logger_connector.update_train_step_metrics() # ----------------------------------------- # VALIDATE IF NEEDED From 24b41e390b24d0e747112dd7d936bc7ac5d00c0a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 01:12:13 +0200 Subject: [PATCH 363/455] Update CHANGELOG --- CHANGELOG.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b2d79ba0fa4a..664bb31d163b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -98,7 +98,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Refactored logging * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736)) - + * Allow passing `self.log(batch_size=...)` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) + * Allow passing `self.log(metric_attribute='your_metric')` to properly serialize the state of any `torchmetrics.Metric`s in your model ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) + * Remove `EpochResultStore` and `HookResultStore` in favor of `ResultCollection` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) + * Dramatically simplify the `LoggerConnector` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) + * `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) + * Remove `MetricsHolder` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) + * Completely overhaul the `Result` object in favor of `ResultMetric` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) + * Each of the training loops now keeps its own metrics ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) + * Improve epoch-level reduction time and overall memory usage ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) From 6d71e6afd73944c84e39db12e83a38a92d17940b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 01:18:21 +0200 Subject: [PATCH 364/455] Copy call_hook changes --- pytorch_lightning/trainer/training_loop.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7fd0366af45ac..372d3c9de20bb 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -576,7 +576,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # This implementation is copied from Trainer.call_hook hook_name = "on_train_epoch_end" - self.trainer.lightning_module._current_fx_name = hook_name + prev_fx_name = self.lightning_module._current_fx_name + self.lightning_module._current_fx_name = hook_name # always 
profile hooks with self.trainer.profiler.profile(hook_name): @@ -605,7 +606,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: accelerator_hook = getattr(self.trainer.accelerator, hook_name) accelerator_hook() - self.trainer.lightning_module._current_fx_name = None + # restore current_fx when nested context + self.lightning_module._current_fx_name = prev_fx_name def run_training_batch(self, batch, batch_idx, dataloader_idx): model_ref = self.trainer.lightning_module From 44ad4ac2b3105a8900930b5227866535e62c6915 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 01:25:04 +0200 Subject: [PATCH 365/455] Docs --- docs/source/extensions/logging.rst | 4 ++++ pytorch_lightning/core/lightning.py | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 107eca2dd9d74..12760f0ee6898 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -68,6 +68,10 @@ except functions with `batch_start` in their names. def training_step(self, batch, batch_idx): self.log('my_metric', x) + # or a dict + def training_step(self, batch, batch_idx): + self.log('performance', {'acc': acc, 'recall': recall}) + Depending on where log is called from, Lightning auto-determines the correct logging mode for you. \ But of course you can override the default behavior by manually setting the :func:`~~pytorch_lightning.core.lightning.LightningModule.log` parameters. diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index dd15cc2d9f14b..91aac91726d4b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -297,7 +297,7 @@ def log( Args: name: key to log - value: value to log + value: value to log. Can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former. prog_bar: if True logs to the progress bar logger: if True logs to the logger on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step @@ -426,7 +426,8 @@ def log_dict( self.log_dict(values) Args: - dictionary: key value pairs (str, tensors) + dictionary: key value pairs. + The values can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former. prog_bar: if True logs to the progress base logger: if True logs to the logger on_step: if True logs at this step. 
None auto-logs for training_step but not validation/test_step From 2c74018d6a671152297cbd9aeb7722d1643e7782 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 01:58:45 +0200 Subject: [PATCH 366/455] Fix ref --- pytorch_lightning/trainer/training_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 372d3c9de20bb..7947684fcbd03 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -576,8 +576,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # This implementation is copied from Trainer.call_hook hook_name = "on_train_epoch_end" - prev_fx_name = self.lightning_module._current_fx_name - self.lightning_module._current_fx_name = hook_name + prev_fx_name = self.trainer.lightning_module._current_fx_name + self.trainer.lightning_module._current_fx_name = hook_name # always profile hooks with self.trainer.profiler.profile(hook_name): @@ -607,7 +607,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: accelerator_hook() # restore current_fx when nested context - self.lightning_module._current_fx_name = prev_fx_name + self.trainer.lightning_module._current_fx_name = prev_fx_name def run_training_batch(self, batch, batch_idx, dataloader_idx): model_ref = self.trainer.lightning_module From 9747023611ee57debb5b29180d3987d9f308889f Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Jun 2021 09:51:22 +0000 Subject: [PATCH 367/455] move to cpu --- .../trainer/connectors/logger_connector/logger_connector.py | 6 ++++++ .../trainer/connectors/logger_connector/result.py | 5 +++++ pytorch_lightning/trainer/trainer.py | 1 + tests/plugins/test_single_device_plugin.py | 2 ++ 4 files changed, 14 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index ac2ecb0c71e03..8f4dede939479 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -230,6 +230,12 @@ def update_train_epoch_metrics(self) -> None: # reset result collection for next epoch self.trainer.result_collection.reset(metrics=True) + def teardown(self): + self.trainer.train_loop.train_results.cpu() + self.trainer.evaluation_loop.validation_results.cpu() + self.trainer.evaluation_loop.test_results.cpu() + import pdb; pdb.set_trace() + """ Utilities and properties """ diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f2a0750b9e570..4178066a01c78 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -479,6 +479,11 @@ def to(self, *args, **kwargs) -> 'ResultCollection': for k, v in self.items(): if isinstance(v, (torch.Tensor, Metric)): self[k] = v.to(*args, **kwargs) + state = self.__getstate__() + for k, v in state.items(): + if isinstance(v, (torch.Tensor, Metric)): + state[k] = v.to(*args, **kwargs) + self.__dict__.update(state) return self def cpu(self) -> 'ResultCollection': diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 84d5d9d0ab706..b6088f045566a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -806,6 
+806,7 @@ def _pre_dispatch(self): def _post_dispatch(self): self.accelerator.post_dispatch(self) self.accelerator.teardown() + self.logger_connector.teardown() def _dispatch(self): if self.evaluating: diff --git a/tests/plugins/test_single_device_plugin.py b/tests/plugins/test_single_device_plugin.py index 2e4834233537e..b30bb6da20189 100644 --- a/tests/plugins/test_single_device_plugin.py +++ b/tests/plugins/test_single_device_plugin.py @@ -39,6 +39,7 @@ def on_train_start(self) -> None: @RunIf(skip_windows=True, min_gpus=1) def test_single_gpu(): """Tests if device is set correctly when training and after teardown for single GPU plugin.""" + torch.cuda.empty_cache() trainer = Trainer(gpus=1, fast_dev_run=True) # assert training type plugin attributes for device setting assert isinstance(trainer.training_type_plugin, SingleDevicePlugin) @@ -51,6 +52,7 @@ def test_single_gpu(): trainer.fit(model) # assert after training, model is moved to CPU and memory is deallocated + import pdb; pdb.set_trace() assert model.device == torch.device("cpu") cuda_memory = torch.cuda.memory_allocated() assert cuda_memory < model.start_cuda_memory From d9ae37a16788a98f6acb0bfe22717ab8e494c138 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 11:50:46 +0200 Subject: [PATCH 368/455] Bad merge --- pytorch_lightning/trainer/training_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7947684fcbd03..2223da4095377 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -286,10 +286,10 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() - self._check_training_step_output(training_step_output) - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) + self._check_training_step_output(training_step_output) + training_step_output = self._process_training_step_output(training_step_output) if training_step_output is None: return From bad51c6712ba4e481a1b86427f4ecd4cd9c95c93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Jun 2021 09:52:27 +0000 Subject: [PATCH 369/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/logger_connector/logger_connector.py | 3 ++- tests/plugins/test_single_device_plugin.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8f4dede939479..ef98e2b2dc81a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -234,7 +234,8 @@ def teardown(self): self.trainer.train_loop.train_results.cpu() self.trainer.evaluation_loop.validation_results.cpu() self.trainer.evaluation_loop.test_results.cpu() - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() """ Utilities and properties diff --git a/tests/plugins/test_single_device_plugin.py b/tests/plugins/test_single_device_plugin.py index b30bb6da20189..33b96570896a0 100644 --- a/tests/plugins/test_single_device_plugin.py +++ 
b/tests/plugins/test_single_device_plugin.py @@ -52,7 +52,8 @@ def test_single_gpu(): trainer.fit(model) # assert after training, model is moved to CPU and memory is deallocated - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() assert model.device == torch.device("cpu") cuda_memory = torch.cuda.memory_allocated() assert cuda_memory < model.start_cuda_memory From 273bc92988709e561900e05890c8d0adb8a2300b Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Jun 2021 09:52:54 +0000 Subject: [PATCH 370/455] remove pdb --- .../trainer/connectors/logger_connector/logger_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8f4dede939479..2a37ae4e5c033 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -234,7 +234,6 @@ def teardown(self): self.trainer.train_loop.train_results.cpu() self.trainer.evaluation_loop.validation_results.cpu() self.trainer.evaluation_loop.test_results.cpu() - import pdb; pdb.set_trace() """ Utilities and properties From f214632eca30841cc5418a562e60163c6766eccc Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Jun 2021 09:53:25 +0000 Subject: [PATCH 371/455] remove pdb --- tests/plugins/test_single_device_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/plugins/test_single_device_plugin.py b/tests/plugins/test_single_device_plugin.py index b30bb6da20189..7cf123ae9aa0d 100644 --- a/tests/plugins/test_single_device_plugin.py +++ b/tests/plugins/test_single_device_plugin.py @@ -52,7 +52,6 @@ def test_single_gpu(): trainer.fit(model) # assert after training, model is moved to CPU and memory is deallocated - import pdb; pdb.set_trace() assert model.device == torch.device("cpu") cuda_memory = torch.cuda.memory_allocated() assert cuda_memory < model.start_cuda_memory From 99543a758c188e18ffb9ea1e3eb4507d07cd49a4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 12:07:56 +0200 Subject: [PATCH 372/455] Refactor to --- .../connectors/logger_connector/result.py | 19 +++++++++++-------- tests/plugins/test_single_device_plugin.py | 1 - 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 4178066a01c78..014e9d9fcb17a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -476,14 +476,17 @@ def _extract_batch_size(self, batch: Any) -> int: def to(self, *args, **kwargs) -> 'ResultCollection': """Move all data to the given device.""" - for k, v in self.items(): - if isinstance(v, (torch.Tensor, Metric)): - self[k] = v.to(*args, **kwargs) - state = self.__getstate__() - for k, v in state.items(): - if isinstance(v, (torch.Tensor, Metric)): - state[k] = v.to(*args, **kwargs) - self.__dict__.update(state) + + def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]: + return item.to(*args, **kwargs) + + apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs) + + if self.minimize is not None: + self.minimize = self.minimize.to(*args, **kwargs) + self._batch_size = self._batch_size.to(*args, **kwargs) + if 'device' in kwargs: + self.device = kwargs['device'] return self def 
cpu(self) -> 'ResultCollection': diff --git a/tests/plugins/test_single_device_plugin.py b/tests/plugins/test_single_device_plugin.py index 7cf123ae9aa0d..2e4834233537e 100644 --- a/tests/plugins/test_single_device_plugin.py +++ b/tests/plugins/test_single_device_plugin.py @@ -39,7 +39,6 @@ def on_train_start(self) -> None: @RunIf(skip_windows=True, min_gpus=1) def test_single_gpu(): """Tests if device is set correctly when training and after teardown for single GPU plugin.""" - torch.cuda.empty_cache() trainer = Trainer(gpus=1, fast_dev_run=True) # assert training type plugin attributes for device setting assert isinstance(trainer.training_type_plugin, SingleDevicePlugin) From 738c810952a35ac58bf0f5c51017616988f89bf3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 12:18:09 +0200 Subject: [PATCH 373/455] Avoid partial --- .../trainer/connectors/logger_connector/result.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 014e9d9fcb17a..40481dd9afb68 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -358,7 +358,7 @@ def fn(result_metric, v): apply_to_collections(self[key], value, ResultMetric, fn) @staticmethod - def _get_cache(on_step: bool, result_metric: ResultMetric) -> Optional[torch.Tensor]: + def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]: cache = None if on_step and result_metric.meta.on_step: cache = result_metric._forward_cache @@ -395,9 +395,7 @@ def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: for key, result_metric in self.valid_items(): # extract forward_cache or computed from the ResultMetric. ignore when the output is None - value = apply_to_collection( - result_metric, ResultMetric, partial(self._get_cache, on_step), include_none=False - ) + value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False) # check if the collection is empty has_tensor = False From 6a7637d2a1307ccff9a411fe2a8b798af1b99e36 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 13:03:12 +0200 Subject: [PATCH 374/455] trigger ci From aff9e3dcf6b91e844003d5644672747162e5ef8a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 13:33:24 +0200 Subject: [PATCH 375/455] Bad merge --- pytorch_lightning/core/lightning.py | 30 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 91aac91726d4b..759f28911c749 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -472,6 +472,21 @@ def __check_allowed(name: str, value: Any, v: Any) -> None: def __to_float(self, value: numbers.Number) -> torch.Tensor: return torch.tensor(value, device=self.device, dtype=torch.float) + def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None: + """Override this method to change the default behaviour of ``log_grad_norm``. 
+ + Args: + grad_norm_dict: Dictionary containing current grad norm metrics + + Examples:: + + # DEFAULT + def log_grad_norm(self, grad_norm_dict): + self.log_dict(grad_norm_dict, on_step=False, on_epoch=True, prog_bar=False, logger=True) + + """ + self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True) + def write_prediction( self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt' ): @@ -1531,21 +1546,6 @@ def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): """ optimizer.zero_grad() - def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None: - """Override this method to change the default behaviour of ``log_grad_norm``. - - Args: - grad_norm_dict: Dictionary containing current grad norm metrics - - Examples:: - - # DEFAULT - def log_grad_norm(self, grad_norm_dict): - self.log_dict(grad_norm_dict, on_step=False, on_epoch=True, prog_bar=False, logger=True) - - """ - self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True) - def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: r""" When using truncated backpropagation through time, each batch must be split along the From 461332bc11458e6891d0c1618f8683bd604390eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 8 Jun 2021 14:06:14 +0200 Subject: [PATCH 376/455] integrate latest logger connector changes --- pytorch_lightning/loops/batch_loop.py | 21 ++++++--------------- pytorch_lightning/loops/training_loop.py | 7 ++++--- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index c8c185e1d1966..c8a0e9f748572 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -70,10 +70,7 @@ def run(self, batch, batch_idx, dataloader_idx): super().run(batch, batch_idx, dataloader_idx) - return AttributeDict( - signal=0, - training_step_output=self.batch_outputs, - ) + return AttributeDict(signal=0, training_step_output=self.batch_outputs) def reset(self) -> None: # self.iteration_count = 0 @@ -233,10 +230,10 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() - self._check_training_step_output(training_step_output) - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) + self._check_training_step_output(training_step_output) + training_step_output = self._process_training_step_output(training_step_output) if training_step_output is None: return @@ -317,7 +314,9 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): def track_and_norm_grad(self, optimizer) -> dict: # track gradient norms - grad_norm_dict = self._track_gradient_norm() + grad_norm_dict = {} + if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 and float(self.trainer.track_grad_norm) > 0: + grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm) # clip gradients self.trainer.accelerator.clip_gradients( @@ -325,14 +324,6 @@ def track_and_norm_grad(self, optimizer) -> dict: ) return grad_norm_dict - def _track_gradient_norm(self): - grad_norm_dict = {} - if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: - if float(self.trainer.track_grad_norm) > 0: - model = self.trainer.lightning_module - grad_norm_dict = grad_norm(model, 
self.trainer.track_grad_norm) - return grad_norm_dict - def _accumulated_batches_reached(self): # TODO: use progress tracking of batches instead of iteration count, because iteration count may reset # iteration count is required to be global here, not reset diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index cec18fc8dda53..c7f34ffd6a702 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -108,7 +108,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs): # ----------------------------------------- # SAVE METRICS TO LOGGERS AND PROGRESS_BAR # ----------------------------------------- - self.trainer.logger_connector.update_train_step_metrics(batch_output) + self.trainer.logger_connector.update_train_step_metrics() def on_advance_end(self): # ----------------------------------------- @@ -184,7 +184,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # This implementation is copied from Trainer.call_hook hook_name = "on_train_epoch_end" - + prev_fx_name = self.trainer.lightning_module._current_fx_name self.trainer.lightning_module._current_fx_name = hook_name # always profile hooks @@ -214,7 +214,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: accelerator_hook = getattr(self.trainer.accelerator, hook_name) accelerator_hook() - self.trainer.lightning_module._current_fx_name = None + # restore current_fx when nested context + self.trainer.lightning_module._current_fx_name = prev_fx_name def _num_training_batches_reached(self, is_last_batch=False): return self.batches_seen == self.trainer.num_training_batches or is_last_batch From 417ad319e16d31a01b92997c166234be59511fc0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Jun 2021 12:10:35 +0000 Subject: [PATCH 377/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/batch_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index c8a0e9f748572..cd54b46ec436b 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -315,7 +315,8 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): def track_and_norm_grad(self, optimizer) -> dict: # track gradient norms grad_norm_dict = {} - if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 and float(self.trainer.track_grad_norm) > 0: + if (self.trainer.global_step + + 1) % self.trainer.log_every_n_steps == 0 and float(self.trainer.track_grad_norm) > 0: grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm) # clip gradients From 9321b11789c6c21ba4c7a28a02c61ac461322b84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 8 Jun 2021 14:31:02 +0200 Subject: [PATCH 378/455] remove grad norm dicts list --- pytorch_lightning/loops/batch_loop.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index cd54b46ec436b..1cdc574ee1fd5 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -78,7 +78,6 @@ def reset(self) -> None: self._hiddens = None # TODO: let loops track individual outputs self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - 
self.grad_norm_dicts = [] def on_run_start(self, batch, batch_idx, dataloader_idx): self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) @@ -105,9 +104,6 @@ def advance(self, batch, batch_idx, dataloader_idx): if result: self.batch_outputs[0].append(result.training_step_output) - # TODO: needed? - self.grad_norm_dicts.append(grad_norm_dict) - # ------------------------------------------------------------------------------------------------------------ # HELPER --- TO BE CLEANED UP From e75a958806f4e8ff04b9071b4cefe7afbe12332f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 14:39:41 +0200 Subject: [PATCH 379/455] Diff --- CHANGELOG.md | 8 ++++---- pytorch_lightning/core/lightning.py | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3199c3e94cab5..65c3b3742dbc7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -104,15 +104,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Refactored logging * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736)) + * Dramatically simplify the `LoggerConnector` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * Completely overhaul the `Result` object in favor of `ResultMetric` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * Improve epoch-level reduction time and overall memory usage ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) * Allow passing `self.log(batch_size=...)` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) * Allow passing `self.log(metric_attribute='your_metric')` to properly serialize the state of any `torchmetrics.Metric`s in your model ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) * Remove `EpochResultStore` and `HookResultStore` in favor of `ResultCollection` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) - * Dramatically simplify the `LoggerConnector` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) - * `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) * Remove `MetricsHolder` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) - * Completely overhaul the `Result` object in favor of `ResultMetric` ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) * Each of the training loops now keeps its own metrics ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) - * Improve epoch-level reduction time and overall memory usage ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 759f28911c749..794e21fec6675 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -478,12 +478,11 @@ def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None: Args: grad_norm_dict: Dictionary containing 
current grad norm metrics - Examples:: + Example:: # DEFAULT def log_grad_norm(self, grad_norm_dict): self.log_dict(grad_norm_dict, on_step=False, on_epoch=True, prog_bar=False, logger=True) - """ self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True) From 2e4bb2498e61945393835f5e06f3b88b30c04653 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 16:50:11 +0200 Subject: [PATCH 380/455] Bad merge --- pytorch_lightning/core/lightning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 110bae7bd4175..46dd8c43492da 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -455,6 +455,7 @@ def log_dict( add_dataloader_idx=add_dataloader_idx ) + @staticmethod def __check_not_nested(value: dict, name: str) -> None: # self-imposed restriction. for simplicity if any(isinstance(v, dict) for v in value.values()): From f5154ae62239f48fbb5045d504a55cca8b5c7c05 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 17:27:40 +0200 Subject: [PATCH 381/455] Reuse metrics_to_scalars --- .../connectors/logger_connector/result.py | 8 ++---- pytorch_lightning/utilities/metrics.py | 25 ++++++++----------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 40481dd9afb68..03a7c78e11175 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -25,6 +25,7 @@ from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.metrics import metrics_to_scalars # re-define the ones from pytorch_lightning.utilities.types without the `Number` type _METRIC = Union[Metric, torch.Tensor] @@ -370,10 +371,6 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten return cache.detach() return cache - @staticmethod - def __to_item(t: torch.Tensor) -> float: - return t.item() - def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" return ((k, v) for k, v in self.items() @@ -421,8 +418,7 @@ def any_tensor(_): # populate progress_bar metrics. convert tensors to numbers if result_metric.meta.prog_bar: - value = apply_to_collection(value, torch.Tensor, self.__to_item, include_none=False) - metrics[MetricSource.PBAR][forked_name] = value + metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value) return metrics diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py index 7ee6caf21743b..8433e9e370640 100644 --- a/pytorch_lightning/utilities/metrics.py +++ b/pytorch_lightning/utilities/metrics.py @@ -12,28 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. """Helper functions to operate on metric values. """ +import numbers import torch +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException def metrics_to_scalars(metrics: dict) -> dict: """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. 
""" - new_metrics = {} - for k, v in metrics.items(): - if isinstance(v, torch.Tensor): - if v.numel() != 1: - raise MisconfigurationException( - f"The metric `{k}` does not contain a single element" - f" thus it cannot be converted to float. Found `{v}`" - ) - v = v.item() + def to_item(value: torch.Tensor) -> numbers.Number: + if value.numel() != 1: + raise MisconfigurationException( + f"The metric `{value}` does not contain a single element" + f" thus it cannot be converted to float." + ) + return value.item() - if isinstance(v, dict): - v = metrics_to_scalars(v) - - new_metrics[k] = v - - return new_metrics + return apply_to_collection(metrics, torch.Tensor, to_item) From 558cdf4be243c3cac91f844550f549a12b017fcf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 17:28:16 +0200 Subject: [PATCH 382/455] Use active loop --- pytorch_lightning/core/lightning.py | 10 ++++----- .../logger_connector/logger_connector.py | 22 +++++++++---------- pytorch_lightning/trainer/evaluation_loop.py | 15 +++++++++---- pytorch_lightning/trainer/properties.py | 18 +++++++++------ pytorch_lightning/trainer/training_loop.py | 16 +++++++------- 5 files changed, 46 insertions(+), 35 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 46dd8c43492da..2ceef175917c1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -337,10 +337,10 @@ def log( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) - result_collection: 'ResultCollection' = self.trainer.result_collection # noqa F821 - assert result_collection is not None + results = self.trainer.results + assert results is not None assert self._current_fx_name is not None - result_collection.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) + results.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) # make sure user doesn't introduce logic for multi-dataloaders if "/dataloader_idx_" in name: @@ -374,9 +374,9 @@ def log( if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name): # when restarting an new epoch, reset the tensors - result_collection.reset(metrics=False, fx=self._current_fx_name) + results.reset(metrics=False, fx=self._current_fx_name) - result_collection.log( + results.log( self._current_fx_name, name, value, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 22ae8a6af5423..73ef1432e21a2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -188,7 +188,7 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None # track batch_size - self.trainer.result_collection.extract_batch_size(batch) + self.trainer.results.extract_batch_size(batch) self._batch_idx = batch_idx def update_evaluation_step_metrics(self) -> None: @@ -209,7 +209,7 @@ def update_evaluation_step_metrics(self) -> None: """ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: - self.trainer.result_collection.extract_batch_size(split_batch) + self.trainer.results.extract_batch_size(split_batch) self._batch_idx = batch_idx self._split_idx = split_idx @@ -231,12 +231,12 @@ def 
update_train_epoch_metrics(self) -> None: self.log_metrics(metrics) # reset result collection for next epoch - self.trainer.result_collection.reset(metrics=True) + self.trainer.results.reset(metrics=True) def teardown(self): - self.trainer.train_loop.train_results.cpu() - self.trainer.evaluation_loop.validation_results.cpu() - self.trainer.evaluation_loop.test_results.cpu() + self.trainer.train_loop.results.cpu() + self.trainer.evaluation_loop._val_results.cpu() + self.trainer.evaluation_loop._test_results.cpu() """ Utilities and properties @@ -277,7 +277,7 @@ def should_reset_tensors(self, fx: str) -> bool: return is_different_fx and is_first_batch def reset(self, metrics: Optional[bool] = None) -> None: - self.trainer.result_collection.reset(metrics=metrics) + self.trainer.results.reset(metrics=metrics) self._batch_idx = None self._split_idx = None self._current_fx = None @@ -286,25 +286,25 @@ def reset(self, metrics: Optional[bool] = None) -> None: def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]: """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``.""" on_step = not self._epoch_end_reached - return self.trainer.result_collection.metrics(on_step) + return self.trainer.results.metrics(on_step) @property def callback_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.result_collection: + if self.trainer.results: metrics = self.metrics[MetricSource.CALLBACK] self._callback_metrics.update(metrics) return self._callback_metrics @property def logged_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.result_collection: + if self.trainer.results: metrics = self.metrics[MetricSource.LOG] self._logged_metrics.update(metrics) return self._logged_metrics @property def progress_bar_metrics(self) -> Dict[str, float]: - if self.trainer.result_collection: + if self.trainer.results: metrics = self.metrics[MetricSource.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 977ae10dc5464..2d303c54511ec 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -34,8 +34,15 @@ def __init__(self, trainer: 'pl.Trainer'): self.max_batches: Optional[List[Union[int, float]]] = None self.warning_cache = WarningCache() self.num_dataloaders: Optional[int] = None - self.validation_results = ResultCollection(False) - self.test_results = ResultCollection(False) + self._val_results = ResultCollection(False) + self._test_results = ResultCollection(False) + + @property + def results(self) -> Optional[ResultCollection]: + if self.trainer.validating or self.trainer.sanity_checking: + return self._val_results + elif self.trainer.testing: + return self._test_results def on_trainer_init(self) -> None: self.trainer.num_sanity_val_batches = [] @@ -80,8 +87,8 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() - assert self.trainer.result_collection is not None - self.trainer.result_collection.device = self.trainer.lightning_module.device + assert self.trainer.results is not None + self.trainer.results.device = self.trainer.lightning_module.device if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) diff --git a/pytorch_lightning/trainer/properties.py 
b/pytorch_lightning/trainer/properties.py index efa6f6d8d8d41..4bae6ad05b3dd 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -483,6 +483,13 @@ def sanity_checking(self, val: bool) -> None: Loop properties """ + @property + def active_loop(self) -> Optional[Union[TrainLoop, EvaluationLoop]]: + if self.training: + return self.train_loop + elif self.sanity_checking or self.evaluating: + return self.evaluation_loop + @property def global_step(self) -> int: return self.train_loop.global_step @@ -524,13 +531,10 @@ def progress_bar_metrics(self) -> dict: return self.logger_connector.progress_bar_metrics @property - def result_collection(self) -> Optional[ResultCollection]: - if self.training: - return self.train_loop.train_results - elif self.validating or self.sanity_checking: - return self.evaluation_loop.validation_results - elif self.testing: - return self.evaluation_loop.test_results + def results(self) -> Optional[ResultCollection]: + active_loop = self.active_loop + if active_loop is not None: + return active_loop.results """ Other diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 68184fafced6b..2b5e8e39e2e0c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -82,7 +82,7 @@ def __init__( else: self.trainer.num_sanity_val_steps = num_sanity_val_steps - self.train_results = ResultCollection(True) + self.results = ResultCollection(True) @property def num_active_optimizers(self) -> int: @@ -100,7 +100,7 @@ def should_skip_training(self) -> bool: return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0 def on_train_start(self): - self.trainer.result_collection.device = self.trainer.lightning_module.device + self.trainer.results.device = self.trainer.lightning_module.device self.trainer.call_hook("on_train_start") def on_train_end(self): @@ -307,11 +307,11 @@ def _process_training_step_output(self, training_step_output): if training_step_output is None: return None - result = self.trainer.result_collection + results = self.trainer.results loss = None hiddens = None - result.extra = {} + results.extra = {} # handle dict return if isinstance(training_step_output, dict): @@ -319,20 +319,20 @@ def _process_training_step_output(self, training_step_output): hiddens = training_step_output.pop("hiddens", None) if hiddens is not None: hiddens = hiddens.detach() - result.extra = training_step_output + results.extra = training_step_output # handle scalar return elif isinstance(training_step_output, torch.Tensor): loss = training_step_output # map to results under the hood - result.minimize = loss + results.minimize = loss self._hiddens = hiddens if self.trainer.move_metrics_to_cpu: - result = result.cpu() + results = results.cpu() - return result + return results @staticmethod def _prepare_outputs( From 90d71bf987fa3dc5f9aad31939fcd4dd43a57151 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 17:40:48 +0200 Subject: [PATCH 383/455] Move to device --- .../connectors/logger_connector/logger_connector.py | 10 +++++----- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/training_loop.py | 3 ++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 73ef1432e21a2..c1d85e4a84b27 100644 --- 
a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -233,11 +233,6 @@ def update_train_epoch_metrics(self) -> None: # reset result collection for next epoch self.trainer.results.reset(metrics=True) - def teardown(self): - self.trainer.train_loop.results.cpu() - self.trainer.evaluation_loop._val_results.cpu() - self.trainer.evaluation_loop._test_results.cpu() - """ Utilities and properties """ @@ -308,3 +303,8 @@ def progress_bar_metrics(self) -> Dict[str, float]: metrics = self.metrics[MetricSource.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics + + def teardown(self): + self.trainer.train_loop.results.cpu() + self.trainer.evaluation_loop._val_results.cpu() + self.trainer.evaluation_loop._test_results.cpu() diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2d303c54511ec..beaabd5eb8345 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -88,7 +88,7 @@ def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() assert self.trainer.results is not None - self.trainer.results.device = self.trainer.lightning_module.device + self.trainer.results.to(device=self.trainer.lightning_module.device) if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2b5e8e39e2e0c..dc1ae39bdc78f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -100,7 +100,8 @@ def should_skip_training(self) -> bool: return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0 def on_train_start(self): - self.trainer.results.device = self.trainer.lightning_module.device + self.trainer.results.to(device=self.trainer.lightning_module.device) + self.trainer.call_hook("on_train_start") def on_train_end(self): From 6ce67628c042efb0218c8bfd0939378b25e50f33 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Jun 2021 19:29:44 +0000 Subject: [PATCH 384/455] resolve test --- pytorch_lightning/core/lightning.py | 13 +++++++++++++ .../connectors/logger_connector/result.py | 16 ++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 2ceef175917c1..a9a096c8db5a6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -376,6 +376,12 @@ def log( # when restarting an new epoch, reset the tensors results.reset(metrics=False, fx=self._current_fx_name) + if isinstance(sync_dist_op, str): + sync_dist_op = sync_dist_op.lower() + if sync_dist_op == "avg": + sync_dist_op = 'mean' + reduce_fx = self.__check_sync_dist_op(sync_dist_op, reduce_fx) + results.log( self._current_fx_name, name, @@ -469,6 +475,13 @@ def __check_allowed(v: Any, name: str, value: Any) -> None: def __to_float(self, value: numbers.Number) -> torch.Tensor: return torch.tensor(value, device=self.device, dtype=torch.float) + @staticmethod + def __check_sync_dist_op(sync_dist_op: str, fx: Callable) -> Callable: + torch_fx = getattr(torch, sync_dist_op) + if getattr(torch, sync_dist_op) != fx: + return torch_fx + return fx + def log_grad_norm(self, grad_norm_dict: 
Dict[str, torch.Tensor]) -> None: """Override this method to change the default behaviour of ``log_grad_norm``. diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 03a7c78e11175..daf4d0b8b6344 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -81,17 +81,21 @@ def forked_name(self, on_step: bool) -> str: def is_mean_reduction(self) -> bool: return self.reduce_fx == torch.mean + @property + def is_sum_reduction(self) -> bool: + return self.reduce_fx in (torch.sum, sum, "sum") + @property def is_max_reduction(self) -> bool: - return self.reduce_fx in (torch.max, max) + return self.reduce_fx in (torch.max, max, 'max') @property def is_min_reduction(self) -> bool: - return self.reduce_fx in (torch.min, min) + return self.reduce_fx in (torch.min, min, 'min') @property def is_custom_reduction(self) -> bool: - return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction) + return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction or self.is_sum_reduction) class ResultMetric(Metric, DeviceDtypeModuleMixin): @@ -121,6 +125,8 @@ def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: self.cumulated_batch_size += batch_size elif self.meta.is_max_reduction or self.meta.is_min_reduction: self.value = self.meta.reduce_fx(self.value, value.mean()) + elif self.meta.is_sum_reduction: + self.value += value.mean() * batch_size else: self.value = value # noqa: attribute-defined-outside-init self._forward_cache = value._forward_cache @@ -131,6 +137,8 @@ def compute(self) -> torch.Tensor: if self.meta.is_mean_reduction: cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) return value / cumulated_batch_size + elif self.meta.is_sum_reduction: + return value elif self.meta.is_max_reduction or self.meta.is_min_reduction: return value raise MisconfigurationException( @@ -323,7 +331,7 @@ def log( if key not in self: if meta.is_custom_reduction: raise MisconfigurationException( - 'Only `self.log(..., reduce_fx={min,max,mean})` are currently supported.' + 'Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported.' 
' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`' ) self.register_key(key, meta, value) From fba9a87f0d6974810d353b4de35f96eb667b4b95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 8 Jun 2021 22:59:09 +0200 Subject: [PATCH 385/455] properties first --- pytorch_lightning/loops/base.py | 8 ++++---- pytorch_lightning/loops/batch_loop.py | 8 ++++---- pytorch_lightning/loops/epoch_loop.py | 8 ++++---- pytorch_lightning/loops/training_loop.py | 10 +++++----- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 9ef2903dd5761..94c20ebfa67c0 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -11,15 +11,15 @@ def __init__(self): self.iteration_count: int = 0 self.trainer: Optional['pl.Trainer'] = None - def connect(self, trainer, *args, **kwargs) -> None: - """Connects Loop with all the necessary things like connectors and accelerators""" - self.trainer = proxy(trainer) - @property @abstractmethod def done(self) -> bool: """Property indicating when loop is finished""" + def connect(self, trainer, *args, **kwargs) -> None: + """Connects Loop with all the necessary things like connectors and accelerators""" + self.trainer = proxy(trainer) + @abstractmethod def reset(self) -> None: pass diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 1cdc574ee1fd5..1a1a73ebaea8a 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -35,6 +35,10 @@ def __init__(self): self._optimizer_freq_cumsum = None self._skip_backward = False + @property + def done(self): + return len(self._remaining_splits) == 0 + @property def skip_backward(self) -> bool: """ Determines whether the loop will skip backward during automatic optimization. """ @@ -48,10 +52,6 @@ def skip_backward(self, value: bool): def connect(self, trainer, *args, **kwargs): self.trainer = trainer - @property - def done(self): - return len(self._remaining_splits) == 0 - def run(self, batch, batch_idx, dataloader_idx): if batch is None: self.warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index a84225becb9d4..3bfae16fee247 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -93,10 +93,6 @@ def skip_backward(self, value: bool): """ Determines whether the loop will skip backward during automatic optimization. 
""" self.training_loop.batch_loop.skip_backward = value - def connect(self, trainer: 'pl.Trainer', *args, **kwargs): - self.trainer = trainer - self.training_loop.connect(trainer) - @property def done(self) -> bool: # TODO: Move track steps inside training loop and move part of these condition inside training loop @@ -122,6 +118,10 @@ def done(self) -> bool: return stop_steps or should_stop or stop_epochs + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): + self.trainer = trainer + self.training_loop.connect(trainer) + def reset(self) -> None: self.iteration_count = 0 diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_loop.py index c7f34ffd6a702..9214c22af4594 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_loop.py @@ -43,16 +43,16 @@ def __init__(self, min_steps, max_steps): def batch_idx(self) -> int: return self.iteration_count - def connect(self, trainer: 'pl.Trainer', *args, **kwargs): - self.trainer = trainer - self.batch_loop = BatchLoop() - self.batch_loop.connect(trainer) - @property def done(self): max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) + def connect(self, trainer: 'pl.Trainer', *args, **kwargs): + self.trainer = trainer + self.batch_loop = BatchLoop() + self.batch_loop.connect(trainer) + def run(self, *args, **kwargs): self.reset() self.on_run_start() From 79c73b99e39449472e2ff829fc9dd55195ebad3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 8 Jun 2021 23:24:05 +0200 Subject: [PATCH 386/455] define union --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 291072c287c96..f9dfbf7a83e22 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -13,7 +13,7 @@ # limitations under the License. import os from pprint import pprint -from typing import Any, Dict, Iterable, Mapping, Optional +from typing import Any, Dict, Iterable, Mapping, Optional, Union import torch From 37a0b9d744c1fc4290357dd456c946e1beb03e8b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 00:10:53 +0200 Subject: [PATCH 387/455] Update logger connector --- .../logger_connector/logger_connector.py | 84 ++--- .../logger_connector/logger_connector_new.py | 311 ------------------ 2 files changed, 42 insertions(+), 353 deletions(-) delete mode 100644 pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 291072c287c96..83dfa3294b218 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -13,14 +13,14 @@ # limitations under the License. 
import os from pprint import pprint -from typing import Any, Dict, Iterable, Mapping, Optional +from typing import Any, Dict, Iterable, Mapping, Optional, Union import torch import pytorch_lightning as pl from pytorch_lightning.core import memory from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource +from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars @@ -119,7 +119,44 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) - Evaluation metric updates """ - def prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: + @property + def _eval_log_step(self) -> Optional[int]: + if self.trainer.state.stage is RunningStage.VALIDATING: + return self._val_log_step + elif self.trainer.state.stage is RunningStage.TESTING: + return self._test_log_step + else: + return None + + def _increment_eval_log_step(self) -> None: + if self.trainer.state.stage is RunningStage.VALIDATING: + self._val_log_step += 1 + elif self.trainer.state.stage is RunningStage.TESTING: + self._test_log_step += 1 + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: + model = self.trainer.lightning_module + # set dataloader_idx only if multiple ones + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + + # track batch_size + self.trainer.result_collection.extract_batch_size(batch) + self._batch_idx = batch_idx + + def update_eval_step_metrics(self) -> None: + if self.trainer.sanity_checking: + return + + # logs user requested information to logger + assert not self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if metrics: + self.log_metrics(metrics, step=self._eval_log_step) + + # increment the step even if nothing was logged + self._increment_eval_log_step() + + def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: if self.trainer.sanity_checking: return @@ -136,7 +173,7 @@ def prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: else: self.eval_loop_results.append(callback_metrics) - def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: + def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: assert self._epoch_end_reached metrics = self.metrics @@ -146,7 +183,7 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if log_metrics: self.log_metrics(log_metrics) - self.prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) + self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) # log results of evaluation if ( @@ -167,43 +204,6 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: self.eval_loop_results = [] return results - @property - def evaluation_log_step(self) -> Optional[int]: - if self.trainer.state.stage is RunningStage.VALIDATING: - return self._val_log_step - elif self.trainer.state.stage is RunningStage.TESTING: - return self._test_log_step - else: - return None - - def increment_evaluation_log_step(self) -> None: - if self.trainer.state.stage is RunningStage.VALIDATING: - self._val_log_step += 1 - elif self.trainer.state.stage is RunningStage.TESTING: - self._test_log_step += 1 - - def on_evaluation_batch_start(self, 
batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: - model = self.trainer.lightning_module - # set dataloader_idx only if multiple ones - model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - - # track batch_size - self.trainer.results.extract_batch_size(batch) - self._batch_idx = batch_idx - - def update_evaluation_step_metrics(self) -> None: - if self.trainer.sanity_checking: - return - - # logs user requested information to logger - assert not self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if metrics: - self.log_metrics(metrics, step=self.evaluation_log_step) - - # increment the step even if nothing was logged - self.increment_evaluation_log_step() - """ Train metric updates """ diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py deleted file mode 100644 index 069a7f5183c70..0000000000000 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ /dev/null @@ -1,311 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from pprint import pprint -from typing import Any, Dict, Iterable, Mapping, Optional, Union - -import torch - -import pytorch_lightning as pl -from pytorch_lightning.core import memory -from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource -from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities import DeviceType -from pytorch_lightning.utilities.metrics import metrics_to_scalars -from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT - - -# TODO(@carmocca): Remove `New` suffix -class LoggerConnectorNew: - - def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None: - self.trainer = trainer - self.log_gpu_memory = log_gpu_memory - self.eval_loop_results = [] - self._val_log_step: int = 0 - self._test_log_step: int = 0 - self._progress_bar_metrics: Dict[str, float] = {} - self._logged_metrics: Dict[str, _METRIC] = {} - self._callback_metrics: Dict[str, _METRIC] = {} - self._epoch_end_reached = False - self._current_fx: Optional[str] = None - self._batch_idx: Optional[int] = None - self._split_idx: Optional[int] = None - - def on_trainer_init( - self, - logger: LightningLoggerBase, - flush_logs_every_n_steps: int, - log_every_n_steps: int, - move_metrics_to_cpu: bool, - ) -> None: - self.configure_logger(logger) - self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps - self.trainer.log_every_n_steps = log_every_n_steps - self.trainer.move_metrics_to_cpu = move_metrics_to_cpu - - @property - def should_flush_logs(self) -> bool: - should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 - return should_flush or self.trainer.should_stop - 
- @property - def should_update_logs(self) -> bool: - should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 - return should_log_every_n_steps or self.trainer.should_stop - - def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None: - if logger is True: - version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) - - # default logger - self.trainer.logger = TensorBoardLogger( - save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs' - ) - elif logger is False: - self.trainer.logger = None - else: - if isinstance(logger, Iterable): - self.trainer.logger = LoggerCollection(logger) - else: - self.trainer.logger = logger - - def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None: - """Logs the metric dict passed in. - If `step` parameter is None and `step` key is presented is metrics, - uses metrics["step"] as a step - - Args: - metrics: Metric values - step: Step for which metrics should be logged. Default value is `self.global_step` during training or - the total validation / test log step count during validation and testing. - """ - # add gpu memory - if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory: - mem_map = memory.get_memory_profile(self.log_gpu_memory) - metrics.update(mem_map) - - # turn all tensors to scalars - scalar_metrics = metrics_to_scalars(metrics) - - if "step" in scalar_metrics and step is None: - step = scalar_metrics.pop("step") - - elif step is None: - # added metrics by Lightning for convenience - scalar_metrics['epoch'] = self.trainer.current_epoch - step = self.trainer.global_step - - # log actual metrics - if self.trainer.logger is not None: - if self.trainer.is_global_zero: - self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step) - self.trainer.logger.save() - - self._logged_metrics.update(scalar_metrics) - - """ - Evaluation metric updates - """ - - @property - def _eval_log_step(self) -> Optional[int]: - if self.trainer.state.stage is RunningStage.VALIDATING: - return self._val_log_step - elif self.trainer.state.stage is RunningStage.TESTING: - return self._test_log_step - else: - return None - - def _increment_eval_log_step(self) -> None: - if self.trainer.state.stage is RunningStage.VALIDATING: - self._val_log_step += 1 - elif self.trainer.state.stage is RunningStage.TESTING: - self._test_log_step += 1 - - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: - model = self.trainer.lightning_module - # set dataloader_idx only if multiple ones - model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - - # track batch_size - self.trainer.result_collection.extract_batch_size(batch) - self._batch_idx = batch_idx - - def update_eval_step_metrics(self) -> None: - if self.trainer.sanity_checking: - return - - # logs user requested information to logger - assert not self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if metrics: - self.log_metrics(metrics, step=self._eval_log_step) - - # increment the step even if nothing was logged - self._increment_eval_log_step() - - def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: - if self.trainer.sanity_checking: - return - - num_dataloaders = self.trainer.evaluation_loop.num_dataloaders - has_been_initialized = len(self.eval_loop_results) == num_dataloaders - for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): - # 
remove callback metrics that don't belong to this dataloader - callback_metrics = { - k: v - for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k - } - if has_been_initialized: - self.eval_loop_results[dl_idx].update(callback_metrics) - else: - self.eval_loop_results.append(callback_metrics) - - def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: - assert self._epoch_end_reached - metrics = self.metrics - - if not self.trainer.sanity_checking: - # log all the metrics as a single dict - log_metrics = metrics[MetricSource.LOG] - if log_metrics: - self.log_metrics(log_metrics) - - self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) - - # log results of evaluation - if ( - self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero - and self.trainer.verbose_evaluate - ): - print('-' * 80) - for result_idx, results in enumerate(self.eval_loop_results): - print(f'DATALOADER:{result_idx} {self.trainer.state.stage.upper()} RESULTS') - pprint({ - k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v - for k, v in results.items() - }) - print('-' * 80) - - results = self.eval_loop_results - # clear mem - self.eval_loop_results = [] - return results - - """ - Train metric updates - """ - - def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: - self.trainer.results.extract_batch_size(split_batch) - self._batch_idx = batch_idx - self._split_idx = split_idx - - def update_train_step_metrics(self) -> None: - if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: - return - - # when metrics should be logged - assert not self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if self.should_update_logs or self.trainer.fast_dev_run is True and metrics: - self.log_metrics(metrics) - - def update_train_epoch_metrics(self) -> None: - # add the metrics to the loggers - assert self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if metrics: - self.log_metrics(metrics) - - # reset result collection for next epoch - self.trainer.results.reset(metrics=True) - - """ - Utilities and properties - """ - - def on_epoch_start(self) -> None: - self._epoch_end_reached = False - - def on_batch_start(self) -> None: - self._epoch_end_reached = False - - def epoch_end_reached(self): - self.trainer.logger_connector._epoch_end_reached = True - self.trainer.logger_connector._batch_idx = None - self.trainer.logger_connector._split_idx = None - - def on_epoch_end(self) -> None: - assert self._epoch_end_reached - metrics = self.metrics - self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) - self._callback_metrics.update(metrics[MetricSource.CALLBACK]) - self._logged_metrics.update(metrics[MetricSource.LOG]) - self._current_fx = None - - def on_batch_end(self) -> None: - assert not self._epoch_end_reached - metrics = self.metrics - self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) - self._callback_metrics.update(metrics[MetricSource.CALLBACK]) - self._logged_metrics.update(metrics[MetricSource.LOG]) - - def should_reset_tensors(self, fx: str) -> bool: - is_different_fx = self._current_fx != fx - if self._split_idx is None: - is_first_batch = self._batch_idx in (None, 0) - else: - is_first_batch = self._batch_idx + self._split_idx == 0 - return is_different_fx and is_first_batch - - def reset(self, metrics: Optional[bool] = None) -> None: - 
self.trainer.results.reset(metrics=metrics) - self._batch_idx = None - self._split_idx = None - self._current_fx = None - - @property - def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]: - """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``.""" - on_step = not self._epoch_end_reached - return self.trainer.results.metrics(on_step) - - @property - def callback_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.results: - metrics = self.metrics[MetricSource.CALLBACK] - self._callback_metrics.update(metrics) - return self._callback_metrics - - @property - def logged_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.results: - metrics = self.metrics[MetricSource.LOG] - self._logged_metrics.update(metrics) - return self._logged_metrics - - @property - def progress_bar_metrics(self) -> Dict[str, float]: - if self.trainer.results: - metrics = self.metrics[MetricSource.PBAR] - self._progress_bar_metrics.update(metrics) - return self._progress_bar_metrics - - def teardown(self): - self.trainer.train_loop.results.cpu() - self.trainer.evaluation_loop._val_results.cpu() - self.trainer.evaluation_loop._test_results.cpu() From aaea3877bce6bef1939226d271227ef59c277129 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 00:13:30 +0200 Subject: [PATCH 388/455] Update result --- .../connectors/logger_connector/result.py | 2 +- .../connectors/logger_connector/result_new.py | 499 ------------------ 2 files changed, 1 insertion(+), 500 deletions(-) delete mode 100644 pytorch_lightning/trainer/connectors/logger_connector/result_new.py diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index daf4d0b8b6344..6bb478b18789f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -62,7 +62,7 @@ class _Metadata: logger: bool = True on_step: bool = False on_epoch: bool = True - reduce_fx: Callable = torch.mean + reduce_fx: Union[str, Callable] = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None metric_attribute: Optional[str] = None diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py deleted file mode 100644 index 03a7c78e11175..0000000000000 --- a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py +++ /dev/null @@ -1,499 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from collections.abc import Generator -from dataclasses import dataclass, field -from functools import partial, wraps -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union - -import torch -from torchmetrics import Metric - -from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections -from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.utilities.enums import LightningEnum -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.metrics import metrics_to_scalars - -# re-define the ones from pytorch_lightning.utilities.types without the `Number` type -_METRIC = Union[Metric, torch.Tensor] -_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] - - -class MetricSource(LightningEnum): - CALLBACK = "callback" - PBAR = "pbar" - LOG = "log" - - -@dataclass -class _Sync: - fn: Callable - should: bool = False - op: Union[Any, str] = 'mean' - group: Optional[Any] = None - - @property - def __call__(self) -> Any: - return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op - - @staticmethod - def no_op(value: Any, *_, **__) -> Any: - return value - - -@dataclass -class _Metadata: - fx: str - name: str - prog_bar: bool = False - logger: bool = True - on_step: bool = False - on_epoch: bool = True - reduce_fx: Callable = torch.mean - enable_graph: bool = False - dataloader_idx: Optional[int] = None - metric_attribute: Optional[str] = None - sync: _Sync = field(default_factory=_Sync) - - @property - def forked(self) -> bool: - return self.on_step and self.on_epoch - - def forked_name(self, on_step: bool) -> str: - if self.forked: - return f'{self.name}_{"step" if on_step else "epoch"}' - return self.name - - @property - def is_mean_reduction(self) -> bool: - return self.reduce_fx == torch.mean - - @property - def is_max_reduction(self) -> bool: - return self.reduce_fx in (torch.max, max) - - @property - def is_min_reduction(self) -> bool: - return self.reduce_fx in (torch.min, min) - - @property - def is_custom_reduction(self) -> bool: - return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction) - - -class ResultMetric(Metric, DeviceDtypeModuleMixin): - """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" - - def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: - super().__init__() - self.is_tensor = is_tensor - self.meta = metadata - self.has_reset = False - if is_tensor: - self.add_state("value", torch.tensor(0, dtype=torch.float)) - if self.meta.is_mean_reduction: - self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) - - def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: - if self.is_tensor: - value = value.float() - self._forward_cache = value - # performance: no need to accumulate on values only logged on_step - if self.meta.on_step and not self.meta.on_epoch: - self.value = self.meta.sync(value) - return - # perform accumulation with reduction - if self.meta.is_mean_reduction: - self.value += value.mean() * batch_size - self.cumulated_batch_size += batch_size - elif self.meta.is_max_reduction or self.meta.is_min_reduction: - self.value = self.meta.reduce_fx(self.value, value.mean()) - else: - self.value = value # noqa: 
attribute-defined-outside-init - self._forward_cache = value._forward_cache - - def compute(self) -> torch.Tensor: - if self.is_tensor: - value = self.meta.sync(self.value) - if self.meta.is_mean_reduction: - cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) - return value / cumulated_batch_size - elif self.meta.is_max_reduction or self.meta.is_min_reduction: - return value - raise MisconfigurationException( - f"Only [min, max, mean] reductions are supported. Found {self.meta.reduce_fx}" - ) - return self.value.compute() - - def reset(self) -> None: - if self.is_tensor: - super().reset() - else: - self.value.reset() - self.has_reset = True - - def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None: - if self.meta.enable_graph: - with torch.no_grad(): - self.update(value, batch_size) - else: - # performance: skip the `torch.no_grad` context manager by calling `update` directly - self.update(value, batch_size) - - def _wrap_compute(self, compute: Any) -> Any: - # Override to avoid syncing - we handle it ourselves. - @wraps(compute) - def wrapped_func(*args, **kwargs): - if not self._update_called: - rank_zero_warn( - f"The ``compute`` method of metric {self.__class__.__name__}" - " was called before the ``update`` method which may lead to errors," - " as metric states have not yet been updated.", UserWarning - ) - - # return cached value - if self._computed is not None: - return self._computed - self._computed = compute(*args, **kwargs) - return self._computed - - return wrapped_func - - def __setattr__(self, key: str, value: Any) -> None: - # performance: skip the `torch.nn.Module.__setattr__` checks - object.__setattr__(self, key, value) - - def __repr__(self) -> str: - state = f"value={self.value}" - if self.is_tensor and self.meta.is_mean_reduction: - state += f", cumulated_batch_size={self.cumulated_batch_size}" - return f"{self.__class__.__name__}({state})" - - -class ResultMetricCollection(dict): - """ - Dict wrapper for easy access to metadata. - - All of the leaf items should be instances of - :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` - with the same metadata. - """ - - def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: - super().__init__(*args) - self.meta = metadata - - -class ResultCollection(dict): - """ - Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or - :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection` - - Example: - - # `device` needs to be provided before logging - result = ResultCollection(True, torch.device("cpu")) - - # you can log to a specific collection. 
- # arguments: fx, key, value, metadata - result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) - result.log('validation_step', 'recall', torch.tensor(...), on_step=True, on_epoch=True) - """ - - DATALOADER_SUFFIX = "/dataloader_idx_{}" - - def __init__(self, training: bool, device: Optional[torch.device] = None) -> None: - super().__init__() - self.training = training - self._minimize = None - self._batch_size = torch.tensor(1, device=device) - self.device: Optional[Union[str, torch.device]] = device - self.fx_validator = FxValidator() - - @property - def batch_size(self) -> torch.Tensor: - # performance: cache the `batch_size` tensor instead of re-creating it - return self._batch_size - - @batch_size.setter - def batch_size(self, value: int) -> None: - self._batch_size = torch.tensor(value, device=self.device) - - @property - def minimize(self) -> Optional[torch.Tensor]: - """ - The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss - will be saved as the ``minimize`` attribute. - """ - return self._minimize - - @minimize.setter - def minimize(self, loss: Optional[torch.Tensor]) -> None: - if loss is not None: - if not isinstance(loss, torch.Tensor): - raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") - if loss.grad_fn is None: - raise RuntimeError("`Result.minimize` must have a `grad_fn`") - self._minimize = loss - - @property - def extra(self) -> Dict[str, Any]: - """ - Extras are any keys other than the loss returned by - :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` - """ - return self.get('_extra', {}) - - @extra.setter - def extra(self, extra: Mapping[str, Any]) -> None: - - def check_fn(v): - if v.grad_fn is not None: - raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}') - - apply_to_collection(extra, torch.Tensor, check_fn) - self['_extra'] = extra - - def log( - self, - fx: str, - name: str, - value: _METRIC_COLLECTION, - prog_bar: bool = False, - logger: bool = True, - on_step: bool = False, - on_epoch: bool = True, - reduce_fx: Callable = torch.mean, - enable_graph: bool = False, - sync_dist: bool = False, - sync_dist_fn: Callable = _Sync.no_op, - sync_dist_op: Union[Any, str] = 'mean', - sync_dist_group: Optional[Any] = None, - dataloader_idx: Optional[int] = None, - batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, - ) -> None: - """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" - # no metrics should be logged with graphs - if not enable_graph and isinstance(value, torch.Tensor): - value = value.detach() - - # move metrics to cpu on TPU. - if isinstance(value, torch.Tensor) and value.device.type == "xla": - value = value.cpu() - - # storage key - key = f"{fx}.{name}" - # add dataloader_suffix to both key and fx - if dataloader_idx is not None: - key += f'.{dataloader_idx}' - fx += f'.{dataloader_idx}' - - meta = _Metadata( - fx=fx, - name=name, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - enable_graph=enable_graph, - dataloader_idx=dataloader_idx, - metric_attribute=metric_attribute, - sync=_Sync( - should=sync_dist, - fn=sync_dist_fn, - op=sync_dist_op, - group=sync_dist_group, - ) - ) - if key not in self: - if meta.is_custom_reduction: - raise MisconfigurationException( - 'Only `self.log(..., reduce_fx={min,max,mean})` are currently supported.' 
- ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`' - ) - self.register_key(key, meta, value) - elif meta != self[key].meta: - raise MisconfigurationException( - f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed' - ) - - if batch_size is not None: - self.batch_size = batch_size - - self.update_metrics(key, value) - - def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: - """Create one ResultMetric object per value. Value can be provided as a nested collection""" - - def fn(v: _METRIC) -> ResultMetric: - metric = ResultMetric(meta, isinstance(v, torch.Tensor)) - return metric.to(self.device) - - value = apply_to_collection(value, (torch.Tensor, Metric), fn) - if isinstance(value, dict): - value = ResultMetricCollection(value, metadata=meta) - self[key] = value - - def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None: - - def fn(result_metric, v): - # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` - result_metric.forward(v.to(self.device), self.batch_size) - result_metric.has_reset = False - - apply_to_collections(self[key], value, ResultMetric, fn) - - @staticmethod - def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]: - cache = None - if on_step and result_metric.meta.on_step: - cache = result_metric._forward_cache - elif not on_step and result_metric.meta.on_epoch: - if not result_metric._computed: - result_metric.compute() - cache = result_metric._computed - if cache is not None and not result_metric.meta.enable_graph: - return cache.detach() - return cache - - def valid_items(self) -> Generator: - """This function is used to iterate over current valid metrics.""" - return ((k, v) for k, v in self.items() - if not k == "_extra" and not (isinstance(v, ResultMetric) and v.has_reset)) - - def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]: - name = result_metric.meta.name - forked_name = result_metric.meta.forked_name(on_step) - dl_idx = result_metric.meta.dataloader_idx - if dl_idx is not None: - dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx) - name += dataloader_suffix - forked_name += dataloader_suffix - return name, forked_name - - def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: - metrics = {k: {} for k in MetricSource} - - for key, result_metric in self.valid_items(): - - # extract forward_cache or computed from the ResultMetric. ignore when the output is None - value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False) - - # check if the collection is empty - has_tensor = False - - def any_tensor(_): - nonlocal has_tensor - has_tensor = True - - apply_to_collection(value, torch.Tensor, any_tensor) - if not has_tensor: - continue - - name, forked_name = self._forked_name(result_metric, on_step) - - # populate logging metrics - if result_metric.meta.logger: - metrics[MetricSource.LOG][forked_name] = value - - # populate callback metrics. callback metrics don't take `_step` forked metrics - if self.training or result_metric.meta.on_epoch and not on_step: - metrics[MetricSource.CALLBACK][name] = value - metrics[MetricSource.CALLBACK][forked_name] = value - - # populate progress_bar metrics. 
convert tensors to numbers - if result_metric.meta.prog_bar: - metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value) - - return metrics - - def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None: - """ - Reset the result collection - - Args: - metrics: If True, only ``torchmetrics.Metric`` results are reset, - if False, only ``torch.Tensors`` are reset, - if ``None``, both are. - fx: Function to reset - """ - - def fn(item: ResultMetric) -> None: - requested_type = metrics is None or metrics ^ item.is_tensor - same_fx = fx is None or fx == item.meta.fx - if requested_type and same_fx: - item.reset() - - apply_to_collection(self, ResultMetric, fn) - - def extract_batch_size(self, batch: Any) -> None: - try: - self.batch_size = self._extract_batch_size(batch) - except RecursionError: - self.batch_size = 1 - - def _extract_batch_size(self, batch: Any) -> int: - """ - Recursively unpack a batch to find a torch.Tensor. - - Returns: - ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable. - """ - if isinstance(batch, torch.Tensor): - size = batch.size(0) - elif isinstance(batch, str): - return len(batch) - elif isinstance(batch, dict): - sample = next(iter(batch.values()), 1) - size = self._extract_batch_size(sample) - elif isinstance(batch, Iterable): - sample = next(iter(batch), 1) - size = self._extract_batch_size(sample) - else: - size = 1 - return size - - def to(self, *args, **kwargs) -> 'ResultCollection': - """Move all data to the given device.""" - - def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]: - return item.to(*args, **kwargs) - - apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs) - - if self.minimize is not None: - self.minimize = self.minimize.to(*args, **kwargs) - self._batch_size = self._batch_size.to(*args, **kwargs) - if 'device' in kwargs: - self.device = kwargs['device'] - return self - - def cpu(self) -> 'ResultCollection': - """Move all data to CPU.""" - return self.to(device="cpu") - - def __str__(self) -> str: - return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' - - def __getstate__(self) -> dict: - d = self.__dict__.copy() - # can't deepcopy tensors with grad_fn - minimize = d.get('_minimize') - if minimize is not None: - d['_minimize'] = minimize.detach() - return d From e2f69cea340204db3b3f5966cd45c92da3ecb14d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 00:20:34 +0200 Subject: [PATCH 389/455] Update imports --- .../logger_connector/logger_connector.py | 2 +- tests/core/test_metric_result_integration.py | 2 +- tests/core/test_results.py | 2 +- tests/models/test_tpu.py | 2 +- .../trainer/logging_/test_logger_connector.py | 29 ++----------------- 5 files changed, 6 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 83dfa3294b218..0ebd388b26d2c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.core import memory from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource +from 
pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 8a636a0b15dd1..270cfff30797f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -17,7 +17,7 @@ from torchmetrics import Metric import tests.helpers.utils as tutils -from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection from tests.helpers.runif import RunIf diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 59f01184ba33e..5fffb64331ae4 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -21,7 +21,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync +from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync from pytorch_lightning.utilities.distributed import sync_ddp_if_available from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index b178126c4f81b..2e7db175801b9 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -25,7 +25,7 @@ from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.plugins import TPUSpawnPlugin -from pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync +from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index e4e2dfbc24713..06413838340b5 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -18,11 +18,11 @@ from torch.utils.data import DataLoader from torchmetrics import Accuracy, AveragePrecision -from pytorch_lightning import LightningModule, seed_everything +from pytorch_lightning import LightningModule from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator -from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset @@ -203,31 +203,6 @@ def test_dataloader(self): trainer.test(model, ckpt_path=None) -def test_metric_holder_raises(tmpdir): - """Check that an error is raised when trying to convert non-scalar tensors""" - - class TestModel(BoringModel): - - def validation_step(self, batch, *args, 
**kwargs): - output = self(batch) - self.log('test', output) - - def test_step(self, *args, **kwargs): - return self.validation_step(*args, **kwargs) - - model = TestModel() - model.validation_epoch_end = None - model.test_epoch_end = None - - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - - match = "The metric `.*` does not contain a single element" - with pytest.raises(MisconfigurationException, match=match): - trainer.validate(model) - with pytest.raises(MisconfigurationException, match=match): - trainer.test(model) - - def test_can_return_tensor_with_more_than_one_element(tmpdir): """Ensure {validation,test}_step return values are not included as callback metrics. #6623""" From 6037833ae6511ed26d0eb3605fdc6f1ab856b9d5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 00:21:03 +0200 Subject: [PATCH 390/455] Update after rename --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0ebd388b26d2c..67b4e88fff3a1 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -140,7 +140,7 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None # track batch_size - self.trainer.result_collection.extract_batch_size(batch) + self.trainer.results.extract_batch_size(batch) self._batch_idx = batch_idx def update_eval_step_metrics(self) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b6088f045566a..5fa6d40251f2d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -986,7 +986,7 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) # log batch metrics - self.logger_connector.update_evaluation_step_metrics() + self.logger_connector.update_eval_step_metrics() # track epoch level outputs dl_outputs = self._track_output_for_epoch_end(dl_outputs, output) @@ -1011,7 +1011,7 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: self.evaluation_loop.on_evaluation_epoch_end() # log epoch metrics - eval_loop_results = self.logger_connector.get_evaluate_epoch_results() + eval_loop_results = self.logger_connector.update_eval_epoch_metrics() # hook self.evaluation_loop.on_evaluation_end() From 499da7641f2ba6cd7d758cd414cfd88f1dffbc73 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 00:56:09 +0200 Subject: [PATCH 391/455] Refactor reduce_fx and op --- pytorch_lightning/core/lightning.py | 19 ++------ .../connectors/logger_connector/result.py | 44 ++++++++++++------- .../logging_/test_train_loop_logging.py | 5 ++- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a9a096c8db5a6..f07f7d2f7a08d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -263,7 +263,7 @@ def log( logger: bool = True, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Callable = torch.mean, + reduce_fx: Union[str, Callable] = 'mean', tbptt_reduce_fx: Optional = None, # 
noqa: Remove in 1.6 tbptt_pad_token: Optional = None, # noqa: Remove in 1.6 enable_graph: bool = False, @@ -376,12 +376,6 @@ def log( # when restarting an new epoch, reset the tensors results.reset(metrics=False, fx=self._current_fx_name) - if isinstance(sync_dist_op, str): - sync_dist_op = sync_dist_op.lower() - if sync_dist_op == "avg": - sync_dist_op = 'mean' - reduce_fx = self.__check_sync_dist_op(sync_dist_op, reduce_fx) - results.log( self._current_fx_name, name, @@ -410,7 +404,7 @@ def log_dict( logger: bool = True, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Callable = torch.mean, + reduce_fx: Union[str, Callable] = 'mean', tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6 tbptt_pad_token: Optional = None, # noqa: Remove in 1.6 enable_graph: bool = False, @@ -462,7 +456,7 @@ def log_dict( ) @staticmethod - def __check_not_nested(value: dict, name: str) -> None: + def __check_not_nested(value: dict, name: str) -> dict: # self-imposed restriction. for simplicity if any(isinstance(v, dict) for v in value.values()): raise ValueError(f'`self.log({name}, {value})` was called, but nested dictionaries cannot be logged') @@ -475,13 +469,6 @@ def __check_allowed(v: Any, name: str, value: Any) -> None: def __to_float(self, value: numbers.Number) -> torch.Tensor: return torch.tensor(value, device=self.device, dtype=torch.float) - @staticmethod - def __check_sync_dist_op(sync_dist_op: str, fx: Callable) -> Callable: - torch_fx = getattr(torch, sync_dist_op) - if getattr(torch, sync_dist_op) != fx: - return torch_fx - return fx - def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None: """Override this method to change the default behaviour of ``log_grad_norm``. diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 6bb478b18789f..a20b6d931c04d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -45,6 +45,13 @@ class _Sync: op: Union[Any, str] = 'mean' group: Optional[Any] = None + def __post_init__(self) -> None: + if isinstance(self.op, str): + op = self.op.lower() + if op == 'avg': + op = 'mean' + self.op = op + @property def __call__(self) -> Any: return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op @@ -68,6 +75,22 @@ class _Metadata: metric_attribute: Optional[str] = None sync: _Sync = field(default_factory=_Sync) + def __post_init__(self) -> None: + error = ( + 'Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported.' + ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`.' 
+ f' Found: {self.reduce_fx}' + ) + if isinstance(self.reduce_fx, str): + reduce_fx = self.reduce_fx.lower() + if reduce_fx == 'avg': + reduce_fx = 'mean' + if reduce_fx not in ('min', 'max', 'mean', 'sum'): + raise MisconfigurationException(error) + self.reduce_fx = getattr(torch, reduce_fx) + elif self.is_custom_reduction: + raise MisconfigurationException(error) + @property def forked(self) -> bool: return self.on_step and self.on_epoch @@ -79,19 +102,19 @@ def forked_name(self, on_step: bool) -> str: @property def is_mean_reduction(self) -> bool: - return self.reduce_fx == torch.mean + return self.reduce_fx is torch.mean @property def is_sum_reduction(self) -> bool: - return self.reduce_fx in (torch.sum, sum, "sum") + return self.reduce_fx in (torch.sum, sum) @property def is_max_reduction(self) -> bool: - return self.reduce_fx in (torch.max, max, 'max') + return self.reduce_fx in (torch.max, max) @property def is_min_reduction(self) -> bool: - return self.reduce_fx in (torch.min, min, 'min') + return self.reduce_fx in (torch.min, min) @property def is_custom_reduction(self) -> bool: @@ -137,13 +160,8 @@ def compute(self) -> torch.Tensor: if self.meta.is_mean_reduction: cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) return value / cumulated_batch_size - elif self.meta.is_sum_reduction: - return value - elif self.meta.is_max_reduction or self.meta.is_min_reduction: + elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction: return value - raise MisconfigurationException( - f"Only [min, max, mean] reductions are supported. Found {self.meta.reduce_fx}" - ) return self.value.compute() def reset(self) -> None: @@ -328,12 +346,8 @@ def log( group=sync_dist_group, ) ) + if key not in self: - if meta.is_custom_reduction: - raise MisconfigurationException( - 'Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported.' 
- ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`' - ) self.register_key(key, meta, value) elif meta != self[key].meta: raise MisconfigurationException( diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index ef3ef492ab37c..95e0815eb91eb 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -193,7 +193,8 @@ def training_epoch_end(self, outputs): assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics | {'a', 'b'}) - {'epoch'} -@pytest.mark.parametrize(['batches', 'fx', 'result'], [(3, min, 0), (3, max, 2), (11, max, 10)]) +@pytest.mark.parametrize(['batches', 'fx', 'result'], [(3, min, 0), (3, torch.max, 2), (11, max, 10), (5, 'avg', 2), + (5, 'SUM', 10)]) def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result): """ Tests that log works correctly with different tensor types @@ -779,5 +780,5 @@ def training_step(self, *args): trainer = Trainer(default_root_dir=tmpdir) model = TestModel() - with pytest.raises(MisconfigurationException, match=r'reduce_fx={min,max,mean}\)` are currently supported'): + with pytest.raises(MisconfigurationException, match=r'reduce_fx={min,max,mean,sum}\)` are currently supported'): trainer.fit(model) From 6eb448aac3e4ab690ae104244339df221304d98d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 00:58:07 +0200 Subject: [PATCH 392/455] Fix test after rename --- tests/trainer/loops/test_evaluation_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 6e601b577d648..769cf07366b82 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -47,12 +47,12 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): @mock.patch( - "pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.get_evaluate_epoch_results" + "pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.update_eval_epoch_metrics" ) -def test_log_epoch_metrics_before_on_evaluation_end(get_evaluate_epoch_results_mock, tmpdir): +def test_log_epoch_metrics_before_on_evaluation_end(update_eval_epoch_metrics_mock, tmpdir): """Test that the epoch metrics are logged before the `on_evalutaion_end` hook is fired""" order = [] - get_evaluate_epoch_results_mock.side_effect = lambda: order.append("log_epoch_metrics") + update_eval_epoch_metrics_mock.side_effect = lambda: order.append("log_epoch_metrics") class LessBoringModel(BoringModel): From f871cbdb3dfef3641160edb6ebf66370a55e7cec Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 01:25:37 +0200 Subject: [PATCH 393/455] mypy --- pytorch_lightning/trainer/evaluation_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index beaabd5eb8345..cd196d4629984 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -43,6 +43,7 @@ def results(self) -> Optional[ResultCollection]: return self._val_results elif self.trainer.testing: return self._test_results + return None def on_trainer_init(self) -> None: self.trainer.num_sanity_val_batches = [] From d10d5c7d53811a48994a9836f320059b2bb2e6dc Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 9 Jun 2021 02:06:37 +0200 Subject: [PATCH 394/455] integrate latest changes from logger connector poc --- pytorch_lightning/loops/batch_loop.py | 12 ++++++------ pytorch_lightning/loops/epoch_loop.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/batch_loop.py index 1a1a73ebaea8a..d7cf6692c5380 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/batch_loop.py @@ -247,11 +247,11 @@ def _process_training_step_output(self, training_step_output): if training_step_output is None: return None - result = self.trainer.result_collection + results = self.trainer.results loss = None hiddens = None - result.extra = {} + results.extra = {} # handle dict return if isinstance(training_step_output, dict): @@ -259,20 +259,20 @@ def _process_training_step_output(self, training_step_output): hiddens = training_step_output.pop("hiddens", None) if hiddens is not None: hiddens = hiddens.detach() - result.extra = training_step_output + results.extra = training_step_output # handle scalar return elif isinstance(training_step_output, torch.Tensor): loss = training_step_output # map to results under the hood - result.minimize = loss + results.minimize = loss self._hiddens = hiddens if self.trainer.move_metrics_to_cpu: - result = result.cpu() + results = results.cpu() - return result + return results def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): model_ref = self.trainer.lightning_module diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epoch_loop.py index 3bfae16fee247..644e8ac3e6b40 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epoch_loop.py @@ -36,7 +36,7 @@ def __init__(self, min_epochs, max_epochs, min_steps, max_steps): self.training_loop = TrainingLoop(min_steps, max_steps) - self.train_results = ResultCollection(True) + self.results = ResultCollection(True) @property def current_epoch(self) -> int: @@ -130,7 +130,7 @@ def run(self): return super().run() def on_run_start(self): - self.trainer.result_collection.device = self.trainer.lightning_module.device + self.trainer.results.to(device=self.trainer.lightning_module.device) self.trainer.call_hook("on_train_start") def on_advance_start(self): # equal to old on_train_epoch_start From 7b6803a6a8c5da2badbaeb1d2f103990f79f4a4d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 02:18:12 +0200 Subject: [PATCH 395/455] Fix test --- tests/trainer/logging_/test_train_loop_logging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 95e0815eb91eb..e621ee1fe5d94 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -493,14 +493,14 @@ class TestLoggingSyncDistModel(BoringModel): def training_step(self, batch, batch_idx): acc = self.step(batch[0]) - self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM') + self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='SUM') self.log('cho', acc, on_step=False, on_epoch=True) return acc def validation_step(self, batch, batch_idx): output = self.layer(batch) loss = self.loss(batch, output) - self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG') + self.log('bar', 2, on_step=False, 
on_epoch=True, sync_dist=True, reduce_fx='AVG') return {"x": loss} model = TestLoggingSyncDistModel() From 9bfedc906d03421ebbf1c93df66b9e07edd82fa0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 02:40:10 +0200 Subject: [PATCH 396/455] Refactor test --- .../logging_/test_train_loop_logging.py | 54 ++++--------------- 1 file changed, 9 insertions(+), 45 deletions(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index e621ee1fe5d94..5c959833dee1e 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -448,25 +448,23 @@ def get_expected(on_epoch, values): assert is_included if should_include else not is_included -def test_logging_sync_dist_true_cpu(tmpdir): +@pytest.mark.parametrize('gpus', [None, pytest.param(1, marks=RunIf(min_gpus=1))]) +def test_logging_sync_dist_true_cpu(tmpdir, gpus): """ - Tests to ensure that the sync_dist flag works with CPU (should just return the original value) + Tests to ensure that the sync_dist flag works (should just return the original value) """ fake_result = 1 class TestModel(BoringModel): def training_step(self, batch, batch_idx): - acc = self.step(batch[0]) - self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') - self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') - return acc + self.log('foo', fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='sum') + self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='sum') + return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log('bar', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') - return {"x": loss} + self.log('bar', fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='sum') + return super().validation_step(batch, batch_idx) model = TestModel() trainer = Trainer( @@ -475,6 +473,7 @@ def validation_step(self, batch, batch_idx): limit_val_batches=1, max_epochs=2, weights_summary=None, + gpus=gpus, ) trainer.fit(model) @@ -520,41 +519,6 @@ def validation_step(self, batch, batch_idx): assert trainer.logged_metrics['bar'] == 2 -@RunIf(min_gpus=1) -def test_logging_sync_dist_true_gpu(tmpdir): - """ - Tests to ensure that the sync_dist flag works with GPU (should just return the original value) - """ - fake_result = 1 - - class TestModel(BoringModel): - - def training_step(self, batch, batch_idx): - acc = self.step(batch[0]) - self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') - return acc - - def validation_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log('bar', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') - return {"x": loss} - - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=1, - limit_val_batches=1, - max_epochs=2, - gpus=1, - weights_summary=None, - ) - trainer.fit(model) - - assert trainer.logged_metrics['foo'] == fake_result - assert trainer.logged_metrics['bar'] == fake_result - - def test_progress_bar_dict_contains_values_on_train_epoch_end(tmpdir): class TestModel(BoringModel): From c9c7829015f7df4e79315e83cff94419a5304a31 Mon Sep 17 00:00:00 2001 From: 
Carlos Mocholi Date: Wed, 9 Jun 2021 03:02:33 +0200 Subject: [PATCH 397/455] Deprecate `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)` --- CHANGELOG.md | 3 +++ pytorch_lightning/core/lightning.py | 20 ++++++++++++------- .../connectors/logger_connector/result.py | 12 ++--------- tests/deprecated_api/test_remove_1-6.py | 13 ++++++++++++ tests/models/test_horovod.py | 2 +- .../logging_/test_train_loop_logging.py | 2 +- 6 files changed, 33 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a3f58312726f..3e8aaf0b27511 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -167,6 +167,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026)) +- Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7631](https://github.com/PyTorchLightning/pytorch-lightning/pull/7631)) + + ### Removed - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f07f7d2f7a08d..d7b8cde7dd196 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -263,12 +263,12 @@ def log( logger: bool = True, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Union[str, Callable] = 'mean', + reduce_fx: Union[str, Callable] = 'default', # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6 tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6 tbptt_pad_token: Optional = None, # noqa: Remove in 1.6 enable_graph: bool = False, sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', + sync_dist_op: Optional = None, # noqa: Remove in 1.6 sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, @@ -304,7 +304,6 @@ def log( reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default. enable_graph: if True, will not auto detach the graph sync_dist: if True, reduces the metric across GPUs/TPUs - sync_dist_op: the op to sync across GPUs/TPUs sync_dist_group: the ddp group to sync across add_dataloader_idx: if True, appends the index of the current dataloader to the name (when using multiple). If False, user needs to give unique names for @@ -326,6 +325,15 @@ def log( ' Please, open a discussion explaining your use-case in' ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) + if sync_dist_op is not None: + rank_zero_deprecation( + f"`self.log(sync_dist_op='{sync_dist_op}')` is deprecated and will be removed in v.1.6." + f" Use `self.log(reduce_fx={sync_dist_op})` instead." 
+ ) + if reduce_fx == 'default': + reduce_fx = sync_dist_op + elif reduce_fx == 'default': + reduce_fx = 'mean' # check for invalid values apply_to_collection(value, dict, self.__check_not_nested, name) @@ -391,7 +399,6 @@ def log( metric_attribute=metric_attribute, sync_dist=sync_dist, sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, - sync_dist_op=sync_dist_op, sync_dist_group=sync_dist_group, ) @@ -404,12 +411,12 @@ def log_dict( logger: bool = True, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Union[str, Callable] = 'mean', + reduce_fx: Union[str, Callable] = 'default', # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6 tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6 tbptt_pad_token: Optional = None, # noqa: Remove in 1.6 enable_graph: bool = False, sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', + sync_dist_op: Optional = None, # noqa: Remove in 1.6 sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, ) -> None: @@ -431,7 +438,6 @@ def log_dict( reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default. enable_graph: if True, will not auto detach the graph sync_dist: if True, reduces the metric across GPUs/TPUs - sync_dist_op: the op to sync across GPUs/TPUs sync_dist_group: the ddp group sync across add_dataloader_idx: if True, appends the index of the current dataloader to the name (when using multiple). If False, user needs to give unique names for diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index a20b6d931c04d..dd917ac4082eb 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -42,16 +42,9 @@ class MetricSource(LightningEnum): class _Sync: fn: Callable should: bool = False - op: Union[Any, str] = 'mean' + op: Optional[str] = field(init=False, default=None) group: Optional[Any] = None - def __post_init__(self) -> None: - if isinstance(self.op, str): - op = self.op.lower() - if op == 'avg': - op = 'mean' - self.op = op - @property def __call__(self) -> Any: return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op @@ -90,6 +83,7 @@ def __post_init__(self) -> None: self.reduce_fx = getattr(torch, reduce_fx) elif self.is_custom_reduction: raise MisconfigurationException(error) + self.sync.op = self.reduce_fx.__name__ @property def forked(self) -> bool: @@ -306,7 +300,6 @@ def log( enable_graph: bool = False, sync_dist: bool = False, sync_dist_fn: Callable = _Sync.no_op, - sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, @@ -342,7 +335,6 @@ def log( sync=_Sync( should=sync_dist, fn=sync_dist_fn, - op=sync_dist_op, group=sync_dist_group, ) ) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 7ca0939fd60d2..7a92501caee4a 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -86,3 +86,16 @@ def training_step(self, *args): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) with pytest.deprecated_call(match=r"tbptt_pad_token=...\)` is no longer supported"): trainer.fit(TestModel()) + + +def test_v1_6_0_sync_dist_op(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, *args): + 
self.log("foo", 1, sync_dist_op='sum') + return super().training_step(*args) + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.deprecated_call(match=r"`self.log\(sync_dist_op='sum'\)` is deprecated"): + trainer.fit(TestModel()) diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 10f96845a7a48..a37935cb1b5de 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -296,7 +296,7 @@ def training_step(self, batch, batch_idx): self.training_step_called = True tensor = torch.tensor([1.0]) - self.log("test_tensor", tensor, sync_dist=True, sync_dist_op='sum', on_step=True, on_epoch=True) + self.log("test_tensor", tensor, sync_dist=True, reduce_fx='sum', on_step=True, on_epoch=True) res = self._results diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 5c959833dee1e..323362c9e8c2e 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -534,7 +534,7 @@ def on_train_epoch_end(self, *_): prog_bar=True, on_epoch=True, sync_dist=True, - sync_dist_op='sum' + reduce_fx='sum' ) self.on_train_epoch_end_called = True From e3dde0ba798d9074bca262db503593d0dddaf114 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 9 Jun 2021 03:21:55 +0200 Subject: [PATCH 398/455] Undo field --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- tests/core/test_results.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index dd917ac4082eb..0470a25a75f46 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -42,7 +42,7 @@ class MetricSource(LightningEnum): class _Sync: fn: Callable should: bool = False - op: Optional[str] = field(init=False, default=None) + op: Optional[str] = None group: Optional[Any] = None @property diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 5fffb64331ae4..e2e3c892cc124 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -39,7 +39,7 @@ def _setup_ddp(rank, worldsize): def _ddp_test_fn(rank, worldsize): _setup_ddp(rank, worldsize) tensor = torch.tensor([1.0]) - sync = _Sync(sync_ddp_if_available, should=True, op=torch.distributed.ReduceOp.SUM) + sync = _Sync(sync_ddp_if_available, should=True, op='SUM') actual = sync(tensor) assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors" From 2c167cc9c2ae17452fd31f0676b68f41ad6b0913 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 9 Jun 2021 16:59:25 +0200 Subject: [PATCH 399/455] rename rename files and classes --- .../loops/{epoch_loop.py => epochs_loop.py} | 6 +++--- .../loops/{batch_loop.py => training_batch_loop.py} | 2 +- .../{training_loop.py => training_epoch_loop.py} | 6 +++--- pytorch_lightning/trainer/properties.py | 12 ++++++------ pytorch_lightning/trainer/trainer.py | 4 ++-- 5 files changed, 15 insertions(+), 15 deletions(-) rename pytorch_lightning/loops/{epoch_loop.py => epochs_loop.py} (98%) rename pytorch_lightning/loops/{batch_loop.py => training_batch_loop.py} (99%) rename pytorch_lightning/loops/{training_loop.py => training_epoch_loop.py} (99%) diff --git a/pytorch_lightning/loops/epoch_loop.py b/pytorch_lightning/loops/epochs_loop.py 
similarity index 98% rename from pytorch_lightning/loops/epoch_loop.py rename to pytorch_lightning/loops/epochs_loop.py index 644e8ac3e6b40..78a935446dfa1 100644 --- a/pytorch_lightning/loops/epoch_loop.py +++ b/pytorch_lightning/loops/epochs_loop.py @@ -9,7 +9,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.loops.base import Loop -from pytorch_lightning.loops.training_loop import TrainingLoop +from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_info @@ -22,7 +22,7 @@ # TODO: typing -class EpochLoop(Loop): +class EpochsLoop(Loop): def __init__(self, min_epochs, max_epochs, min_steps, max_steps): super().__init__() @@ -34,7 +34,7 @@ def __init__(self, min_epochs, max_epochs, min_steps, max_steps): # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.training_loop = TrainingLoop(min_steps, max_steps) + self.training_loop = TrainingEpochLoop(min_steps, max_steps) self.results = ResultCollection(True) diff --git a/pytorch_lightning/loops/batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py similarity index 99% rename from pytorch_lightning/loops/batch_loop.py rename to pytorch_lightning/loops/training_batch_loop.py index d7cf6692c5380..7952f26af1689 100644 --- a/pytorch_lightning/loops/batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -21,7 +21,7 @@ # TODO: typing -class BatchLoop(Loop): +class TrainingBatchLoop(Loop): """ Runs over a single batch of data. """ def __init__(self): diff --git a/pytorch_lightning/loops/training_loop.py b/pytorch_lightning/loops/training_epoch_loop.py similarity index 99% rename from pytorch_lightning/loops/training_loop.py rename to pytorch_lightning/loops/training_epoch_loop.py index 9214c22af4594..7075703200c25 100644 --- a/pytorch_lightning/loops/training_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -2,7 +2,7 @@ import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop -from pytorch_lightning.loops.batch_loop import BatchLoop +from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -11,7 +11,7 @@ # TODO: typing -class TrainingLoop(Loop): +class TrainingEpochLoop(Loop): """ Runs over all batches in a dataloader (one epoch). 
""" def __init__(self, min_steps, max_steps): @@ -50,7 +50,7 @@ def done(self): def connect(self, trainer: 'pl.Trainer', *args, **kwargs): self.trainer = trainer - self.batch_loop = BatchLoop() + self.batch_loop = TrainingBatchLoop() self.batch_loop.connect(trainer) def run(self, *args, **kwargs): diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 9c80c601ef81b..7d22f370dcfa0 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -29,9 +29,9 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from pytorch_lightning.loops.batch_loop import BatchLoop -from pytorch_lightning.loops.epoch_loop import EpochLoop -from pytorch_lightning.loops.training_loop import TrainingLoop +from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop +from pytorch_lightning.loops.epochs_loop import EpochsLoop +from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector @@ -494,16 +494,16 @@ def active_loop(self) -> Optional[Union[TrainLoop, EvaluationLoop]]: return self.evaluation_loop @property - def epoch_loop(self) -> EpochLoop: + def epoch_loop(self) -> EpochsLoop: # TODO: the current train_loop should be renamed to epoch_loop return self.train_loop @property - def training_loop(self) -> TrainingLoop: + def training_loop(self) -> TrainingEpochLoop: return self.epoch_loop.training_loop @property - def batch_loop(self) -> BatchLoop: + def batch_loop(self) -> TrainingBatchLoop: return self.epoch_loop.training_loop.batch_loop @property diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b5cd200c4b554..93d1414a36be9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,7 +29,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.loops.epoch_loop import EpochLoop +from pytorch_lightning.loops.epochs_loop import EpochsLoop from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.profiler import ( @@ -339,7 +339,7 @@ def __init__( self.tuner = Tuner(self) if NEW_LOOP: - self.train_loop = EpochLoop(min_epochs, max_epochs, min_steps, max_steps) + self.train_loop = EpochsLoop(min_epochs, max_epochs, min_steps, max_steps) self.train_loop.connect(self) else: # old loops: From 99db49755390fbd559742144820c926436e121e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Jun 2021 17:19:25 +0000 Subject: [PATCH 400/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 7d22f370dcfa0..8e5634ccaa777 100644 --- a/pytorch_lightning/trainer/properties.py +++ 
b/pytorch_lightning/trainer/properties.py @@ -29,8 +29,8 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop from pytorch_lightning.loops.epochs_loop import EpochsLoop +from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector From 832dfb999d40e7d7bea2544c3e4ba18f35891e1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 9 Jun 2021 23:43:13 +0200 Subject: [PATCH 401/455] rename --- pytorch_lightning/core/lightning.py | 4 ++-- .../loops/{epochs_loop.py => fit_loop.py} | 6 +++--- pytorch_lightning/trainer/properties.py | 14 +++++++------- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/deprecated_api/test_remove_1-5.py | 2 +- tests/trainer/loops/test_evaluation_loop_flow.py | 8 ++++---- .../loops/test_training_loop_flow_scalar.py | 12 ++++++------ tests/trainer/test_trainer.py | 16 ++++++++-------- 8 files changed, 33 insertions(+), 33 deletions(-) rename pytorch_lightning/loops/{epochs_loop.py => fit_loop.py} (98%) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ef730b8182a92..3aae86dfe241f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1371,7 +1371,7 @@ def training_step(...): # backward self._running_manual_backward = True - self.trainer.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) + self.trainer.training_batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) self._running_manual_backward = False def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: @@ -1470,7 +1470,7 @@ def optimizer_step( If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter to ``optimizer.step()`` function as shown in the examples. This ensures that ``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within - :meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`. + :meth:`~pytorch_lightning.trainer.training_batch_loop.TrainingBatchLoop.advance`. Args: epoch: Current epoch diff --git a/pytorch_lightning/loops/epochs_loop.py b/pytorch_lightning/loops/fit_loop.py similarity index 98% rename from pytorch_lightning/loops/epochs_loop.py rename to pytorch_lightning/loops/fit_loop.py index 78a935446dfa1..4d96a8a1e41a6 100644 --- a/pytorch_lightning/loops/epochs_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -22,7 +22,7 @@ # TODO: typing -class EpochsLoop(Loop): +class FitLoop(Loop): def __init__(self, min_epochs, max_epochs, min_steps, max_steps): super().__init__() @@ -106,7 +106,7 @@ def done(self) -> bool: met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: # TODO: THIS is now in on_run_end, always run? 
- # self.training_loop.on_train_end() + # self.training_epoch_loop.on_train_end() should_stop = True else: log.info( @@ -181,7 +181,7 @@ def advance(self): def on_advance_end(self): # # handle epoch_output on epoch end - # self.on_train_epoch_end(outputs) # Handled in on_run_end of training_loop now + # self.on_train_epoch_end(outputs) # Handled in on_run_end of training_epoch_loop now if self.training_loop.batches_seen == 0: return diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 8e5634ccaa777..7053d6797f4ec 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -29,7 +29,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from pytorch_lightning.loops.epochs_loop import EpochsLoop +from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin @@ -494,17 +494,17 @@ def active_loop(self) -> Optional[Union[TrainLoop, EvaluationLoop]]: return self.evaluation_loop @property - def epoch_loop(self) -> EpochsLoop: - # TODO: the current train_loop should be renamed to epoch_loop + def fit_loop(self) -> FitLoop: + # TODO: the current train_loop should be renamed to fit_loop return self.train_loop @property - def training_loop(self) -> TrainingEpochLoop: - return self.epoch_loop.training_loop + def training_epoch_loop(self) -> TrainingEpochLoop: + return self.fit_loop.training_loop @property - def batch_loop(self) -> TrainingBatchLoop: - return self.epoch_loop.training_loop.batch_loop + def training_batch_loop(self) -> TrainingBatchLoop: + return self.fit_loop.training_loop.batch_loop @property def global_step(self) -> int: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 93d1414a36be9..53e87c6c02145 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,7 +29,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.loops.epochs_loop import EpochsLoop +from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.profiler import ( @@ -339,7 +339,7 @@ def __init__( self.tuner = Tuner(self) if NEW_LOOP: - self.train_loop = EpochsLoop(min_epochs, max_epochs, min_steps, max_steps) + self.train_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps) self.train_loop.connect(self) else: # old loops: diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 4edf42854fabb..dcf5138e295eb 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -242,7 +242,7 @@ def on_train_epoch_end(self, outputs): # noqa with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer.fit(model) - trainer.training_loop.warning_cache.clear() + trainer.training_epoch_loop.warning_cache.clear() class NewSignature(Callback): diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py 
b/tests/trainer/loops/test_evaluation_loop_flow.py index 0f9bf7ca9e68c..8896a4d186d34 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -69,7 +69,7 @@ def backward(self, loss, optimizer, optimizer_idx): # simulate training manually trainer.state.stage = RunningStage.TRAINING batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.batch_loop.run(batch, batch_idx, 0) + out = trainer.training_batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -79,7 +79,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.batch_loop.training_step_and_backward( + opt_closure_result = trainer.training_batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -140,7 +140,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.batch_loop.run(batch, batch_idx, 0) + out = trainer.training_batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -150,7 +150,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.batch_loop.training_step_and_backward( + opt_closure_result = trainer.training_batch_loop.training_step_and_backward( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 63e0482468c39..8ba64fc920c3a 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -153,7 +153,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.batch_loop.run(batch, batch_idx, 0) + out = trainer.training_batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -163,7 +163,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.batch_loop.training_step_and_backward( + opt_closure_result = trainer.training_batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -231,7 +231,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.batch_loop.run(batch, batch_idx, 0) + out = trainer.training_batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -241,7 +241,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.batch_loop.training_step_and_backward( + opt_closure_result = trainer.training_batch_loop.training_step_and_backward( batch, batch_idx, 0, 
trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 @@ -317,7 +317,7 @@ def training_step(self, batch, batch_idx): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.batch_loop.run(batch, batch_idx, 0) + out = trainer.training_batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output == [[]] assert out.signal == 0 @@ -362,7 +362,7 @@ def train_dataloader(self): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.batch_loop.run(batch, batch_idx, 0) + out = trainer.training_batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output == [[]] assert out.signal == 0 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index aa4a78bf67912..ed4325979bb7b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -894,7 +894,7 @@ def test_gradient_clipping(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.training_batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -908,7 +908,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.batch_loop.training_step_and_backward = training_step_and_backward + trainer.training_batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -932,7 +932,7 @@ def test_gradient_clipping_by_value(tmpdir): default_root_dir=tmpdir ) - old_training_step_and_backward = trainer.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.training_batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -948,7 +948,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.batch_loop.training_step_and_backward = training_step_and_backward + trainer.training_batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -973,7 +973,7 @@ def test_gradient_clipping_fp16(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.training_batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -987,7 +987,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.batch_loop.training_step_and_backward = training_step_and_backward + trainer.training_batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) @@ -1012,7 +1012,7 @@ def test_gradient_clipping_by_value_fp16(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.training_batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -1028,7 +1028,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.batch_loop.training_step_and_backward = training_step_and_backward + 
trainer.training_batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) From f92e01d1e40026ad9548190ab1e41b1e17957ddf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 9 Jun 2021 23:50:13 +0200 Subject: [PATCH 402/455] imports --- pytorch_lightning/loops/__init__.py | 19 +++++++++++++++ pytorch_lightning/loops/base.py | 14 +++++++++++ pytorch_lightning/loops/fit_loop.py | 24 ++++++++++++------- .../loops/training_batch_loop.py | 16 +++++++++++-- .../loops/training_epoch_loop.py | 15 +++++++++++- 5 files changed, 76 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index e69de29bb2d1d..8d6ab8f2c9bf0 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -0,0 +1,19 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.fit_loop import FitLoop +from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop +from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop + diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 94c20ebfa67c0..c0128a449210d 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -1,3 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod from typing import Any, Optional from weakref import proxy diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 4d96a8a1e41a6..1b89d0baef147 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -1,27 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import logging from contextlib import suppress -from copy import deepcopy -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple -import torch from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_info -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.parsing import AttributeDict -from pytorch_lightning.utilities.warnings import WarningCache log = logging.getLogger(__name__) -# TODO: typing class FitLoop(Loop): def __init__(self, min_epochs, max_epochs, min_steps, max_steps): diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 7952f26af1689..113931da54dcd 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -1,6 +1,19 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections import OrderedDict from contextlib import contextmanager -from copy import copy from functools import partial, update_wrapper from typing import Any, Callable, List, Mapping, Optional, Tuple @@ -20,7 +33,6 @@ from pytorch_lightning.utilities.warnings import WarningCache -# TODO: typing class TrainingBatchLoop(Loop): """ Runs over a single batch of data. """ diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index 7075703200c25..bdccf4988eb22 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -1,3 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Dict, Iterator, List, Union import pytorch_lightning as pl @@ -10,7 +24,6 @@ from pytorch_lightning.utilities.warnings import WarningCache -# TODO: typing class TrainingEpochLoop(Loop): """ Runs over all batches in a dataloader (one epoch). 
""" From b15fc345d040773cc151112a056416825faf87e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 00:09:19 +0200 Subject: [PATCH 403/455] loop hygiene --- pytorch_lightning/loops/fit_loop.py | 20 ++++------- .../loops/training_batch_loop.py | 36 ++++++++----------- .../loops/training_epoch_loop.py | 23 +++--------- 3 files changed, 25 insertions(+), 54 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 1b89d0baef147..35ea94692c88c 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -30,18 +30,16 @@ class FitLoop(Loop): - def __init__(self, min_epochs, max_epochs, min_steps, max_steps): + def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None, min_steps: Optional[int] = None, max_steps: Optional[int] = None): super().__init__() self._teardown_already_run = False - # TODO: Move this to trainer (it's a trainer default, loops shouldn't have to care about this # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs self.training_loop = TrainingEpochLoop(min_steps, max_steps) - self.results = ResultCollection(True) @property @@ -82,7 +80,7 @@ def max_steps(self): @max_steps.setter def max_steps(self, value): - # TODO: This setter is required by debugging connector (fast dev run) + # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided self.training_loop.max_steps = value @property @@ -101,7 +99,7 @@ def skip_backward(self, value: bool): @property def done(self) -> bool: - # TODO: Move track steps inside training loop and move part of these condition inside training loop + # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop stop_steps = self.max_steps is not None and self.global_step >= self.max_steps stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs @@ -111,8 +109,6 @@ def done(self) -> bool: met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: - # TODO: THIS is now in on_run_end, always run? 
- # self.training_epoch_loop.on_train_end() should_stop = True else: log.info( @@ -139,14 +135,14 @@ def on_run_start(self): self.trainer.results.to(device=self.trainer.lightning_module.device) self.trainer.call_hook("on_train_start") - def on_advance_start(self): # equal to old on_train_epoch_start + def on_advance_start(self): model = self.trainer.lightning_module # reset train dataloader if self.current_epoch != 0 and self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_train_dataloader(model) - # todo: specify the possible exception + # TODO: specify the possible exception with suppress(Exception): # set seed for distributed sampler (enables shuffling for each epoch) self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch) @@ -186,9 +182,6 @@ def advance(self): self.global_step += 1 def on_advance_end(self): - # # handle epoch_output on epoch end - # self.on_train_epoch_end(outputs) # Handled in on_run_end of training_epoch_loop now - if self.training_loop.batches_seen == 0: return @@ -202,7 +195,6 @@ def on_advance_end(self): self.check_checkpoint_callback(True) self.global_step += 1 - # why is this not the same as the old on_train_epoch_end? def on_run_end(self): if self._teardown_already_run: return @@ -248,7 +240,7 @@ def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[i return self.training_loop.batch_loop.get_active_optimizers(batch_idx) def check_checkpoint_callback(self, should_update, is_last=False): - # TODO bake this logic into the ModelCheckpoint callback + # TODO: bake this logic into the ModelCheckpoint callback if should_update and self.trainer.checkpoint_connector.has_trained: callbacks = self.trainer.checkpoint_callbacks diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 113931da54dcd..88ee9ef75c424 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -39,12 +39,14 @@ class TrainingBatchLoop(Loop): def __init__(self): super().__init__() self.accumulated_loss = None + self.batch_outputs = None self.running_loss = TensorRunningAccum(window_length=20) self.split_idx = None self.warning_cache = WarningCache() self._hiddens = None self._optimizer_freq_cumsum = None + self._remaining_splits = None self._skip_backward = False @property @@ -61,6 +63,12 @@ def skip_backward(self, value: bool): """ Determines whether the loop will skip backward during automatic optimization. 
""" self._skip_backward = value + @property + def optimizer_freq_cumsum(self): + if self._optimizer_freq_cumsum is None: + self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) + return self._optimizer_freq_cumsum + def connect(self, trainer, *args, **kwargs): self.trainer = trainer @@ -85,10 +93,8 @@ def run(self, batch, batch_idx, dataloader_idx): return AttributeDict(signal=0, training_step_output=self.batch_outputs) def reset(self) -> None: - # self.iteration_count = 0 - self._hiddens = None - # TODO: let loops track individual outputs + # TODO(@awaelchli): let loops track individual outputs self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] def on_run_start(self, batch, batch_idx, dataloader_idx): @@ -101,10 +107,6 @@ def advance(self, batch, batch_idx, dataloader_idx): # let logger connector extract current batch size self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) - # TODO: this list needs to go outside this loop - # batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - grad_norm_dict = {} - if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) @@ -116,22 +118,11 @@ def advance(self, batch, batch_idx, dataloader_idx): if result: self.batch_outputs[0].append(result.training_step_output) - -# ------------------------------------------------------------------------------------------------------------ -# HELPER --- TO BE CLEANED UP -# ------------------------------------------------------------------------------------------------------------ - def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int: return len(self.get_active_optimizers(batch_idx)) - @property - def optimizer_freq_cumsum(self): - if self._optimizer_freq_cumsum is None: - self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) - return self._optimizer_freq_cumsum - def _run_optimization(self, batch_idx, split_batch, opt_idx=0, optimizer=None): - # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change + # TODO(@awaelchli): In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change # opt_idx=0 to opt_idx=None in the signature here # toggle model params @@ -334,12 +325,13 @@ def track_and_norm_grad(self, optimizer) -> dict: return grad_norm_dict def _accumulated_batches_reached(self): - # TODO: use progress tracking of batches instead of iteration count, because iteration count may reset - # iteration count is required to be global here, not reset + # TODO(@awaelchli): use progress tracking of batches instead of iteration count, because iteration count may + # reset iteration count is required to be global here, not reset return self.iteration_count % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self, is_last_batch=False): - # TODO: use progress tracking of batches instead of iteration count, because iteration count may reset + # TODO(@awaelchli): use progress tracking of batches instead of iteration count, because iteration + # count may reset return (self.iteration_count + 1) == self.trainer.num_training_batches or is_last_batch def should_accumulate(self): diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index bdccf4988eb22..a3d2e6f03f9fd 100644 --- 
a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -29,8 +29,6 @@ class TrainingEpochLoop(Loop): def __init__(self, min_steps, max_steps): super().__init__() - # cache of all outputs in a single training run / epoch - # self.epoch_output = [[]] self.min_steps = min_steps self.max_steps = max_steps @@ -49,6 +47,7 @@ def __init__(self, min_steps, max_steps): self.is_last_batch = None self.batches_seen = 0 self.warning_cache = WarningCache() + self.epoch_output = None self.batch_loop = None @@ -70,7 +69,7 @@ def run(self, *args, **kwargs): self.reset() self.on_run_start() - # TODO: while condition is different from super.run(), + # TODO(@awaelchli): while condition is different from super.run(), # redesign the done conditions and use the base class run() implementation while True: try: @@ -94,7 +93,6 @@ def reset(self) -> None: self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] def advance(self, dataloader_iter: Iterator, **kwargs): - # TODO: profiling is gone _, (batch, is_last) = next(dataloader_iter) self.is_last_batch = is_last @@ -150,7 +148,6 @@ def on_advance_end(self): if self.done: raise StopIteration - # this is the old on train_epoch_end? def on_run_end(self): if self.batches_seen == 0: # dataloader/iterator did not produce a batch @@ -185,11 +182,6 @@ def on_run_end(self): self.trainer.logger_connector.on_epoch_end() return self.epoch_output - -# ------------------------------------------------------------------------------------------------------------ -# HELPER --- TO BE CLEANED UP -# ------------------------------------------------------------------------------------------------------------ - def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # We cannot rely on Trainer.call_hook because the signatures might be different across # lightning module and callback @@ -233,15 +225,10 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: def _num_training_batches_reached(self, is_last_batch=False): return self.batches_seen == self.trainer.num_training_batches or is_last_batch - # TODO move to on_advance_end() ?? + # TODO(@awaelchli): merge with on_advance_end() def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): - - # epoch output : [[] ... 
] - # batch_end_outputs[0][0] = Result obj - batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)] - - processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) # dict with loss + processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) # hook self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, batch_idx, dataloader_idx) @@ -377,7 +364,7 @@ def should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: if self.trainer.should_stop: return True - # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch + # TODO(awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = is_last_batch if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 From 7175a50e067d04fe6e20daa3dfff3ff455e88774 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 00:09:42 +0200 Subject: [PATCH 404/455] yapf on loops --- pytorch_lightning/loops/__init__.py | 1 - pytorch_lightning/loops/fit_loop.py | 8 +++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index 8d6ab8f2c9bf0..c9e25f212af2f 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -16,4 +16,3 @@ from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop - diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 35ea94692c88c..df6741bbf4fc1 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -30,7 +30,13 @@ class FitLoop(Loop): - def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None, min_steps: Optional[int] = None, max_steps: Optional[int] = None): + def __init__( + self, + min_epochs: Optional[int] = None, + max_epochs: Optional[int] = None, + min_steps: Optional[int] = None, + max_steps: Optional[int] = None + ): super().__init__() self._teardown_already_run = False From 59d6227c9369dc478df06a603b9d172d0ee03775 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 00:13:35 +0200 Subject: [PATCH 405/455] protected new loop trigger --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 53e87c6c02145..cb8e6b21c1687 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -81,7 +81,7 @@ 'please use torch.distributed.ReduceOp instead' ) -NEW_LOOP = True +_NEW_LOOP = True class Trainer( @@ -338,7 +338,7 @@ def __init__( self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - if NEW_LOOP: + if _NEW_LOOP: self.train_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps) self.train_loop.connect(self) else: @@ -907,7 +907,7 @@ def reset_train_val_dataloaders(self, model) -> None: self.reset_val_dataloader(model) def _run_train(self) -> None: - if NEW_LOOP: + if _NEW_LOOP: self._run_train_new_loop() else: self._run_train_old_loop() From a7c3555f965928b39526cbce6ca77b2a6fe022b8 Mon Sep 17 00:00:00 2001 From: Carlos 
Mocholi Date: Thu, 10 Jun 2021 00:35:10 +0200 Subject: [PATCH 406/455] Replace code --- .../logger_connector/logger_connector.py | 22 +- .../logger_connector/logger_connector_new.py | 312 ----------- .../connectors/logger_connector/result.py | 6 +- .../connectors/logger_connector/result_new.py | 509 ------------------ 4 files changed, 13 insertions(+), 836 deletions(-) delete mode 100644 pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py delete mode 100644 pytorch_lightning/trainer/connectors/logger_connector/result_new.py diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 67b4e88fff3a1..058c7575cb3fd 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -20,14 +20,15 @@ import pytorch_lightning as pl from pytorch_lightning.core import memory from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource +from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT -class LoggerConnector: +# TODO(@carmocca): Remove `New` suffix +class LoggerConnectorNew: def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None: self.trainer = trainer @@ -140,7 +141,7 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None # track batch_size - self.trainer.results.extract_batch_size(batch) + self.trainer._results.extract_batch_size(batch) self._batch_idx = batch_idx def update_eval_step_metrics(self) -> None: @@ -209,7 +210,7 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: """ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: - self.trainer.results.extract_batch_size(split_batch) + self.trainer._results.extract_batch_size(split_batch) self._batch_idx = batch_idx self._split_idx = split_idx @@ -231,7 +232,7 @@ def update_train_epoch_metrics(self) -> None: self.log_metrics(metrics) # reset result collection for next epoch - self.trainer.results.reset(metrics=True) + self.trainer._results.reset(metrics=True) """ Utilities and properties @@ -272,7 +273,7 @@ def should_reset_tensors(self, fx: str) -> bool: return is_different_fx and is_first_batch def reset(self, metrics: Optional[bool] = None) -> None: - self.trainer.results.reset(metrics=metrics) + self.trainer._results.reset(metrics=metrics) self._batch_idx = None self._split_idx = None self._current_fx = None @@ -281,30 +282,31 @@ def reset(self, metrics: Optional[bool] = None) -> None: def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]: """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``.""" on_step = not self._epoch_end_reached - return self.trainer.results.metrics(on_step) + return self.trainer._results.metrics(on_step) @property def callback_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.results: + if self.trainer._results: metrics = 
self.metrics[MetricSource.CALLBACK] self._callback_metrics.update(metrics) return self._callback_metrics @property def logged_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.results: + if self.trainer._results: metrics = self.metrics[MetricSource.LOG] self._logged_metrics.update(metrics) return self._logged_metrics @property def progress_bar_metrics(self) -> Dict[str, float]: - if self.trainer.results: + if self.trainer._results: metrics = self.metrics[MetricSource.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics def teardown(self): + # TODO(@awaelchli): This should be handled by the loops themselves self.trainer.train_loop.results.cpu() self.trainer.evaluation_loop._val_results.cpu() self.trainer.evaluation_loop._test_results.cpu() diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py deleted file mode 100644 index 058c7575cb3fd..0000000000000 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ /dev/null @@ -1,312 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from pprint import pprint -from typing import Any, Dict, Iterable, Mapping, Optional, Union - -import torch - -import pytorch_lightning as pl -from pytorch_lightning.core import memory -from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource -from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities import DeviceType -from pytorch_lightning.utilities.metrics import metrics_to_scalars -from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT - - -# TODO(@carmocca): Remove `New` suffix -class LoggerConnectorNew: - - def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None: - self.trainer = trainer - self.log_gpu_memory = log_gpu_memory - self.eval_loop_results = [] - self._val_log_step: int = 0 - self._test_log_step: int = 0 - self._progress_bar_metrics: Dict[str, float] = {} - self._logged_metrics: Dict[str, _METRIC] = {} - self._callback_metrics: Dict[str, _METRIC] = {} - self._epoch_end_reached = False - self._current_fx: Optional[str] = None - self._batch_idx: Optional[int] = None - self._split_idx: Optional[int] = None - - def on_trainer_init( - self, - logger: LightningLoggerBase, - flush_logs_every_n_steps: int, - log_every_n_steps: int, - move_metrics_to_cpu: bool, - ) -> None: - self.configure_logger(logger) - self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps - self.trainer.log_every_n_steps = log_every_n_steps - self.trainer.move_metrics_to_cpu = move_metrics_to_cpu - - @property - def should_flush_logs(self) -> bool: - should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 - return should_flush or 
self.trainer.should_stop - - @property - def should_update_logs(self) -> bool: - should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 - return should_log_every_n_steps or self.trainer.should_stop - - def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None: - if logger is True: - version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) - - # default logger - self.trainer.logger = TensorBoardLogger( - save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs' - ) - elif logger is False: - self.trainer.logger = None - else: - if isinstance(logger, Iterable): - self.trainer.logger = LoggerCollection(logger) - else: - self.trainer.logger = logger - - def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None: - """Logs the metric dict passed in. - If `step` parameter is None and `step` key is presented is metrics, - uses metrics["step"] as a step - - Args: - metrics: Metric values - step: Step for which metrics should be logged. Default value is `self.global_step` during training or - the total validation / test log step count during validation and testing. - """ - # add gpu memory - if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory: - mem_map = memory.get_memory_profile(self.log_gpu_memory) - metrics.update(mem_map) - - # turn all tensors to scalars - scalar_metrics = metrics_to_scalars(metrics) - - if "step" in scalar_metrics and step is None: - step = scalar_metrics.pop("step") - - elif step is None: - # added metrics by Lightning for convenience - scalar_metrics['epoch'] = self.trainer.current_epoch - step = self.trainer.global_step - - # log actual metrics - if self.trainer.logger is not None: - if self.trainer.is_global_zero: - self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step) - self.trainer.logger.save() - - self._logged_metrics.update(scalar_metrics) - - """ - Evaluation metric updates - """ - - @property - def _eval_log_step(self) -> Optional[int]: - if self.trainer.state.stage is RunningStage.VALIDATING: - return self._val_log_step - elif self.trainer.state.stage is RunningStage.TESTING: - return self._test_log_step - else: - return None - - def _increment_eval_log_step(self) -> None: - if self.trainer.state.stage is RunningStage.VALIDATING: - self._val_log_step += 1 - elif self.trainer.state.stage is RunningStage.TESTING: - self._test_log_step += 1 - - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: - model = self.trainer.lightning_module - # set dataloader_idx only if multiple ones - model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - - # track batch_size - self.trainer._results.extract_batch_size(batch) - self._batch_idx = batch_idx - - def update_eval_step_metrics(self) -> None: - if self.trainer.sanity_checking: - return - - # logs user requested information to logger - assert not self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if metrics: - self.log_metrics(metrics, step=self._eval_log_step) - - # increment the step even if nothing was logged - self._increment_eval_log_step() - - def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: - if self.trainer.sanity_checking: - return - - num_dataloaders = self.trainer.evaluation_loop.num_dataloaders - has_been_initialized = len(self.eval_loop_results) == num_dataloaders - for dl_idx in 
range(self.trainer.evaluation_loop.num_dataloaders): - # remove callback metrics that don't belong to this dataloader - callback_metrics = { - k: v - for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k - } - if has_been_initialized: - self.eval_loop_results[dl_idx].update(callback_metrics) - else: - self.eval_loop_results.append(callback_metrics) - - def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: - assert self._epoch_end_reached - metrics = self.metrics - - if not self.trainer.sanity_checking: - # log all the metrics as a single dict - log_metrics = metrics[MetricSource.LOG] - if log_metrics: - self.log_metrics(log_metrics) - - self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) - - # log results of evaluation - if ( - self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero - and self.trainer.verbose_evaluate - ): - print('-' * 80) - for result_idx, results in enumerate(self.eval_loop_results): - print(f'DATALOADER:{result_idx} {self.trainer.state.stage.upper()} RESULTS') - pprint({ - k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v - for k, v in results.items() - }) - print('-' * 80) - - results = self.eval_loop_results - # clear mem - self.eval_loop_results = [] - return results - - """ - Train metric updates - """ - - def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: - self.trainer._results.extract_batch_size(split_batch) - self._batch_idx = batch_idx - self._split_idx = split_idx - - def update_train_step_metrics(self) -> None: - if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: - return - - # when metrics should be logged - assert not self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if self.should_update_logs or self.trainer.fast_dev_run is True and metrics: - self.log_metrics(metrics) - - def update_train_epoch_metrics(self) -> None: - # add the metrics to the loggers - assert self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if metrics: - self.log_metrics(metrics) - - # reset result collection for next epoch - self.trainer._results.reset(metrics=True) - - """ - Utilities and properties - """ - - def on_epoch_start(self) -> None: - self._epoch_end_reached = False - - def on_batch_start(self) -> None: - self._epoch_end_reached = False - - def epoch_end_reached(self): - self.trainer.logger_connector._epoch_end_reached = True - self.trainer.logger_connector._batch_idx = None - self.trainer.logger_connector._split_idx = None - - def on_epoch_end(self) -> None: - assert self._epoch_end_reached - metrics = self.metrics - self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) - self._callback_metrics.update(metrics[MetricSource.CALLBACK]) - self._logged_metrics.update(metrics[MetricSource.LOG]) - self._current_fx = None - - def on_batch_end(self) -> None: - assert not self._epoch_end_reached - metrics = self.metrics - self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) - self._callback_metrics.update(metrics[MetricSource.CALLBACK]) - self._logged_metrics.update(metrics[MetricSource.LOG]) - - def should_reset_tensors(self, fx: str) -> bool: - is_different_fx = self._current_fx != fx - if self._split_idx is None: - is_first_batch = self._batch_idx in (None, 0) - else: - is_first_batch = self._batch_idx + self._split_idx == 0 - return is_different_fx and is_first_batch - - def reset(self, metrics: 
Optional[bool] = None) -> None: - self.trainer._results.reset(metrics=metrics) - self._batch_idx = None - self._split_idx = None - self._current_fx = None - - @property - def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]: - """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``.""" - on_step = not self._epoch_end_reached - return self.trainer._results.metrics(on_step) - - @property - def callback_metrics(self) -> Dict[str, _METRIC]: - if self.trainer._results: - metrics = self.metrics[MetricSource.CALLBACK] - self._callback_metrics.update(metrics) - return self._callback_metrics - - @property - def logged_metrics(self) -> Dict[str, _METRIC]: - if self.trainer._results: - metrics = self.metrics[MetricSource.LOG] - self._logged_metrics.update(metrics) - return self._logged_metrics - - @property - def progress_bar_metrics(self) -> Dict[str, float]: - if self.trainer._results: - metrics = self.metrics[MetricSource.PBAR] - self._progress_bar_metrics.update(metrics) - return self._progress_bar_metrics - - def teardown(self): - # TODO(@awaelchli): This should be handled by the loops themselves - self.trainer.train_loop.results.cpu() - self.trainer.evaluation_loop._val_results.cpu() - self.trainer.evaluation_loop._test_results.cpu() diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 0470a25a75f46..bf155ad4210e5 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -65,7 +65,6 @@ class _Metadata: reduce_fx: Union[str, Callable] = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None - metric_attribute: Optional[str] = None sync: _Sync = field(default_factory=_Sync) def __post_init__(self) -> None: @@ -225,7 +224,7 @@ class ResultCollection(dict): Example: # `device` needs to be provided before logging - result = ResultCollection(True, torch.device("cpu")) + result = ResultCollection(training=True, torch.device("cpu")) # you can log to a specific collection. # arguments: fx, key, value, metadata @@ -303,7 +302,6 @@ def log( sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -331,14 +329,12 @@ def log( reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=dataloader_idx, - metric_attribute=metric_attribute, sync=_Sync( should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, ) ) - if key not in self: self.register_key(key, meta, value) elif meta != self[key].meta: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py deleted file mode 100644 index bf155ad4210e5..0000000000000 --- a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py +++ /dev/null @@ -1,509 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections.abc import Generator -from dataclasses import dataclass, field -from functools import partial, wraps -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union - -import torch -from torchmetrics import Metric - -from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections -from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.utilities.enums import LightningEnum -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.metrics import metrics_to_scalars - -# re-define the ones from pytorch_lightning.utilities.types without the `Number` type -_METRIC = Union[Metric, torch.Tensor] -_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] - - -class MetricSource(LightningEnum): - CALLBACK = "callback" - PBAR = "pbar" - LOG = "log" - - -@dataclass -class _Sync: - fn: Callable - should: bool = False - op: Optional[str] = None - group: Optional[Any] = None - - @property - def __call__(self) -> Any: - return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op - - @staticmethod - def no_op(value: Any, *_, **__) -> Any: - return value - - -@dataclass -class _Metadata: - fx: str - name: str - prog_bar: bool = False - logger: bool = True - on_step: bool = False - on_epoch: bool = True - reduce_fx: Union[str, Callable] = torch.mean - enable_graph: bool = False - dataloader_idx: Optional[int] = None - sync: _Sync = field(default_factory=_Sync) - - def __post_init__(self) -> None: - error = ( - 'Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported.' - ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`.' 
- f' Found: {self.reduce_fx}' - ) - if isinstance(self.reduce_fx, str): - reduce_fx = self.reduce_fx.lower() - if reduce_fx == 'avg': - reduce_fx = 'mean' - if reduce_fx not in ('min', 'max', 'mean', 'sum'): - raise MisconfigurationException(error) - self.reduce_fx = getattr(torch, reduce_fx) - elif self.is_custom_reduction: - raise MisconfigurationException(error) - self.sync.op = self.reduce_fx.__name__ - - @property - def forked(self) -> bool: - return self.on_step and self.on_epoch - - def forked_name(self, on_step: bool) -> str: - if self.forked: - return f'{self.name}_{"step" if on_step else "epoch"}' - return self.name - - @property - def is_mean_reduction(self) -> bool: - return self.reduce_fx is torch.mean - - @property - def is_sum_reduction(self) -> bool: - return self.reduce_fx in (torch.sum, sum) - - @property - def is_max_reduction(self) -> bool: - return self.reduce_fx in (torch.max, max) - - @property - def is_min_reduction(self) -> bool: - return self.reduce_fx in (torch.min, min) - - @property - def is_custom_reduction(self) -> bool: - return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction or self.is_sum_reduction) - - -class ResultMetric(Metric, DeviceDtypeModuleMixin): - """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" - - def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: - super().__init__() - self.is_tensor = is_tensor - self.meta = metadata - self.has_reset = False - if is_tensor: - self.add_state("value", torch.tensor(0, dtype=torch.float)) - if self.meta.is_mean_reduction: - self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) - - def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: - if self.is_tensor: - value = value.float() - self._forward_cache = value - # performance: no need to accumulate on values only logged on_step - if self.meta.on_step and not self.meta.on_epoch: - self.value = self.meta.sync(value) - return - # perform accumulation with reduction - if self.meta.is_mean_reduction: - self.value += value.mean() * batch_size - self.cumulated_batch_size += batch_size - elif self.meta.is_max_reduction or self.meta.is_min_reduction: - self.value = self.meta.reduce_fx(self.value, value.mean()) - elif self.meta.is_sum_reduction: - self.value += value.mean() * batch_size - else: - self.value = value # noqa: attribute-defined-outside-init - self._forward_cache = value._forward_cache - - def compute(self) -> torch.Tensor: - if self.is_tensor: - value = self.meta.sync(self.value) - if self.meta.is_mean_reduction: - cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) - return value / cumulated_batch_size - elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction: - return value - return self.value.compute() - - def reset(self) -> None: - if self.is_tensor: - super().reset() - else: - self.value.reset() - self.has_reset = True - - def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None: - if self.meta.enable_graph: - with torch.no_grad(): - self.update(value, batch_size) - else: - # performance: skip the `torch.no_grad` context manager by calling `update` directly - self.update(value, batch_size) - - def _wrap_compute(self, compute: Any) -> Any: - # Override to avoid syncing - we handle it ourselves. 
- @wraps(compute) - def wrapped_func(*args, **kwargs): - if not self._update_called: - rank_zero_warn( - f"The ``compute`` method of metric {self.__class__.__name__}" - " was called before the ``update`` method which may lead to errors," - " as metric states have not yet been updated.", UserWarning - ) - - # return cached value - if self._computed is not None: - return self._computed - self._computed = compute(*args, **kwargs) - return self._computed - - return wrapped_func - - def __setattr__(self, key: str, value: Any) -> None: - # performance: skip the `torch.nn.Module.__setattr__` checks - object.__setattr__(self, key, value) - - def __repr__(self) -> str: - state = f"value={self.value}" - if self.is_tensor and self.meta.is_mean_reduction: - state += f", cumulated_batch_size={self.cumulated_batch_size}" - return f"{self.__class__.__name__}({state})" - - -class ResultMetricCollection(dict): - """ - Dict wrapper for easy access to metadata. - - All of the leaf items should be instances of - :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` - with the same metadata. - """ - - def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: - super().__init__(*args) - self.meta = metadata - - -class ResultCollection(dict): - """ - Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or - :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection` - - Example: - - # `device` needs to be provided before logging - result = ResultCollection(training=True, torch.device("cpu")) - - # you can log to a specific collection. - # arguments: fx, key, value, metadata - result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) - result.log('validation_step', 'recall', torch.tensor(...), on_step=True, on_epoch=True) - """ - - DATALOADER_SUFFIX = "/dataloader_idx_{}" - - def __init__(self, training: bool, device: Optional[torch.device] = None) -> None: - super().__init__() - self.training = training - self._minimize = None - self._batch_size = torch.tensor(1, device=device) - self.device: Optional[Union[str, torch.device]] = device - self.fx_validator = FxValidator() - - @property - def batch_size(self) -> torch.Tensor: - # performance: cache the `batch_size` tensor instead of re-creating it - return self._batch_size - - @batch_size.setter - def batch_size(self, value: int) -> None: - self._batch_size = torch.tensor(value, device=self.device) - - @property - def minimize(self) -> Optional[torch.Tensor]: - """ - The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss - will be saved as the ``minimize`` attribute. - """ - return self._minimize - - @minimize.setter - def minimize(self, loss: Optional[torch.Tensor]) -> None: - if loss is not None: - if not isinstance(loss, torch.Tensor): - raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") - if loss.grad_fn is None: - raise RuntimeError("`Result.minimize` must have a `grad_fn`") - self._minimize = loss - - @property - def extra(self) -> Dict[str, Any]: - """ - Extras are any keys other than the loss returned by - :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` - """ - return self.get('_extra', {}) - - @extra.setter - def extra(self, extra: Mapping[str, Any]) -> None: - - def check_fn(v): - if v.grad_fn is not None: - raise MisconfigurationException(f'You returned a tensor with `grad_fn`. 
The extra values are {extra}') - - apply_to_collection(extra, torch.Tensor, check_fn) - self['_extra'] = extra - - def log( - self, - fx: str, - name: str, - value: _METRIC_COLLECTION, - prog_bar: bool = False, - logger: bool = True, - on_step: bool = False, - on_epoch: bool = True, - reduce_fx: Callable = torch.mean, - enable_graph: bool = False, - sync_dist: bool = False, - sync_dist_fn: Callable = _Sync.no_op, - sync_dist_group: Optional[Any] = None, - dataloader_idx: Optional[int] = None, - batch_size: Optional[int] = None, - ) -> None: - """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" - # no metrics should be logged with graphs - if not enable_graph and isinstance(value, torch.Tensor): - value = value.detach() - - # move metrics to cpu on TPU. - if isinstance(value, torch.Tensor) and value.device.type == "xla": - value = value.cpu() - - # storage key - key = f"{fx}.{name}" - # add dataloader_suffix to both key and fx - if dataloader_idx is not None: - key += f'.{dataloader_idx}' - fx += f'.{dataloader_idx}' - - meta = _Metadata( - fx=fx, - name=name, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - enable_graph=enable_graph, - dataloader_idx=dataloader_idx, - sync=_Sync( - should=sync_dist, - fn=sync_dist_fn, - group=sync_dist_group, - ) - ) - if key not in self: - self.register_key(key, meta, value) - elif meta != self[key].meta: - raise MisconfigurationException( - f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed' - ) - - if batch_size is not None: - self.batch_size = batch_size - - self.update_metrics(key, value) - - def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: - """Create one ResultMetric object per value. 
Value can be provided as a nested collection""" - - def fn(v: _METRIC) -> ResultMetric: - metric = ResultMetric(meta, isinstance(v, torch.Tensor)) - return metric.to(self.device) - - value = apply_to_collection(value, (torch.Tensor, Metric), fn) - if isinstance(value, dict): - value = ResultMetricCollection(value, metadata=meta) - self[key] = value - - def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None: - - def fn(result_metric, v): - # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` - result_metric.forward(v.to(self.device), self.batch_size) - result_metric.has_reset = False - - apply_to_collections(self[key], value, ResultMetric, fn) - - @staticmethod - def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]: - cache = None - if on_step and result_metric.meta.on_step: - cache = result_metric._forward_cache - elif not on_step and result_metric.meta.on_epoch: - if not result_metric._computed: - result_metric.compute() - cache = result_metric._computed - if cache is not None and not result_metric.meta.enable_graph: - return cache.detach() - return cache - - def valid_items(self) -> Generator: - """This function is used to iterate over current valid metrics.""" - return ((k, v) for k, v in self.items() - if not k == "_extra" and not (isinstance(v, ResultMetric) and v.has_reset)) - - def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]: - name = result_metric.meta.name - forked_name = result_metric.meta.forked_name(on_step) - dl_idx = result_metric.meta.dataloader_idx - if dl_idx is not None: - dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx) - name += dataloader_suffix - forked_name += dataloader_suffix - return name, forked_name - - def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: - metrics = {k: {} for k in MetricSource} - - for key, result_metric in self.valid_items(): - - # extract forward_cache or computed from the ResultMetric. ignore when the output is None - value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False) - - # check if the collection is empty - has_tensor = False - - def any_tensor(_): - nonlocal has_tensor - has_tensor = True - - apply_to_collection(value, torch.Tensor, any_tensor) - if not has_tensor: - continue - - name, forked_name = self._forked_name(result_metric, on_step) - - # populate logging metrics - if result_metric.meta.logger: - metrics[MetricSource.LOG][forked_name] = value - - # populate callback metrics. callback metrics don't take `_step` forked metrics - if self.training or result_metric.meta.on_epoch and not on_step: - metrics[MetricSource.CALLBACK][name] = value - metrics[MetricSource.CALLBACK][forked_name] = value - - # populate progress_bar metrics. convert tensors to numbers - if result_metric.meta.prog_bar: - metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value) - - return metrics - - def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None: - """ - Reset the result collection - - Args: - metrics: If True, only ``torchmetrics.Metric`` results are reset, - if False, only ``torch.Tensors`` are reset, - if ``None``, both are. 
- fx: Function to reset - """ - - def fn(item: ResultMetric) -> None: - requested_type = metrics is None or metrics ^ item.is_tensor - same_fx = fx is None or fx == item.meta.fx - if requested_type and same_fx: - item.reset() - - apply_to_collection(self, ResultMetric, fn) - - def extract_batch_size(self, batch: Any) -> None: - try: - self.batch_size = self._extract_batch_size(batch) - except RecursionError: - self.batch_size = 1 - - def _extract_batch_size(self, batch: Any) -> int: - """ - Recursively unpack a batch to find a torch.Tensor. - - Returns: - ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable. - """ - if isinstance(batch, torch.Tensor): - size = batch.size(0) - elif isinstance(batch, str): - return len(batch) - elif isinstance(batch, dict): - sample = next(iter(batch.values()), 1) - size = self._extract_batch_size(sample) - elif isinstance(batch, Iterable): - sample = next(iter(batch), 1) - size = self._extract_batch_size(sample) - else: - size = 1 - return size - - def to(self, *args, **kwargs) -> 'ResultCollection': - """Move all data to the given device.""" - - def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]: - return item.to(*args, **kwargs) - - apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs) - - if self.minimize is not None: - self.minimize = self.minimize.to(*args, **kwargs) - self._batch_size = self._batch_size.to(*args, **kwargs) - if 'device' in kwargs: - self.device = kwargs['device'] - return self - - def cpu(self) -> 'ResultCollection': - """Move all data to CPU.""" - return self.to(device="cpu") - - def __str__(self) -> str: - return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' - - def __getstate__(self) -> dict: - d = self.__dict__.copy() - # can't deepcopy tensors with grad_fn - minimize = d.get('_minimize') - if minimize is not None: - d['_minimize'] = minimize.detach() - return d From 501224d51b58c0bd0fd1537ff99dab58b2d575d5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Jun 2021 00:36:35 +0200 Subject: [PATCH 407/455] Fix names and imports --- .../trainer/connectors/logger_connector/logger_connector.py | 5 ++--- pytorch_lightning/trainer/properties.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 058c7575cb3fd..8223c56147a4a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -20,15 +20,14 @@ import pytorch_lightning as pl from pytorch_lightning.core import memory from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource +from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT -# TODO(@carmocca): Remove `New` suffix -class LoggerConnectorNew: +class LoggerConnector: def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None: self.trainer 
= trainer diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 452eb8f481ac0..0f25acf11d821 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -60,7 +60,7 @@ class TrainerProperties(ABC): checkpoint_connector: CheckpointConnector limit_val_batches: int logger: LightningLoggerBase - logger_connector: LoggerConnectorNew + logger_connector: LoggerConnector state: TrainerState train_loop: TrainLoop evaluation_loop: EvaluationLoop diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0f9e0be737916..c042a6c6d6ade 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -326,7 +326,7 @@ def __init__( num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark, replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins ) - self.logger_connector = LoggerConnectorNew(self, log_gpu_memory) + self.logger_connector = LoggerConnector(self, log_gpu_memory) self.model_connector = ModelConnector(self) self.callback_connector = CallbackConnector(self) self.debugging_connector = DebuggingConnector(self) From dee7e5f9b7d03c28bf7cd842dfec20e308c1974a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Jun 2021 00:37:53 +0200 Subject: [PATCH 408/455] Remove metric_attribute --- pytorch_lightning/core/lightning.py | 1 - tests/trainer/logging_/test_logger_connector.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 41a4e870be6ba..02633d3df16fa 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -111,7 +111,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() - self._metric_attributes: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 06413838340b5..d93054439082b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -322,8 +322,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc, metric_attribute=f"acc_{stage}") - self.log(f"{stage}/ap", ap, metric_attribute=f"ap_{stage}") + self.log(f"{stage}/accuracy", acc) + self.log(f"{stage}/ap", ap) return loss From d4bb357d1d29ed79ad74d2f84c72826a039b2b08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 01:26:28 +0200 Subject: [PATCH 409/455] integrate latest logger connector changes --- pytorch_lightning/loops/fit_loop.py | 4 ++-- pytorch_lightning/loops/training_batch_loop.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index df6741bbf4fc1..7ed2efbbc7993 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -46,7 +46,7 @@ def __init__( self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs self.training_loop = TrainingEpochLoop(min_steps, max_steps) - self.results = ResultCollection(True) + 
self.results = ResultCollection(training=True) @property def current_epoch(self) -> int: @@ -138,7 +138,7 @@ def run(self): return super().run() def on_run_start(self): - self.trainer.results.to(device=self.trainer.lightning_module.device) + self.results.to(device=self.trainer.lightning_module.device) self.trainer.call_hook("on_train_start") def on_advance_start(self): diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 88ee9ef75c424..25a7462dd9b49 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -250,7 +250,7 @@ def _process_training_step_output(self, training_step_output): if training_step_output is None: return None - results = self.trainer.results + results = self.trainer._results loss = None hiddens = None @@ -273,8 +273,7 @@ def _process_training_step_output(self, training_step_output): self._hiddens = hiddens if self.trainer.move_metrics_to_cpu: - results = results.cpu() - + results.cpu() return results def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): From c9b4e9e6d4002fbd89843bc0afbd9b7d24f89f70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 02:18:48 +0200 Subject: [PATCH 410/455] resolve todo dataloading reset --- pytorch_lightning/trainer/data_loading.py | 13 +++++++++++++ pytorch_lightning/trainer/properties.py | 2 +- pytorch_lightning/trainer/trainer.py | 16 +--------------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 77835a19769b4..c9b8a6f29652b 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -444,6 +444,19 @@ def reset_predict_dataloader(self, model) -> None: if has_loader: self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict') + def reset_train_val_dataloaders(self, model) -> None: + """ + Resets train and val dataloaders if none are attached to the trainer. + + The val dataloader must be initialized before training loop starts, as the training loop + inspects the val dataloader to determine whether to run the evaluation loop. + """ + if self.train_dataloader is None: + self.reset_train_dataloader(model) + + if self.val_dataloaders is None: + self.reset_val_dataloader(model) + def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: """Handles downloading data in the GPU or TPU case. 
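Note on the data_loading.py hunk above: `reset_train_val_dataloaders` is an idempotent guard that only builds the loaders which are still missing, and the val loader has to exist before the fit loop starts because the loop inspects it to decide whether to run validation. A minimal sketch of the intended call order, assuming a trainer-like object with the same attributes as in the hunk (the `run_fit` wrapper and the final `run()` call are illustrative, not lines from this patch):

    def run_fit(trainer, model):
        # Build only what is missing; calling this twice is harmless.
        if trainer.train_dataloader is None:
            trainer.reset_train_dataloader(model)
        if trainer.val_dataloaders is None:
            trainer.reset_val_dataloader(model)
        # With both loaders attached, the fit loop can check per batch whether
        # a validation run is due.
        trainer.train_loop.run()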
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index c4bf173084aaf..1f1977badb2e2 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -488,7 +488,7 @@ def sanity_checking(self, val: bool) -> None: @property def fit_loop(self) -> FitLoop: - # TODO: the current train_loop should be renamed to fit_loop + # TODO(@awaelchli): the current train_loop should be renamed to fit_loop return self.train_loop @property diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2d7046201a264..2bbfee820f770 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -893,26 +893,13 @@ def _pre_training_routine(self): self.on_pretrain_routine_end() ref_model.on_pretrain_routine_end() - def reset_train_val_dataloaders(self, model) -> None: - """ - Resets train and val dataloaders if none are attached to the trainer. - - The val dataloader must be initialized before training loop starts, as the training loop - inspects the val dataloader to determine whether to run the evaluation loop. - """ - if self.train_dataloader is None: - self.reset_train_dataloader(model) - - if self.val_dataloaders is None: - self.reset_val_dataloader(model) - def _run_train(self) -> None: if _NEW_LOOP: self._run_train_new_loop() else: self._run_train_old_loop() - # TODO: remove together with old loop + # TODO(@awaelchli): remove together with old loop def _should_skip_training(self) -> bool: should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs @@ -935,7 +922,6 @@ def _run_train_new_loop(self) -> None: # reload data when needed model = self.lightning_module - # TODO: This might move somewhere else self.reset_train_val_dataloaders(model) try: From a3ef0aaf27637af58f2227b40db494ada041cc07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 02:21:19 +0200 Subject: [PATCH 411/455] re-add notebooks --- notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 notebooks diff --git a/notebooks b/notebooks new file mode 160000 index 0000000000000..aeae8085b4833 --- /dev/null +++ b/notebooks @@ -0,0 +1 @@ +Subproject commit aeae8085b48339e9bd9ab61d81cc0dc8b0d48f9c From 53deef8076a5488d4d7e73d1ec1745dcf6e0877d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 10:52:52 +0200 Subject: [PATCH 412/455] add missing init --- .../trainer/connectors/logger_connector/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/__init__.py b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py index e69de29bb2d1d..4034840a09b97 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/__init__.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py @@ -0,0 +1 @@ +from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector From 93fd682a1192238e6f10e534ad218308315017ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 11:17:47 +0200 Subject: [PATCH 413/455] bad merge --- .../trainer/connectors/logger_connector/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/__init__.py 
b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py index 4034840a09b97..f14e20f232533 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/__init__.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py @@ -1 +1 @@ -from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector +from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector # noqa: F401 From a041b6fafee4cb2b45c81e711bef185c1dc461dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 11:42:46 +0200 Subject: [PATCH 414/455] remove iteration count method --- pytorch_lightning/loops/base.py | 5 +---- pytorch_lightning/loops/training_epoch_loop.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index c0128a449210d..9df72bfc22d43 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -47,7 +47,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any: self.on_advance_start(*args, **kwargs) self.advance(*args, **kwargs) self.on_advance_end() - self.iteration_count = self.increment_iteration(self.iteration_count) + self.iteration_count += 1 except StopIteration: break @@ -68,6 +68,3 @@ def on_advance_end(self) -> None: def on_run_end(self) -> Any: pass - - def increment_iteration(self, iteration: int) -> int: - return iteration + 1 diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index a3d2e6f03f9fd..68ce480a177ed 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -76,7 +76,7 @@ def run(self, *args, **kwargs): self.on_advance_start(*args, **kwargs) self.advance(*args, **kwargs) self.on_advance_end() - self.iteration_count = self.increment_iteration(self.iteration_count) + self.iteration_count += 1 except StopIteration: break From e080be83250567b2d6dc20e2a9810271ae2775e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 11:46:21 +0200 Subject: [PATCH 415/455] todo for a fix in #5007 --- pytorch_lightning/loops/fit_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 7ed2efbbc7993..f3e36c25f3d3a 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -209,6 +209,7 @@ def on_run_end(self): # NOTE: the iteration_count/current_epoch is already incremented # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit # To simulate that current behavior, we decrement here. + # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007 self.current_epoch -= 1 # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates From c56adc17a77d52c7049c424dabeeb883786aa49e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 13:54:23 +0200 Subject: [PATCH 416/455] remove NEW_LOOP guard --- pytorch_lightning/trainer/properties.py | 5 +- pytorch_lightning/trainer/trainer.py | 100 +----------------------- 2 files changed, 4 insertions(+), 101 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 1f1977badb2e2..cb0fa350e07db 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -39,7 +39,6 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.states import RunningStage, TrainerState, TrainerStatus -from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn from pytorch_lightning.utilities.argparse import ( add_argparse_args, @@ -65,7 +64,7 @@ class TrainerProperties(ABC): logger: LightningLoggerBase logger_connector: LoggerConnector state: TrainerState - train_loop: TrainLoop + train_loop: FitLoop evaluation_loop: EvaluationLoop """ Accelerator properties @@ -524,7 +523,7 @@ def min_steps(self) -> Optional[int]: return self.train_loop.min_steps @property - def _active_loop(self) -> Optional[Union[TrainLoop, EvaluationLoop]]: + def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop]]: if self.training: return self.train_loop elif self.sanity_checking or self.evaluating: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2bbfee820f770..3b87ec5d0c21a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -62,7 +62,6 @@ from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus -from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner @@ -81,8 +80,6 @@ 'please use torch.distributed.ReduceOp instead' ) -_NEW_LOOP = True - class Trainer( TrainerProperties, @@ -338,13 +335,8 @@ def __init__( self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - if _NEW_LOOP: - self.train_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps) - self.train_loop.connect(self) - else: - # old loops: - self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) - + self.train_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps) + self.train_loop.connect(self) self.evaluation_loop = EvaluationLoop(self) self.predict_loop = PredictLoop(self) @@ -894,18 +886,6 @@ def _pre_training_routine(self): ref_model.on_pretrain_routine_end() def _run_train(self) -> None: - if _NEW_LOOP: - self._run_train_new_loop() - else: - self._run_train_old_loop() - - # TODO(@awaelchli): remove together with old loop - def _should_skip_training(self) -> bool: - should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps - should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs 
- return should_by_max_steps or should_by_epoch or self.num_training_batches == 0 - - def _run_train_new_loop(self) -> None: self._pre_training_routine() if not self.is_global_zero and self.progress_bar_callback is not None: @@ -943,82 +923,6 @@ def _run_train_new_loop(self) -> None: self.state.stage = None raise - def _run_train_old_loop(self) -> None: - - self._pre_training_routine() - - if not self.is_global_zero and self.progress_bar_callback is not None: - self.progress_bar_callback.disable() - - self._run_sanity_check(self.lightning_module) - - self.checkpoint_connector.has_trained = False - - # enable train mode - self.model.train() - torch.set_grad_enabled(True) - - # reload data when needed - model = self.lightning_module - self.train_loop.reset_train_val_dataloaders(model) - - # hook - self.call_hook("on_train_start") - - try: - if self._should_skip_training(): - return - # run all epochs - epochs = range(self.current_epoch, self.max_epochs) if self.max_epochs else count(self.current_epoch) - for epoch in epochs: - - # hook - self.train_loop.on_train_epoch_start(epoch) - - with self.profiler.profile("run_training_epoch"): - # run train epoch - self.train_loop.run_training_epoch() - - if self.max_steps and self.max_steps <= self.global_step: - self.train_loop.on_train_end() - return - - # early stopping - met_min_epochs = (epoch >= self.min_epochs - 1) if self.min_epochs else True - met_min_steps = self.global_step >= self.min_steps if self.min_steps else True - - if self.should_stop: - if met_min_epochs and met_min_steps: - self.train_loop.on_train_end() - return - else: - log.info( - 'Trainer was signaled to stop but required minimum epochs' - f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' - ' not been met. Training will continue...' - ) - self.should_stop = False - - # hook - self.train_loop.on_train_end() - - except KeyboardInterrupt: - rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') - # user could press Ctrl+c many times... only shutdown once - if not self.interrupted: - self.state.status = TrainerStatus.INTERRUPTED - self.on_keyboard_interrupt() - # same treatment as below - self.accelerator.on_train_end() - self.state.stage = None - except BaseException: - self.state.status = TrainerStatus.INTERRUPTED - # give accelerators a chance to finish - self.accelerator.on_train_end() - # reset bookkeeping - self.state.stage = None - raise - def _run_evaluation(self) -> _EVALUATE_OUTPUT: if not (self.evaluating or self.sanity_checking): rank_zero_warn( From bace4a2aab67028b2bd94c5778b1aebdb8dd1af5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 14:10:10 +0200 Subject: [PATCH 417/455] flake8 --- pytorch_lightning/loops/__init__.py | 8 ++++---- pytorch_lightning/loops/training_batch_loop.py | 3 +-- pytorch_lightning/trainer/trainer.py | 1 - 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index c9e25f212af2f..64de11151dc32 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from pytorch_lightning.loops.base import Loop -from pytorch_lightning.loops.fit_loop import FitLoop -from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop -from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop +from pytorch_lightning.loops.base import Loop # noqa: F401 +from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 +from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop # noqa: F401 +from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop # noqa: F401 diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 25a7462dd9b49..4f9142b1ac84e 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -313,8 +313,7 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): def track_and_norm_grad(self, optimizer) -> dict: # track gradient norms grad_norm_dict = {} - if (self.trainer.global_step - + 1) % self.trainer.log_every_n_steps == 0 and float(self.trainer.track_grad_norm) > 0: + if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 and float(self.trainer.track_grad_norm) > 0: grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm) # clip gradients diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3b87ec5d0c21a..4a3d0ec11eced 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -15,7 +15,6 @@ import logging import warnings from datetime import timedelta -from itertools import count from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Union from weakref import proxy From 71bfb6fab37c6d35024e9349f66b93d988d50eaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 14:17:43 +0200 Subject: [PATCH 418/455] exclude coverage --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index f12f3e2a03fe0..accd6a7f79abb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,6 +45,7 @@ exclude_lines = omit = pytorch_lightning/cluster_environments/*.py pytorch_lightning/utilities/distributed.py + pytorch_lightning/trainer/training_loop.py pytorch_lightning/tuner/auto_gpu_select.py From 41e0e6408bc8f6e73837bf8ab52adb98f015aeb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Jun 2021 12:29:08 +0000 Subject: [PATCH 419/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/training_batch_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 4f9142b1ac84e..25a7462dd9b49 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -313,7 +313,8 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): def track_and_norm_grad(self, optimizer) -> dict: # track gradient norms grad_norm_dict = {} - if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 and float(self.trainer.track_grad_norm) > 0: + if (self.trainer.global_step + + 1) % self.trainer.log_every_n_steps == 0 and float(self.trainer.track_grad_norm) > 0: grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm) # clip gradients From 
536574a3bf14a21f48e334737c78823f33e0883e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 10 Jun 2021 20:38:33 +0200 Subject: [PATCH 420/455] integrate #7917, remove teardown from training loop --- pytorch_lightning/loops/fit_loop.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index f3e36c25f3d3a..0d74d1a3ad628 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -38,8 +38,6 @@ def __init__( max_steps: Optional[int] = None ): super().__init__() - self._teardown_already_run = False - # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 @@ -202,10 +200,6 @@ def on_advance_end(self): self.global_step += 1 def on_run_end(self): - if self._teardown_already_run: - return - self._teardown_already_run = True - # NOTE: the iteration_count/current_epoch is already incremented # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit # To simulate that current behavior, we decrement here. From b28fb09b201110dc5f92686838273e07ba394cb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Jun 2021 10:05:42 +0200 Subject: [PATCH 421/455] update "accumulated_batches_reached" condition based on if iter count was updated or not --- pytorch_lightning/loops/training_batch_loop.py | 17 +++++++++++++++-- pytorch_lightning/loops/training_epoch_loop.py | 4 ++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index ac8306a88bb53..e2529ce864c33 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -323,9 +323,22 @@ def track_and_norm_grad(self, optimizer) -> dict: ) return grad_norm_dict - def _accumulated_batches_reached(self): + def _accumulated_batches_reached_before_iter_count_update(self): + """ + Determine if accumulation will be finished by the end of the current batch. + This returns the correct answer only if the loop iteration count has not yet been updated for the current batch + in progress. Use `_accumulated_batches_reached_after_iter_count_update` otherwise. + """ # TODO(@awaelchli): use progress tracking of batches instead of iteration count, because iteration count may # reset iteration count is required to be global here, not reset + return (self.iteration_count + 1) % self.trainer.accumulate_grad_batches == 0 + + def _accumulated_batches_reached_after_iter_count_update(self): + """ + Determine if accumulation has finished. + This returns the correct answer only right after a batch has ended, i.e., the iteration count was just updated. + Use `_accumulated_batches_reached_after_iter_count_update` otherwise. 
+ """ return self.iteration_count % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self, is_last_batch=False): @@ -335,7 +348,7 @@ def _num_training_batches_reached(self, is_last_batch=False): def should_accumulate(self): # checks if backward or backward + optimizer step (via closure) - accumulation_done = self._accumulated_batches_reached() + accumulation_done = self._accumulated_batches_reached_before_iter_count_update() is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index 68ce480a177ed..450e3ae323511 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -328,7 +328,7 @@ def _prepare_outputs( def update_lr_schedulers(self, interval: str) -> None: if interval == "step": - finished_accumulation = self.batch_loop._accumulated_batches_reached() + finished_accumulation = self.batch_loop._accumulated_batches_reached_before_iter_count_update() finished_epoch = self._num_training_batches_reached() if not finished_accumulation and not finished_epoch: return @@ -338,7 +338,7 @@ def update_lr_schedulers(self, interval: str) -> None: ) def increment_accumulated_grad_global_step(self): - num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached() + num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached_after_iter_count_update() num_training_batches_reached = self._num_training_batches_reached() # progress global step according to grads progress From 6f17688385923d3ea9d72818db03826f0c272965 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Jun 2021 11:18:27 +0200 Subject: [PATCH 422/455] remove public loop properties --- pytorch_lightning/core/lightning.py | 4 ++-- pytorch_lightning/trainer/properties.py | 16 +++------------- pytorch_lightning/trainer/trainer.py | 6 +++--- tests/deprecated_api/test_remove_1-5.py | 2 +- tests/trainer/loops/test_evaluation_loop_flow.py | 8 ++++---- .../loops/test_training_loop_flow_scalar.py | 12 ++++++------ tests/trainer/test_trainer.py | 16 ++++++++-------- 7 files changed, 27 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 1d91c1f28f479..bc070b25e7b4e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1346,7 +1346,7 @@ def training_step(...): # backward self._running_manual_backward = True - self.trainer.training_batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) + self.trainer.fit_loop.training_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) self._running_manual_backward = False def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: @@ -1445,7 +1445,7 @@ def optimizer_step( If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter to ``optimizer.step()`` function as shown in the examples. This ensures that ``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within - :meth:`~pytorch_lightning.trainer.training_batch_loop.TrainingBatchLoop.advance`. + :meth:`~pytorch_lightning.trainer.fit_loop.training_loop.batch_loop.TrainingBatchLoop.advance`. 
Args: epoch: Current epoch diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index cb0fa350e07db..2f95a50219bfa 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -30,8 +30,6 @@ from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from pytorch_lightning.loops.fit_loop import FitLoop -from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop -from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector @@ -64,7 +62,7 @@ class TrainerProperties(ABC): logger: LightningLoggerBase logger_connector: LoggerConnector state: TrainerState - train_loop: FitLoop + fit_loop: FitLoop evaluation_loop: EvaluationLoop """ Accelerator properties @@ -486,17 +484,9 @@ def sanity_checking(self, val: bool) -> None: """ @property - def fit_loop(self) -> FitLoop: + def train_loop(self) -> FitLoop: # TODO(@awaelchli): the current train_loop should be renamed to fit_loop - return self.train_loop - - @property - def training_epoch_loop(self) -> TrainingEpochLoop: - return self.fit_loop.training_loop - - @property - def training_batch_loop(self) -> TrainingBatchLoop: - return self.fit_loop.training_loop.batch_loop + return self.fit_loop @property def global_step(self) -> int: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4a3d0ec11eced..4e30f8b353c00 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -334,8 +334,8 @@ def __init__( self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - self.train_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps) - self.train_loop.connect(self) + self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps) + self.fit_loop.connect(self) self.evaluation_loop = EvaluationLoop(self) self.predict_loop = PredictLoop(self) @@ -904,7 +904,7 @@ def _run_train(self) -> None: self.reset_train_val_dataloaders(model) try: - self.train_loop.run() + self.fit_loop.run() except KeyboardInterrupt: rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') # user could press Ctrl+c many times... 
only shutdown once diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index dcf5138e295eb..f7e87059b62f6 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -242,7 +242,7 @@ def on_train_epoch_end(self, outputs): # noqa with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer.fit(model) - trainer.training_epoch_loop.warning_cache.clear() + trainer.fit_loop.training_loop.warning_cache.clear() class NewSignature(Callback): diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 8896a4d186d34..c9eb997c98dd6 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -69,7 +69,7 @@ def backward(self, loss, optimizer, optimizer_idx): # simulate training manually trainer.state.stage = RunningStage.TRAINING batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.training_batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -79,7 +79,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.training_batch_loop.training_step_and_backward( + opt_closure_result = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -140,7 +140,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.training_batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -150,7 +150,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.training_batch_loop.training_step_and_backward( + opt_closure_result = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 0f6ea5eb8c7ba..0e57797a80890 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -149,7 +149,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.training_batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -159,7 +159,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.training_batch_loop.training_step_and_backward( + opt_closure_result = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -227,7 
+227,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.training_batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -237,7 +237,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.training_batch_loop.training_step_and_backward( + opt_closure_result = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 @@ -313,7 +313,7 @@ def training_step(self, batch, batch_idx): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.training_batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output == [[]] assert out.signal == 0 @@ -358,7 +358,7 @@ def train_dataloader(self): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.training_batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output == [[]] assert out.signal == 0 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ed4325979bb7b..fa25224527e67 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -894,7 +894,7 @@ def test_gradient_clipping(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.training_batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -908,7 +908,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.training_batch_loop.training_step_and_backward = training_step_and_backward + trainer.fit_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -932,7 +932,7 @@ def test_gradient_clipping_by_value(tmpdir): default_root_dir=tmpdir ) - old_training_step_and_backward = trainer.training_batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -948,7 +948,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.training_batch_loop.training_step_and_backward = training_step_and_backward + trainer.fit_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -973,7 +973,7 @@ def test_gradient_clipping_fp16(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.training_batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward def 
training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -987,7 +987,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.training_batch_loop.training_step_and_backward = training_step_and_backward + trainer.fit_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) @@ -1012,7 +1012,7 @@ def test_gradient_clipping_by_value_fp16(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.training_batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -1028,7 +1028,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.training_batch_loop.training_step_and_backward = training_step_and_backward + trainer.fit_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) From 6dd4e1d35249a97ae9bdb722dd7f9031a858e5e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Jun 2021 11:24:16 +0200 Subject: [PATCH 423/455] make skip backward protected again --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 4 ++-- pytorch_lightning/loops/fit_loop.py | 10 +++++----- pytorch_lightning/loops/training_batch_loop.py | 12 +----------- tests/callbacks/test_stochastic_weight_avg.py | 4 ++-- 4 files changed, 10 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 236145d00f4a8..3ec7774d5f8b6 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -220,12 +220,12 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo # performing only one pass over the train data-loader to compute activation statistics # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward. trainer.num_training_batches += 1 - trainer.train_loop.skip_backward = True + trainer.train_loop._skip_backward = True self._accumulate_grad_batches = trainer.accumulate_grad_batches trainer.accumulate_grad_batches = len(trainer.train_dataloader) def on_train_epoch_end(self, trainer: 'pl.Trainer', *args): - trainer.train_loop.skip_backward = False + trainer.train_loop._skip_backward = False def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1: diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 0d74d1a3ad628..c9e3117a4f8ef 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -92,14 +92,14 @@ def running_loss(self): return self.training_loop.batch_loop.running_loss @property - def skip_backward(self) -> bool: + def _skip_backward(self) -> bool: """ Determines whether the loop will skip backward during automatic optimization. 
""" - return self.training_loop.batch_loop.skip_backward + return self.training_loop.batch_loop._skip_backward - @skip_backward.setter - def skip_backward(self, value: bool): + @_skip_backward.setter + def _skip_backward(self, value: bool): """ Determines whether the loop will skip backward during automatic optimization. """ - self.training_loop.batch_loop.skip_backward = value + self.training_loop.batch_loop._skip_backward = value @property def done(self) -> bool: diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index e2529ce864c33..6ca33b8936acf 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -53,16 +53,6 @@ def __init__(self): def done(self): return len(self._remaining_splits) == 0 - @property - def skip_backward(self) -> bool: - """ Determines whether the loop will skip backward during automatic optimization. """ - return self._skip_backward - - @skip_backward.setter - def skip_backward(self, value: bool): - """ Determines whether the loop will skip backward during automatic optimization. """ - self._skip_backward = value - @property def optimizer_freq_cumsum(self): if self._optimizer_freq_cumsum is None: @@ -424,7 +414,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) - if not self.skip_backward and self.trainer.lightning_module.automatic_optimization: + if not self._skip_backward and self.trainer.lightning_module.automatic_optimization: is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 if is_first_batch_to_accumulate: diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index b8bb5e220eda9..81efc12b34662 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -74,7 +74,7 @@ def transfer_weights(self, *args, **kwargs): def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) - assert trainer.train_loop.skip_backward == (trainer.current_epoch > self.swa_end) + assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR) assert trainer.lr_schedulers[0]["interval"] == "epoch" @@ -92,7 +92,7 @@ def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) # make sure these are correctly set again - assert not trainer.train_loop.skip_backward + assert not trainer.train_loop._skip_backward assert trainer.accumulate_grad_batches == 2 assert trainer.num_training_batches == 5 From c3942675a6025e1589d410fbfc165fd3febdb82f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Jun 2021 14:49:44 +0200 Subject: [PATCH 424/455] typing base loop Co-authored-by: Justus Schock --- pytorch_lightning/loops/base.py | 63 ++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 9df72bfc22d43..be74a9f8d7d6d 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -16,12 +16,35 @@ from typing import Any, Optional from weakref import proxy +from deprecate import void + import pytorch_lightning as pl class Loop(ABC): + """ + Basic Loops interface. 
All classes derived from this must implement the following properties and methods: + + * :attr`done` (property): Condition to break the loop + * :attr`reset` (method): Resets the internal state between multiple calls of :attr`run` + * :attr`advance` (method): Implements one step of the loop + + This class implements the following loop structure: + + .. codeblock:: python + + on_run_start() + + while not done: + on_advance_start() + advance() + on_advance_end() + + on_run_end() - def __init__(self): + """ + + def __init__(self) -> None: self.iteration_count: int = 0 self.trainer: Optional['pl.Trainer'] = None @@ -30,15 +53,24 @@ def __init__(self): def done(self) -> bool: """Property indicating when loop is finished""" - def connect(self, trainer, *args, **kwargs) -> None: - """Connects Loop with all the necessary things like connectors and accelerators""" + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: + """Connects Loop with all the necessary things like connectors and accelerators.""" self.trainer = proxy(trainer) @abstractmethod def reset(self) -> None: - pass + """Resets the internal state of the loop at the beginning of each call to :attr:`run`.""" def run(self, *args: Any, **kwargs: Any) -> Any: + """ + The main entry point to the loop. + + Will frequently check the :attr:`done` condition and calls :attr:`advance` + until :attr`done` evaluates to ``True``. + + Returns: + the output of :attr`on_run_end` (often outputs collected from each step of the loop) + """ self.reset() self.on_run_start(*args, **kwargs) @@ -54,17 +86,30 @@ def run(self, *args: Any, **kwargs: Any) -> Any: return self.on_run_end() def on_run_start(self, *args: Any, **kwargs: Any) -> None: - pass + """ + Hook to be called as the first thing after entering :attr:`run` (except the state reset). + + Accepts all arguments passed to :attr:`run`. + + """ + void(*args, **kwargs) def on_advance_start(self, *args: Any, **kwargs: Any) -> None: - pass + """ + Hook to be called each time before :attr:`advance` is called. Accepts all arguments passed to :attr`run`. + + """ + void(*args, **kwargs) @abstractmethod def advance(self, *args: Any, **kwargs: Any) -> None: - """What to do within a single step""" + """ + Performs a single step. Accepts all arguments passed to :attr:`run`. + + """ def on_advance_end(self) -> None: - pass + """Hook to be called each time after :attr:`advance` is called.""" def on_run_end(self) -> Any: - pass + """Hook to be called at the end of the run. Its return argument is returned from :attr:`run`. 
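A minimal sketch (not part of the patch series) of how the typed `Loop` interface documented in the patch above can be subclassed. `CountingLoop`, its attributes and its behaviour are hypothetical illustrations; the sketch only exercises the contract defined in `pytorch_lightning/loops/base.py` (abstract `done`, `reset`, `advance`, the `on_run_*`/`on_advance_*` hooks, and `run()` returning the output of `on_run_end()`):

    from pytorch_lightning.loops.base import Loop


    class CountingLoop(Loop):
        """Toy loop: collects the integers 0..limit-1 via the documented run() structure."""

        def __init__(self, limit: int) -> None:
            super().__init__()
            self.limit = limit
            self.outputs = []
            # own counter so the sketch does not depend on how the base run() tracks iteration_count
            self._count = 0

        @property
        def done(self) -> bool:
            # checked by run() before every advance() call
            return self._count >= self.limit

        def reset(self) -> None:
            # called at the beginning of every run() call
            self._count = 0
            self.outputs = []

        def advance(self) -> None:
            # one unit of work per iteration
            self.outputs.append(self._count)
            self._count += 1

        def on_run_end(self):
            # whatever is returned here becomes the return value of run()
            return self.outputs


    assert CountingLoop(limit=3).run() == [0, 1, 2]

The same contract (reset state, advance until `done`, hand back results from `on_run_end`) is what `FitLoop`, `TrainingEpochLoop` and `TrainingBatchLoop` implement in the typing patches that follow.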
From 4adae06a0ed97a8354adb48a237034d1e7f25d54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Jun 2021 14:58:17 +0200 Subject: [PATCH 425/455] typing fit loop Co-authored-by: Justus Schock --- pytorch_lightning/loops/fit_loop.py | 89 ++++++++++++++++++++++------- 1 file changed, 68 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index c9e3117a4f8ef..88cc809889dad 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,8 +14,9 @@ import logging from contextlib import suppress -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple +from deprecate import void from torch.optim import Optimizer import pytorch_lightning as pl @@ -29,6 +30,18 @@ class FitLoop(Loop): + """This Loop iterates over the epochs to run the training + + Args: + min_epochs: The minimum number of epochs + max_epochs: The maximum number of epochs + min_steps: The minimum number of steps + max_steps: The maximum number of epoch + + .. note:: + If neither the minimum epochs nor steps are specified the minimum number of epochs is set to 1 + and if neither the maximum steps nor epochs are specified, the maximum epochs are set to 1000. + """ def __init__( self, @@ -48,47 +61,61 @@ def __init__( @property def current_epoch(self) -> int: + """Return the current epoch""" return self.iteration_count @current_epoch.setter - def current_epoch(self, value: int): + def current_epoch(self, value: int) -> None: + """Setter for the current epoch + """ self.iteration_count = value @property - def global_step(self): + def global_step(self) -> int: + """Returns the global step""" return self.training_loop.global_step @global_step.setter - def global_step(self, value): + def global_step(self, value: int) -> None: + """Sets the global step (forwards to training_loop) + """ self.training_loop.global_step = value @property - def total_batch_idx(self): + def total_batch_idx(self) -> int: + """Returns the total number of batches already run (across all epochs)""" return self.training_loop.total_batch_idx @property - def batch_idx(self): + def batch_idx(self) -> int: + """Returns the number of batches already run within this epoch""" return self.training_loop.iteration_count @property - def split_idx(self): + def split_idx(self) -> int: + """Returns the index of the current batch split (within the current batch) for bptt""" return self.training_loop.split_idx @property - def min_steps(self): + def min_steps(self) -> int: + # TODO(@justusschock): Why aren't we using the attribute in this class? 
+ """Returns the minimum numnber of steps to run""" return self.training_loop.min_steps @property - def max_steps(self): + def max_steps(self) -> int: + """Returns the maximum number of steps to run""" return self.training_loop.max_steps @max_steps.setter - def max_steps(self, value): + def max_steps(self, value: int) -> None: + """Sets the maximum number of steps (forwards to training_loop)""" # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided self.training_loop.max_steps = value @property - def running_loss(self): + def running_loss(self) -> TensorRunningAccum: + """Returns the running loss""" return self.training_loop.batch_loop.running_loss @property @@ -97,12 +124,17 @@ def _skip_backward(self) -> bool: return self.training_loop.batch_loop._skip_backward @_skip_backward.setter - def _skip_backward(self, value: bool): + def _skip_backward(self, value: bool) -> None: """ Determines whether the loop will skip backward during automatic optimization. """ self.training_loop.batch_loop._skip_backward = value @property def done(self) -> bool: + """Evaluates when to leave the loop. + + Returns True if trainer.should_stop was set (e.g. by early stopping) + or if the maximum number of steps or epochs is reached. + """ # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop stop_steps = self.max_steps is not None and self.global_step >= self.max_steps stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs @@ -124,22 +156,30 @@ def done(self) -> bool: return stop_steps or should_stop or stop_epochs - def connect(self, trainer: 'pl.Trainer', *args, **kwargs): + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: + """Connects the loop with necessary arguments like the trainer""" + # TODO(@justusschock): Do we want to forward *args and **kwargs to the inner loop here? + # TODO(@justusschock): Can we make the trainer a weakref/proxy? 
+ void(*args, **kwargs) self.trainer = trainer self.training_loop.connect(trainer) def reset(self) -> None: + """Resets the trainer's internal state""" self.iteration_count = 0 - def run(self): + def run(self) -> None: + """Loops over epochs if the training should not be skipped""" if not self._should_skip_training(): return super().run() - def on_run_start(self): + def on_run_start(self) -> None: + """Calls the ``on_train_start`` hook.""" self.results.to(device=self.trainer.lightning_module.device) self.trainer.call_hook("on_train_start") - def on_advance_start(self): + def on_advance_start(self) -> None: + """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and ``on_train_epoch_start``""" model = self.trainer.lightning_module # reset train dataloader @@ -164,7 +204,8 @@ def on_advance_start(self): self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") - def advance(self): + def advance(self) -> None: + """Runs one whole epoch.""" train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) @@ -185,7 +226,8 @@ def advance(self): self.trainer.logger_connector.update_train_epoch_metrics() self.global_step += 1 - def on_advance_end(self): + def on_advance_end(self) -> None: + """Updates the LR schedulers and does some internal bookkeeping""" if self.training_loop.batches_seen == 0: return @@ -199,7 +241,8 @@ def on_advance_end(self): self.check_checkpoint_callback(True) self.global_step += 1 - def on_run_end(self): + def on_run_end(self) -> None: + """Runs teardown logic and calls the ``on_train_end`` hook""" # NOTE: the iteration_count/current_epoch is already incremented # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit # To simulate that current behavior, we decrement here. 
@@ -232,15 +275,19 @@ def on_run_end(self): self.trainer._running_stage = None def _should_skip_training(self) -> bool: + """Whether we should skip the training""" return self.done or self.trainer.num_training_batches == 0 - def should_accumulate(self): + def should_accumulate(self) -> bool: + """Whether the gradients should be accumulated""" return self.training_loop.batch_loop.should_accumulate() def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: + """Generates a list of active optimizers""" return self.training_loop.batch_loop.get_active_optimizers(batch_idx) - def check_checkpoint_callback(self, should_update, is_last=False): + def check_checkpoint_callback(self, should_update: bool, is_last: bool = False): + """Checks if checkpointing needs to be done""" # TODO: bake this logic into the ModelCheckpoint callback if should_update and self.trainer.checkpoint_connector.has_trained: callbacks = self.trainer.checkpoint_callbacks From c49875d7b51476f329b109b3eba847f2102e04de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Jun 2021 15:00:10 +0200 Subject: [PATCH 426/455] typing training_batch_loop Co-authored-by: Justus Schock --- .../loops/training_batch_loop.py | 272 +++++++++++++++--- 1 file changed, 233 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 6ca33b8936acf..cacfc216519a8 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -15,54 +15,70 @@ from collections import OrderedDict from contextlib import contextmanager from functools import partial, update_wrapper -from typing import Any, Callable, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple import numpy as np import torch +from deprecate import void from torch.optim import Optimizer +import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.types import STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache class TrainingBatchLoop(Loop): """ Runs over a single batch of data. 
""" - def __init__(self): + def __init__(self) -> None: super().__init__() - self.accumulated_loss = None - self.batch_outputs = None - self.running_loss = TensorRunningAccum(window_length=20) - self.split_idx = None - self.warning_cache = WarningCache() + self.accumulated_loss: torch.Tensor = None + self.batch_outputs: Optional[List[List[STEP_OUTPUT]]] = None + self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) + self.split_idx: Optional[int] = None + self.warning_cache: WarningCache = WarningCache() - self._hiddens = None - self._optimizer_freq_cumsum = None - self._remaining_splits = None - self._skip_backward = False + self._hiddens: Optional[torch.Tensor] = None + self._optimizer_freq_cumsum: Optional[int] = None + self._remaining_splits: Optional[List[Any]] = None + self._skip_backward: bool = False @property - def done(self): + def done(self) -> bool: + """Returns if all batch splits have been processed already""" return len(self._remaining_splits) == 0 @property - def optimizer_freq_cumsum(self): + def optimizer_freq_cumsum(self) -> int: + """Returns the cumulated sum of optimizer frequencies""" if self._optimizer_freq_cumsum is None: self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) return self._optimizer_freq_cumsum - def connect(self, trainer, *args, **kwargs): + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: + # TODO(@justusschock): can we make this a weakref/proxy? + void(*args, **kwargs) self.trainer = trainer - def run(self, batch, batch_idx, dataloader_idx): + def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: + """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks + + Args: + batch: the current batch to run the train step on + batch_idx: the index of the current batch + dataloader_idx: the index of the dataloader producing the current batch + + """ if batch is None: self.warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") return AttributeDict(signal=0, training_step_output=[[]]) @@ -83,14 +99,32 @@ def run(self, batch, batch_idx, dataloader_idx): return AttributeDict(signal=0, training_step_output=self.batch_outputs) def reset(self) -> None: + """Resets the loop state""" self._hiddens = None # TODO(@awaelchli): let loops track individual outputs self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - def on_run_start(self, batch, batch_idx, dataloader_idx): + def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): + """Splits the data into tbptt splits + + Args: + batch: the current batch to run the trainstep on + batch_idx: the index of the current batch + dataloader_idx: the index of the dataloader producing the current batch + + """ + void(batch_idx, dataloader_idx) self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) def advance(self, batch, batch_idx, dataloader_idx): + """Runs the train step together with optimization (if necessary) on the current batch split + + Args: + batch: the current batch to run the training on (this is not the split!) 
+ batch_idx: the index of the current batch + dataloader_idx: the index of the dataloader producing the current batch + """ + void(batch, dataloader_idx) split_idx, split_batch = self._remaining_splits.pop(0) self.split_idx = split_idx @@ -109,9 +143,21 @@ def advance(self, batch, batch_idx, dataloader_idx): self.batch_outputs[0].append(result.training_step_output) def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int: + """Gets the number of active optimizers based on their frequency""" return len(self.get_active_optimizers(batch_idx)) - def _run_optimization(self, batch_idx, split_batch, opt_idx=0, optimizer=None): + def _run_optimization( + self, batch_idx: int, split_batch: Any, opt_idx: int = 0, optimizer: Optional[torch.optim.Optimizer] = None + ): + """Runs closure (train step + backward) together with optimization if necessary. + + Args: + batch_idx: the index of the current batch + split_batch: the current tbptt split of the whole batch + opt_idx: the index of the current optimizer + optimizer: the current optimizer + + """ # TODO(@awaelchli): In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change # opt_idx=0 to opt_idx=None in the signature here @@ -161,21 +207,37 @@ def training_step_and_backward_closure( batch_idx: int, opt_idx: int, optimizer: Optimizer, - hiddens, + hiddens: torch.Tensor, return_result: AttributeDict, ) -> Optional[torch.Tensor]: + """Closure for training step and backward + + Args: + split_batch: the current tbptt split of the batch + batch_idx: the index of the current batch + opt_idx: the index of the current optimizer + optimizer: the current optimizer + hiddens: the hidden state of the recurrent net + return_result: the storage of the trainstep results + + """ result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) if result is not None: return_result.update(result) return return_result.loss - def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable: + def make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable: """ Wraps the training step closure into a partial object which will be called within ``optimizer.step``. 
""" partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs) return update_wrapper(partial_func, self.training_step_and_backward_closure) def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None: + """Checks if the closure results is finite and optionally breaks if it is not + + Args: + opt_closure_result: the result of the train step wrapped in an attribute dict + """ if not opt_closure_result: return @@ -183,14 +245,32 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) - if self.trainer.terminate_on_nan: self._check_finite(opt_closure_result.loss) - def on_after_backward(self, training_step_output, batch_idx, untouched_loss): + def on_after_backward( + self, training_step_output: STEP_OUTPUT, batch_idx: int, untouched_loss: torch.Tensor + ) -> None: + """Calls ``on_after_backward`` hook and tracks loss history + + Args: + training_step_output: the result from the training step (either a dict with key 'loss' or the loss tensor) + batch_idx: the index of the current batch + untouched_loss: the original loss value + """ + + void(training_step_output) # insert after step hook self.trainer.call_hook("on_after_backward") # when in dev debugging track the losses self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) - def _check_training_step_output(self, training_step_output): + def _check_training_step_output(self, training_step_output: STEP_OUTPUT) -> None: + """Sanity checks that training produced a valid output and optimizer step has already been called in manual + optimization. + + Args: + training_step_output: the output of the training step (before wrapping in an AttributeDict) + + """ if isinstance(training_step_output, torch.Tensor) and not self.trainer.lightning_module.automatic_optimization: if training_step_output.grad_fn is None: # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... @@ -206,7 +286,19 @@ def _check_training_step_output(self, training_step_output): "a dict with key 'loss' or None (where the step will be skipped)." ) - def training_step(self, split_batch, batch_idx, opt_idx, hiddens): + def training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, + hiddens: torch.Tensor) -> Optional[AttributeDict]: + """Performs the actual train step with the tied hooks. + + Args: + split_batch: the current tbptt split of the current batch + batch_idx: the index of the current batch + opt_idx: the index of the current optimizer + hiddens: the model's hidden state of the previous iteration + + Returns: + an AttributeDict containing the loss value and the training step output. 
+ """ # give the PL module a result for logging model_ref = self.trainer.lightning_module @@ -236,7 +328,15 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): loss = closure_loss.detach().clone() return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output) - def _process_training_step_output(self, training_step_output): + def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Optional[ResultCollection]: + """Adds the :param:`training_step_output` to the trainer's results + + Args: + training_step_output: the output of the training step (before wrapping into an AttributeDict) + + Returns: + the updated results if the training_step's output was not None else None + """ if training_step_output is None: return None @@ -266,7 +366,19 @@ def _process_training_step_output(self, training_step_output): results.cpu() return results - def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): + def optimizer_step( + self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable + ) -> None: + """Performs the optimizer step and some sanity checking. + + Args: + optimizer: the optimizer to perform the step with + opt_idx: the index of the current :param:`optimizer` + batch_idx: the index of the current batch + train_step_and_backward_closure: the closure function performing the train step and computing the + gradients. By default called by the optimizer (if possible) + + """ model_ref = self.trainer.lightning_module is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) @@ -294,17 +406,36 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_ using_lbfgs=is_lbfgs, ) - def on_before_zero_grad(self, optimizer): + def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: + """Calls the ``on_before_zero_grad`` hook. + + Args: + optimizer: the current optimizer + + """ self.trainer.call_hook('on_before_zero_grad', optimizer) - def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): + def optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None: + """Zeroes out all gradients of parameters optimized by the current optimizer. + + Args: + batch_idx: the index of the current batch + optimizer: the current optimizer + opt_idx: the index of the current optimizer + + """ self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - def track_and_norm_grad(self, optimizer) -> dict: + def track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, torch.Tensor]: + """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer. + + Args: + optimizer: the current optimizer + """ # track gradient norms grad_norm_dict = {} - if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 and \ - float(self.trainer.track_grad_norm) > 0: + if ((self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 + and float(self.trainer.track_grad_norm) > 0): grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm) # clip gradients @@ -313,7 +444,7 @@ def track_and_norm_grad(self, optimizer) -> dict: ) return grad_norm_dict - def _accumulated_batches_reached_before_iter_count_update(self): + def _accumulated_batches_reached_before_iter_count_update(self) -> bool: """ Determine if accumulation will be finished by the end of the current batch. 
This returns the correct answer only if the loop iteration count has not yet been updated for the current batch @@ -331,18 +462,31 @@ def _accumulated_batches_reached_after_iter_count_update(self): """ return self.iteration_count % self.trainer.accumulate_grad_batches == 0 - def _num_training_batches_reached(self, is_last_batch=False): + def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool: + """Checks whether sufficient training batches have been processed. + + Args: + is_last_batch: Whether the current batch is the last one + + """ # TODO(@awaelchli): use progress tracking of batches instead of iteration count, because iteration # count may reset return (self.iteration_count + 1) == self.trainer.num_training_batches or is_last_batch - def should_accumulate(self): + def should_accumulate(self) -> bool: + """Checks if the optimizer step should be performed or gradients should be accumulated for the current step.""" # checks if backward or backward + optimizer step (via closure) accumulation_done = self._accumulated_batches_reached_before_iter_count_update() is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) - def tbptt_split_batch(self, batch): + def tbptt_split_batch(self, batch: Any) -> List[Any]: + """Splits a single batch into a list of sequence steps for tbptt. + + Args: + batch: the current batch to split + + """ splits = [batch] if self.trainer.truncated_bptt_steps is not None: model_ref = self.trainer.lightning_module @@ -350,7 +494,18 @@ def tbptt_split_batch(self, batch): splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) return splits - def build_train_args(self, batch, batch_idx, opt_idx, hiddens): + def build_train_args(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: torch.Tensor) -> List[Any]: + """Builds arguments for train step + + Args: + batch: the current batch to train on + batch_idx: the index of the current batch + opt_idx: the index of the current optimizer + hiddens: the hidden state of the previous RNN iteration + + Returns: + the positional arguments for training + """ # enable not needing to add opt_idx to training_step args = [batch, batch_idx] @@ -377,7 +532,14 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): return args - def run_optimization_start(self, opt_idx, optimizer): + def run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None: + """Toggles the optimizer to ensure the correct one is used and prevend dangling grads. + + Args: + opt_idx: the index of the optimizer to use + optimizer: the optimizer to use + + """ # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1: @@ -385,7 +547,7 @@ def run_optimization_start(self, opt_idx, optimizer): model.toggle_optimizer(optimizer, opt_idx) @contextmanager - def block_ddp_sync_behaviour(self, should_block_sync: bool = False): + def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator[None, None, None]: """ automatic_optimization = True Blocks ddp sync gradients behaviour on backwards pass. 
@@ -408,7 +570,10 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False): else: yield None - def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): + def training_step_and_backward( + self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: torch.optim.Optimizer, + hiddens: Optional[torch.Tensor] + ) -> STEP_OUTPUT: """Wrap forward, zero_grad and backward in a closure so second order methods work""" with self.trainer.profiler.profile("training_step_and_backward"): # lightning module hook @@ -443,12 +608,27 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, return result def _check_finite(self, loss: torch.Tensor) -> None: + """Checks fotr finite parameters and loss values. + + Args: + loss: the loss value to check to be finite + """ if not torch.isfinite(loss).all(): raise ValueError(f'The loss returned in `training_step` is {loss}.') model = self.trainer.lightning_module detect_nan_parameters(model) - def backward(self, result, optimizer, opt_idx, *args, **kwargs): + def backward( + self, result: STEP_OUTPUT, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any + ) -> None: + """Performs the backward step. + + Args: + result: The output of the trainstep (including the loss value) + optimizer: The optimizer optimizing the gradients to call backward for + opt_idx: the index of the current optimizer + + """ self.trainer.dev_debugger.track_event("backward_call") should_accumulate = self.should_accumulate() @@ -469,6 +649,7 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): self.trainer.lightning_module.log_grad_norm(grad_norm_dict) def update_running_loss(self, current_loss: torch.Tensor) -> None: + """Updates the running loss value with the current value""" if self.trainer.lightning_module.automatic_optimization: # track total loss for logging (avoid mem leaks) self.accumulated_loss.append(current_loss) @@ -501,7 +682,19 @@ def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[i opt_idx = int(np.argmax(self.optimizer_freq_cumsum > current_place_in_loop)) return [(opt_idx, self.trainer.optimizers[opt_idx])] - def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens): + def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, + hiddens: Optional[torch.Tensor]) -> Dict[str, Any]: + """Builds the keyword arguments for training_step + + Args: + batch: the batch to train on + batch_idx: the index of the current batch + opt_idx: the index of the current optimizer + hiddens: the hidden state of the previous RNN iteration + + Returns: + the keyword arguments for the training step + """ # enable not needing to add opt_idx to training_step step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) @@ -535,6 +728,7 @@ def _truncated_bptt_enabled(self) -> bool: return self._truncated_bptt_steps() > 0 def _truncated_bptt_steps(self) -> int: + """Returns the number of tbptt steps""" lightning_module = self.trainer.lightning_module # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 if lightning_module.truncated_bptt_steps > 0: From 80edb7512aab65f55498f3c2d894ae7a2196d5a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Jun 2021 15:01:00 +0200 Subject: [PATCH 427/455] typing training epoch loop Co-authored-by: Justus Schock --- .../loops/training_epoch_loop.py | 123 +++++++++++++----- 1 file changed, 94 insertions(+), 29 deletions(-) diff --git 
a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index 450e3ae323511..7b0f02eb3c1f0 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterator, List, Union +from typing import Any, Dict, Iterator, List, Optional, Union import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop @@ -21,51 +21,67 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.types import STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache class TrainingEpochLoop(Loop): """ Runs over all batches in a dataloader (one epoch). """ - def __init__(self, min_steps, max_steps): + def __init__(self, min_steps: int, max_steps: int): super().__init__() - self.min_steps = min_steps - self.max_steps = max_steps + self.min_steps: int = min_steps + self.max_steps: int = max_steps - self.global_step = 0 + self.global_step: int = 0 # the total batch index across all epochs - self.total_batch_idx = 0 + self.total_batch_idx: int = 0 # the current batch index in the loop that runs over the dataloader(s) - self.iteration_count = 0 + self.iteration_count: int = 0 # the current split index when the batch gets split into chunks in truncated backprop through time - self.split_idx = None + self.split_idx: Optional[int] = None - self._dataloader_idx = None - self._should_stop = False + self._dataloader_idx: Optional[int] = None + self._should_stop: bool = False - self.is_last_batch = None - self.batches_seen = 0 - self.warning_cache = WarningCache() - self.epoch_output = None + self.is_last_batch: Optional[bool] = None + self.batches_seen: int = 0 + self.warning_cache: WarningCache = WarningCache() + self.epoch_output: Optional[List[List[STEP_OUTPUT]]] = None - self.batch_loop = None + self.batch_loop: Optional[TrainingBatchLoop] = None @property def batch_idx(self) -> int: + """Returns the current batch index (within this epoch)""" return self.iteration_count @property - def done(self): + def done(self) -> bool: + """Returns whether the training should be stopped. + The criteria are that the number of steps reached the max steps, + the last batch is reached or the trainer signals to stop (e.g. by early stopping). + + """ max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) - def connect(self, trainer: 'pl.Trainer', *args, **kwargs): + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: + """Connects the loop with all necessary parts like trainer and accelerators""" + + # TODO(@justusschock): should we forward *args and **kwargs to lower loops? + # TODO(@justusschock): can we make the trainer a proxy here? self.trainer = trainer self.batch_loop = TrainingBatchLoop() self.batch_loop.connect(trainer) - def run(self, *args, **kwargs): + def run(self, *args: Any, **kwargs: Any) -> List[List[STEP_OUTPUT]]: + """Runs over the dataloader until a StopIteration occurs. 
+ + Returns: + the outputs of each step for each optimizer + """ self.reset() self.on_run_start() @@ -83,6 +99,7 @@ def run(self, *args, **kwargs): return self.on_run_end() def reset(self) -> None: + """Resets the internal state of the loop for a new run""" self.iteration_count = 0 self.batches_seen = 0 self.is_last_batch = False @@ -92,7 +109,16 @@ def reset(self) -> None: # track epoch output self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] - def advance(self, dataloader_iter: Iterator, **kwargs): + def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: + """Runs a single training batch. + + Args: + dataloader_iter: the iterator over the dataloader producing the new batch + + Raises: + StopIteration: When the epoch is canceled by the user returning -1 + + """ _, (batch, is_last) = next(dataloader_iter) self.is_last_batch = is_last @@ -122,6 +148,12 @@ def advance(self, dataloader_iter: Iterator, **kwargs): self.trainer.logger_connector.update_train_step_metrics() def on_advance_end(self): + """Runs validation and checkpointing if necessary. + + Raises: + StopIteration: if :attr:`done` evaluates to ``True`` to finish this epoch + + """ # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- @@ -148,7 +180,15 @@ def on_advance_end(self): if self.done: raise StopIteration - def on_run_end(self): + def on_run_end(self) -> List[List[STEP_OUTPUT]]: + """Calls the on_epoch_end hook. + + Returns: + The output of each training step for each optimizer + + Raises: + MisconfigurationException: ``training_epoch_end`` does not return ``None`` + """ if self.batches_seen == 0: # dataloader/iterator did not produce a batch return @@ -182,7 +222,8 @@ def on_run_end(self): self.trainer.logger_connector.on_epoch_end() return self.epoch_output - def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: + def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None: + """Runs the ``on_train_epoch_end`` hook.""" # We cannot rely on Trainer.call_hook because the signatures might be different across # lightning module and callback # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end` @@ -222,11 +263,27 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # restore current_fx when nested context self.trainer.lightning_module._current_fx_name = prev_fx_name - def _num_training_batches_reached(self, is_last_batch=False): + def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool: + """Checks if we are in the last batch or if there are more batches to follow.""" + + # TODO: Can we combine this with training_batch_loop's arg that does a similar check? return self.batches_seen == self.trainer.num_training_batches or is_last_batch # TODO(@awaelchli): merge with on_advance_end() - def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end( + self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: + """Runs ``on_train_batch_end`` hook. 
+ + Args: + epoch_output: the store to add the batch outputs to + batch_end_outputs: the outputs of the batch step + batch: the batch this outputs were produced with + batch_idx: the index of the current batch + dataloader_idx: the index of the dataloader producing the current batch + + """ batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)] processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) @@ -238,7 +295,10 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) - def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): + def track_epoch_end_reduce_metrics( + self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT + ) -> None: + """Adds the batch outputs to the epoch outputs and prepares reduction""" hook_overridden = self._should_add_batch_output_to_epoch_output() if not hook_overridden: return @@ -255,9 +315,11 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): epoch_output[opt_idx].append(opt_outputs) def _should_add_batch_output_to_epoch_output(self) -> bool: - # We add to the epoch outputs if - # 1. The model defines training_epoch_end OR - # 2. The model overrides on_train_epoch_end which has `outputs` in the signature + """ + We add to the epoch outputs if + 1. The model defines training_epoch_end OR + 2. The model overrides on_train_epoch_end which has `outputs` in the signature + """ # TODO: in v1.5 this only needs to check if training_epoch_end is overridden lightning_module = self.trainer.lightning_module if is_overridden("training_epoch_end", model=lightning_module): @@ -327,6 +389,7 @@ def _prepare_outputs( return processed_outputs def update_lr_schedulers(self, interval: str) -> None: + """updates the lr schedulers based on the given interval""" if interval == "step": finished_accumulation = self.batch_loop._accumulated_batches_reached_before_iter_count_update() finished_epoch = self._num_training_batches_reached() @@ -337,7 +400,8 @@ def update_lr_schedulers(self, interval: str) -> None: opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)], ) - def increment_accumulated_grad_global_step(self): + def increment_accumulated_grad_global_step(self) -> None: + """increments global step""" num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached_after_iter_count_update() num_training_batches_reached = self._num_training_batches_reached() @@ -372,7 +436,8 @@ def should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 return is_val_check_batch - def save_loggers_on_train_batch_end(self): + def save_loggers_on_train_batch_end(self) -> None: + """Flushes loggers to disk""" # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: From 8b545055cc5abe077f3c11482ca2c3108cd8112c Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Fri, 11 Jun 2021 12:23:56 +0200 Subject: [PATCH 428/455] fix merge error --- pytorch_lightning/loops/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index be74a9f8d7d6d..87433e94eb6a3 100644 --- 
a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -112,4 +112,4 @@ def on_advance_end(self) -> None: """Hook to be called each time after :attr:`advance` is called.""" def on_run_end(self) -> Any: - """Hook to be called at the end of the run. Its return argument is returned from :attr:`run`. + """Hook to be called at the end of the run. Its return argument is returned from :attr:`run`.""" From e4ffa6cd9e436dcc49201599f10ac2935d62ac4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Jun 2021 23:58:03 +0200 Subject: [PATCH 429/455] integrate train loop changes from master --- pytorch_lightning/loops/training_batch_loop.py | 6 ++---- pytorch_lightning/loops/training_epoch_loop.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index cacfc216519a8..6a54b2e60f2a8 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -246,17 +246,15 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) - self._check_finite(opt_closure_result.loss) def on_after_backward( - self, training_step_output: STEP_OUTPUT, batch_idx: int, untouched_loss: torch.Tensor + self, batch_idx: int, untouched_loss: torch.Tensor ) -> None: """Calls ``on_after_backward`` hook and tracks loss history Args: - training_step_output: the result from the training step (either a dict with key 'loss' or the loss tensor) batch_idx: the index of the current batch untouched_loss: the original loss value """ - void(training_step_output) # insert after step hook self.trainer.call_hook("on_after_backward") @@ -594,7 +592,7 @@ def training_step_and_backward( # hook - call this hook only # when gradients have finished to accumulate if not self.should_accumulate(): - self.on_after_backward(result.training_step_output, batch_idx, result.loss) + self.on_after_backward(batch_idx, result.loss) # check if loss or model weights are nan if self.trainer.terminate_on_nan: diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index 7b0f02eb3c1f0..87c231d01c6f4 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -202,7 +202,7 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]: # get the model and call model.training_epoch_end model = self.trainer.lightning_module - if is_overridden('training_epoch_end', model=model): + if is_overridden('training_epoch_end', model): # run training_epoch_end # refresh the result for custom logging at the epoch level model._current_fx_name = 'training_epoch_end' @@ -322,10 +322,10 @@ def _should_add_batch_output_to_epoch_output(self) -> bool: """ # TODO: in v1.5 this only needs to check if training_epoch_end is overridden lightning_module = self.trainer.lightning_module - if is_overridden("training_epoch_end", model=lightning_module): + if is_overridden("training_epoch_end", lightning_module): return True - if is_overridden("on_train_epoch_end", model=lightning_module): + if is_overridden("on_train_epoch_end", lightning_module): model_hook_fx = getattr(lightning_module, "on_train_epoch_end") if is_param_in_hook_signature(model_hook_fx, "outputs"): return True From 69ed0e7228b8582cfe720f737b017e508acbc40c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Jun 2021 22:08:00 +0000 
Subject: [PATCH 430/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/training_batch_loop.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 6a54b2e60f2a8..fad4ab22ab6d0 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -245,9 +245,7 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) - if self.trainer.terminate_on_nan: self._check_finite(opt_closure_result.loss) - def on_after_backward( - self, batch_idx: int, untouched_loss: torch.Tensor - ) -> None: + def on_after_backward(self, batch_idx: int, untouched_loss: torch.Tensor) -> None: """Calls ``on_after_backward`` hook and tracks loss history Args: From eeebc9aff462d049998c18fbb975ad43dc7510ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 20:10:04 +0200 Subject: [PATCH 431/455] fix tpipes moving model to cpu and leaving it there. --- tests/helpers/pipelines.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index f7a6484f6b27e..96b64bdf1788d 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -96,6 +96,7 @@ def run_model_test( @torch.no_grad() def run_prediction_eval_model_template(trained_model, dataloader, min_acc=0.50): + orig_device = trained_model.device # run prediction on 1 batch trained_model.cpu() trained_model.eval() @@ -108,3 +109,4 @@ def run_prediction_eval_model_template(trained_model, dataloader, min_acc=0.50): acc = accuracy(y_hat.cpu(), y.cpu(), top_k=2).item() assert acc >= min_acc, f"This model is expected to get > {min_acc} in test set (it got {acc})" + trained_model.to(orig_device) From ce9dd2a34d63ea8f83eed9c876973bbb477e5e3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 20:34:51 +0200 Subject: [PATCH 432/455] don't reset fit loop don't reset fit loop --- pytorch_lightning/loops/fit_loop.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 88cc809889dad..d080e3821b319 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -165,8 +165,7 @@ def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: self.training_loop.connect(trainer) def reset(self) -> None: - """Resets the trainer's internal state""" - self.iteration_count = 0 + """Resets the internal state of this loop""" def run(self) -> None: """Loops over epochs if the training should not be skipped""" From 80e225ae82484cf8fdc70c0339cb87e90850760f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 11:52:04 +0200 Subject: [PATCH 433/455] fix test iteration count <-> batch_idx reset --- .../loops/training_batch_loop.py | 40 +++++++------------ .../loops/training_epoch_loop.py | 4 +- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index fad4ab22ab6d0..54c2e8dc16969 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -45,6 +45,7 @@ def __init__(self) -> None: self.accumulated_loss: torch.Tensor = None self.batch_outputs: 
Optional[List[List[STEP_OUTPUT]]] = None self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) + self.batch_idx: int = 0 self.split_idx: Optional[int] = None self.warning_cache: WarningCache = WarningCache() @@ -70,6 +71,13 @@ def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: void(*args, **kwargs) self.trainer = trainer + def reset(self) -> None: + """Resets the loop state""" + self._hiddens = None + self.batch_idx = 0 + # TODO(@awaelchli): let loops track individual outputs + self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] + def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks @@ -98,12 +106,6 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: return AttributeDict(signal=0, training_step_output=self.batch_outputs) - def reset(self) -> None: - """Resets the loop state""" - self._hiddens = None - # TODO(@awaelchli): let loops track individual outputs - self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): """Splits the data into tbptt splits @@ -126,6 +128,7 @@ def advance(self, batch, batch_idx, dataloader_idx): """ void(batch, dataloader_idx) split_idx, split_batch = self._remaining_splits.pop(0) + self.batch_idx = batch_idx self.split_idx = split_idx # let logger connector extract current batch size @@ -440,39 +443,26 @@ def track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, tor ) return grad_norm_dict - def _accumulated_batches_reached_before_iter_count_update(self) -> bool: + def _accumulated_batches_reached(self) -> bool: """ Determine if accumulation will be finished by the end of the current batch. - This returns the correct answer only if the loop iteration count has not yet been updated for the current batch - in progress. Use `_accumulated_batches_reached_after_iter_count_update` otherwise. """ - # TODO(@awaelchli): use progress tracking of batches instead of iteration count, because iteration count may - # reset iteration count is required to be global here, not reset - return (self.iteration_count + 1) % self.trainer.accumulate_grad_batches == 0 - - def _accumulated_batches_reached_after_iter_count_update(self): - """ - Determine if accumulation has finished. - This returns the correct answer only right after a batch has ended, i.e., the iteration count was just updated. - Use `_accumulated_batches_reached_after_iter_count_update` otherwise. - """ - return self.iteration_count % self.trainer.accumulate_grad_batches == 0 + # TODO(@awaelchli): use progress tracking of batches instead of manual batch_idx + return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool: """Checks whether sufficient training batches have been processed. 
Args: is_last_batch: Whether the current batch is the last one - """ - # TODO(@awaelchli): use progress tracking of batches instead of iteration count, because iteration - # count may reset - return (self.iteration_count + 1) == self.trainer.num_training_batches or is_last_batch + # TODO(@awaelchli): use progress tracking of batches instead of manual batch_idx + return (self.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch def should_accumulate(self) -> bool: """Checks if the optimizer step should be performed or gradients should be accumulated for the current step.""" # checks if backward or backward + optimizer step (via closure) - accumulation_done = self._accumulated_batches_reached_before_iter_count_update() + accumulation_done = self._accumulated_batches_reached() is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index 87c231d01c6f4..2ba472c291430 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -391,7 +391,7 @@ def _prepare_outputs( def update_lr_schedulers(self, interval: str) -> None: """updates the lr schedulers based on the given interval""" if interval == "step": - finished_accumulation = self.batch_loop._accumulated_batches_reached_before_iter_count_update() + finished_accumulation = self.batch_loop._accumulated_batches_reached() finished_epoch = self._num_training_batches_reached() if not finished_accumulation and not finished_epoch: return @@ -402,7 +402,7 @@ def update_lr_schedulers(self, interval: str) -> None: def increment_accumulated_grad_global_step(self) -> None: """increments global step""" - num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached_after_iter_count_update() + num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() # progress global step according to grads progress From 4880b26943ab22abc9d8e6ad2248fe986cfe5871 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 11:52:47 +0200 Subject: [PATCH 434/455] replace torch.Tensor -> Tensor --- .../loops/training_batch_loop.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 54c2e8dc16969..dbd22f08c3985 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -20,6 +20,7 @@ import numpy as np import torch from deprecate import void +from torch import Tensor from torch.optim import Optimizer import pytorch_lightning as pl @@ -42,14 +43,14 @@ class TrainingBatchLoop(Loop): def __init__(self) -> None: super().__init__() - self.accumulated_loss: torch.Tensor = None + self.accumulated_loss: Optional[Tensor] = None self.batch_outputs: Optional[List[List[STEP_OUTPUT]]] = None self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) self.batch_idx: int = 0 self.split_idx: Optional[int] = None self.warning_cache: WarningCache = WarningCache() - self._hiddens: Optional[torch.Tensor] = None + self._hiddens: Optional[Tensor] = None self._optimizer_freq_cumsum: Optional[int] = None self._remaining_splits: Optional[List[Any]] = None self._skip_backward: bool = False @@ -210,9 +211,9 @@ def 
training_step_and_backward_closure( batch_idx: int, opt_idx: int, optimizer: Optimizer, - hiddens: torch.Tensor, + hiddens: Tensor, return_result: AttributeDict, - ) -> Optional[torch.Tensor]: + ) -> Optional[Tensor]: """Closure for training step and backward Args: @@ -248,7 +249,7 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) - if self.trainer.terminate_on_nan: self._check_finite(opt_closure_result.loss) - def on_after_backward(self, batch_idx: int, untouched_loss: torch.Tensor) -> None: + def on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None: """Calls ``on_after_backward`` hook and tracks loss history Args: @@ -270,13 +271,13 @@ def _check_training_step_output(self, training_step_output: STEP_OUTPUT) -> None training_step_output: the output of the training step (before wrapping in an AttributeDict) """ - if isinstance(training_step_output, torch.Tensor) and not self.trainer.lightning_module.automatic_optimization: + if isinstance(training_step_output, Tensor) and not self.trainer.lightning_module.automatic_optimization: if training_step_output.grad_fn is None: # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") elif self.trainer.lightning_module.automatic_optimization: if not any(( - isinstance(training_step_output, torch.Tensor), + isinstance(training_step_output, Tensor), (isinstance(training_step_output, Mapping) and 'loss' in training_step_output), training_step_output is None )): @@ -285,8 +286,7 @@ def _check_training_step_output(self, training_step_output: STEP_OUTPUT) -> None "a dict with key 'loss' or None (where the step will be skipped)." ) - def training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, - hiddens: torch.Tensor) -> Optional[AttributeDict]: + def training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> Optional[AttributeDict]: """Performs the actual train step with the tied hooks. Args: @@ -354,7 +354,7 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op results.extra = training_step_output # handle scalar return - elif isinstance(training_step_output, torch.Tensor): + elif isinstance(training_step_output, Tensor): loss = training_step_output # map to results under the hood @@ -425,7 +425,7 @@ def optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, """ self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - def track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, torch.Tensor]: + def track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]: """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer. 
Args: @@ -480,7 +480,7 @@ def tbptt_split_batch(self, batch: Any) -> List[Any]: splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) return splits - def build_train_args(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: torch.Tensor) -> List[Any]: + def build_train_args(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> List[Any]: """Builds arguments for train step Args: @@ -558,7 +558,7 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator def training_step_and_backward( self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: torch.optim.Optimizer, - hiddens: Optional[torch.Tensor] + hiddens: Optional[Tensor] ) -> STEP_OUTPUT: """Wrap forward, zero_grad and backward in a closure so second order methods work""" with self.trainer.profiler.profile("training_step_and_backward"): @@ -593,7 +593,7 @@ def training_step_and_backward( return result - def _check_finite(self, loss: torch.Tensor) -> None: + def _check_finite(self, loss: Tensor) -> None: """Checks fotr finite parameters and loss values. Args: @@ -620,7 +620,7 @@ def backward( should_accumulate = self.should_accumulate() # backward can be called manually in the training loop - if isinstance(result, torch.Tensor): + if isinstance(result, Tensor): self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) else: result.closure_loss = self.trainer.accelerator.backward( @@ -634,7 +634,7 @@ def backward( self.trainer.lightning_module._current_fx_name = "on_after_backward" self.trainer.lightning_module.log_grad_norm(grad_norm_dict) - def update_running_loss(self, current_loss: torch.Tensor) -> None: + def update_running_loss(self, current_loss: Tensor) -> None: """Updates the running loss value with the current value""" if self.trainer.lightning_module.automatic_optimization: # track total loss for logging (avoid mem leaks) @@ -668,8 +668,7 @@ def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[i opt_idx = int(np.argmax(self.optimizer_freq_cumsum > current_place_in_loop)) return [(opt_idx, self.trainer.optimizers[opt_idx])] - def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, - hiddens: Optional[torch.Tensor]) -> Dict[str, Any]: + def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optional[Tensor]) -> Dict[str, Any]: """Builds the keyword arguments for training_step Args: From 5461f736c07dfa879791592fa89dcfd7bc407757 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 12:57:00 +0200 Subject: [PATCH 435/455] fix attribute error to block_ddp_sync_behaviour --- pytorch_lightning/core/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 174631ae73e8b..b34e6662a7ab2 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -120,7 +120,7 @@ def toggle_model(self, sync_grad: bool = True): during the accumulation phase. Setting `sync_grad` to False will block this synchronization and improve performance. 
""" - with self._trainer.train_loop.block_ddp_sync_behaviour(not sync_grad): + with self._trainer.train_loop.training_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad): self._toggle_model() yield self._untoggle_model() From 0fe6d9f5e44c78a898345c33c6df738a8473bf4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 13:38:10 +0200 Subject: [PATCH 436/455] ignore mypy errors --- setup.cfg | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup.cfg b/setup.cfg index accd6a7f79abb..b90c0663c5cef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -130,6 +130,10 @@ ignore_errors = True [mypy-pytorch_lightning.loggers.*] ignore_errors = True +# todo: add proper typing to this module... +[mypy-pytorch_lightning.loops.*] +ignore_errors = True + # todo: add proper typing to this module... [mypy-pytorch_lightning.metrics.*] ignore_errors = True From 5497fc0897bff0a84a32e1080262c99eea9e5c29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 14:01:56 +0200 Subject: [PATCH 437/455] fix flake8 and yapf conflict --- pytorch_lightning/loops/training_batch_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index dbd22f08c3985..25c7c00845dbc 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -433,8 +433,9 @@ def track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Ten """ # track gradient norms grad_norm_dict = {} - if ((self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 - and float(self.trainer.track_grad_norm) > 0): + can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 + should_track = float(self.trainer.track_grad_norm) > 0 + if should_track and can_log: grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm) # clip gradients From 4c51c453d7e41302da012fcd8dd8d173a14d0ff2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 16:23:32 +0200 Subject: [PATCH 438/455] remove redundant override --- .../loops/training_epoch_loop.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index 2ba472c291430..79defa207afca 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -76,28 +76,6 @@ def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: self.batch_loop = TrainingBatchLoop() self.batch_loop.connect(trainer) - def run(self, *args: Any, **kwargs: Any) -> List[List[STEP_OUTPUT]]: - """Runs over the dataloader until a StopIteration occurs. 
- - Returns: - the outputs of each step for each optimizer - """ - self.reset() - self.on_run_start() - - # TODO(@awaelchli): while condition is different from super.run(), - # redesign the done conditions and use the base class run() implementation - while True: - try: - self.on_advance_start(*args, **kwargs) - self.advance(*args, **kwargs) - self.on_advance_end() - self.iteration_count += 1 - except StopIteration: - break - - return self.on_run_end() - def reset(self) -> None: """Resets the internal state of the loop for a new run""" self.iteration_count = 0 From 8f68b613efe3ee32d16966ee1f2315415b928620 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 17:29:17 +0200 Subject: [PATCH 439/455] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ethan Harris Co-authored-by: Carlos Mocholí --- pytorch_lightning/loops/base.py | 4 ++-- pytorch_lightning/loops/fit_loop.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 87433e94eb6a3..76217dd3a8ee6 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -26,7 +26,7 @@ class Loop(ABC): Basic Loops interface. All classes derived from this must implement the following properties and methods: * :attr`done` (property): Condition to break the loop - * :attr`reset` (method): Resets the internal state between multiple calls of :attr`run` + * :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run` * :attr`advance` (method): Implements one step of the loop This class implements the following loop structure: @@ -66,7 +66,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any: The main entry point to the loop. Will frequently check the :attr:`done` condition and calls :attr:`advance` - until :attr`done` evaluates to ``True``. + until :attr:`done` evaluates to ``True``. 
Returns: the output of :attr`on_run_end` (often outputs collected from each step of the loop) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index d080e3821b319..a691b2e8056b5 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -77,8 +77,7 @@ def global_step(self) -> int: @global_step.setter def global_step(self, value: int) -> None: - """Sets the global step (forwards to training_loop) - """ + """Sets the global step (forwards to training_loop)""" self.training_loop.global_step = value @property @@ -211,7 +210,6 @@ def advance(self) -> None: with self.trainer.profiler.profile("run_training_epoch"): # run train epoch epoch_output = self.training_loop.run(train_dataloader) - # log epoch metrics if epoch_output is None: return From 0150f6cfd78b2c613a8f225fab5cf30f65c804e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 17:30:23 +0200 Subject: [PATCH 440/455] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: Ethan Harris Co-authored-by: Jirka Borovec --- pytorch_lightning/loops/training_batch_loop.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 25c7c00845dbc..e57db0f02cd40 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -86,7 +86,6 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: batch: the current batch to run the train step on batch_idx: the index of the current batch dataloader_idx: the index of the dataloader producing the current batch - """ if batch is None: self.warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") @@ -114,7 +113,6 @@ def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): batch: the current batch to run the trainstep on batch_idx: the index of the current batch dataloader_idx: the index of the dataloader producing the current batch - """ void(batch_idx, dataloader_idx) self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch))) @@ -160,7 +158,6 @@ def _run_optimization( split_batch: the current tbptt split of the whole batch opt_idx: the index of the current optimizer optimizer: the current optimizer - """ # TODO(@awaelchli): In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change # opt_idx=0 to opt_idx=None in the signature here @@ -223,7 +220,6 @@ def training_step_and_backward_closure( optimizer: the current optimizer hiddens: the hidden state of the recurrent net return_result: the storage of the trainstep results - """ result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) @@ -376,7 +372,6 @@ def optimizer_step( batch_idx: the index of the current batch train_step_and_backward_closure: the closure function performing the train step and computing the gradients. 
By default called by the optimizer (if possible) - """ model_ref = self.trainer.lightning_module @@ -410,7 +405,6 @@ def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: Args: optimizer: the current optimizer - """ self.trainer.call_hook('on_before_zero_grad', optimizer) @@ -421,7 +415,6 @@ def optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, batch_idx: the index of the current batch optimizer: the current optimizer opt_idx: the index of the current optimizer - """ self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) @@ -472,7 +465,6 @@ def tbptt_split_batch(self, batch: Any) -> List[Any]: Args: batch: the current batch to split - """ splits = [batch] if self.trainer.truncated_bptt_steps is not None: @@ -559,7 +551,7 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator def training_step_and_backward( self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: torch.optim.Optimizer, - hiddens: Optional[Tensor] + hiddens: Optional[Tensor], ) -> STEP_OUTPUT: """Wrap forward, zero_grad and backward in a closure so second order methods work""" with self.trainer.profiler.profile("training_step_and_backward"): From fd90c10974c73c52f5736ccff4e3d8d78026e888 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 17:31:02 +0200 Subject: [PATCH 441/455] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/loops/training_epoch_loop.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index 79defa207afca..b2fe4be508013 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -62,7 +62,6 @@ def done(self) -> bool: """Returns whether the training should be stopped. The criteria are that the number of steps reached the max steps, the last batch is reached or the trainer signals to stop (e.g. by early stopping). 
- """ max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) @@ -260,7 +259,6 @@ def on_train_batch_end( batch: the batch this outputs were produced with batch_idx: the index of the current batch dataloader_idx: the index of the dataloader producing the current batch - """ batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)] processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) From 153d2649d8573d3c54d65fdb591a74b3b2791ef6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Jun 2021 15:31:24 +0000 Subject: [PATCH 442/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/training_batch_loop.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index e57db0f02cd40..14296508c865c 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -550,7 +550,11 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator yield None def training_step_and_backward( - self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: torch.optim.Optimizer, + self, + split_batch: Any, + batch_idx: int, + opt_idx: int, + optimizer: torch.optim.Optimizer, hiddens: Optional[Tensor], ) -> STEP_OUTPUT: """Wrap forward, zero_grad and backward in a closure so second order methods work""" From 4eb0eb1d7f3e2052528af26eaa0f0c7e42b23fd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 17:36:14 +0200 Subject: [PATCH 443/455] remove all empty space between atoms --- pytorch_lightning/loops/base.py | 8 +------- pytorch_lightning/loops/training_batch_loop.py | 2 -- pytorch_lightning/loops/training_epoch_loop.py | 2 -- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 76217dd3a8ee6..d1da76ba8b6ba 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -41,7 +41,6 @@ class Loop(ABC): on_advance_end() on_run_end() - """ def __init__(self) -> None: @@ -90,23 +89,18 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: Hook to be called as the first thing after entering :attr:`run` (except the state reset). Accepts all arguments passed to :attr:`run`. - """ void(*args, **kwargs) def on_advance_start(self, *args: Any, **kwargs: Any) -> None: """ Hook to be called each time before :attr:`advance` is called. Accepts all arguments passed to :attr`run`. - """ void(*args, **kwargs) @abstractmethod def advance(self, *args: Any, **kwargs: Any) -> None: - """ - Performs a single step. Accepts all arguments passed to :attr:`run`. - - """ + """Performs a single step. 
Accepts all arguments passed to :attr:`run`.""" def on_advance_end(self) -> None: """Hook to be called each time after :attr:`advance` is called.""" diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 14296508c865c..6f3b11f050a3c 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -538,7 +538,6 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator Returns: context manager with sync behaviour off - """ if ( isinstance(self.trainer.training_type_plugin, ParallelPlugin) @@ -610,7 +609,6 @@ def backward( result: The output of the trainstep (including the loss value) optimizer: The optimizer optimizing the gradients to call backward for opt_idx: the index of the current optimizer - """ self.trainer.dev_debugger.track_event("backward_call") diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index b2fe4be508013..cfa008779b6c6 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -94,7 +94,6 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: Raises: StopIteration: When the epoch is canceled by the user returning -1 - """ _, (batch, is_last) = next(dataloader_iter) self.is_last_batch = is_last @@ -129,7 +128,6 @@ def on_advance_end(self): Raises: StopIteration: if :attr:`done` evaluates to ``True`` to finish this epoch - """ # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK From 70cdb144238276315412fc5b242fd1c8bfd48699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 17:38:58 +0200 Subject: [PATCH 444/455] carlos --- pytorch_lightning/loops/fit_loop.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a691b2e8056b5..cf458e5d9152c 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -51,11 +51,8 @@ def __init__( max_steps: Optional[int] = None ): super().__init__() - # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs - # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.training_loop = TrainingEpochLoop(min_steps, max_steps) self.results = ResultCollection(training=True) From bf26aa3e842e1746c03f89d9d2ff5273b0467666 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 14 Jun 2021 17:45:04 +0200 Subject: [PATCH 445/455] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: Jirka Borovec --- pytorch_lightning/loops/fit_loop.py | 5 ++--- pytorch_lightning/loops/training_batch_loop.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a691b2e8056b5..635fa85d62c41 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -66,8 +66,7 @@ def current_epoch(self) -> int: @current_epoch.setter def current_epoch(self, value: int) -> None: - """Setter for the current epoch - """ 
+ """Setter for the current epoch""" self.iteration_count = value @property @@ -151,7 +150,7 @@ def done(self) -> bool: f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' ' not been met. Training will continue...' ) - self.trainer.should_stop = False + self.trainer.should_stop = should_stop return stop_steps or should_stop or stop_epochs diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 6f3b11f050a3c..65b761f81bee1 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -691,7 +691,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio " the old signature will be removed in v1.5", DeprecationWarning ) step_kwargs['optimizer_idx'] = opt_idx - elif not has_opt_idx_in_train_step and self.trainer.lightning_module.automatic_optimization: + elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization: raise ValueError( f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" ' `training_step` is missing the `optimizer_idx` argument.' From ffc4f458d4ccb7cdd9f1cabed7b13a9eacc89861 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 14 Jun 2021 17:48:10 +0200 Subject: [PATCH 446/455] Apply suggestions from code review Co-authored-by: Jirka Borovec --- pytorch_lightning/core/optimizer.py | 2 +- pytorch_lightning/trainer/trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index b34e6662a7ab2..1da8a7af36221 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -120,7 +120,7 @@ def toggle_model(self, sync_grad: bool = True): during the accumulation phase. Setting `sync_grad` to False will block this synchronization and improve performance. 
""" - with self._trainer.train_loop.training_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad): + with self._trainer.fit_loop.training_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad): self._toggle_model() yield self._untoggle_model() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 12471d58ba976..c0e6ad0658481 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -419,7 +419,7 @@ def __init__( def _setup_on_init( self, num_sanity_val_steps: int, - ): + ) -> None: self.should_stop = False self.state = TrainerState() self.num_training_batches = 0 @@ -430,7 +430,7 @@ def _setup_on_init( else: self.num_sanity_val_steps = num_sanity_val_steps - def _setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): + def _setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None) -> None: # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) From 3373cc89b4f398f887a9771450b6a49bdd8b69ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 18:28:01 +0200 Subject: [PATCH 447/455] resolve a todo integrating on_train_batch_end with on_advance_end --- .../loops/training_epoch_loop.py | 41 ++++--------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index cfa008779b6c6..500cf3bfe8c99 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -109,14 +109,16 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: if batch_output.signal == -1: raise StopIteration + batch_end_outputs = [opt_idx_out for opt_idx_out in batch_output.training_step_output if len(opt_idx_out)] + processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) + # hook - self.on_train_batch_end( - self.epoch_output, - batch_output.training_step_output, - batch, - self.iteration_count, - self._dataloader_idx, - ) + self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, self.iteration_count, self._dataloader_idx) + self.trainer.call_hook('on_batch_end') + self.trainer.logger_connector.on_batch_end() + + # figure out what to track for epoch end + self.track_epoch_end_reduce_metrics(self.epoch_output, batch_end_outputs) # ----------------------------------------- # SAVE METRICS TO LOGGERS AND PROGRESS_BAR @@ -244,31 +246,6 @@ def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool: # TODO: Can we combine this with training_batch_loop's arg that does a similar check? return self.batches_seen == self.trainer.num_training_batches or is_last_batch - # TODO(@awaelchli): merge with on_advance_end() - def on_train_batch_end( - self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT, batch: Any, batch_idx: int, - dataloader_idx: int - ) -> None: - """Runs ``on_train_batch_end`` hook. 
- - Args: - epoch_output: the store to add the batch outputs to - batch_end_outputs: the outputs of the batch step - batch: the batch this outputs were produced with - batch_idx: the index of the current batch - dataloader_idx: the index of the dataloader producing the current batch - """ - batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)] - processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) - - # hook - self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, batch_idx, dataloader_idx) - self.trainer.call_hook('on_batch_end') - self.trainer.logger_connector.on_batch_end() - - # figure out what to track for epoch end - self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) - def track_epoch_end_reduce_metrics( self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT ) -> None: From e1a40c080f43e09f1029006bdebb3228ab572c73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 18:40:42 +0200 Subject: [PATCH 448/455] clarify what is todo and what is fixme --- pytorch_lightning/loops/training_batch_loop.py | 6 +++--- pytorch_lightning/loops/training_epoch_loop.py | 2 +- .../trainer/connectors/logger_connector/logger_connector.py | 1 - pytorch_lightning/trainer/properties.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index 65b761f81bee1..c0a47ff660345 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -76,7 +76,7 @@ def reset(self) -> None: """Resets the loop state""" self._hiddens = None self.batch_idx = 0 - # TODO(@awaelchli): let loops track individual outputs + # TODO: let loops track individual outputs self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: @@ -441,7 +441,7 @@ def _accumulated_batches_reached(self) -> bool: """ Determine if accumulation will be finished by the end of the current batch. 
""" - # TODO(@awaelchli): use progress tracking of batches instead of manual batch_idx + # FIXME(@awaelchli): use progress tracking of batches instead of manual batch_idx return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool: @@ -450,7 +450,7 @@ def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool: Args: is_last_batch: Whether the current batch is the last one """ - # TODO(@awaelchli): use progress tracking of batches instead of manual batch_idx + # FIXME(@awaelchli): use progress tracking of batches instead of manual batch_idx return (self.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch def should_accumulate(self) -> bool: diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index 500cf3bfe8c99..bd892e0eebca0 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -379,7 +379,7 @@ def should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: if self.trainer.should_stop: return True - # TODO(awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch + # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = is_last_batch if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8223c56147a4a..9d838d6bf91ef 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -305,7 +305,6 @@ def progress_bar_metrics(self) -> Dict[str, float]: return self._progress_bar_metrics def teardown(self): - # TODO(@awaelchli): This should be handled by the loops themselves self.trainer.train_loop.results.cpu() self.trainer.evaluation_loop._val_results.cpu() self.trainer.evaluation_loop._test_results.cpu() diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 2f95a50219bfa..811d7eaa80291 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -485,7 +485,7 @@ def sanity_checking(self, val: bool) -> None: @property def train_loop(self) -> FitLoop: - # TODO(@awaelchli): the current train_loop should be renamed to fit_loop + # FIXME(@awaelchli): the current train_loop should be renamed to fit_loop return self.fit_loop @property From b5bb08a7c7c49dbaa7a3472c947c84d026433008 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Jun 2021 16:42:26 +0000 Subject: [PATCH 449/455] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/training_epoch_loop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index bd892e0eebca0..75de0f17c6f4b 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -113,7 +113,9 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: 
processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) # hook - self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, self.iteration_count, self._dataloader_idx) + self.trainer.call_hook( + 'on_train_batch_end', processed_batch_end_outputs, batch, self.iteration_count, self._dataloader_idx + ) self.trainer.call_hook('on_batch_end') self.trainer.logger_connector.on_batch_end() From 5d98009704191d34515a38116e634a67ea42f930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 19:16:41 +0200 Subject: [PATCH 450/455] shorten a docstring --- pytorch_lightning/loops/training_batch_loop.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index c0a47ff660345..b581c6c8c1384 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -76,7 +76,6 @@ def reset(self) -> None: """Resets the loop state""" self._hiddens = None self.batch_idx = 0 - # TODO: let loops track individual outputs self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: @@ -438,9 +437,7 @@ def track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Ten return grad_norm_dict def _accumulated_batches_reached(self) -> bool: - """ - Determine if accumulation will be finished by the end of the current batch. - """ + """Determine if accumulation will be finished by the end of the current batch.""" # FIXME(@awaelchli): use progress tracking of batches instead of manual batch_idx return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 From f001f812421e01b669c8684e3415601157796148 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 21:28:09 +0200 Subject: [PATCH 451/455] move on_epoch_start to on_run_start of training loop --- pytorch_lightning/loops/fit_loop.py | 5 ----- pytorch_lightning/loops/training_epoch_loop.py | 6 ++++++ 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 676597878ab50..4839ad9f7d460 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -193,11 +193,6 @@ def on_advance_start(self) -> None: window_length=self.trainer.accumulate_grad_batches ) - # hook - self.trainer.logger_connector.on_epoch_start() - self.trainer.call_hook("on_epoch_start") - self.trainer.call_hook("on_train_epoch_start") - def advance(self) -> None: """Runs one whole epoch.""" train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index 75de0f17c6f4b..d029c525d71ac 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -86,6 +86,12 @@ def reset(self) -> None: # track epoch output self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] + def on_run_start(self, *args: Any, **kwargs: Any) -> None: + # hook + self.trainer.logger_connector.on_epoch_start() + self.trainer.call_hook("on_epoch_start") + self.trainer.call_hook("on_train_epoch_start") + def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: """Runs a single training batch. 
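Note (illustrative, not part of the patch series): the hook move in the patch above relies on the generic run() template of the Loop base class, in which on_run_start fires exactly once per call to run(), before any advance() step. The sketch below is a minimal stand-in, assuming only the template visible in the diffs; the names SketchLoop and SketchEpochLoop are invented for illustration and do not exist in Lightning.

# Minimal, self-contained sketch of the run() template assumed by patch 451:
# epoch-start hooks now fire from the inner loop's on_run_start, once per epoch.
class SketchLoop:
    def __init__(self):
        self.iteration_count = 0

    @property
    def done(self):
        raise NotImplementedError

    def reset(self): ...
    def on_run_start(self, *args, **kwargs): ...
    def on_advance_start(self, *args, **kwargs): ...
    def advance(self, *args, **kwargs): raise NotImplementedError
    def on_advance_end(self): ...
    def on_run_end(self): ...

    def run(self, *args, **kwargs):
        self.reset()
        self.on_run_start(*args, **kwargs)
        while not self.done:
            try:
                self.on_advance_start(*args, **kwargs)
                self.advance(*args, **kwargs)
                self.on_advance_end()
                self.iteration_count += 1
            except StopIteration:
                break
        return self.on_run_end()


class SketchEpochLoop(SketchLoop):
    """Fires its epoch-start hooks once per run(), mirroring TrainingEpochLoop.on_run_start above."""

    def __init__(self, num_batches):
        super().__init__()
        self.num_batches = num_batches
        self.calls = []

    @property
    def done(self):
        return self.iteration_count >= self.num_batches

    def on_run_start(self, *args, **kwargs):
        # corresponds to logger_connector.on_epoch_start() and the
        # "on_epoch_start" / "on_train_epoch_start" trainer hooks in the diff above
        self.calls.append("on_epoch_start")

    def advance(self, *args, **kwargs):
        self.calls.append(f"train_batch_{self.iteration_count}")


loop = SketchEpochLoop(num_batches=2)
loop.run()
assert loop.calls == ["on_epoch_start", "train_batch_0", "train_batch_1"]

With this shape, FitLoop.on_advance_start no longer needs to call the epoch-start hooks itself: each FitLoop.advance() delegates to the epoch loop's run(), whose on_run_start fires them at the start of every epoch.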
From d191fe1e1468b88cb769924455efbe04261ab1be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 15 Jun 2021 03:00:07 +0200 Subject: [PATCH 452/455] Update pytorch_lightning/loops/base.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/loops/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index d1da76ba8b6ba..c44965faa9a8b 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -25,9 +25,9 @@ class Loop(ABC): """ Basic Loops interface. All classes derived from this must implement the following properties and methods: - * :attr`done` (property): Condition to break the loop + * :attr:`done` (property): Condition to break the loop * :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run` - * :attr`advance` (method): Implements one step of the loop + * :attr:`advance` (method): Implements one step of the loop This class implements the following loop structure: From 1d21065bc3d5243868956c6e0a4687ce7d83c718 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 15 Jun 2021 03:03:48 +0200 Subject: [PATCH 453/455] update class names in changelog --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4ab209c26a08..700f723252201 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -118,8 +118,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) * Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) * Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701)) - ... - * Refactored internal loop interface; added new classes `EpochLoop`, `TrainingLoop`, `BatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871)) + * Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871)) - Refactored logging * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736)) From 2d8c441294786109f847f928407dead00fa82a5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 15 Jun 2021 13:19:12 +0200 Subject: [PATCH 454/455] add empty teardown method --- pytorch_lightning/loops/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index c44965faa9a8b..25b3795740840 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -82,7 +82,9 @@ def run(self, *args: Any, **kwargs: Any) -> Any: except StopIteration: break - return self.on_run_end() + output = self.on_run_end() + self.teardown() + return output def on_run_start(self, *args: Any, **kwargs: Any) -> None: """ @@ -107,3 +109,6 @@ def on_advance_end(self) -> None: def on_run_end(self) -> Any: """Hook to be called at the end of the run. 
Its return argument is returned from :attr:`run`.""" + + def teardown(self) -> None: + """The very last method called inside :meth:`run`. Use to release memory etc.""" From f87418222136b638b75ee16ec3412c7268b4fa44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 15 Jun 2021 13:32:29 +0200 Subject: [PATCH 455/455] added skip property --- pytorch_lightning/loops/base.py | 10 +++++++++- pytorch_lightning/loops/fit_loop.py | 14 +++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 25b3795740840..b5a684832401d 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -52,6 +52,11 @@ def __init__(self) -> None: def done(self) -> bool: """Property indicating when loop is finished""" + @property + def skip(self) -> bool: + """Determine whether to return immediately from the call to :meth:`run`.""" + return False + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """Connects Loop with all the necessary things like connectors and accelerators.""" self.trainer = proxy(trainer) @@ -60,7 +65,7 @@ def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: def reset(self) -> None: """Resets the internal state of the loop at the beginning of each call to :attr:`run`.""" - def run(self, *args: Any, **kwargs: Any) -> Any: + def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: """ The main entry point to the loop. @@ -70,6 +75,9 @@ def run(self, *args: Any, **kwargs: Any) -> Any: Returns: the output of :attr`on_run_end` (often outputs collected from each step of the loop) """ + if self.skip: + return + self.reset() self.on_run_start(*args, **kwargs) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 4839ad9f7d460..2837de5ff6537 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -151,6 +151,11 @@ def done(self) -> bool: return stop_steps or should_stop or stop_epochs + @property + def skip(self) -> bool: + """Whether we should skip the training and immediately return from the call to :meth:`run`.""" + return self.done or self.trainer.num_training_batches == 0 + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" # TODO(@justusschock): Do we want to forward *args and **kwargs to the inner loop here? @@ -162,11 +167,6 @@ def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: def reset(self) -> None: """Resets the internal state of this loop""" - def run(self) -> None: - """Loops over epochs if the training should not be skipped""" - if not self._should_skip_training(): - return super().run() - def on_run_start(self) -> None: """Calls the ``on_train_start`` hook.""" self.results.to(device=self.trainer.lightning_module.device) @@ -262,10 +262,6 @@ def on_run_end(self) -> None: # reset bookkeeping self.trainer._running_stage = None - def _should_skip_training(self) -> bool: - """Whether we should skip the training""" - return self.done or self.trainer.num_training_batches == 0 - def should_accumulate(self) -> bool: """Whether the gradients should be accumulated""" return self.training_loop.batch_loop.should_accumulate()
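Note (illustrative, not part of the patch series): the skip property added above lets Loop.run return immediately, before reset(), any hooks, or advance(), which is how FitLoop now expresses "no training batches, nothing to do". The sketch below assumes only that guard as shown in the diff; SketchLoop and EmptyFitSketch are invented names, not Lightning classes.

# Self-contained sketch of the early-return behaviour introduced by the skip property:
# when skip is True, run() returns None without touching any other loop machinery,
# mirroring FitLoop.skip == done or trainer.num_training_batches == 0.
from typing import Any, List, Optional


class SketchLoop:
    def __init__(self) -> None:
        self.iteration_count = 0
        self.calls: List[str] = []

    @property
    def done(self) -> bool:
        return self.iteration_count >= 3

    @property
    def skip(self) -> bool:
        # default: never skip (same as the base Loop in the patch above)
        return False

    def reset(self) -> None:
        self.calls.append("reset")

    def advance(self) -> None:
        self.calls.append("advance")

    def run(self) -> Optional[Any]:
        if self.skip:  # same guard as the patched Loop.run above
            return None
        self.reset()
        while not self.done:
            self.advance()
            self.iteration_count += 1
        return self.calls


class EmptyFitSketch(SketchLoop):
    """Stands in for a FitLoop whose trainer has num_training_batches == 0 (name is illustrative)."""

    num_training_batches = 0

    @property
    def skip(self) -> bool:
        return self.done or self.num_training_batches == 0


assert SketchLoop().run() == ["reset", "advance", "advance", "advance"]
assert EmptyFitSketch().run() is None  # training is skipped entirely

Expressing the skip condition as a property (rather than the removed _should_skip_training check inside a run() override) keeps the control flow in one place, Loop.run, so every subclass gets the same early-return semantics for free.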