ref: inner train loop (intermediate step) 9/n (#3368)
williamFalcon committed Sep 6, 2020
1 parent df7e064 commit b375a26
Showing 2 changed files with 30 additions and 22 deletions.
pytorch_lightning/trainer/training_loop.py (3 additions, 22 deletions)
```diff
@@ -787,31 +787,12 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                         optimizer,
                         self.hiddens
                     )
-                    using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
 
-                    # ------------------------------
-                    # POST forward bookkeeping
-                    # ------------------------------
-                    batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)
-
-                    # add metrics to loggers
-                    if using_results_obj:
-                        metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
-                        step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
-                    else:
-                        metrics_to_log = opt_closure_result.training_step_output.log_metrics
-                        step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
-
-                    # track metrics
-                    batch_log_metrics.append(metrics_to_log)
-                    if len(step_pbar_metrics) > 0:
-                        self.add_progress_bar_metrics(step_pbar_metrics)
+                    # log metrics
+                    self.train_loop.log_training_step_metrics(opt_closure_result, batch_callback_metrics, batch_log_metrics)
 
                     # track hiddens
-                    self.hiddens = opt_closure_result.hiddens
-
-                    if using_results_obj:
-                        opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
+                    self.hiddens = self.train_loop.process_hiddens(opt_closure_result)
 
                     # check if loss or model weights are nan
                     if self.terminate_on_nan:
```
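The net effect in training_loop.py: run_training_batch no longer inspects the training_step return type itself; it hands the closure result to the TrainLoop collaborator along with the accumulator lists it still owns. Below is a minimal, runnable sketch of that delegation pattern; the Trainer/TrainLoop shapes mirror the diff, but the SimpleNamespace closure result and the metric values are hypothetical stand-ins, not the real Lightning classes.

```python
from types import SimpleNamespace

class TrainLoop:
    """Collaborator that owns per-step bookkeeping (shape mirrors training_loop_temp.py)."""

    def __init__(self, trainer):
        self.trainer = trainer

    def log_training_step_metrics(self, opt_closure_result, batch_callback_metrics, batch_log_metrics):
        out = opt_closure_result.training_step_output
        # appends mutate the caller's lists, so the trainer keeps ownership of the accumulators
        batch_callback_metrics.append(out.callback_metrics)
        batch_log_metrics.append(out.log_metrics)

class Trainer:
    def __init__(self):
        self.train_loop = TrainLoop(self)

    def run_training_batch(self, opt_closure_result):
        batch_callback_metrics, batch_log_metrics = [], []
        # the inner loop shrinks to a single delegated call
        self.train_loop.log_training_step_metrics(
            opt_closure_result, batch_callback_metrics, batch_log_metrics)
        return batch_log_metrics

# hypothetical closure result with a dict-style training_step output
step_output = SimpleNamespace(callback_metrics={'loss': 0.3}, log_metrics={'loss': 0.3})
result = SimpleNamespace(training_step_output=step_output)
print(Trainer().run_training_batch(result))  # [{'loss': 0.3}]
```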
pytorch_lightning/trainer/training_loop_temp.py (27 additions, 0 deletions)
```diff
@@ -253,3 +253,30 @@ def _track_gradient_norm(self, batch_idx):
             grad_norm_dic = model.grad_norm(
                 self.trainer.track_grad_norm)
         return grad_norm_dic
+
+    def log_training_step_metrics(self, opt_closure_result, batch_callback_metrics, batch_log_metrics):
+        # track callback metrics
+        callback_metrics = opt_closure_result.training_step_output.callback_metrics
+        batch_callback_metrics.append(callback_metrics)
+
+        # decide which metrics to log (results vs dict return)
+        using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
+        if using_results_obj:
+            metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
+            step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
+        else:
+            metrics_to_log = opt_closure_result.training_step_output.log_metrics
+            step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
+
+        # track batch log metrics
+        batch_log_metrics.append(metrics_to_log)
+
+        # track progress bar metrics
+        if len(step_pbar_metrics) > 0:
+            self.trainer.add_progress_bar_metrics(step_pbar_metrics)
+
+    def process_hiddens(self, opt_closure_result):
+        hiddens = opt_closure_result.hiddens
+        if isinstance(opt_closure_result.training_step_output, Result):
+            opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
+        return hiddens
```
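log_training_step_metrics keeps both return contracts working: a training_step that returns a Result object carries its step metrics in batch_log_metrics / batch_pbar_metrics, while the older dict-style return is processed into an output exposing log_metrics / pbar_on_batch_end. process_hiddens likewise carries the recurrent state forward (used for truncated backpropagation through time) and, for Result returns, drops the hiddens from the epoch-end output, presumably so large hidden tensors are not retained for epoch-end aggregation. A small sketch of the same type dispatch follows; the Result and ProcessedDictOutput classes are hypothetical stand-ins (the real Result lives in pytorch_lightning.core.step_result).

```python
class Result:
    """Hypothetical stand-in for pytorch_lightning.core.step_result.Result."""

    def __init__(self, batch_log_metrics, batch_pbar_metrics):
        self.batch_log_metrics = batch_log_metrics
        self.batch_pbar_metrics = batch_pbar_metrics

class ProcessedDictOutput:
    """Hypothetical stand-in for a processed dict-style training_step return."""

    def __init__(self, log_metrics, pbar_on_batch_end):
        self.log_metrics = log_metrics
        self.pbar_on_batch_end = pbar_on_batch_end

def select_step_metrics(training_step_output):
    # the same isinstance dispatch used by log_training_step_metrics
    if isinstance(training_step_output, Result):
        return training_step_output.batch_log_metrics, training_step_output.batch_pbar_metrics
    return training_step_output.log_metrics, training_step_output.pbar_on_batch_end

assert select_step_metrics(Result({'loss': 0.1}, {'acc': 0.9})) == ({'loss': 0.1}, {'acc': 0.9})
assert select_step_metrics(ProcessedDictOutput({'loss': 0.2}, {})) == ({'loss': 0.2}, {})
```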
