
Commit

Update training_loop.py
ananthsub committed Apr 15, 2021
1 parent 5bd3cd5 commit 77faf1c
Showing 1 changed file with 35 additions and 34 deletions.
69 changes: 35 additions & 34 deletions pytorch_lightning/trainer/training_loop.py
@@ -14,12 +14,13 @@
 
 from contextlib import contextmanager, suppress
 from copy import copy, deepcopy
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
 
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.plugins import ParallelPlugin
@@ -28,8 +29,8 @@
 from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
 from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -80,16 +81,15 @@ def on_trainer_init(
         self.trainer.num_sanity_val_steps = num_sanity_val_steps
 
     @property
-    def num_optimizers(self):
-        num_optimizers = len(self.get_optimizers_iterable())
-        return num_optimizers
+    def num_optimizers(self) -> int:
+        return len(self.get_optimizers_iterable())
 
-    def should_skip_training(self):
+    def should_skip_training(self) -> bool:
         should_by_max_steps = self.trainer.max_steps is not None and self.trainer.global_step >= self.trainer.max_steps
         should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs
         return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
         # hook
         self.trainer.call_hook("on_train_start")
 
@@ -107,7 +107,7 @@ def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodu
         # attach model log function to callback
         self.trainer.callback_connector.attach_model_logging_functions(model)
 
-    def on_train_end(self):
+    def on_train_end(self) -> None:
         if self._teardown_already_run:
             return
         self._teardown_already_run = True
@@ -136,7 +136,7 @@ def on_train_end(self):
         # reset bookkeeping
         self.trainer._running_stage = None
 
-    def check_checkpoint_callback(self, should_update, is_last=False):
+    def check_checkpoint_callback(self, should_update: bool, is_last: bool = False) -> None:
         # TODO bake this logic into the ModelCheckpoint callback
         if should_update and self.trainer.checkpoint_connector.has_trained:
             callbacks = self.trainer.checkpoint_callbacks
@@ -149,7 +149,7 @@ def check_checkpoint_callback(self, should_update, is_last=False):
             for cb in callbacks:
                 cb.on_validation_end(self.trainer, model)
 
-    def check_early_stopping_callback(self, should_update):
+    def check_early_stopping_callback(self, should_update: bool) -> None:
         # TODO bake this logic into the EarlyStopping callback
         if should_update and self.trainer.checkpoint_connector.has_trained:
             callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
@@ -158,7 +158,7 @@ def check_early_stopping_callback(self, should_update):
             for cb in callbacks:
                 cb.on_validation_end(self.trainer, model)
 
-    def on_train_epoch_start(self, epoch):
+    def on_train_epoch_start(self, epoch: int) -> None:
 
         # update training progress in trainer
         self.trainer.current_epoch = epoch
@@ -184,7 +184,9 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")
 
-    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(
+        self, epoch_output, batch_end_outputs, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)]
 
         processed_batch_end_outputs = TrainLoop._prepare_outputs(batch_end_outputs, batch_mode=True)
@@ -199,14 +201,14 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx,
         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
 
-    def reset_train_val_dataloaders(self, model):
+    def reset_train_val_dataloaders(self, model: LightningModule) -> None:
         if self.trainer.train_dataloader is None or not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_train_dataloader(model)
 
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_val_dataloader(model)
 
-    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
+    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs) -> None:
 
         # track the outputs to reduce at the end of the epoch
         for opt_idx, opt_outputs in enumerate(batch_end_outputs):
@@ -229,7 +231,7 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
 
             epoch_output[opt_idx].append(opt_outputs)
 
-    def get_optimizers_iterable(self):
+    def get_optimizers_iterable(self) -> List[Tuple[int, Optimizer]]:
         """
         Generates an iterable with (idx, optimizer) for each optimizer.
         """
@@ -348,8 +350,8 @@ def _process_training_step_output(self, training_step_output, split_batch):
 
     @staticmethod
     def _prepare_outputs(
-            outputs: List[List[List[Result]]],
-            batch_mode: bool,
+        outputs: List[List[List[Result]]],
+        batch_mode: bool,
     ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]:
         """
         Extract required information from batch or epoch end results.
@@ -438,8 +440,7 @@ def track_and_norm_grad(self, optimizer):
 
         # clip gradients
         self.trainer.accelerator.clip_gradients(
-            optimizer, self.trainer.gradient_clip_val,
-            gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
+            optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
         )
         self._cur_grad_norm_dict = grad_norm_dic
 
@@ -451,15 +452,15 @@ def _track_gradient_norm(self):
                 grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm)
         return grad_norm_dict
 
-    def tbptt_split_batch(self, batch):
+    def tbptt_split_batch(self, batch) -> list:
         splits = [batch]
         if self.trainer.truncated_bptt_steps is not None:
             model_ref = self.trainer.lightning_module
             with self.trainer.profiler.profile("tbptt_split_batch"):
                 splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
         return splits
 
-    def run_training_epoch(self):
+    def run_training_epoch(self) -> None:
         # modify dataloader if needed (ddp, etc...)
         train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
 
@@ -599,7 +600,7 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
         self.trainer.call_hook('on_train_epoch_end', processed_epoch_output)
         self.trainer.call_hook('on_epoch_end')
 
-    def run_training_batch(self, batch, batch_idx, dataloader_idx):
+    def run_training_batch(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         # track grad norms
         grad_norm_dic = {}
 
@@ -748,7 +749,7 @@ def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list:
 
         return batch_outputs
 
-    def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
+    def training_step_and_backward(self, split_batch, batch_idx: int, opt_idx: int, optimizer, hiddens):
         """
         wrap the forward step in a closure so second order methods work
         """
@@ -793,7 +794,7 @@ def _check_finite(self, loss: torch.Tensor) -> None:
         model = self.trainer.lightning_module
         detect_nan_parameters(model)
 
-    def backward(self, result, optimizer, opt_idx, *args, **kwargs):
+    def backward(self, result, optimizer, opt_idx: int, *args, **kwargs) -> None:
         self.trainer.dev_debugger.track_event("backward_call")
 
         should_accumulate = self.should_accumulate()
@@ -810,15 +811,15 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):
             # track gradients
             self.track_and_norm_grad(optimizer=optimizer)
 
-    def update_train_loop_lr_schedulers(self, monitor_metrics=None):
+    def update_train_loop_lr_schedulers(self, monitor_metrics: Optional[dict] = None) -> None:
         num_accumulated_batches_reached = self._accumulated_batches_reached()
         num_training_batches_reached = self._num_training_batches_reached()
 
         if num_accumulated_batches_reached or num_training_batches_reached:
             # update lr
             self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)
 
-    def increment_accumulated_grad_global_step(self):
+    def increment_accumulated_grad_global_step(self) -> None:
         num_accumulated_batches_reached = self._accumulated_batches_reached()
         num_training_batches_reached = self._num_training_batches_reached()
 
@@ -828,19 +829,19 @@ def increment_accumulated_grad_global_step(self):
                 self.trainer.total_batch_idx, self.trainer.global_step
             )
 
-    def _accumulated_batches_reached(self):
+    def _accumulated_batches_reached(self) -> bool:
         return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
 
-    def _num_training_batches_reached(self, is_last_batch=False):
+    def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool:
         return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch
 
-    def should_accumulate(self):
+    def should_accumulate(self) -> bool:
         # checks if backward or backward + optimizer step (via closure)
         accumulation_done = self._accumulated_batches_reached()
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)
 
-    def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
+    def should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
         is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
@@ -854,7 +855,7 @@ def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
 
         return should_check_val and can_check_val
 
-    def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
+    def build_train_args(self, batch: Any, batch_idx: int, opt_idx: int, hiddens) -> list:
         # enable not needing to add opt_idx to training_step
         args = [batch, batch_idx]
 
@@ -879,7 +880,7 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
 
         return args
 
-    def save_loggers_on_train_batch_end(self):
+    def save_loggers_on_train_batch_end(self) -> None:
         # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
         if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
@@ -892,7 +893,7 @@ def prepare_optimizers(self):
             optimizers = [optimizers[0]]
         return optimizers
 
-    def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
+    def run_train_split_start(self, split_idx: int, split_batch, opt_idx: int, optimizer) -> None:
         # set split_idx to trainer for tracking
         self.trainer.split_idx = split_idx
 
@@ -905,7 +906,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
         # use to track metrics internally
         self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch)
 
-    def update_running_loss(self):
+    def update_running_loss(self) -> None:
         accumulated_loss = self.accumulated_loss.mean()
 
         if accumulated_loss is not None:
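
A side note for readers of the hunks above: the accumulation bookkeeping whose signatures are annotated here (_accumulated_batches_reached, _num_training_batches_reached, should_accumulate) is compact enough to exercise standalone. Below is a minimal sketch of that logic, assuming a toy Loop container with hypothetical demo values in place of the real TrainLoop/Trainer pair; it is an illustration of the boundary arithmetic, not Lightning code.

from dataclasses import dataclass


@dataclass
class Loop:
    # hypothetical stand-in for TrainLoop + Trainer state, demo values only
    batch_idx: int = 0
    accumulate_grad_batches: int = 4
    num_training_batches: int = 10

    def _accumulated_batches_reached(self) -> bool:
        # True on every accumulate_grad_batches-th batch (batch_idx is 0-based)
        return (self.batch_idx + 1) % self.accumulate_grad_batches == 0

    def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool:
        # True on the final batch of the epoch
        return (self.batch_idx + 1) == self.num_training_batches or is_last_batch

    def should_accumulate(self) -> bool:
        # keep accumulating (backward only) unless at an accumulation
        # boundary or the end of the epoch, where the optimizer steps
        return not (self._accumulated_batches_reached() or self._num_training_batches_reached())


loop = Loop()
for i in range(loop.num_training_batches):
    loop.batch_idx = i
    print(i, "accumulate" if loop.should_accumulate() else "step")

Running this prints "step" at batch_idx 3 and 7 (accumulation boundaries) and at 9 (last batch), and "accumulate" everywhere else, which is exactly the backward-only versus backward-plus-optimizer-step split the training-step closure relies on.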
