Skip to content

Commit

Permalink
stratified group kfold splitter (microsoft#899)
Browse files Browse the repository at this point in the history
* stratified group kfold splitter

* exclude catboost

---------

Co-authored-by: Shaokun <shaokunzhang529@gmail.com>
Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
  • Loading branch information
3 people authored Feb 5, 2023
1 parent 44a5bb5 commit 7b3a8f2
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
5 changes: 3 additions & 2 deletions flaml/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
GroupKFold,
TimeSeriesSplit,
GroupShuffleSplit,
StratifiedGroupKFold,
)
from sklearn.utils import shuffle
from sklearn.base import BaseEstimator
Expand Down Expand Up @@ -1575,8 +1576,8 @@ def _prepare_data(self, eval_method, split_ratio, n_splits):
else:
# logger.info("Using splitter object")
self._state.kf = self._split_type
if isinstance(self._state.kf, GroupKFold):
# self._split_type is either "group" or a GroupKFold object
if isinstance(self._state.kf, (GroupKFold, StratifiedGroupKFold)):
# self._split_type is either "group", a GroupKFold object, or a StratifiedGroupKFold object
self._state.kf.groups = self._state.groups_all

def add_learner(self, learner_name, learner_class):
Expand Down
21 changes: 17 additions & 4 deletions flaml/automl/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
mean_absolute_percentage_error,
ndcg_score,
)
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
from sklearn.model_selection import (
RepeatedStratifiedKFold,
GroupKFold,
TimeSeriesSplit,
StratifiedGroupKFold,
)
from flaml.automl.model import (
XGBoostSklearnEstimator,
XGBoost_TS,
Expand Down Expand Up @@ -517,7 +522,7 @@ def evaluate_model_CV(
shuffle = getattr(kf, "shuffle", task not in TS_FORECAST)
if isinstance(kf, RepeatedStratifiedKFold):
kf = kf.split(X_train_split, y_train_split)
elif isinstance(kf, GroupKFold):
elif isinstance(kf, (GroupKFold, StratifiedGroupKFold)):
groups = kf.groups
kf = kf.split(X_train_split, y_train_split, groups)
shuffle = False
Expand Down Expand Up @@ -548,8 +553,16 @@ def evaluate_model_CV(
weight[val_index],
)
if groups is not None:
fit_kwargs["groups"] = groups[train_index]
groups_val = groups[val_index]
fit_kwargs["groups"] = (
groups[train_index]
if isinstance(groups, np.ndarray)
else groups.iloc[train_index]
)
groups_val = (
groups[val_index]
if isinstance(groups, np.ndarray)
else groups.iloc[val_index]
)
else:
groups_val = None
val_loss_i, metric_i, train_time_i, pred_time_i = get_val_loss(
Expand Down
27 changes: 27 additions & 0 deletions test/automl/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,33 @@ def test_groups():
automl.fit(X, y, **automl_settings)


def test_stratified_groupkfold():
    """Fit AutoML using a ``StratifiedGroupKFold`` instance as ``split_type``.

    The training part of OpenML dataset 1169 is loaded (cached under
    ``test/``), and its ``"Airline"`` column is passed as the group labels.
    The test makes no assertions on the result; it only checks that
    ``fit`` completes with a splitter object of this type under
    cross-validation.
    """
    from sklearn.model_selection import StratifiedGroupKFold
    from flaml.data import load_openml_dataset

    X_train, _, y_train, _ = load_openml_dataset(dataset_id=1169, data_dir="test/")
    # Fixed seed so the shuffled fold assignment is reproducible between runs.
    sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0)

    automl = AutoML()
    automl.fit(
        X_train=X_train,
        y_train=y_train,
        time_budget=6,
        metric="ap",
        eval_method="cv",
        split_type=sgkf,
        groups=X_train["Airline"],
        # Explicit learner list; catboost is not part of it.
        estimator_list=[
            "lgbm",
            "rf",
            "xgboost",
            "extra_tree",
            "xgb_limitdepth",
            "lrl1",
        ],
    )


def test_rank():
from sklearn.externals._arff import ArffException

Expand Down

0 comments on commit 7b3a8f2

Please sign in to comment.