Skip to content

Commit

Permalink
[Chore] simplify annotations (#27)
Browse files Browse the repository at this point in the history
- Simplify annotations and remove too much docstring
- Simplify objective
  • Loading branch information
RektPunk authored Sep 7, 2024
1 parent bedae56 commit eca3887
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 304 deletions.
15 changes: 8 additions & 7 deletions mqboost/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Callable, Dict, List, Union
from typing import Callable

import lightgbm as lgb
import numpy as np
Expand All @@ -23,11 +23,12 @@ def _isin(cls, text: str) -> None:


# Type
XdataLike = Union[pd.DataFrame, pd.Series, np.ndarray]
YdataLike = Union[pd.Series, np.ndarray]
AlphaLike = Union[List[float], float]
ModelLike = Union[lgb.basic.Booster, xgb.Booster]
DtrainLike = Union[lgb.basic.Dataset, xgb.DMatrix]
XdataLike = pd.DataFrame | pd.Series | np.ndarray
YdataLike = pd.Series | np.ndarray
AlphaLike = list[float] | float
ModelLike = lgb.basic.Booster | xgb.Booster
DtrainLike = lgb.basic.Dataset | xgb.DMatrix
ParamsLike = dict[str, float | int | str | bool]


# Name
Expand Down Expand Up @@ -59,7 +60,7 @@ def _lgb_predict_dtype(data: XdataLike):
return data


FUNC_TYPE: Dict[ModelName, Dict[TypeName, Callable]] = {
FUNC_TYPE: dict[ModelName, dict[TypeName, Callable]] = {
ModelName.lightgbm: {
TypeName.train_dtype: lgb.Dataset,
TypeName.predict_dtype: _lgb_predict_dtype,
Expand Down
19 changes: 12 additions & 7 deletions mqboost/constraints.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from typing import Any, Dict

import pandas as pd

from mqboost.base import FUNC_TYPE, ModelName, MQStr, TypeName, ValidationException
from mqboost.base import (
FUNC_TYPE,
ModelName,
MQStr,
ParamsLike,
TypeName,
ValidationException,
)


def set_monotone_constraints(
params: Dict[str, Any],
params: ParamsLike,
columns: pd.Index,
model_name: ModelName,
) -> Dict[str, Any]:
) -> ParamsLike:
"""
Set monotone constraints in params
Args:
params (Dict[str, Any])
params (ParamsLike)
columns (pd.Index)
model_name (ModelName)
Raises:
ValidationException: when "objective" is in params.keys()
Returns:
Dict[str, Any]
ParamsLike
"""
constraints_fucs = FUNC_TYPE.get(model_name).get(TypeName.constraints_type)
if MQStr.obj.value in params:
Expand Down
70 changes: 15 additions & 55 deletions mqboost/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Union
from typing import Callable

import pandas as pd

Expand All @@ -21,7 +21,7 @@ class MQDataset:
It supports both LightGBM and XGBoost models, handling data preparation, validation, and conversion for training and prediction.
Attributes:
alphas (List[float]):
alphas (list[float]):
List of quantile levels.
Must be in ascending order and contain no duplicates.
data (pd.DataFrame): The input features.
Expand All @@ -45,7 +45,7 @@ def __init__(
self,
alphas: AlphaLike,
data: XdataLike,
label: Optional[YdataLike] = None,
label: YdataLike | None = None,
model: str = ModelName.lightgbm.value,
) -> None:
"""Initialize the MQDataset."""
Expand All @@ -66,94 +66,54 @@ def __init__(

@property
def train_dtype(self) -> Callable:
"""
Get the data type function for training data.
Returns:
Callable: The function that converts data to the required training data type.
"""
"""Get the data type function for training data."""
return self._train_dtype

@property
def predict_dtype(self) -> Callable:
"""
Get the data type function for prediction data.
Returns:
Callable: The function that converts data to the required prediction data type.
"""
"""Get the data type function for prediction data."""
return self._predict_dtype

@property
def model(self) -> ModelName:
"""
Get the model type.
Returns:
ModelName: The model type (LightGBM or XGBoost).
"""
"""Get the model type."""
return self._model

@property
def columns(self) -> pd.Index:
"""
Get the column names of the input features.
Returns:
pd.Index: The column names.
"""
"""Get the column names of the input features."""
return self._columns

@property
def nrow(self) -> int:
"""
Get the number of rows in the dataset.
Returns:
int: The number of rows.
"""
"""Get the number of rows in the dataset."""
return self._nrow

@property
def data(self) -> pd.DataFrame:
"""
Get the raw input features.
Returns:
pd.DataFrame: The input features.
"""
"""Get the raw input features."""
return self._data

@property
def label(self) -> pd.DataFrame:
"""
Get the raw target labels.
Returns:
pd.DataFrame: The target labels.
"""
"""Get the raw target labels."""
self.__label_available()
return self._label

@property
def alphas(self) -> List[float]:
"""
Get the list of quantile levels.
Returns:
List[float]: The quantile levels.
"""
def alphas(self) -> list[float]:
"""Get the list of quantile levels."""
return self._alphas

@property
def dtrain(self) -> DtrainLike:
"""
Get the training data in the required format for the model.
Returns:
DtrainLike: The training data.
"""
"""Get the training data in the required format for the model."""
self.__label_available()
return self._train_dtype(data=self._data, label=self._label)

@property
def dpredict(self) -> Union[DtrainLike, Callable]:
"""
Get the prediction data in the required format for the model.
Returns:
Union[DtrainLike, Callable]: The prediction data.
"""
def dpredict(self) -> DtrainLike | Callable:
"""Get the prediction data in the required format for the model."""
return self._predict_dtype(data=self._data)

def __label_available(self) -> None:
Expand Down
Loading

0 comments on commit eca3887

Please sign in to comment.