Skip to content

Commit

Permalink
Fixes train outputs (#2428)
Browse files Browse the repository at this point in the history
* fix outputs

* fix outputs
  • Loading branch information
williamFalcon committed Jun 30, 2020
1 parent a753985 commit a42a0e1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 20 deletions.
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
# PROCESS THE RESULT
# ----------------------------
# format and reduce outputs accordingly
training_step_output_for_epoch_end = training_step_output
training_step_output = self.process_output(training_step_output, train=True)

# TODO: temporary part of structured results PR
Expand All @@ -788,7 +789,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
)

# if the user decides to finally reduce things in epoch_end, save raw output without graphs
training_step_output_for_epoch_end = recursive_detach(training_step_output)
training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
Expand Down
16 changes: 9 additions & 7 deletions tests/base/deterministic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def training_step_dict_return(self, batch, batch_idx):
pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)}

self.training_step_called = True
return {'loss': acc, 'log': logs, 'progress_bar': pbar}
return {'loss': acc, 'log': logs, 'progress_bar': pbar, 'train_step_test': torch.tensor(549).type_as(acc)}

def training_step_for_step_end_dict(self, batch, batch_idx):
"""sends outputs to training_batch_end"""
Expand Down Expand Up @@ -89,11 +89,11 @@ def training_step_end_dict(self, output):
assert 'pbar_acc1' in output
assert 'pbar_acc2' in output

logs = {'log_acc1': output['log_acc1'], 'log_acc2': output['log_acc2']}
pbar = {'pbar_acc1': output['pbar_acc1'], 'pbar_acc2': output['pbar_acc2']}
logs = {'log_acc1': output['log_acc1'] + 2, 'log_acc2': output['log_acc2'] + 2}
pbar = {'pbar_acc1': output['pbar_acc1'] + 2, 'pbar_acc2': output['pbar_acc2'] + 2}

acc = output['loss']
return {'loss': acc, 'log': logs, 'progress_bar': pbar}
return {'loss': acc, 'log': logs, 'progress_bar': pbar, 'train_step_end': acc}

def training_epoch_end_dict(self, outputs):
self.training_epoch_end_called = True
Expand All @@ -104,12 +104,14 @@ def training_epoch_end_dict(self, outputs):
# only saw 4 batches
assert len(outputs) == 4
for batch_out in outputs:
assert len(batch_out.keys()) == 5
keys = ['batch_loss', 'pbar_on_batch_end', 'log_metrics', 'callback_metrics']
assert len(batch_out.keys()) == 4
assert self.count_num_graphs(batch_out) == 0
last_key = 'train_step_end' if self.training_step_end_called else 'train_step_test'
keys = ['loss', 'log', 'progress_bar', last_key]
for key in keys:
assert key in batch_out

prototype_loss = outputs[0]['batch_loss']
prototype_loss = outputs[0]['loss']
logs = {'epoch_end_log_1': torch.tensor(178).type_as(prototype_loss)}
pbar = {'epoch_end_pbar_1': torch.tensor(234).type_as(prototype_loss)}

Expand Down
33 changes: 21 additions & 12 deletions tests/trainer/test_trainer_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def test_training_step_dict(tmpdir):
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
train_step_out = out.training_step_output_for_epoch_end
pbar_metrics = train_step_out['progress_bar']
assert 'loss' in train_step_out
assert 'log' in train_step_out
assert 'progress_bar' in train_step_out
assert train_step_out['train_step_test'] == 549
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0

Expand Down Expand Up @@ -59,12 +64,14 @@ def training_step_with_step_end(tmpdir):

out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0
assert out.batch_log_metrics['log_acc1'] == 14.0
assert out.batch_log_metrics['log_acc2'] == 9.0

pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0
train_step_end_out = out.training_step_output_for_epoch_end
pbar_metrics = train_step_end_out['progress_bar']
assert 'train_step_end' in train_step_end_out
assert pbar_metrics['pbar_acc1'] == 19.0
assert pbar_metrics['pbar_acc2'] == 21.0


def test_full_training_loop_dict(tmpdir):
Expand Down Expand Up @@ -99,12 +106,13 @@ def test_full_training_loop_dict(tmpdir):

out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0
assert out.batch_log_metrics['log_acc1'] == 14.0
assert out.batch_log_metrics['log_acc2'] == 9.0

pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0
train_step_end_out = out.training_step_output_for_epoch_end
pbar_metrics = train_step_end_out['progress_bar']
assert pbar_metrics['pbar_acc1'] == 19.0
assert pbar_metrics['pbar_acc2'] == 21.0


def test_train_step_epoch_end(tmpdir):
Expand Down Expand Up @@ -138,6 +146,7 @@ def test_train_step_epoch_end(tmpdir):
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
train_step_end_out = out.training_step_output_for_epoch_end
pbar_metrics = train_step_end_out['progress_bar']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0

0 comments on commit a42a0e1

Please sign in to comment.