Skip to content

Commit

Permalink
fix contiguous targets gam_multitask
Browse files (browse the repository at this point in the history)
  • Loading branch information
csinva committed Mar 10, 2024
1 parent ec12a02 commit f9664a6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
22 changes: 13 additions & 9 deletions imodels/algebraic/gam_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class MultiTaskGAM(BaseEstimator):

def __init__(
self,
ebm_kwargs=None,
ebm_kwargs={},
multitask=True,
random_state=42,

Expand All @@ -43,11 +43,13 @@ def __init__(
self.ebm_kwargs = ebm_kwargs
self.multitask = multitask
self.random_state = random_state
if not 'random_state' in ebm_kwargs:
ebm_kwargs['random_state'] = random_state
self.ebm_ = ExplainableBoostingRegressor(**(ebm_kwargs or {}))

# self.ebm_ = ExplainableBoostingClassifier(**(ebm_kwargs or {}))

def fit(self, X, y, task_weights=None, sample_weight=None):
def fit(self, X, y, sample_weight=None):
X, y = check_X_y(X, y, accept_sparse=False, multi_output=False)
if isinstance(self, ClassifierMixin):
check_classification_targets(y)
Expand All @@ -64,7 +66,7 @@ def fit(self, X, y, task_weights=None, sample_weight=None):
num_features = X.shape[1]
for task_num in tqdm(range(num_features)):
self.ebms_[task_num] = deepcopy(self.ebm_)
y_ = X[:, task_num]
y_ = np.ascontiguousarray(X[:, task_num])
X_ = deepcopy(X)
X_[:, task_num] = 0
self.ebms_[task_num].fit(X_, y_, sample_weight=sample_weight)
Expand All @@ -77,7 +79,7 @@ def fit(self, X, y, task_weights=None, sample_weight=None):
feats = self.extract_ebm_features(X)

# fit a linear model to the features
self.lin_model = RidgeCV()
self.lin_model = RidgeCV(alphas=np.logspace(-2, 3, 7))
self.lin_model.fit(feats, y)
return self

Expand Down Expand Up @@ -124,7 +126,8 @@ class MultiTaskGAMClassifier(MultiTaskGAM, ClassifierMixin):


if __name__ == "__main__":
X, y, feature_names = imodels.get_clean_dataset("heart")
# X, y, feature_names = imodels.get_clean_dataset("heart")
X, y, feature_names = imodels.get_clean_dataset("bike_sharing")
# X, y, feature_names = imodels.get_clean_dataset("diabetes")

# remove some features to speed things up
Expand All @@ -135,13 +138,14 @@ class MultiTaskGAMClassifier(MultiTaskGAM, ClassifierMixin):
random_state=42,
)
results = defaultdict(list)
for gam in [
MultiTaskGAMRegressor(multitask=True),
for gam in tqdm([
MultiTaskGAMRegressor(multitask=False),
]:
MultiTaskGAMRegressor(multitask=True),
]):
np.random.seed(42)
gam.fit(X, y_train)
results["model_name"].append(gam)
print('Fitting', results['model_name'][-1])
gam.fit(X, y_train)

# check roc auc score
# y_pred = gam.predict_proba(X_test)[:, 1]
Expand Down
8 changes: 4 additions & 4 deletions imodels/util/data_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

DSET_CLASSIFICATION_KWARGS = {
# classification
"pima_diabetes": {"dataset_name": "40715", "data_source": "openml"},
"pima_diabetes": {"dataset_name": 40715, "data_source": "openml"},
"sonar": {"dataset_name": "sonar", "data_source": "pmlb"},
"heart": {"dataset_name": "heart", "data_source": "imodels"},
"diabetes": {"dataset_name": "diabetes", "data_source": "pmlb"},
Expand All @@ -36,7 +36,7 @@
"data_source": "imodels",
}, # big, 100k points
# big, 1e6 points
"adult": {"dataset_name": "1182", "data_source": "openml"},
"adult": {"dataset_name": 1182, "data_source": "openml"},
# CDI classification
"csi_pecarn": {"dataset_name": "csi_pecarn_pred", "data_source": "imodels"},
"iai_pecarn": {"dataset_name": "iai_pecarn_pred", "data_source": "imodels"},
Expand All @@ -45,12 +45,12 @@

DSET_REGRESSION_KWARGS = {
# regression
"bike_sharing": {"dataset_name": "42712", "data_source": "openml"},
"bike_sharing": {"dataset_name": 42712, "data_source": "openml"},
"friedman1": {"dataset_name": "friedman1", "data_source": "synthetic"},
"friedman2": {"dataset_name": "friedman2", "data_source": "synthetic"},
"friedman3": {"dataset_name": "friedman3", "data_source": "synthetic"},
"diabetes_regr": {"dataset_name": "diabetes", "data_source": "sklearn"},
"abalone": {"dataset_name": "183", "data_source": "openml"},
"abalone": {"dataset_name": 183, "data_source": "openml"},
"echo_months": {"dataset_name": "1199_BNG_echoMonths", "data_source": "pmlb"},
"satellite_image": {"dataset_name": "294_satellite_image", "data_source": "pmlb"},
"california_housing": {
Expand Down

0 comments on commit f9664a6

Please sign in to comment.