Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref: inner train loop (intermediate step) 8/n" #3367

Merged
merged 2 commits into from
Sep 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,9 +829,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# gradient update with accumulated gradients
if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
or (self.batch_idx + 1) == self.num_training_batches):
# hook
grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)

# optimizer step
self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)

# hook
self.train_loop.on_before_zero_grad(optimizer)

# backward
grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
# clear gradients
self.train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

# calculate running loss for display
self.running_loss.append(self.batch_loss_value.mean() * self.accumulate_grad_batches)
Expand All @@ -854,15 +862,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
)
return result

def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
# hook
grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)

# optimizer step (TODO: decouple zero grad)
self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)

return grad_norm_dic

def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
"""
wrap the forward step in a closure so second order methods work
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/training_loop_temp.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
# calls .step(), .zero_grad()
# override function to modify this behavior
model = self.trainer.get_model()

with self.trainer.profiler.profile('optimizer_step'):
lambda_closure = lambda: self.trainer.optimizer_closure(
Expand All @@ -231,11 +230,12 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
# optimizer step lightningModule hook
self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx, lambda_closure)

# hook
model.on_before_zero_grad(optimizer)
def on_before_zero_grad(self, optimizer):
model = self.trainer.get_model()
model.on_before_zero_grad(optimizer)

# clear gradients
self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

def on_before_backward(self, batch_idx, optimizer):
# track gradient norms
Expand Down