Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Make our estimators compatible with scikit-learn #116

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
46 changes: 46 additions & 0 deletions 0001-Relax-init-parameter-type-checks.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
From 48a6dd8f0e7ff52d116b0c01e36dac4504547781 Mon Sep 17 00:00:00 2001
From: Timo Kaufmann <timokau@zoho.com>
Date: Thu, 16 Jul 2020 16:57:14 +0200
Subject: [PATCH] Relax init parameter type checks

We now allow any "type" (uninitialized classes) and all numeric numpy
types. See https://github.com/scikit-learn/scikit-learn/issues/17756 for
a discussion.
---
sklearn/utils/estimator_checks.py | 20 ++++++++++++++------
1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 30c668237..ed3c40d67 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -2529,12 +2529,20 @@ def check_parameters_default_constructible(name, Estimator):
assert init_param.default != init_param.empty, (
"parameter %s for %s has no default value"
% (init_param.name, type(estimator).__name__))
- if type(init_param.default) is type:
- assert init_param.default in [np.float64, np.int64]
- else:
- assert (type(init_param.default) in
- [str, int, float, bool, tuple, type(None),
- np.float64, types.FunctionType, joblib.Memory])
+ allowed_types = {
+ str,
+ int,
+ float,
+ bool,
+ tuple,
+ type(None),
+ type,
+ types.FunctionType,
+ joblib.Memory,
+ }
+ # Any numpy numeric such as np.int32.
+ allowed_types.update(np.core.numerictypes.allTypes.values())
+ assert type(init_param.default) in allowed_types
if init_param.name not in params.keys():
# deprecated parameter, not in get_params
assert init_param.default is None
--
2.28.0

22 changes: 12 additions & 10 deletions csrank/choicefunction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
from .pairwise_choice import PairwiseSVMChoiceFunction
from .ranknet_choice import RankNetChoiceFunction

__all__ = [
"AllPositive",
"CmpNetChoiceFunction",
"FATEChoiceFunction",
"FATELinearChoiceFunction",
"FETAChoiceFunction",
"FETALinearChoiceFunction",
"GeneralizedLinearModel",
"PairwiseSVMChoiceFunction",
"RankNetChoiceFunction",
algorithms = [
AllPositive,
CmpNetChoiceFunction,
FATEChoiceFunction,
FATELinearChoiceFunction,
FETAChoiceFunction,
FETALinearChoiceFunction,
GeneralizedLinearModel,
PairwiseSVMChoiceFunction,
RankNetChoiceFunction,
]

__all__ = [algo.__name__ for algo in algorithms]
2 changes: 1 addition & 1 deletion csrank/dataset_reader/objectranking/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def generate_pairwise_instances(features):

def generate_complete_pairwise_dataset(X, Y):
"""
Generates the pairiwse preference data from the given rankings.The ranking amongst the objects in a query set
Generates the pairwise preference data from the given rankings.The ranking amongst the objects in a query set
:math:`Q = \\{x_1, x_2, x_3\\}` is represented by :math:`\\pi = (2,1,3)`, such that :math:`\\pi(2)=1` is the position of the :math:`x_2`.
One can extract the following *pairwise preferences* :math:`x_2 \\succ x_1, x_2 \\succ x_3 and x_1 \\succ x_3`.
This function generates pairwise preferences which can be used to learn different :class:`ObjectRanker` as:
Expand Down
32 changes: 17 additions & 15 deletions csrank/discretechoice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@
from .pairwise_discrete_choice import PairwiseSVMDiscreteChoiceFunction
from .ranknet_discrete_choice import RankNetDiscreteChoiceFunction

__all__ = [
"RandomBaselineDC",
"CmpNetDiscreteChoiceFunction",
"FATEDiscreteChoiceFunction",
"FATELinearDiscreteChoiceFunction",
"FETADiscreteChoiceFunction",
"FETALinearDiscreteChoiceFunction",
"GeneralizedNestedLogitModel",
"MixedLogitModel",
"ModelSelector",
"MultinomialLogitModel",
"NestedLogitModel",
"PairedCombinatorialLogit",
"PairwiseSVMDiscreteChoiceFunction",
"RankNetDiscreteChoiceFunction",
algorithms = [
RandomBaselineDC,
CmpNetDiscreteChoiceFunction,
FATEDiscreteChoiceFunction,
FATELinearDiscreteChoiceFunction,
FETADiscreteChoiceFunction,
FETALinearDiscreteChoiceFunction,
GeneralizedNestedLogitModel,
MixedLogitModel,
ModelSelector,
MultinomialLogitModel,
NestedLogitModel,
PairedCombinatorialLogit,
PairwiseSVMDiscreteChoiceFunction,
RankNetDiscreteChoiceFunction,
]

__all__ = [algo.__name__ for algo in algorithms]
24 changes: 13 additions & 11 deletions csrank/objectranking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
from .rank_net import RankNet
from .rank_svm import RankSVM

__all__ = [
"CmpNet",
"ExpectedRankRegression",
"FATEObjectRanker",
"FATELinearObjectRanker",
"FETAObjectRanker",
"FETALinearObjectRanker",
"ListNet",
"RankNet",
"RankSVM",
"RandomBaselineRanker",
algorithms = [
CmpNet,
ExpectedRankRegression,
FATEObjectRanker,
FATELinearObjectRanker,
FETAObjectRanker,
FETALinearObjectRanker,
ListNet,
RankNet,
RankSVM,
RandomBaselineRanker,
]

__all__ = [algo.__name__ for algo in algorithms]
70 changes: 70 additions & 0 deletions csrank/tests/test_estimators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Check that our estimators adhere to the scikit-learn interface.

https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator
"""

from functools import partial

import pytest
from sklearn.utils.estimator_checks import check_estimator

from csrank import objectranking


def get_check_name(check):
if isinstance(check, partial):
return check.func.__name__
else:
return check.__name__


def _reshape_x(X):
n_instances, n_objects = X.shape
n_features = 1
return X.reshape((n_instances, n_objects, n_features))


@pytest.mark.parametrize(
"Estimator",
# TODO write wrappers for choice, discretechoice
objectranking.algorithms,
)
def test_all_estimators(Estimator):
class WrappedRanker(Estimator):
# scikit learn assumes that "X" is an array of one-dimensional
# feature vectors. Our learners however assume an array of objects
# as a "feature vector", hence they expect one more dimension.
# This is one scikit-learn API expectation that we do not fulfill.
# This thin wrapper is needed so that we can still use the other
# estimator checks. It just pretends every feature is itself a
# one-feature object.
def fit(self, X, Y, *args, **kwargs):
Xnew = _reshape_x(X)
Ynew = Xnew.argsort(axis=1).argsort(axis=1).squeeze(axis=-1)
return super().fit(Xnew, Ynew, *args, **kwargs)

def predict(self, X, *args, **kwargs):
super().predict(_reshape_x(X), *args, **kwargs)

for (estimator, check) in check_estimator(WrappedRanker, generate_only=True):
# checks that attempt to call "fit" do not work since our estimators
# expect a 3-dimensional data shape while scikit-learn assumes two
# dimensions (an array of 1d data).
if not get_check_name(check) in {
"check_estimators_fit_returns_self", # fails for all
"check_complex_data", # fails for CmpNet
"check_dtype_object", # fails for ExpectedRankRegression
"check_estimators_empty_data_messages", # fails for all
"check_estimators_nan_inf", # fails for CmpNet
"check_estimators_overwrite_params", # fails for FATELinearObjectRanker
"check_estimator_sparse_data", # fails for ExpectedRankRegression
"check_estimators_pickle", # fails for ExpectedRankRegression
"check_fit2d_predict1d", # fails for ExpectedRankRegression
"check_methods_subset_invariance", # fails for ExpectedRankRegression
"check_fit2d_1sample", # fails for FETAObjectRanker
"check_dict_unchanged", # fails for ListNet
"check_dont_overwrite_parameters", # fails for CmpNet
"check_fit_idempotent", # fails for ExpectedRankRegression
"check_n_features_in" # fails for RankSVM
}:
check(estimator)
5 changes: 5 additions & 0 deletions shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ let
buildInputs = with pkgs; [ xorg.libX11 ] ++ old.buildInputs;
}
);
scikit-learn = super.scikit-learn.overridePythonAttrs (
old: {
patches = [./0001-Relax-init-parameter-type-checks.patch];
}
);
});
};
in pkgs.mkShell {
Expand Down