[Feature] simplify code for readability #34

Merged (3 commits) on Sep 12, 2024
6 changes: 0 additions & 6 deletions mqboost/base.py

@@ -49,12 +49,6 @@ class TypeName(BaseEnum):
     constraints_type: str = "constraints_type"
 
 
-class MQStr(BaseEnum):
-    mono: str = "monotone_constraints"
-    obj: str = "objective"
-    valid: str = "valid"
-
-
 # Functions
 def _lgb_predict_dtype(data: XdataLike):
     return data
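Reviewer note, not part of the diff: the deleted `MQStr` enum only wrapped three string keys ("monotone_constraints", "objective", "valid"), so each use site can carry the literal directly. A minimal sketch of the before/after access pattern, using a hypothetical params dict:

```python
# Hypothetical params dict for illustration only.
params = {"monotone_constraints": [0, 1]}

# Before: the key was reached through an enum member's .value
#   params[MQStr.mono.value]
# After: the literal (or a local constant) is used directly
MONOTONE_CONSTRAINTS = "monotone_constraints"
assert params[MONOTONE_CONSTRAINTS] == [0, 1]
```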
12 changes: 7 additions & 5 deletions mqboost/constraints.py

@@ -1,6 +1,6 @@
 import pandas as pd
 
-from mqboost.base import FUNC_TYPE, ModelName, MQStr, ParamsLike, TypeName
+from mqboost.base import FUNC_TYPE, ModelName, ParamsLike, TypeName
 
 
 def set_monotone_constraints(
@@ -19,16 +19,18 @@ def set_monotone_constraints(
     Returns:
         ParamsLike
     """
+    MONOTONE_CONSTRAINTS: str = "monotone_constraints"
+
     constraints_fucs = FUNC_TYPE.get(model_name).get(TypeName.constraints_type)
     _params = params.copy()
-    if MQStr.mono.value in _params:
-        _monotone_constraints = list(_params[MQStr.mono.value])
+    if MONOTONE_CONSTRAINTS in _params:
+        _monotone_constraints = list(_params[MONOTONE_CONSTRAINTS])
         _monotone_constraints.append(1)
-        _params.update({MQStr.mono.value: constraints_fucs(_monotone_constraints)})
+        _params.update({MONOTONE_CONSTRAINTS: constraints_fucs(_monotone_constraints)})
     else:
         _params.update(
             {
-                MQStr.mono.value: constraints_fucs(
+                MONOTONE_CONSTRAINTS: constraints_fucs(
                     [1 if "_tau" == col else 0 for col in columns]
                 )
             }
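As a reading aid (assumed behavior, not part of the diff): with the enum gone, `set_monotone_constraints` keys everything through the local `MONOTONE_CONSTRAINTS` constant. A standalone sketch of the resulting logic, with the `FUNC_TYPE` lookup stubbed by an identity function so the example runs on its own:

```python
# Sketch only: constraints_fucs is stubbed; mqboost resolves the real
# model-specific converter via FUNC_TYPE.get(model_name).
MONOTONE_CONSTRAINTS: str = "monotone_constraints"


def set_monotone_constraints_sketch(params: dict, columns: list[str]) -> dict:
    constraints_fucs = lambda c: c  # stub for the FUNC_TYPE lookup
    _params = params.copy()
    if MONOTONE_CONSTRAINTS in _params:
        # user supplied constraints: append 1 for the extra "_tau" column
        _monotone_constraints = list(_params[MONOTONE_CONSTRAINTS])
        _monotone_constraints.append(1)
        _params[MONOTONE_CONSTRAINTS] = constraints_fucs(_monotone_constraints)
    else:
        # default: constrain only the "_tau" column to be monotone increasing
        _params[MONOTONE_CONSTRAINTS] = constraints_fucs(
            [1 if "_tau" == col else 0 for col in columns]
        )
    return _params


print(set_monotone_constraints_sketch({}, ["x1", "_tau"]))
# {'monotone_constraints': [0, 1]}
```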
75 changes: 39 additions & 36 deletions mqboost/objective.py

@@ -7,6 +7,11 @@
 from mqboost.utils import delta_validate, epsilon_validate
 
 CHECK_LOSS: str = "check_loss"
+GradFnLike = Callable[[Any], np.ndarray]
+HessFnLike = Callable[[Any], np.ndarray]
+CallGradHessLike = Callable[
+    [np.ndarray, DtrainLike, list[float], Any], tuple[np.ndarray, np.ndarray]
+]
 
 
 # check loss
@@ -42,13 +47,13 @@ def _hess_huber(error: np.ndarray, **kwargs) -> np.ndarray:
 
 
 # Approx loss (MM loss)
-def _grad_approx(error: np.ndarray, alpha: float, epsilon: float):
+def _grad_approx(error: np.ndarray, alpha: float, epsilon: float) -> np.ndarray:
     """Compute the gradient of the smooth approximated check loss function."""
     _grad = 0.5 * (1 - 2 * alpha - error / (epsilon + np.abs(error)))
     return _grad
 
 
-def _hess_approx(error: np.ndarray, epsilon: float, **kwargs):
+def _hess_approx(error: np.ndarray, epsilon: float, **kwargs) -> np.ndarray:
     """Compute the Hessian of the smooth approximated check loss function."""
     _hess = 1 / (2 * (epsilon + np.abs(error)))
     return _hess
@@ -64,40 +69,38 @@ def _train_pred_reshape(
     return _y_train.reshape(len_alpha, -1), y_pred.reshape(len_alpha, -1)
 
 
-def _compute_grads_hess(
-    y_pred: np.ndarray,
-    dtrain: DtrainLike,
-    alphas: list[float],
-    grad_fn: Callable[[np.ndarray, float, Any], np.ndarray],
-    hess_fn: Callable[[np.ndarray, float, Any], np.ndarray],
-    **kwargs: dict[str, float],
-) -> tuple[np.ndarray, np.ndarray]:
-    """Compute gradients and Hessians for the given loss function."""
-    _len_alpha = len(alphas)
-    _y_train, _y_pred = _train_pred_reshape(
-        y_pred=y_pred, dtrain=dtrain, len_alpha=_len_alpha
-    )
-    grads = []
-    hess = []
-    for alpha_inx in range(len(alphas)):
-        _err_for_alpha = _y_train[alpha_inx] - _y_pred[alpha_inx]
-        _grad = grad_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
-        _hess = hess_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
-        grads.append(_grad)
-        hess.append(_hess)
-
-    return np.concatenate(grads), np.concatenate(hess)
-
-
-check_loss_grad_hess: Callable = partial(
-    _compute_grads_hess, grad_fn=_grad_rho, hess_fn=_hess_rho
-)
-huber_loss_grad_hess: Callable = partial(
-    _compute_grads_hess, grad_fn=_grad_huber, hess_fn=_hess_huber
-)
-approx_loss_grad_hess: Callable = partial(
-    _compute_grads_hess, grad_fn=_grad_approx, hess_fn=_hess_approx
-)
+# Compute gradient-Hessian logic
+def compute_grad_hess(grad_fn: GradFnLike, hess_fn: HessFnLike) -> CallGradHessLike:
+    """Return a function that computes gradients and Hessians."""
+
+    def _compute_grads_hess(
+        y_pred: np.ndarray,
+        dtrain: DtrainLike,
+        alphas: list[float],
+        **kwargs: Any,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Compute gradients and Hessians with the given gradient and Hessian functions."""
+        _len_alpha = len(alphas)
+        _y_train, _y_pred = _train_pred_reshape(
+            y_pred=y_pred, dtrain=dtrain, len_alpha=_len_alpha
+        )
+        grads = []
+        hess = []
+        for alpha_inx in range(len(alphas)):
+            _err_for_alpha = _y_train[alpha_inx] - _y_pred[alpha_inx]
+            _grad = grad_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
+            _hess = hess_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
+            grads.append(_grad)
+            hess.append(_hess)
+
+        return np.concatenate(grads), np.concatenate(hess)
+
+    return _compute_grads_hess
+
+
+check_loss_grad_hess = compute_grad_hess(grad_fn=_grad_rho, hess_fn=_hess_rho)
+huber_loss_grad_hess = compute_grad_hess(grad_fn=_grad_huber, hess_fn=_hess_huber)
+approx_loss_grad_hess = compute_grad_hess(grad_fn=_grad_approx, hess_fn=_hess_approx)
 
 
 def _eval_check_loss(
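For context on the main refactor above (an illustrative sketch, not the module itself): the three `functools.partial` bindings become a `compute_grad_hess` factory that closes over `grad_fn`/`hess_fn` and returns a callable matching `CallGradHessLike`. The sketch below reproduces the pattern in a self-contained form; it takes already-reshaped `(len(alphas), n)` arrays instead of a `dtrain` object, and the grad/hess lambdas are dummies, not mqboost's real `_grad_rho`/`_hess_rho`:

```python
from typing import Any, Callable

import numpy as np


def compute_grad_hess(
    grad_fn: Callable[..., np.ndarray], hess_fn: Callable[..., np.ndarray]
) -> Callable:
    """Bind grad_fn/hess_fn once; the returned closure runs the per-alpha loop."""

    def _compute_grads_hess(
        y_train: np.ndarray,  # shape (len(alphas), n)
        y_pred: np.ndarray,   # shape (len(alphas), n)
        alphas: list[float],
        **kwargs: Any,
    ) -> tuple[np.ndarray, np.ndarray]:
        grads, hess = [], []
        for idx, alpha in enumerate(alphas):
            error = y_train[idx] - y_pred[idx]
            grads.append(grad_fn(error=error, alpha=alpha, **kwargs))
            hess.append(hess_fn(error=error, alpha=alpha, **kwargs))
        return np.concatenate(grads), np.concatenate(hess)

    return _compute_grads_hess


# Dummy stand-ins for the real gradient/Hessian functions:
toy_loss_grad_hess = compute_grad_hess(
    grad_fn=lambda error, alpha, **kw: alpha - (error < 0).astype(float),
    hess_fn=lambda error, **kw: np.ones_like(error),
)
g, h = toy_loss_grad_hess(np.zeros((2, 3)), np.ones((2, 3)), alphas=[0.1, 0.9])
print(g.shape, h.shape)  # (6,) (6,)
```

One practical gain over `partial`: the factory can declare precise parameter and return types (`GradFnLike`, `HessFnLike`, `CallGradHessLike`), whereas the old module-level `partial` objects were typed only as bare `Callable`.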
3 changes: 1 addition & 2 deletions mqboost/optimize.py

@@ -12,7 +12,6 @@
     DtrainLike,
     FittingException,
     ModelName,
-    MQStr,
     ObjectiveName,
     ParamsLike,
 )
@@ -200,7 +199,7 @@ def __optuna_objective(
         model_params = dict(
             params=params,
             dtrain=dtrain,
-            evals=[(dvalid, MQStr.valid.value)],
+            evals=[(dvalid, "valid")],
         )
         _gbm = xgb.train(**model_params)
         _preds = _gbm.predict(data=deval)
4 changes: 2 additions & 2 deletions mqboost/regressor.py

@@ -2,7 +2,7 @@
 import numpy as np
 import xgboost as xgb
 
-from mqboost.base import FittingException, ModelName, MQStr, ObjectiveName, ParamsLike
+from mqboost.base import FittingException, ModelName, ObjectiveName, ParamsLike
 from mqboost.constraints import set_monotone_constraints
 from mqboost.dataset import MQDataset
 from mqboost.objective import MQObjective
@@ -84,7 +84,7 @@ def fit(
             epsilon=self._epsilon,
         )
         if self.__is_lgb:
-            params.update({MQStr.obj.value: self._MQObj.fobj})
+            params.update({"objective": self._MQObj.fobj})
             self.model = lgb.train(
                 train_set=dataset.dtrain,
                 params=params,
3 changes: 1 addition & 2 deletions mqboost/utils.py

@@ -6,7 +6,6 @@
 
 from mqboost.base import (
     AlphaLike,
-    MQStr,
     ParamsLike,
     ValidationException,
     XdataLike,
@@ -98,7 +97,7 @@ def epsilon_validate(epsilon: float) -> None:
 
 def params_validate(params: ParamsLike) -> None:
     """Validate the model parameters, ensuring the key 'objective' is absent."""
-    if MQStr.obj.value in params:
+    if "objective" in params:
         raise ValidationException(
             "The parameter named 'objective' must be excluded in params"
         )
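A quick illustration (assumed usage, not from the diff) of the check after the change; the sketch raises a plain `ValueError` where mqboost raises its own `ValidationException`:

```python
# Hypothetical standalone version of params_validate for illustration.
def params_validate(params: dict) -> None:
    if "objective" in params:
        raise ValueError("The parameter named 'objective' must be excluded in params")


params_validate({"learning_rate": 0.05})  # passes silently
# params_validate({"objective": "mse"})   # would raise ValueError
```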
98 changes: 1 addition & 97 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 0 additions & 2 deletions pyproject.toml

@@ -19,8 +19,6 @@ scikit-learn = "^1.5.1"
 
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.7.1"
-isort = "^5.13.2"
-black = "^24.4.2"
 pytest = "^8.3.2"
 pytest-cov = "^5.0.0"
 