diff --git a/docs/source/autodocs/eli5.rst b/docs/source/autodocs/eli5.rst index b1030549..7b471543 100644 --- a/docs/source/autodocs/eli5.rst +++ b/docs/source/autodocs/eli5.rst @@ -13,3 +13,5 @@ The following functions are exposed to a top level, e.g. .. autofunction:: eli5.show_weights .. autofunction:: eli5.show_prediction + +.. autofunction:: eli5.transform_feature_names diff --git a/docs/source/libraries/sklearn.rst b/docs/source/libraries/sklearn.rst index 0abaaa89..a8d99854 100644 --- a/docs/source/libraries/sklearn.rst +++ b/docs/source/libraries/sklearn.rst @@ -195,6 +195,23 @@ is independent. .. _ExtraTreesClassifier: http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html#sklearn.ensemble.ExtraTreesClassifier .. _ExtraTreesRegressor: http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesRegressor.html#sklearn.ensemble.ExtraTreesRegressor +.. _sklearn-pipelines: + +Transformation pipelines +------------------------ + +:func:`eli5.explain_weights` can be applied to a scikit-learn Pipeline_ as +long as: + +* ``explain_weights`` is supported for the final step of the Pipeline +* :func:`eli5.transform_feature_names` is supported for all preceding steps + of the Pipeline. singledispatch_ can be used to register + ``transform_feature_names`` for transformer classes not handled (yet) by ELI5 + or to override the default implementation. + +.. _Pipeline: http://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html#sklearn.pipeline.Pipeline +.. 
_singledispatch: https://pypi.python.org/pypi/singledispatch + Reversing hashing trick ----------------------- diff --git a/eli5/__init__.py b/eli5/__init__.py index 5178da52..e2530c29 100644 --- a/eli5/__init__.py +++ b/eli5/__init__.py @@ -6,6 +6,7 @@ from .formatters import format_as_html, format_html_styles, format_as_text from .explain import explain_weights, explain_prediction from .sklearn import explain_weights_sklearn, explain_prediction_sklearn +from .transform import transform_feature_names try: diff --git a/eli5/sklearn/__init__.py b/eli5/sklearn/__init__.py index 9067b20b..5c96f1f6 100644 --- a/eli5/sklearn/__init__.py +++ b/eli5/sklearn/__init__.py @@ -17,3 +17,4 @@ FeatureUnhasher, invert_hashing_and_fit, ) +from . import transform as _ diff --git a/eli5/sklearn/explain_weights.py b/eli5/sklearn/explain_weights.py index 81190ca3..9f7b09e3 100644 --- a/eli5/sklearn/explain_weights.py +++ b/eli5/sklearn/explain_weights.py @@ -5,6 +5,7 @@ import numpy as np # type: ignore from sklearn.base import BaseEstimator, RegressorMixin # type: ignore +from sklearn.pipeline import Pipeline # type: ignore from sklearn.linear_model import ( # type: ignore ElasticNet, # includes Lasso, MultiTaskElasticNet, etc. 
ElasticNetCV, @@ -61,6 +62,7 @@ get_default_target_names, ) from eli5.explain import explain_weights +from eli5.transform import transform_feature_names from eli5._feature_importances import ( get_feature_importances_filtered, get_feature_importance_explanation, @@ -422,3 +424,17 @@ def _features(target_id): method='linear model', is_regression=True, ) + + +@register(Pipeline) +def explain_weights_pipeline(estimator, feature_names=None, **kwargs): + last_estimator = estimator.steps[-1][1] + transform_pipeline = Pipeline(estimator.steps[:-1]) + if 'vec' in kwargs: + feature_names = get_feature_names(feature_names, vec=kwargs.pop('vec')) + feature_names = transform_feature_names(transform_pipeline, feature_names) + out = explain_weights(last_estimator, + feature_names=feature_names, + **kwargs) + out.estimator = repr(estimator) + return out diff --git a/eli5/sklearn/transform.py b/eli5/sklearn/transform.py new file mode 100644 index 00000000..ce1364cc --- /dev/null +++ b/eli5/sklearn/transform.py @@ -0,0 +1,30 @@ +"""transform_feature_names implementations for scikit-learn transformers +""" + +import numpy as np # type: ignore +from sklearn.pipeline import Pipeline # type: ignore +from sklearn.feature_selection.base import SelectorMixin # type: ignore + +from eli5.transform import transform_feature_names +from eli5.sklearn.utils import get_feature_names as _get_feature_names + + +# Feature selection: + +@transform_feature_names.register(SelectorMixin) +def _select_names(est, in_names=None): + mask = est.get_support(indices=False) + in_names = _get_feature_names(est, feature_names=in_names, + num_features=len(mask)) + return [in_names[i] for i in np.flatnonzero(mask)] + + +# Pipelines + +@transform_feature_names.register(Pipeline) +def _pipeline_names(est, in_names=None): + names = in_names + for name, trans in est.steps: + if trans is not None: + names = transform_feature_names(trans, names) + return names diff --git a/eli5/transform.py b/eli5/transform.py new file 
mode 100644 index 00000000..6bdb9b4e --- /dev/null +++ b/eli5/transform.py @@ -0,0 +1,29 @@ +"""Handling transformation pipelines in explanations""" + +from singledispatch import singledispatch + + +@singledispatch +def transform_feature_names(transformer, in_names=None): + """Get feature names for transformer output as a function of input names + + Used by :func:`explain_weights` when applied to a scikit-learn Pipeline, + this ``singledispatch`` should be registered with custom name + transformations for each class of transformer. + + Parameters + ---------- + transformer : scikit-learn-compatible transformer + in_names : list of str, optional + Names for features input to transformer.transform(). + If not provided, the implementation may generate default feature names + if the number of input features is known. + + Returns + ------- + feature_names : list of str + """ + if hasattr(transformer, 'get_feature_names'): + return transformer.get_feature_names() + raise NotImplementedError('transform_feature_names not available for ' + '{}'.format(transformer)) diff --git a/tests/test_sklearn_explain_weights.py b/tests/test_sklearn_explain_weights.py index f328f5f2..6f599234 100644 --- a/tests/test_sklearn_explain_weights.py +++ b/tests/test_sklearn_explain_weights.py @@ -50,7 +50,9 @@ AdaBoostRegressor, ) from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, clone +from sklearn.pipeline import make_pipeline +from sklearn.feature_selection import SelectKBest from sklearn.multiclass import OneVsRestClassifier import pytest @@ -484,3 +486,30 @@ def test_feature_importances_no_remaining(clf): for expl in format_as_all(res, clf): assert 'more features' not in expl and 'more …' not in expl assert 'x1' not in expl # it has zero importance + + +@pytest.mark.parametrize(['transformer', 'X', 'feature_names', + 'explain_kwargs'], [ + [None, [[1, 0], [0, 1]], ['hello', 'world'], {}], +
[None, [[1, 0], [0, 1]], None, + {'vec': CountVectorizer().fit(['hello', 'world'])}], + [CountVectorizer(), ['hello', 'world'], None, {'top': 1}], + [CountVectorizer(), ['hello', 'world'], None, {'top': 2}], + [make_pipeline(CountVectorizer(), + SelectKBest(lambda X, y: np.array([3, 2, 1]), k=2)), + ['hello', 'world zzzignored'], None, {}], +]) +@pytest.mark.parametrize(['predictor'], [ + [LogisticRegression()], + [LinearSVR()], +]) +def test_explain_pipeline(predictor, transformer, X, feature_names, + explain_kwargs): + y = [1, 0] + expected = explain_weights(clone(predictor).fit([[1, 0], [0, 1]], y), + feature_names=['hello', 'world'], + **explain_kwargs) + pipe = make_pipeline(transformer, clone(predictor)).fit(X, y) + actual = explain_weights(pipe, feature_names=feature_names, + **explain_kwargs) + assert expected._repr_html_() == actual._repr_html_()