Skip to content

Commit

Permalink
[Feature] introduce weight in MQDataset (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk authored Oct 24, 2024
1 parent 7e05273 commit 64e2681
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mqboost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from mqboost.optimize import MQOptimizer
from mqboost.regressor import MQRegressor

__version__ = "0.2.8"
__version__ = "0.2.9"
1 change: 1 addition & 0 deletions mqboost/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def _isin(cls, text: str) -> None:
ModelLike = lgb.basic.Booster | xgb.Booster
DtrainLike = lgb.basic.Dataset | xgb.DMatrix
ParamsLike = dict[str, float | int | str | bool]
WeightLike = list[float] | list[int] | np.ndarray | pd.Series


# Name
Expand Down
16 changes: 15 additions & 1 deletion mqboost/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, Optional

import numpy as np
import pandas as pd

from mqboost.base import (
Expand All @@ -9,6 +10,7 @@
FittingException,
ModelName,
TypeName,
WeightLike,
XdataLike,
YdataLike,
)
Expand All @@ -32,6 +34,7 @@ class MQDataset:
Must be in ascending order and contain no duplicates.
data (pd.DataFrame | pd.Series | np.ndarray): The input features.
label (pd.Series | np.ndarray): The target labels (if provided).
weight (list[float] | list[int] | np.ndarray | pd.Series): Weight for each instance (if provided).
model (str): The model type (LightGBM or XGBoost).
reference (MQDataset | None): Reference dataset for label encoding and label mean.
Expand All @@ -52,6 +55,7 @@ def __init__(
alphas: AlphaLike,
data: XdataLike,
label: YdataLike | None = None,
weight: WeightLike | None = None,
model: str = ModelName.lightgbm.value,
reference: Optional["MQDataset"] = None,
) -> None:
Expand Down Expand Up @@ -85,6 +89,10 @@ def __init__(
self._label = prepare_y(y=label - self._label_mean, alphas=self._alphas)
self._is_none_label = False

if weight is not None:
_weight = np.array(weight) if not isinstance(weight, np.ndarray) else weight
self._weight = prepare_y(y=_weight, alphas=self._alphas)

@property
def train_dtype(self) -> Callable:
"""Get the data type function for training data."""
Expand Down Expand Up @@ -123,14 +131,20 @@ def label(self) -> pd.DataFrame:

@property
def label_mean(self) -> float:
    """Return the stored label mean.

    The label-availability check is invoked first, so anything it raises
    propagates before the stored value is read.
    """
    self.__label_available()
    stored_mean = self._label_mean
    return stored_mean

@property
def weight(self) -> WeightLike | None:
    """Return the instance weights, or ``None`` when none were stored.

    ``_weight`` is only assigned when a weight was supplied, so a missing
    attribute is reported as ``None`` rather than raising.
    """
    try:
        return self._weight
    except AttributeError:
        return None

@property
def dtrain(self) -> DtrainLike:
"""Get the training data in the required format for the model."""
self.__label_available()
# NOTE(review): the next two lines look like the before/after pair of this diff
# hunk (the call without `weight` being the single deletion, the call passing
# `weight=self.weight` its replacement) — confirm against the committed file.
# Read as literal code, the second return would be unreachable.
return self._train_dtype(data=self._data, label=self._label)
return self._train_dtype(data=self._data, label=self._label, weight=self.weight)

@property
def dpredict(self) -> DtrainLike | Callable:
Expand Down
1 change: 1 addition & 0 deletions mqboost/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def __optuna_objective(
params=params,
dtrain=dtrain,
evals=[(dvalid, "valid")],
num_boost_round=100,
)
_gbm = xgb.train(**model_params)
_preds = _gbm.predict(data=deval) + self._label_mean
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mqboost"
version = "0.2.8"
version = "0.2.9"
description = "Monotonic composite quantile gradient boost regressor"
authors = ["RektPunk <rektpunk@gmail.com>"]
readme = "README.md"
Expand Down

1 comment on commit 64e2681

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
91 0 💤 0 ❌ 0 🔥 7.046s ⏱️

Please sign in to comment.