Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write predictions in LightningModule instead of EvalResult #3882

Merged
merged 2 commits into from
Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,13 @@ def log_dict(
tbptt_reduce_fx=tbptt_reduce_fx,
)

def write_prediction(self, name, value, filename='predictions.pt'):
self.trainer.evaluation_loop.predictions._add_prediction(name, value, filename)

def write_prediction_dict(self, predictions_dict, filename='predictions.pt'):
for k, v in predictions_dict.items():
self.write_prediction(k, v, filename)

Comment on lines +307 to +313
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this write out as a CSV file? do you think this makes more sense as a callback?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

writes to list of dictionaries via torch.save. making predictions is different for every use case, so this way lets user define how they want to do it.

Could be achieved with callback (in fact, that's how I originally was doing this), but that would only work for 1 use case.

def __auto_choose_log_on_step(self, on_step):
if on_step is None:
if self._current_fx_name in {'training_step', 'training_step_end'}:
Expand Down
33 changes: 18 additions & 15 deletions tests/base/model_test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,38 +151,41 @@ def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):

# Base
if option == 0:
result.write('idxs', lazy_ids, prediction_file)
result.write('preds', labels_hat, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('preds', labels_hat, prediction_file)

# Check mismatching tensor len
elif option == 1:
result.write('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
result.write('preds', labels_hat, prediction_file)
self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
self.write_prediction('preds', labels_hat, prediction_file)

# write multi-dimension
elif option == 2:
result.write('idxs', lazy_ids, prediction_file)
result.write('preds', labels_hat, prediction_file)
result.write('x', x, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('preds', labels_hat, prediction_file)
self.write_prediction('x', x, prediction_file)

# write str list
elif option == 3:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_str, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('vals', lst_of_str, prediction_file)

# write int list
elif option == 4:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_int, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('vals', lst_of_int, prediction_file)

# write nested list
elif option == 5:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_lst, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('vals', lst_of_lst, prediction_file)

# write dict list
elif option == 6:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_dict, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('vals', lst_of_dict, prediction_file)

elif option == 7:
self.write_prediction_dict({'idxs': lazy_ids, 'preds': labels_hat}, prediction_file)

return result
3 changes: 3 additions & 0 deletions tests/core/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def test_result_reduce_ddp(result_cls):
pytest.param(
6, False, 0, id='dict_list_predictions'
),
pytest.param(
7, True, 0, id='write_dict_predictions'
),
pytest.param(
0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
)
Expand Down