diff --git a/imodels/algebraic/gam_multitask.py b/imodels/algebraic/gam_multitask.py index 04c1d297..038006b9 100644 --- a/imodels/algebraic/gam_multitask.py +++ b/imodels/algebraic/gam_multitask.py @@ -31,7 +31,7 @@ class MultiTaskGAM(BaseEstimator): def __init__( self, - ebm_kwargs=None, + ebm_kwargs={}, multitask=True, random_state=42, @@ -43,11 +43,13 @@ def __init__( self.ebm_kwargs = ebm_kwargs self.multitask = multitask self.random_state = random_state + if not 'random_state' in ebm_kwargs: + ebm_kwargs['random_state'] = random_state self.ebm_ = ExplainableBoostingRegressor(**(ebm_kwargs or {})) # self.ebm_ = ExplainableBoostingClassifier(**(ebm_kwargs or {})) - def fit(self, X, y, task_weights=None, sample_weight=None): + def fit(self, X, y, sample_weight=None): X, y = check_X_y(X, y, accept_sparse=False, multi_output=False) if isinstance(self, ClassifierMixin): check_classification_targets(y) @@ -64,7 +66,7 @@ def fit(self, X, y, task_weights=None, sample_weight=None): num_features = X.shape[1] for task_num in tqdm(range(num_features)): self.ebms_[task_num] = deepcopy(self.ebm_) - y_ = X[:, task_num] + y_ = np.ascontiguousarray(X[:, task_num]) X_ = deepcopy(X) X_[:, task_num] = 0 self.ebms_[task_num].fit(X_, y_, sample_weight=sample_weight) @@ -77,7 +79,7 @@ def fit(self, X, y, task_weights=None, sample_weight=None): feats = self.extract_ebm_features(X) # fit a linear model to the features - self.lin_model = RidgeCV() + self.lin_model = RidgeCV(alphas=np.logspace(-2, 3, 7)) self.lin_model.fit(feats, y) return self @@ -124,7 +126,8 @@ class MultiTaskGAMClassifier(MultiTaskGAM, ClassifierMixin): if __name__ == "__main__": - X, y, feature_names = imodels.get_clean_dataset("heart") + # X, y, feature_names = imodels.get_clean_dataset("heart") + X, y, feature_names = imodels.get_clean_dataset("bike_sharing") # X, y, feature_names = imodels.get_clean_dataset("diabetes") # remove some features to speed things up @@ -135,13 +138,14 @@ class MultiTaskGAMClassifier(MultiTaskGAM, ClassifierMixin): random_state=42, ) results = defaultdict(list) - for gam in [ - MultiTaskGAMRegressor(multitask=True), + for gam in tqdm([ MultiTaskGAMRegressor(multitask=False), - ]: + MultiTaskGAMRegressor(multitask=True), + ]): np.random.seed(42) - gam.fit(X, y_train) results["model_name"].append(gam) + print('Fitting', results['model_name'][-1]) + gam.fit(X, y_train) # check roc auc score # y_pred = gam.predict_proba(X_test)[:, 1] diff --git a/imodels/util/data_util.py b/imodels/util/data_util.py index e6af6c83..83b6f8cf 100644 --- a/imodels/util/data_util.py +++ b/imodels/util/data_util.py @@ -15,7 +15,7 @@ DSET_CLASSIFICATION_KWARGS = { # classification - "pima_diabetes": {"dataset_name": "40715", "data_source": "openml"}, + "pima_diabetes": {"dataset_name": 40715, "data_source": "openml"}, "sonar": {"dataset_name": "sonar", "data_source": "pmlb"}, "heart": {"dataset_name": "heart", "data_source": "imodels"}, "diabetes": {"dataset_name": "diabetes", "data_source": "pmlb"}, @@ -36,7 +36,7 @@ "data_source": "imodels", }, # big, 100k points # big, 1e6 points - "adult": {"dataset_name": "1182", "data_source": "openml"}, + "adult": {"dataset_name": 1182, "data_source": "openml"}, # CDI classification "csi_pecarn": {"dataset_name": "csi_pecarn_pred", "data_source": "imodels"}, "iai_pecarn": {"dataset_name": "iai_pecarn_pred", "data_source": "imodels"}, @@ -45,12 +45,12 @@ DSET_REGRESSION_KWARGS = { # regression - "bike_sharing": {"dataset_name": "42712", "data_source": "openml"}, + "bike_sharing": {"dataset_name": 42712, "data_source": "openml"}, "friedman1": {"dataset_name": "friedman1", "data_source": "synthetic"}, "friedman2": {"dataset_name": "friedman2", "data_source": "synthetic"}, "friedman3": {"dataset_name": "friedman3", "data_source": "synthetic"}, "diabetes_regr": {"dataset_name": "diabetes", "data_source": "sklearn"}, - "abalone": {"dataset_name": "183", "data_source": "openml"}, + "abalone": {"dataset_name": 183, "data_source": "openml"}, "echo_months": {"dataset_name": "1199_BNG_echoMonths", "data_source": "pmlb"}, "satellite_image": {"dataset_name": "294_satellite_image", "data_source": "pmlb"}, "california_housing": {