diff --git a/mqboost/dataset.py b/mqboost/dataset.py index 5dea1e7..e43a5e9 100644 --- a/mqboost/dataset.py +++ b/mqboost/dataset.py @@ -33,7 +33,7 @@ class MQDataset: data (pd.DataFrame | pd.Series | np.ndarray): The input features. label (pd.Series | np.ndarray): The target labels (if provided). model (str): The model type (LightGBM or XGBoost). - reference (MQBoost | None): Reference dataset for label encoding. + reference (MQBoost | None): Reference dataset for label encoding and label mean. Property: train_dtype: Returns the data type function for training data. @@ -81,7 +81,7 @@ def __init__( self._data = prepare_x(x=_data, alphas=self._alphas) self._columns = self._data.columns if label is not None: - self._label_mean = label.mean() + self._label_mean = reference.label_mean if reference else label.mean() self._label = prepare_y(y=label - self._label_mean, alphas=self._alphas) self._is_none_label = False