Skip to content

Commit

Permalink
Predictions (openml#1128)
Browse files Browse the repository at this point in the history
* Add easy way to retrieve run predictions

* Log addition of ``predictions`` (openml#1103)
  • Loading branch information
PGijsbers committed Feb 23, 2023
1 parent f9acefe commit b287fb9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion doc/progress.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Changelog
* FIX#1030: ``pre-commit`` hooks should no longer issue a warning.
* FIX#1110: Make arguments to ``create_study`` and ``create_suite`` that are defined as optional by the OpenML XSD actually optional.
* MAIN#1088: Do CI for Windows on Github Actions instead of Appveyor.

* ADD#1103: Add a ``predictions`` property to OpenMLRun for easy access to prediction data.


0.12.2
Expand Down
18 changes: 18 additions & 0 deletions openml/runs/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import arff
import numpy as np
import pandas as pd

import openml
import openml._api_calls
Expand Down Expand Up @@ -116,6 +117,23 @@ def __init__(
self.predictions_url = predictions_url
self.description_text = description_text
self.run_details = run_details
self._predictions = None

@property
def predictions(self) -> pd.DataFrame:
    """Return this run's predictions as a DataFrame.

    The frame is built on first access and stored in ``self._predictions``;
    later accesses return that same stored object.

    The ARFF-style dictionary is taken from ``self._generate_arff_dict()``
    when ``self.data_content`` is truthy. Otherwise, when
    ``self.predictions_url`` is truthy, the text at that URL is downloaded
    and parsed with ``arff.loads``.

    Returns
    -------
    pd.DataFrame
        One row per entry of the dictionary's ``"data"`` item, with columns
        named after the first element of each ``"attributes"`` pair.

    Raises
    ------
    RuntimeError
        If neither ``data_content`` nor ``predictions_url`` is set.
    """
    cached = self._predictions
    if cached is not None:
        return cached

    if self.data_content:
        # Predictions are already held locally on the run object.
        parsed = self._generate_arff_dict()
    elif self.predictions_url:
        # Only a URL is known: fetch the ARFF text and parse it.
        raw_arff = openml._api_calls._download_text_file(self.predictions_url)
        parsed = arff.loads(raw_arff)
    else:
        raise RuntimeError("Run has no predictions.")

    records = parsed["data"]
    column_names = [column for column, _ in parsed["attributes"]]
    self._predictions = pd.DataFrame(records, columns=column_names)
    return self._predictions

@property
def id(self) -> Optional[int]:
Expand Down
1 change: 1 addition & 0 deletions tests/test_runs/test_run_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed, create
predictions_prime = run_prime._generate_arff_dict()

self._compare_predictions(predictions, predictions_prime)
pd.testing.assert_frame_equal(run.predictions, run_prime.predictions)

def _perform_run(
self,
Expand Down

0 comments on commit b287fb9

Please sign in to comment.