From 237dec9edf1c1d6d4e661f3d9ac50529f226b526 Mon Sep 17 00:00:00 2001 From: RektPunk Date: Sat, 5 Oct 2024 22:15:41 +0900 Subject: [PATCH] label mean if reference --- mqboost/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mqboost/dataset.py b/mqboost/dataset.py index 5dea1e7..e43a5e9 100644 --- a/mqboost/dataset.py +++ b/mqboost/dataset.py @@ -33,7 +33,7 @@ class MQDataset: data (pd.DataFrame | pd.Series | np.ndarray): The input features. label (pd.Series | np.ndarray): The target labels (if provided). model (str): The model type (LightGBM or XGBoost). - reference (MQBoost | None): Reference dataset for label encoding. + reference (MQBoost | None): Reference dataset for label encoding and label mean. Property: train_dtype: Returns the data type function for training data. @@ -81,7 +81,7 @@ def __init__( self._data = prepare_x(x=_data, alphas=self._alphas) self._columns = self._data.columns if label is not None: - self._label_mean = label.mean() + self._label_mean = reference.label_mean if reference else label.mean() self._label = prepare_y(y=label - self._label_mean, alphas=self._alphas) self._is_none_label = False