Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Chore] simplify annotations #27

Merged
merged 8 commits on Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions mqboost/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Callable, Dict, List, Union
from typing import Callable

import lightgbm as lgb
import numpy as np
Expand All @@ -23,11 +23,12 @@ def _isin(cls, text: str) -> None:


# Type
XdataLike = Union[pd.DataFrame, pd.Series, np.ndarray]
YdataLike = Union[pd.Series, np.ndarray]
AlphaLike = Union[List[float], float]
ModelLike = Union[lgb.basic.Booster, xgb.Booster]
DtrainLike = Union[lgb.basic.Dataset, xgb.DMatrix]
XdataLike = pd.DataFrame | pd.Series | np.ndarray
YdataLike = pd.Series | np.ndarray
AlphaLike = list[float] | float
ModelLike = lgb.basic.Booster | xgb.Booster
DtrainLike = lgb.basic.Dataset | xgb.DMatrix
ParamsLike = dict[str, float | int | str | bool]


# Name
Expand Down Expand Up @@ -59,7 +60,7 @@ def _lgb_predict_dtype(data: XdataLike):
return data


FUNC_TYPE: Dict[ModelName, Dict[TypeName, Callable]] = {
FUNC_TYPE: dict[ModelName, dict[TypeName, Callable]] = {
ModelName.lightgbm: {
TypeName.train_dtype: lgb.Dataset,
TypeName.predict_dtype: _lgb_predict_dtype,
Expand Down
19 changes: 12 additions & 7 deletions mqboost/constraints.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from typing import Any, Dict

import pandas as pd

from mqboost.base import FUNC_TYPE, ModelName, MQStr, TypeName, ValidationException
from mqboost.base import (
FUNC_TYPE,
ModelName,
MQStr,
ParamsLike,
TypeName,
ValidationException,
)


def set_monotone_constraints(
params: Dict[str, Any],
params: ParamsLike,
columns: pd.Index,
model_name: ModelName,
) -> Dict[str, Any]:
) -> ParamsLike:
"""
Set monotone constraints in params
Args:
params (Dict[str, Any])
params (ParamsLike)
columns (pd.Index)
model_name (ModelName)
Raises:
ValidationException: when "objective" is in params.keys()
Returns:
Dict[str, Any]
ParamsLike
"""
constraints_fucs = FUNC_TYPE.get(model_name).get(TypeName.constraints_type)
if MQStr.obj.value in params:
Expand Down
70 changes: 15 additions & 55 deletions mqboost/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Union
from typing import Callable

import pandas as pd

Expand All @@ -21,7 +21,7 @@ class MQDataset:
It supports both LightGBM and XGBoost models, handling data preparation, validation, and conversion for training and prediction.

Attributes:
alphas (List[float]):
alphas (list[float]):
List of quantile levels.
Must be in ascending order and contain no duplicates.
data (pd.DataFrame): The input features.
Expand All @@ -45,7 +45,7 @@ def __init__(
self,
alphas: AlphaLike,
data: XdataLike,
label: Optional[YdataLike] = None,
label: YdataLike | None = None,
model: str = ModelName.lightgbm.value,
) -> None:
"""Initialize the MQDataset."""
Expand All @@ -66,94 +66,54 @@ def __init__(

@property
def train_dtype(self) -> Callable:
"""
Get the data type function for training data.
Returns:
Callable: The function that converts data to the required training data type.
"""
"""Get the data type function for training data."""
return self._train_dtype

@property
def predict_dtype(self) -> Callable:
"""
Get the data type function for prediction data.
Returns:
Callable: The function that converts data to the required prediction data type.
"""
"""Get the data type function for prediction data."""
return self._predict_dtype

@property
def model(self) -> ModelName:
"""
Get the model type.
Returns:
ModelName: The model type (LightGBM or XGBoost).
"""
"""Get the model type."""
return self._model

@property
def columns(self) -> pd.Index:
"""
Get the column names of the input features.
Returns:
pd.Index: The column names.
"""
"""Get the column names of the input features."""
return self._columns

@property
def nrow(self) -> int:
"""
Get the number of rows in the dataset.
Returns:
int: The number of rows.
"""
"""Get the number of rows in the dataset."""
return self._nrow

@property
def data(self) -> pd.DataFrame:
"""
Get the raw input features.
Returns:
pd.DataFrame: The input features.
"""
"""Get the raw input features."""
return self._data

@property
def label(self) -> pd.DataFrame:
"""
Get the raw target labels.
Returns:
pd.DataFrame: The target labels.
"""
"""Get the raw target labels."""
self.__label_available()
return self._label

@property
def alphas(self) -> List[float]:
"""
Get the list of quantile levels.
Returns:
List[float]: The quantile levels.
"""
def alphas(self) -> list[float]:
"""Get the list of quantile levels."""
return self._alphas

@property
def dtrain(self) -> DtrainLike:
"""
Get the training data in the required format for the model.
Returns:
DtrainLike: The training data.
"""
"""Get the training data in the required format for the model."""
self.__label_available()
return self._train_dtype(data=self._data, label=self._label)

@property
def dpredict(self) -> Union[DtrainLike, Callable]:
"""
Get the prediction data in the required format for the model.
Returns:
Union[DtrainLike, Callable]: The prediction data.
"""
def dpredict(self) -> DtrainLike | Callable:
"""Get the prediction data in the required format for the model."""
return self._predict_dtype(data=self._data)

def __label_available(self) -> None:
Expand Down
Loading
Loading