-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Init estimation for regression. (#8272)
- Loading branch information
1 parent
1b58d81
commit badeff1
Showing
29 changed files
with
466 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
"""Tests for dask shared by different test modules.""" | ||
import numpy as np | ||
from dask import array as da | ||
from distributed import Client | ||
from xgboost.testing.updater import get_basescore | ||
|
||
import xgboost as xgb | ||
|
||
|
||
def check_init_estimation_clf(tree_method: str, client: Client) -> None: | ||
"""Test init estimation for classsifier.""" | ||
from sklearn.datasets import make_classification | ||
|
||
X, y = make_classification(n_samples=4096 * 2, n_features=32, random_state=1994) | ||
clf = xgb.XGBClassifier(n_estimators=1, max_depth=1, tree_method=tree_method) | ||
clf.fit(X, y) | ||
base_score = get_basescore(clf) | ||
|
||
dx = da.from_array(X).rechunk(chunks=(32, None)) | ||
dy = da.from_array(y).rechunk(chunks=(32,)) | ||
dclf = xgb.dask.DaskXGBClassifier( | ||
n_estimators=1, max_depth=1, tree_method=tree_method | ||
) | ||
dclf.client = client | ||
dclf.fit(dx, dy) | ||
dbase_score = get_basescore(dclf) | ||
np.testing.assert_allclose(base_score, dbase_score) | ||
|
||
|
||
def check_init_estimation_reg(tree_method: str, client: Client) -> None: | ||
"""Test init estimation for regressor.""" | ||
from sklearn.datasets import make_regression | ||
|
||
# pylint: disable=unbalanced-tuple-unpacking | ||
X, y = make_regression(n_samples=4096 * 2, n_features=32, random_state=1994) | ||
reg = xgb.XGBRegressor(n_estimators=1, max_depth=1, tree_method=tree_method) | ||
reg.fit(X, y) | ||
base_score = get_basescore(reg) | ||
|
||
dx = da.from_array(X).rechunk(chunks=(32, None)) | ||
dy = da.from_array(y).rechunk(chunks=(32,)) | ||
dreg = xgb.dask.DaskXGBRegressor( | ||
n_estimators=1, max_depth=1, tree_method=tree_method | ||
) | ||
dreg.client = client | ||
dreg.fit(dx, dy) | ||
dbase_score = get_basescore(dreg) | ||
np.testing.assert_allclose(base_score, dbase_score) | ||
|
||
|
||
def check_init_estimation(tree_method: str, client: Client) -> None: | ||
"""Test init estimation.""" | ||
check_init_estimation_reg(tree_method, client) | ||
check_init_estimation_clf(tree_method, client) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""Tests for updaters.""" | ||
import json | ||
|
||
import numpy as np | ||
|
||
import xgboost as xgb | ||
|
||
|
||
def get_basescore(model: xgb.XGBModel) -> float: | ||
"""Get base score from an XGBoost sklearn estimator.""" | ||
base_score = float( | ||
json.loads(model.get_booster().save_config())["learner"]["learner_model_param"][ | ||
"base_score" | ||
] | ||
) | ||
return base_score | ||
|
||
|
||
def check_init_estimation(tree_method: str) -> None: | ||
"""Test for init estimation.""" | ||
from sklearn.datasets import ( | ||
make_classification, | ||
make_multilabel_classification, | ||
make_regression, | ||
) | ||
|
||
def run_reg(X: np.ndarray, y: np.ndarray) -> None: # pylint: disable=invalid-name | ||
reg = xgb.XGBRegressor(tree_method=tree_method, max_depth=1, n_estimators=1) | ||
reg.fit(X, y, eval_set=[(X, y)]) | ||
base_score_0 = get_basescore(reg) | ||
score_0 = reg.evals_result()["validation_0"]["rmse"][0] | ||
|
||
reg = xgb.XGBRegressor( | ||
tree_method=tree_method, max_depth=1, n_estimators=1, boost_from_average=0 | ||
) | ||
reg.fit(X, y, eval_set=[(X, y)]) | ||
base_score_1 = get_basescore(reg) | ||
score_1 = reg.evals_result()["validation_0"]["rmse"][0] | ||
assert not np.isclose(base_score_0, base_score_1) | ||
assert score_0 < score_1 # should be better | ||
|
||
# pylint: disable=unbalanced-tuple-unpacking | ||
X, y = make_regression(n_samples=4096, random_state=17) | ||
run_reg(X, y) | ||
# pylint: disable=unbalanced-tuple-unpacking | ||
X, y = make_regression(n_samples=4096, n_targets=3, random_state=17) | ||
run_reg(X, y) | ||
|
||
def run_clf(X: np.ndarray, y: np.ndarray) -> None: # pylint: disable=invalid-name | ||
clf = xgb.XGBClassifier(tree_method=tree_method, max_depth=1, n_estimators=1) | ||
clf.fit(X, y, eval_set=[(X, y)]) | ||
base_score_0 = get_basescore(clf) | ||
score_0 = clf.evals_result()["validation_0"]["logloss"][0] | ||
|
||
clf = xgb.XGBClassifier( | ||
tree_method=tree_method, max_depth=1, n_estimators=1, boost_from_average=0 | ||
) | ||
clf.fit(X, y, eval_set=[(X, y)]) | ||
base_score_1 = get_basescore(clf) | ||
score_1 = clf.evals_result()["validation_0"]["logloss"][0] | ||
assert not np.isclose(base_score_0, base_score_1) | ||
assert score_0 < score_1 # should be better | ||
|
||
# pylint: disable=unbalanced-tuple-unpacking | ||
X, y = make_classification(n_samples=4096, random_state=17) | ||
run_clf(X, y) | ||
X, y = make_multilabel_classification( | ||
n_samples=4096, n_labels=3, n_classes=5, random_state=17 | ||
) | ||
run_clf(X, y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.