diff --git a/eli5/sklearn/utils.py b/eli5/sklearn/utils.py index 3f3c74de..802b7e90 100644 --- a/eli5/sklearn/utils.py +++ b/eli5/sklearn/utils.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import numpy as np +from sklearn.base import MetaEstimatorMixin def is_multiclass_classifier(clf): @@ -19,7 +20,11 @@ def is_multitarget_regressor(clf): def is_probabilistic_classifier(clf): """ Return True if a classifier can return probabilities """ - return hasattr(clf, 'predict_proba') + if not hasattr(clf, 'predict_proba'): + return False + if isinstance(clf, MetaEstimatorMixin) and hasattr(clf, 'estimator'): + return hasattr(clf.estimator, 'predict_proba') + return True def has_intercept(estimator):