From 6c643c14e9fd2410398aa6536cd39bb238acab46 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 5 Sep 2020 19:25:17 -0400
Subject: [PATCH 1/2] ref: inner train loop (intermediate step) 7/n

---
 pytorch_lightning/trainer/training_loop.py      |  8 +++++++-
 pytorch_lightning/trainer/training_loop_temp.py | 10 +++++-----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index dc84924f86769..9953a55f310ce 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -858,9 +858,15 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
         # hook
         grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
 
-        # optimizer step (TODO: decouple zero grad)
+        # optimizer step
         self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
 
+        # hook
+        self.train_loop.on_before_zero_grad(optimizer)
+
+        # clear gradients
+        self.train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+
         return grad_norm_dic
 
     def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 8adda56a6b3c4..8f77aa9d9e46d 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -217,7 +217,6 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
     def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
         # calls .step(), .zero_grad()
         # override function to modify this behavior
-        model = self.trainer.get_model()
 
         with self.trainer.profiler.profile('optimizer_step'):
             lambda_closure = lambda: self.trainer.optimizer_closure(
@@ -231,11 +230,12 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
             # optimizer step lightningModule hook
             self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx, lambda_closure)
 
-            # hook
-            model.on_before_zero_grad(optimizer)
+    def on_before_zero_grad(self, optimizer):
+        model = self.trainer.get_model()
+        model.on_before_zero_grad(optimizer)
 
-            # clear gradients
-            self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
+        self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
 
     def on_before_backward(self, batch_idx, optimizer):
         # track gradient norms
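For readers following the refactor, here is a minimal, self-contained sketch of the step / hook / zero_grad sequence that patch 1/2 decouples. ToyTrainLoop, the closure, and the tensors below are illustrative stand-ins, not Lightning code; only the method names on_before_zero_grad and optimizer_zero_grad mirror the patch.

import torch
import torch.nn.functional as F


class ToyTrainLoop:
    """Illustrative stand-in showing the decoupled responsibilities."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def optimizer_step(self, closure):
        # step only; zero_grad is no longer folded into this call
        self.optimizer.step(closure)

    def on_before_zero_grad(self, optimizer):
        # hook point: gradients are still populated here
        grads = [p.grad for p in self.model.parameters() if p.grad is not None]
        print(f"on_before_zero_grad sees {len(grads)} gradient tensors")

    def optimizer_zero_grad(self):
        # clearing gradients is its own explicit call
        self.optimizer.zero_grad()


model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loop = ToyTrainLoop(model, optimizer)
x, y = torch.randn(8, 4), torch.randn(8, 1)


def closure():
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss


# same order as the patched run_batch_backward_pass: step, hook, zero grad
loop.optimizer_step(closure)         # optimizer step
loop.on_before_zero_grad(optimizer)  # hook
loop.optimizer_zero_grad()           # clear gradients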
From aa8904a87a8b545d72a0752828631c67c65cd795 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 5 Sep 2020 19:26:27 -0400
Subject: [PATCH 2/2] ref: inner train loop (intermediate step) 8/n

---
 pytorch_lightning/trainer/training_loop.py | 27 ++++++++--------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 9953a55f310ce..5655b4a7f8939 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -829,9 +829,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                 # gradient update with accumulated gradients
                 if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
                         or (self.batch_idx + 1) == self.num_training_batches):
+                    # hook
+                    grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
+
+                    # optimizer step
+                    self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
+
+                    # hook
+                    self.train_loop.on_before_zero_grad(optimizer)
 
-                    # backward
-                    grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
+                    # clear gradients
+                    self.train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
 
                     # calculate running loss for display
                     self.running_loss.append(self.batch_loss_value.mean() * self.accumulate_grad_batches)
@@ -854,21 +862,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         )
         return result
 
-    def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
-        # hook
-        grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
-
-        # optimizer step
-        self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
-
-        # hook
-        self.train_loop.on_before_zero_grad(optimizer)
-
-        # clear gradients
-        self.train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
-
-        return grad_norm_dic
-
     def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
         """
         wrap the forward step in a closure so second order methods work
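For context on where patch 2/2 inlines that sequence, the standalone loop below sketches the accumulated-gradients gate; the (batch_idx + 1) % accumulate_grad_batches condition mirrors the diff, while the model, data, loss, and the hook comment are illustrative assumptions rather than Lightning internals.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accumulate_grad_batches = 4
num_training_batches = 10

for batch_idx in range(num_training_batches):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = F.mse_loss(model(x), y)
    # gradients accumulate across calls until zero_grad is run
    (loss / accumulate_grad_batches).backward()

    # same gate as the patched run_training_batch: step every N batches
    # or on the final batch of the epoch
    if ((batch_idx + 1) % accumulate_grad_batches == 0
            or (batch_idx + 1) == num_training_batches):
        optimizer.step()       # optimizer step
        # an on_before_zero_grad-style hook would run here, while grads exist
        optimizer.zero_grad()  # clear gradients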