Skip to content

Commit

Permalink
Add support for using run_model_on_task simply (#888)
Browse files Browse the repository at this point in the history
* Add support for using run_model_on_task simply

* Add unit test

* fix mypy error
  • Loading branch information
m7142yosuke authored and mfeurer committed Nov 22, 2019
1 parent 2b7e740 commit d5e46fe
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
19 changes: 14 additions & 5 deletions openml/runs/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
OpenMLRegressionTask, OpenMLSupervisedTask, OpenMLLearningCurveTask
from .run import OpenMLRun
from .trace import OpenMLRunTrace
from ..tasks import TaskTypeEnum
from ..tasks import TaskTypeEnum, get_task

# Avoid import cycles: https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
if TYPE_CHECKING:
Expand All @@ -38,7 +38,7 @@

def run_model_on_task(
model: Any,
task: OpenMLTask,
task: Union[int, str, OpenMLTask],
avoid_duplicate_runs: bool = True,
flow_tags: List[str] = None,
seed: int = None,
Expand All @@ -54,8 +54,9 @@ def run_model_on_task(
A model which has a function fit(X,Y) and predict(X),
all supervised estimators of scikit learn follow this definition of a model [1]
[1](http://scikit-learn.org/stable/tutorial/statistical_inference/supervised_learning.html)
task : OpenMLTask
Task to perform. This may be a model instead if the first argument is an OpenMLTask.
task : OpenMLTask or int or str
Task to perform, or the id of the task to perform
(an int, or a str that can be converted to an int).
With the deprecated argument order (task, model), this may be a model instead.
avoid_duplicate_runs : bool, optional (default=True)
If True, the run will throw an error if the setup/task combination is already present on
the server. This feature requires an internet connection.
Expand Down Expand Up @@ -84,7 +85,7 @@ def run_model_on_task(
# Flexibility currently still allowed due to code-snippet in OpenML100 paper (3-2019).
# When removing this please also remove the method `is_estimator` from the extension
# interface as it is only used here (MF, 3-2019)
if isinstance(model, OpenMLTask):
if isinstance(model, (int, str, OpenMLTask)):
warnings.warn("The old argument order (task, model) is deprecated and "
"will not be supported in the future. Please use the "
"order (model, task).", DeprecationWarning)
Expand All @@ -98,6 +99,14 @@ def run_model_on_task(

flow = extension.model_to_flow(model)

def get_task_and_type_conversion(task: Union[int, str, OpenMLTask]) -> OpenMLTask:
    """Resolve ``task`` to a task object.

    An ``int`` or ``str`` is treated as a task id: it is converted with
    ``int()`` and passed to ``get_task``. Any other value is returned
    unchanged.
    """
    # Guard clause: anything that is not an id is handed back as-is.
    if not isinstance(task, (int, str)):
        return task
    return get_task(int(task))

task = get_task_and_type_conversion(task)

run = run_flow_on_task(
task=task,
flow=flow,
Expand Down
38 changes: 27 additions & 11 deletions tests/test_runs/test_run_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def _compare_predictions(self, predictions, predictions_prime):

return True

def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed):
def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed,
create_task_obj):
run = openml.runs.get_run(run_id)
task = openml.tasks.get_task(run.task_id)

# TODO: assert holdout task

Expand All @@ -121,12 +121,24 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed):
predictions_url = openml._api_calls._file_id_to_url(file_id)
response = openml._api_calls._download_text_file(predictions_url)
predictions = arff.loads(response)
run_prime = openml.runs.run_model_on_task(
model=model_prime,
task=task,
avoid_duplicate_runs=False,
seed=seed,
)

# If create_task_obj is False, pass the task id (instead of a task object)
# as the task argument of run_model_on_task.
if create_task_obj:
task = openml.tasks.get_task(run.task_id)
run_prime = openml.runs.run_model_on_task(
model=model_prime,
task=task,
avoid_duplicate_runs=False,
seed=seed,
)
else:
run_prime = openml.runs.run_model_on_task(
model=model_prime,
task=run.task_id,
avoid_duplicate_runs=False,
seed=seed,
)

predictions_prime = run_prime._generate_arff_dict()

self._compare_predictions(predictions, predictions_prime)
Expand Down Expand Up @@ -425,13 +437,17 @@ def determine_grid_size(param_grid):
raise e

self._rerun_model_and_compare_predictions(run.run_id, model_prime,
seed)
seed, create_task_obj=True)
self._rerun_model_and_compare_predictions(run.run_id, model_prime,
seed, create_task_obj=False)
else:
run_downloaded = openml.runs.get_run(run.run_id)
sid = run_downloaded.setup_id
model_prime = openml.setups.initialize_model(sid)
self._rerun_model_and_compare_predictions(run.run_id,
model_prime, seed)
self._rerun_model_and_compare_predictions(run.run_id, model_prime,
seed, create_task_obj=True)
self._rerun_model_and_compare_predictions(run.run_id, model_prime,
seed, create_task_obj=False)

# todo: check if runtime is present
self._check_fold_timing_evaluations(run.fold_evaluations, 1, num_folds,
Expand Down

0 comments on commit d5e46fe

Please sign in to comment.