Commit

ref: inner train loop (intermediate step) 8/n (#3367)
* ref: inner train loop (intermediate step) 7/n

* ref: inner train loop (intermediate step) 8/n
williamFalcon committed Sep 6, 2020
1 parent dcbfd09 commit df7e064
Showing 2 changed files with 15 additions and 16 deletions.
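
For context, the block touched by this commit only steps the optimizer when an accumulation boundary is reached. A minimal sketch of that guard in plain PyTorch follows (fit, model, optimizer, loader and the argument names are placeholders for illustration, not part of this commit):

    import torch
    import torch.nn.functional as F

    def fit(model, optimizer, loader, accumulate_grad_batches, num_training_batches):
        # Accumulate gradients over several batches; step and clear them only on a
        # boundary, mirroring the condition used in run_training_batch below.
        for batch_idx, (x, y) in enumerate(loader):
            loss = F.mse_loss(model(x), y)
            # scale so the accumulated gradient approximates the full-batch average
            (loss / accumulate_grad_batches).backward()

            boundary = ((batch_idx + 1) % accumulate_grad_batches == 0
                        or (batch_idx + 1) == num_training_batches)
            if boundary:
                optimizer.step()
                optimizer.zero_grad()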
pytorch_lightning/trainer/training_loop.py: 21 changes (10 additions, 11 deletions)
@@ -829,9 +829,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
             # gradient update with accumulated gradients
             if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
                     or (self.batch_idx + 1) == self.num_training_batches):
+                # hook
+                grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
+
+                # optimizer step
+                self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
+
+                # hook
+                self.train_loop.on_before_zero_grad(optimizer)
 
-                # backward
-                grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
+                # clear gradients
+                self.train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
 
                 # calculate running loss for display
                 self.running_loss.append(self.batch_loss_value.mean() * self.accumulate_grad_batches)
@@ -854,15 +862,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
             )
             return result
 
-    def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
-        # hook
-        grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
-
-        # optimizer step (TODO: decouple zero grad)
-        self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
-
-        return grad_norm_dic
-
     def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
         """
         wrap the forward step in a closure so second order methods work
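Reading only the added lines of the hunk above, the accumulation branch of run_training_batch now drives every phase through the train_loop helper directly instead of deferring to run_batch_backward_pass. Roughly, the new sequence is:

    # inside run_training_batch, once the accumulation condition is met
    grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)    # hook / grad-norm tracking
    self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)  # optimizer step
    self.train_loop.on_before_zero_grad(optimizer)                              # hook
    self.train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)          # clear gradients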
pytorch_lightning/trainer/training_loop_temp.py: 10 changes (5 additions, 5 deletions)
@@ -217,7 +217,6 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
     def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
         # calls .step(), .zero_grad()
         # override function to modify this behavior
-        model = self.trainer.get_model()
 
         with self.trainer.profiler.profile('optimizer_step'):
             lambda_closure = lambda: self.trainer.optimizer_closure(
@@ -231,11 +230,12 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
             # optimizer step lightningModule hook
             self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx, lambda_closure)
 
-            # hook
-            model.on_before_zero_grad(optimizer)
+    def on_before_zero_grad(self, optimizer):
+        model = self.trainer.get_model()
+        model.on_before_zero_grad(optimizer)
 
-            # clear gradients
-            self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
+        self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
 
     def on_before_backward(self, batch_idx, optimizer):
         # track gradient norms
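Taken together, the training_loop_temp.py changes turn the zero-grad hook and the gradient clearing into standalone methods on the loop object instead of tail calls inside optimizer_step, so the outer batch loop can invoke each phase explicitly. A minimal sketch of that delegation pattern (the class name and constructor are illustrative; only the two method bodies come from the diff):

    class TrainLoopSketch:
        """Thin wrappers the batch loop can call one phase at a time."""

        def __init__(self, trainer):
            self.trainer = trainer

        def on_before_zero_grad(self, optimizer):
            # let the LightningModule inspect gradients before they are cleared
            model = self.trainer.get_model()
            model.on_before_zero_grad(optimizer)

        def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
            # delegate the actual zeroing to the accelerator backend
            self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)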
