Skip to content

Commit

Permalink
label mean if reference
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Oct 5, 2024
1 parent 9181925 commit 237dec9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mqboost/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MQDataset:
data (pd.DataFrame | pd.Series | np.ndarray): The input features.
label (pd.Series | np.ndarray): The target labels (if provided).
model (str): The model type (LightGBM or XGBoost).
reference (MQBoost | None): Reference dataset for label encoding.
reference (MQBoost | None): Reference dataset for label encoding and label mean.
Property:
train_dtype: Returns the data type function for training data.
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
self._data = prepare_x(x=_data, alphas=self._alphas)
self._columns = self._data.columns
if label is not None:
self._label_mean = label.mean()
self._label_mean = reference.label_mean if reference else label.mean()
self._label = prepare_y(y=label - self._label_mean, alphas=self._alphas)
self._is_none_label = False

Expand Down

0 comments on commit 237dec9

Please sign in to comment.