Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for scikit-learn > 0.22 #936

Merged
merged 5 commits into from
Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ env:
- TEST_DIR=/tmp/test_dir/
- MODULE=openml
matrix:
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.21.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" RUN_FLAKE8="true" SKIP_TESTS="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" COVERAGE="true" DOCPUSH="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
mfeurer marked this conversation as resolved.
Show resolved Hide resolved
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.21.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.20.2"
# Checks for older scikit-learn versions (which also don't nicely work with
# Python3.7)
Expand Down
18 changes: 11 additions & 7 deletions openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,12 +994,16 @@ def _get_fn_arguments_with_defaults(self, fn_name: Callable) -> Tuple[Dict, Set]
a set with all parameters that do not have a default value
"""
# parameters with defaults are optional, all others are required.
signature = inspect.getfullargspec(fn_name)
if signature.defaults:
optional_params = dict(zip(reversed(signature.args), reversed(signature.defaults)))
else:
optional_params = dict()
required_params = {arg for arg in signature.args if arg not in optional_params}
parameters = inspect.signature(fn_name).parameters
required_params = set()
optional_params = dict()
for param in parameters.keys():
parameter = parameters.get(param)
default_val = parameter.default # type: ignore
if default_val is inspect.Signature.empty:
required_params.add(param)
else:
optional_params[param] = default_val
return optional_params, required_params

def _deserialize_model(
Expand Down Expand Up @@ -1346,7 +1350,7 @@ def _can_measure_cputime(self, model: Any) -> bool:
# check the parameters for n_jobs
n_jobs_vals = SklearnExtension._get_parameter_values_recursive(model.get_params(), "n_jobs")
for val in n_jobs_vals:
if val is not None and val != 1:
if val is not None and val != 1 and val != "deprecated":
mfeurer marked this conversation as resolved.
Show resolved Hide resolved
return False
return True

Expand Down
Loading