Skip to content

Commit

Permalink
pass ebm_kwargs to gam_shap
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Nov 6, 2024
1 parent 35af526 commit d47ead5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions imodels/algebraic/gam_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class ShapGAMClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, n_estimators=10, feature_fraction=0.7, random_state=None):
def __init__(self, n_estimators=10, feature_fraction=0.7, random_state=None, ebm_kwargs: dict = {}):
"""
Initialize the ensemble EBM classifier.
Expand All @@ -19,6 +19,7 @@ def __init__(self, n_estimators=10, feature_fraction=0.7, random_state=None):
self.random_state = random_state
self.models = []
self.feature_subsets = []
self.ebm_kwargs = ebm_kwargs

def fit(self, X, y):
"""
Expand All @@ -37,7 +38,8 @@ def fit(self, X, y):
self.feature_subsets.append(feature_subset)

# Create an EBM with the selected feature subset
ebm = ExplainableBoostingClassifier(random_state=self.random_state)
ebm = ExplainableBoostingClassifier(
random_state=self.random_state, **self.ebm_kwargs)
X_subset = X[:, feature_subset]
ebm.fit(X_subset, y)
self.models.append(ebm)
Expand Down

0 comments on commit d47ead5

Please sign in to comment.