Skip to content

Commit

Permalink
Adding support for scikit-learn > 0.22 (#936)
Browse files Browse the repository at this point in the history
* Preliminary changes

* Updating unit tests for sklearn 0.22 and above

* Triggering sklearn tests + fixes

* Refactoring to inspect.signature in extensions
  • Loading branch information
Neeratyoy authored Aug 3, 2020
1 parent 9c93f5b commit 666ca68
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 91 deletions.
6 changes: 5 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ env:
- TEST_DIR=/tmp/test_dir/
- MODULE=openml
matrix:
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.21.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" RUN_FLAKE8="true" SKIP_TESTS="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" COVERAGE="true" DOCPUSH="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.21.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.20.2"
# Checks for older scikit-learn versions (which also don't nicely work with
# Python3.7)
Expand Down
18 changes: 11 additions & 7 deletions openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,12 +994,16 @@ def _get_fn_arguments_with_defaults(self, fn_name: Callable) -> Tuple[Dict, Set]
a set with all parameters that do not have a default value
"""
# parameters with defaults are optional, all others are required.
signature = inspect.getfullargspec(fn_name)
if signature.defaults:
optional_params = dict(zip(reversed(signature.args), reversed(signature.defaults)))
else:
optional_params = dict()
required_params = {arg for arg in signature.args if arg not in optional_params}
parameters = inspect.signature(fn_name).parameters
required_params = set()
optional_params = dict()
for param in parameters.keys():
parameter = parameters.get(param)
default_val = parameter.default # type: ignore
if default_val is inspect.Signature.empty:
required_params.add(param)
else:
optional_params[param] = default_val
return optional_params, required_params

def _deserialize_model(
Expand Down Expand Up @@ -1346,7 +1350,7 @@ def _can_measure_cputime(self, model: Any) -> bool:
# check the parameters for n_jobs
n_jobs_vals = SklearnExtension._get_parameter_values_recursive(model.get_params(), "n_jobs")
for val in n_jobs_vals:
if val is not None and val != 1:
if val is not None and val != 1 and val != "deprecated":
return False
return True

Expand Down
Loading

0 comments on commit 666ca68

Please sign in to comment.