From b7a99af93273915d6b60b697bb7e552e116c2169 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Fran=C3=A7a?= Date: Fri, 14 Jan 2022 17:04:15 -0300 Subject: [PATCH 1/2] Fixing sklearn error when using RandomizedSearchCV MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luis França --- .../interpret/glassbox/linear.py | 19 +++++- .../interpret/glassbox/test/test_linear.py | 59 +++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/python/interpret-core/interpret/glassbox/linear.py b/python/interpret-core/interpret/glassbox/linear.py index cb9b07abe..cdd2289e3 100644 --- a/python/interpret-core/interpret/glassbox/linear.py +++ b/python/interpret-core/interpret/glassbox/linear.py @@ -10,12 +10,12 @@ from abc import abstractmethod from sklearn.base import is_classifier import numpy as np -from sklearn.base import ClassifierMixin, RegressorMixin +from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin from sklearn.linear_model import LogisticRegression as SKLogistic from sklearn.linear_model import Lasso as SKLinear -class BaseLinear: +class BaseLinear(BaseEstimator): """ Base linear model. Currently wrapper around linear models in scikit-learn. @@ -43,11 +43,26 @@ def __init__( self.linear_class = linear_class self.kwargs = kwargs + for key, value in self.kwargs.items(): + setattr(self, key, value) + @abstractmethod def _model(self): # This method should be overridden. return None + # get_params and set_params are usually inherited from BaseEstimator, but they will + # fail here due to the **kwargs in the __init__. Therefore, we implement them. + def get_params(self, deep = True): + return {param: getattr(self, param) + for param in self.kwargs} + + def set_params(self, **parameters): + for parameter, value in parameters.items(): + setattr(self, parameter, value) + + return self + def fit(self, X, y): """ Fits model to provided instances. diff --git a/python/interpret-core/interpret/glassbox/test/test_linear.py b/python/interpret-core/interpret/glassbox/test/test_linear.py index f543a94f0..ecb70f546 100644 --- a/python/interpret-core/interpret/glassbox/test/test_linear.py +++ b/python/interpret-core/interpret/glassbox/test/test_linear.py @@ -5,6 +5,7 @@ from sklearn.datasets import load_breast_cancer, load_boston from sklearn.linear_model import LogisticRegression as SKLogistic from sklearn.linear_model import Lasso as SKLinear +from sklearn.model_selection import RandomizedSearchCV import numpy as np @@ -38,6 +39,35 @@ def test_linear_regression(): assert global_viz is not None +def test_linear_regression_sklearn_compatibility(): + boston = load_boston() + X, y = boston.data, boston.target + + distributions = { + 'max_iter': [250, 500], + 'alpha': [0.1 , 0.25, 0.5, 1] + } + + sk_lr = SKLinear() + our_lr = LinearRegression() + + search_sk = RandomizedSearchCV(estimator = sk_lr, + param_distributions = distributions, + random_state = 2022) + + search_our = RandomizedSearchCV(estimator = our_lr, + param_distributions = distributions, + random_state = 2022) + + search_sk.fit(X, y) + search_our.fit(X, y) + + sk_pred = search_sk.predict(X) + our_pred = search_our.predict(X) + + assert np.allclose(sk_pred, our_pred) + + def test_logistic_regression(): cancer = load_breast_cancer() X, y = cancer.data, cancer.target @@ -72,6 +102,35 @@ def test_logistic_regression(): assert global_viz is not None +def test_logistic_regression_sklearn_compatibility(): + cancer = load_breast_cancer() + X, y = cancer.data, cancer.target + + distributions = { + 'penalty': ['l1', 'l2'], + 'C': [1 , 0.5, 0.1, 0.05, 0.01] + } + + sk_lr = SKLogistic() + our_lr = LogisticRegression() + + search_sk = RandomizedSearchCV(estimator = sk_lr, + param_distributions = distributions, + random_state = 2022) + + search_our = RandomizedSearchCV(estimator = our_lr, + param_distributions = distributions, + random_state = 2022) + + search_sk.fit(X, y) + search_our.fit(X, y) + + sk_pred = search_sk.predict_proba(X) + our_pred = search_our.predict_proba(X) + + assert np.allclose(sk_pred, our_pred) + + def test_sorting(): cancer = load_breast_cancer() X, y = cancer.data, cancer.target From d8847f795d90c618104478cac905dec24b81b73c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Fran=C3=A7a?= Date: Fri, 21 Jan 2022 17:45:57 -0300 Subject: [PATCH 2/2] Adding tolerance to np.allclose --- python/interpret-core/interpret/glassbox/test/test_linear.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/interpret-core/interpret/glassbox/test/test_linear.py b/python/interpret-core/interpret/glassbox/test/test_linear.py index ecb70f546..a6f7e87dc 100644 --- a/python/interpret-core/interpret/glassbox/test/test_linear.py +++ b/python/interpret-core/interpret/glassbox/test/test_linear.py @@ -1,7 +1,7 @@ # Copyright (c) 2019 Microsoft Corporation # Distributed under the MIT software license -from ..linear import LogisticRegression, LinearRegression +from interpret.glassbox.linear import LogisticRegression, LinearRegression from sklearn.datasets import load_breast_cancer, load_boston from sklearn.linear_model import LogisticRegression as SKLogistic from sklearn.linear_model import Lasso as SKLinear @@ -128,7 +128,8 @@ def test_logistic_regression_sklearn_compatibility(): sk_pred = search_sk.predict_proba(X) our_pred = search_our.predict_proba(X) - assert np.allclose(sk_pred, our_pred) + assert np.allclose(sk_pred, our_pred, rtol=1e-03, atol=1e-06) + def test_sorting():