Skip to content

Commit

Permalink
Include the necessary scikit-learn patch
Browse files Browse the repository at this point in the history
  • Loading branch information
timokau committed Oct 1, 2020
1 parent bd5c46f commit 0c7ca20
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
46 changes: 46 additions & 0 deletions 0001-Relax-init-parameter-type-checks.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
From 48a6dd8f0e7ff52d116b0c01e36dac4504547781 Mon Sep 17 00:00:00 2001
From: Timo Kaufmann <timokau@zoho.com>
Date: Thu, 16 Jul 2020 16:57:14 +0200
Subject: [PATCH] Relax init parameter type checks

We now allow any "type" (uninitialized classes) and all numeric numpy
types. See https://github.com/scikit-learn/scikit-learn/issues/17756 for
a discussion.
---
sklearn/utils/estimator_checks.py | 20 ++++++++++++++------
1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 30c668237..ed3c40d67 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -2529,12 +2529,20 @@ def check_parameters_default_constructible(name, Estimator):
assert init_param.default != init_param.empty, (
"parameter %s for %s has no default value"
% (init_param.name, type(estimator).__name__))
- if type(init_param.default) is type:
- assert init_param.default in [np.float64, np.int64]
- else:
- assert (type(init_param.default) in
- [str, int, float, bool, tuple, type(None),
- np.float64, types.FunctionType, joblib.Memory])
+ allowed_types = {
+ str,
+ int,
+ float,
+ bool,
+ tuple,
+ type(None),
+ type,
+ types.FunctionType,
+ joblib.Memory,
+ }
+ # Any numpy numeric such as np.int32.
+ allowed_types.update(np.core.numerictypes.allTypes.values())
+ assert type(init_param.default) in allowed_types
if init_param.name not in params.keys():
# deprecated parameter, not in get_params
assert init_param.default is None
--
2.28.0

1 change: 1 addition & 0 deletions csrank/tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,6 @@ def predict(self, X, *args, **kwargs):
"check_dict_unchanged", # fails for ListNet
"check_dont_overwrite_parameters", # fails for CmpNet
"check_fit_idempotent", # fails for ExpectedRankRegression
"check_n_features_in" # fails for RankSVM
}:
check(estimator)
5 changes: 5 additions & 0 deletions shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ let
buildInputs = with pkgs; [ xorg.libX11 ] ++ old.buildInputs;
}
);
scikit-learn = super.scikit-learn.overridePythonAttrs (
old: {
patches = [./0001-Relax-init-parameter-type-checks.patch];
}
);
});
};
in pkgs.mkShell {
Expand Down

0 comments on commit 0c7ca20

Please sign in to comment.