Skip to content

Commit

Permalink
add __sklearn_tags__
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Dec 10, 2024
1 parent bd005ec commit e0369a7
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 1 deletion.
65 changes: 65 additions & 0 deletions python/interpret-core/interpret/glassbox/_aplr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Distributed under the MIT software license
from typing import Dict, List, Optional, Tuple
from warnings import warn
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import pandas as pd
Expand All @@ -23,6 +25,57 @@
IntMatrix = np.ndarray


@dataclass
class APLRInputTags:
one_d_array: bool = False
two_d_array: bool = True
three_d_array: bool = False
sparse: bool = False
categorical: bool = False
string: bool = False
dict: bool = False
positive_only: bool = False
allow_nan: bool = False
pairwise: bool = False


@dataclass
class APLRTargetTags:
required: bool = True
one_d_labels: bool = True
two_d_labels: bool = False
positive_only: bool = False
multi_output: bool = False
single_output: bool = True


@dataclass
class APLRClassifierTags:
poor_score: bool = False
multi_class: bool = True
multi_label: bool = False


@dataclass
class APLRRegressorTags:
poor_score: bool = False


@dataclass
class APLRTags:
estimator_type: Optional[str] = None
target_tags: APLRTargetTags = field(default_factory=APLRTargetTags)
transformer_tags: None = None
classifier_tags: Optional[APLRClassifierTags] = None
regressor_tags: Optional[APLRRegressorTags] = None
array_api_support: bool = False
no_validation: bool = False
non_deterministic: bool = False
requires_fit: bool = True
_skip_test: bool = False
input_tags: APLRInputTags = field(default_factory=APLRInputTags)


class APLRRegressor(RegressorMixin, ExplainerMixin):
available_explanations = ["local", "global"]
explainer_type = "model"
Expand Down Expand Up @@ -409,6 +462,12 @@ def explain_local(
selector=selector,
)

def __sklearn_tags__(self):
tags = APLRTags()
tags.estimator_type = "regressor"
tags.regressor_tags = APLRRegressorTags()
return tags


def calculate_densities(X: FloatMatrix) -> Tuple[List[List[int]], List[List[float]]]:
bin_counts: List[List[int]] = []
Expand Down Expand Up @@ -816,6 +875,12 @@ def explain_local(
selector=selector,
)

def __sklearn_tags__(self):
tags = APLRTags()
tags.estimator_type = "classifier"
tags.classifier_tags = APLRClassifierTags()
return tags


class APLRExplanation(FeatureValueExplanation):
"""Visualizes specifically for APLR."""
Expand Down
70 changes: 69 additions & 1 deletion python/interpret-core/interpret/glassbox/_decisiontree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
from sklearn.base import ClassifierMixin, RegressorMixin, is_classifier
Expand Down Expand Up @@ -219,6 +221,57 @@ def _weight_nodes_feature(self, nodes, feature_name):
return new_nodes


@dataclass
class TreeInputTags:
one_d_array: bool = False
two_d_array: bool = True
three_d_array: bool = False
sparse: bool = True
categorical: bool = False
string: bool = True
dict: bool = True
positive_only: bool = False
allow_nan: bool = True
pairwise: bool = False


@dataclass
class TreeTargetTags:
required: bool = True
one_d_labels: bool = True
two_d_labels: bool = False
positive_only: bool = False
multi_output: bool = False
single_output: bool = True


@dataclass
class TreeClassifierTags:
poor_score: bool = False
multi_class: bool = True
multi_label: bool = False


@dataclass
class TreeRegressorTags:
poor_score: bool = False


@dataclass
class TreeTags:
estimator_type: Optional[str] = None
target_tags: TreeTargetTags = field(default_factory=TreeTargetTags)
transformer_tags: None = None
classifier_tags: Optional[TreeClassifierTags] = None
regressor_tags: Optional[TreeRegressorTags] = None
array_api_support: bool = True
no_validation: bool = False
non_deterministic: bool = False
requires_fit: bool = True
_skip_test: bool = False
input_tags: TreeInputTags = field(default_factory=TreeInputTags)


class BaseShallowDecisionTree:
"""Shallow Decision Tree (low depth).
Expand Down Expand Up @@ -280,7 +333,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
X, n_samples = preclean_X(X, self.feature_names, self.feature_types, len(y))

X, self.feature_names_in_, self.feature_types_in_ = unify_data(
X, n_samples, self.feature_names, self.feature_types, False, 0
X, n_samples, self.feature_names, self.feature_types, True, 0
)

model = self._model()
Expand Down Expand Up @@ -540,6 +593,9 @@ def recur(i, depth=0):
recur(0)
return nodes, edges

def __sklearn_tags__(self):
return TreeTags()


class RegressionTree(BaseShallowDecisionTree, RegressorMixin, ExplainerMixin):
"""Regression tree with shallow depth."""
Expand Down Expand Up @@ -583,6 +639,12 @@ def fit(self, X, y, sample_weight=None, check_input=True):
check_input=check_input,
)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "regressor"
tags.regressor_tags = TreeRegressorTags()
return tags


class ClassificationTree(BaseShallowDecisionTree, ClassifierMixin, ExplainerMixin):
"""Classification tree with shallow depth."""
Expand Down Expand Up @@ -644,3 +706,9 @@ def predict_proba(self, X):
)

return self._model().predict_proba(X)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "classifier"
tags.classifier_tags = TreeClassifierTags()
return tags
79 changes: 79 additions & 0 deletions python/interpret-core/interpret/glassbox/_ebm/_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from math import ceil, isnan
from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union
from warnings import warn
from dataclasses import dataclass, field

import numpy as np
from sklearn.base import (
Expand Down Expand Up @@ -276,6 +277,57 @@ def _clean_exclude(exclude, feature_map):
return ret


@dataclass
class EbmInputTags:
one_d_array: bool = False
two_d_array: bool = True
three_d_array: bool = False
sparse: bool = True
categorical: bool = True
string: bool = True
dict: bool = True
positive_only: bool = False
allow_nan: bool = True
pairwise: bool = False


@dataclass
class EbmTargetTags:
required: bool = True
one_d_labels: bool = True
two_d_labels: bool = False
positive_only: bool = False
multi_output: bool = False
single_output: bool = True


@dataclass
class EbmClassifierTags:
poor_score: bool = False
multi_class: bool = True
multi_label: bool = False


@dataclass
class EbmRegressorTags:
poor_score: bool = False


@dataclass
class EbmTags:
estimator_type: Optional[str] = None
target_tags: EbmTargetTags = field(default_factory=EbmTargetTags)
transformer_tags: None = None
classifier_tags: Optional[EbmClassifierTags] = None
regressor_tags: Optional[EbmRegressorTags] = None
array_api_support: bool = True
no_validation: bool = False
non_deterministic: bool = False
requires_fit: bool = True
_skip_test: bool = False
input_tags: EbmInputTags = field(default_factory=EbmInputTags)


class EBMModel(BaseEstimator):
"""Base class for all EBMs."""

Expand Down Expand Up @@ -2627,6 +2679,9 @@ def _more_tags(self):
],
}

def __sklearn_tags__(self):
return EbmTags()


class ExplainableBoostingClassifier(EBMModel, ClassifierMixin, ExplainerMixin):
r"""An Explainable Boosting Classifier.
Expand Down Expand Up @@ -2977,6 +3032,12 @@ def predict(self, X, init_score=None):
# multiclass
return self.classes_[np.argmax(scores, axis=1)]

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "classifier"
tags.classifier_tags = EbmClassifierTags()
return tags


class ExplainableBoostingRegressor(EBMModel, RegressorMixin, ExplainerMixin):
r"""An Explainable Boosting Regressor.
Expand Down Expand Up @@ -3293,6 +3354,12 @@ def predict(self, X, init_score=None):
scores = self._predict_score(X, init_score)
return inv_link(scores, self.link_, self.link_param_)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "regressor"
tags.regressor_tags = EbmRegressorTags()
return tags


class DPExplainableBoostingClassifier(EBMModel, ClassifierMixin, ExplainerMixin):
r"""Differentially Private Explainable Boosting Classifier.
Expand Down Expand Up @@ -3554,6 +3621,12 @@ def predict(self, X, init_score=None):
# multiclass
return self.classes_[np.argmax(scores, axis=1)]

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "classifier"
tags.classifier_tags = EbmClassifierTags()
return tags


class DPExplainableBoostingRegressor(EBMModel, RegressorMixin, ExplainerMixin):
r"""Differentially Private Explainable Boosting Regressor.
Expand Down Expand Up @@ -3791,3 +3864,9 @@ def predict(self, X, init_score=None):
"""
scores = self._predict_score(X, init_score)
return inv_link(scores, self.link_, self.link_param_)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "regressor"
tags.regressor_tags = EbmRegressorTags()
return tags
Loading

0 comments on commit e0369a7

Please sign in to comment.