Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes train outputs #2428

Merged
merged 2 commits on
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
# PROCESS THE RESULT
# ----------------------------
# format and reduce outputs accordingly
training_step_output_for_epoch_end = training_step_output
training_step_output = self.process_output(training_step_output, train=True)

# TODO: temporary part of structured results PR
Expand All @@ -788,7 +789,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
)

# if the user decides to finally reduce things in epoch_end, save raw output without graphs
training_step_output_for_epoch_end = recursive_detach(training_step_output)
training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
Expand Down
16 changes: 9 additions & 7 deletions tests/base/deterministic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def training_step_dict_return(self, batch, batch_idx):
pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)}

self.training_step_called = True
return {'loss': acc, 'log': logs, 'progress_bar': pbar}
return {'loss': acc, 'log': logs, 'progress_bar': pbar, 'train_step_test': torch.tensor(549).type_as(acc)}

def training_step_for_step_end_dict(self, batch, batch_idx):
"""sends outputs to training_batch_end"""
Expand Down Expand Up @@ -89,11 +89,11 @@ def training_step_end_dict(self, output):
assert 'pbar_acc1' in output
assert 'pbar_acc2' in output

logs = {'log_acc1': output['log_acc1'], 'log_acc2': output['log_acc2']}
pbar = {'pbar_acc1': output['pbar_acc1'], 'pbar_acc2': output['pbar_acc2']}
logs = {'log_acc1': output['log_acc1'] + 2, 'log_acc2': output['log_acc2'] + 2}
pbar = {'pbar_acc1': output['pbar_acc1'] + 2, 'pbar_acc2': output['pbar_acc2'] + 2}

acc = output['loss']
return {'loss': acc, 'log': logs, 'progress_bar': pbar}
return {'loss': acc, 'log': logs, 'progress_bar': pbar, 'train_step_end': acc}

def training_epoch_end_dict(self, outputs):
self.training_epoch_end_called = True
Expand All @@ -104,12 +104,14 @@ def training_epoch_end_dict(self, outputs):
# only saw 4 batches
assert len(outputs) == 4
for batch_out in outputs:
assert len(batch_out.keys()) == 5
keys = ['batch_loss', 'pbar_on_batch_end', 'log_metrics', 'callback_metrics']
assert len(batch_out.keys()) == 4
assert self.count_num_graphs(batch_out) == 0
last_key = 'train_step_end' if self.training_step_end_called else 'train_step_test'
keys = ['loss', 'log', 'progress_bar', last_key]
for key in keys:
assert key in batch_out

prototype_loss = outputs[0]['batch_loss']
prototype_loss = outputs[0]['loss']
logs = {'epoch_end_log_1': torch.tensor(178).type_as(prototype_loss)}
pbar = {'epoch_end_pbar_1': torch.tensor(234).type_as(prototype_loss)}

Expand Down
33 changes: 21 additions & 12 deletions tests/trainer/test_trainer_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def test_training_step_dict(tmpdir):
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
train_step_out = out.training_step_output_for_epoch_end
pbar_metrics = train_step_out['progress_bar']
assert 'loss' in train_step_out
assert 'log' in train_step_out
assert 'progress_bar' in train_step_out
assert train_step_out['train_step_test'] == 549
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0

Expand Down Expand Up @@ -59,12 +64,14 @@ def training_step_with_step_end(tmpdir):

out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0
assert out.batch_log_metrics['log_acc1'] == 14.0
assert out.batch_log_metrics['log_acc2'] == 9.0

pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0
train_step_end_out = out.training_step_output_for_epoch_end
pbar_metrics = train_step_end_out['progress_bar']
assert 'train_step_end' in train_step_end_out
assert pbar_metrics['pbar_acc1'] == 19.0
assert pbar_metrics['pbar_acc2'] == 21.0


def test_full_training_loop_dict(tmpdir):
Expand Down Expand Up @@ -99,12 +106,13 @@ def test_full_training_loop_dict(tmpdir):

out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0
assert out.batch_log_metrics['log_acc1'] == 14.0
assert out.batch_log_metrics['log_acc2'] == 9.0

pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0
train_step_end_out = out.training_step_output_for_epoch_end
pbar_metrics = train_step_end_out['progress_bar']
assert pbar_metrics['pbar_acc1'] == 19.0
assert pbar_metrics['pbar_acc2'] == 21.0


def test_train_step_epoch_end(tmpdir):
Expand Down Expand Up @@ -138,6 +146,7 @@ def test_train_step_epoch_end(tmpdir):
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
train_step_end_out = out.training_step_output_for_epoch_end
pbar_metrics = train_step_end_out['progress_bar']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0