diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index dc84924f86769..5655b4a7f8939 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -829,9 +829,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                 # gradient update with accumulated gradients
                 if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
                         or (self.batch_idx + 1) == self.num_training_batches):
+                    # hook
+                    grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
+
+                    # optimizer step
+                    self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
+
+                    # hook
+                    self.train_loop.on_before_zero_grad(optimizer)
 
-                    # backward
-                    grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
+                    # clear gradients
+                    self.train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
 
                     # calculate running loss for display
                     self.running_loss.append(self.batch_loss_value.mean() * self.accumulate_grad_batches)
@@ -854,15 +862,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         )
         return result
 
-    def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
-        # hook
-        grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
-
-        # optimizer step (TODO: decouple zero grad)
-        self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
-
-        return grad_norm_dic
-
     def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
         """
         wrap the forward step in a closure so second order methods work
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 8adda56a6b3c4..8f77aa9d9e46d 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -217,7 +217,6 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
     def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
         # calls .step(), .zero_grad()
         # override function to modify this behavior
-        model = self.trainer.get_model()
 
         with self.trainer.profiler.profile('optimizer_step'):
             lambda_closure = lambda: self.trainer.optimizer_closure(
@@ -231,11 +230,12 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
             # optimizer step lightningModule hook
             self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx, lambda_closure)
 
-            # hook
-            model.on_before_zero_grad(optimizer)
+    def on_before_zero_grad(self, optimizer):
+        model = self.trainer.get_model()
+        model.on_before_zero_grad(optimizer)
 
-            # clear gradients
-            self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
+        self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
 
     def on_before_backward(self, batch_idx, optimizer):
         # track gradient norms
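
For reviewers, the net effect of this change: the monolithic run_batch_backward_pass helper is removed, and run_training_batch now calls four narrowly scoped train_loop methods in a fixed order. The sketch below is a minimal, self-contained illustration of that call order only; StubTrainLoop and run_accumulated_update are hypothetical stand-ins invented for this example and are not part of this PR or the Lightning API.

    # Minimal sketch (assumption: simplified stand-ins for the real Trainer
    # state) showing the call order this diff inlines into run_training_batch.

    class StubTrainLoop:
        """Hypothetical stand-in for the train_loop methods touched by this diff."""

        def on_before_backward(self, batch_idx, optimizer):
            print('on_before_backward: track gradient norms')
            return {}  # grad_norm_dic

        def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
            print('optimizer_step: closure-based .step() via the accelerator backend')

        def on_before_zero_grad(self, optimizer):
            print('on_before_zero_grad: LightningModule hook before gradients are cleared')

        def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
            print('optimizer_zero_grad: clear gradients via the accelerator backend')


    def run_accumulated_update(train_loop, optimizer, opt_idx, batch_idx, split_batch):
        # same order as the inlined block in run_training_batch
        grad_norm_dic = train_loop.on_before_backward(batch_idx, optimizer)
        train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
        train_loop.on_before_zero_grad(optimizer)
        train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
        return grad_norm_dic


    run_accumulated_update(StubTrainLoop(), optimizer=None, opt_idx=0, batch_idx=0, split_batch=None)

Splitting on_before_zero_grad and optimizer_zero_grad out of optimizer_step lets each piece be overridden or profiled independently, which is what the removed "TODO: decouple zero grad" comment was asking for.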