Commit 6d0acfe
[Feature] simplify code for readability (#34)
RektPunk authored Sep 12, 2024
1 parent 843bdeb commit 6d0acfe
Showing 10 changed files with 65 additions and 164 deletions.
6 changes: 0 additions & 6 deletions mqboost/base.py
@@ -49,12 +49,6 @@ class TypeName(BaseEnum):
     constraints_type: str = "constraints_type"


-class MQStr(BaseEnum):
-    mono: str = "monotone_constraints"
-    obj: str = "objective"
-    valid: str = "valid"
-
-
 # Functions
 def _lgb_predict_dtype(data: XdataLike):
     return data
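Note: the removed MQStr enum only aliased three string literals, so call sites now use the literals directly (see the follow-up diffs below). A standalone sketch of the equivalence, using plain Enum since BaseEnum's definition is not shown in this diff:

    from enum import Enum

    class MQStr(str, Enum):  # sketch of the removed indirection
        mono = "monotone_constraints"

    params_before = {MQStr.mono.value: [0, 1]}       # keyed through the enum
    params_after = {"monotone_constraints": [0, 1]}  # literal used in place
    assert params_before == params_after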
12 changes: 7 additions & 5 deletions mqboost/constraints.py
@@ -1,6 +1,6 @@
 import pandas as pd

-from mqboost.base import FUNC_TYPE, ModelName, MQStr, ParamsLike, TypeName
+from mqboost.base import FUNC_TYPE, ModelName, ParamsLike, TypeName


 def set_monotone_constraints(
@@ -19,16 +19,18 @@ def set_monotone_constraints(
     Returns:
         ParamsLike
     """
+    MONOTONE_CONSTRAINTS: str = "monotone_constraints"
+
     constraints_fucs = FUNC_TYPE.get(model_name).get(TypeName.constraints_type)
     _params = params.copy()
-    if MQStr.mono.value in _params:
-        _monotone_constraints = list(_params[MQStr.mono.value])
+    if MONOTONE_CONSTRAINTS in _params:
+        _monotone_constraints = list(_params[MONOTONE_CONSTRAINTS])
         _monotone_constraints.append(1)
-        _params.update({MQStr.mono.value: constraints_fucs(_monotone_constraints)})
+        _params.update({MONOTONE_CONSTRAINTS: constraints_fucs(_monotone_constraints)})
     else:
         _params.update(
             {
-                MQStr.mono.value: constraints_fucs(
+                MONOTONE_CONSTRAINTS: constraints_fucs(
                     [1 if "_tau" == col else 0 for col in columns]
                 )
             }
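Note: the new local constant replaces MQStr.mono.value; the behavior is unchanged — a +1 (monotone increasing) constraint is set for the synthetic "_tau" column, and constraints_fucs converts the list to the model-specific format (its implementation is not shown in this diff). A standalone sketch of the else-branch logic on hypothetical columns:

    columns = ["feat_a", "feat_b", "_tau"]  # hypothetical feature names
    params = {"monotone_constraints": [1 if "_tau" == col else 0 for col in columns]}
    print(params)  # {'monotone_constraints': [0, 0, 1]}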
75 changes: 39 additions & 36 deletions mqboost/objective.py
@@ -7,6 +7,11 @@
 from mqboost.utils import delta_validate, epsilon_validate

 CHECK_LOSS: str = "check_loss"
+GradFnLike = Callable[[Any], np.ndarray]
+HessFnLike = Callable[[Any], np.ndarray]
+CallGradHessLike = Callable[
+    [np.ndarray, DtrainLike, list[float], Any], tuple[np.ndarray, np.ndarray]
+]


 # check loss
@@ -42,13 +47,13 @@ def _hess_huber(error: np.ndarray, **kwargs) -> np.ndarray:


 # Approx loss (MM loss)
-def _grad_approx(error: np.ndarray, alpha: float, epsilon: float):
+def _grad_approx(error: np.ndarray, alpha: float, epsilon: float) -> np.ndarray:
     """Compute the gradient of the approx of the smooth approximated check loss function."""
     _grad = 0.5 * (1 - 2 * alpha - error / (epsilon + np.abs(error)))
     return _grad


-def _hess_approx(error: np.ndarray, epsilon: float, **kwargs):
+def _hess_approx(error: np.ndarray, epsilon: float, **kwargs) -> np.ndarray:
     """Compute the Hessian of the approx of the smooth approximated check loss function."""
     _hess = 1 / (2 * (epsilon + np.abs(error)))
     return _hess
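Note, for orientation (one consistent reading, with u = y - ŷ so the sign convention matches _err_for_alpha below): the check loss can be written

    \rho_\alpha(u) = u\,(\alpha - \mathbb{1}[u < 0]) = \tfrac{1}{2}\bigl(|u| + (2\alpha - 1)\,u\bigr)

and the MM smoothing in the style of Hunter & Lange majorizes |u| by a quadratic with curvature 1/(2(\epsilon + |u|)). Differentiating with respect to the prediction \hat{y} and substituting \operatorname{sign}(u) \approx u/(\epsilon + |u|) recovers exactly the two functions above:

    \frac{\partial}{\partial \hat{y}} \approx \tfrac{1}{2}\Bigl(1 - 2\alpha - \frac{u}{\epsilon + |u|}\Bigr),
    \qquad
    \frac{\partial^2}{\partial \hat{y}^2} \approx \frac{1}{2\,(\epsilon + |u|)}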
@@ -64,40 +69,38 @@ def _train_pred_reshape(
     return _y_train.reshape(len_alpha, -1), y_pred.reshape(len_alpha, -1)


-def _compute_grads_hess(
-    y_pred: np.ndarray,
-    dtrain: DtrainLike,
-    alphas: list[float],
-    grad_fn: Callable[[np.ndarray, float, Any], np.ndarray],
-    hess_fn: Callable[[np.ndarray, float, Any], np.ndarray],
-    **kwargs: dict[str, float],
-) -> tuple[np.ndarray, np.ndarray]:
-    """Compute gradients and hessians for the given loss function."""
-    _len_alpha = len(alphas)
-    _y_train, _y_pred = _train_pred_reshape(
-        y_pred=y_pred, dtrain=dtrain, len_alpha=_len_alpha
-    )
-    grads = []
-    hess = []
-    for alpha_inx in range(len(alphas)):
-        _err_for_alpha = _y_train[alpha_inx] - _y_pred[alpha_inx]
-        _grad = grad_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
-        _hess = hess_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
-        grads.append(_grad)
-        hess.append(_hess)
-
-    return np.concatenate(grads), np.concatenate(hess)
-
-
-check_loss_grad_hess: Callable = partial(
-    _compute_grads_hess, grad_fn=_grad_rho, hess_fn=_hess_rho
-)
-huber_loss_grad_hess: Callable = partial(
-    _compute_grads_hess, grad_fn=_grad_huber, hess_fn=_hess_huber
-)
-approx_loss_grad_hess: Callable = partial(
-    _compute_grads_hess, grad_fn=_grad_approx, hess_fn=_hess_approx
-)
+# Compute gradient hessian logic
+def compute_grad_hess(grad_fn: GradFnLike, hess_fn: HessFnLike) -> CallGradHessLike:
+    """Return computing gradient hessian function."""
+
+    def _compute_grads_hess(
+        y_pred: np.ndarray,
+        dtrain: DtrainLike,
+        alphas: list[float],
+        **kwargs: Any,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Compute gradients and hessians for the given gradient and hessian function."""
+        _len_alpha = len(alphas)
+        _y_train, _y_pred = _train_pred_reshape(
+            y_pred=y_pred, dtrain=dtrain, len_alpha=_len_alpha
+        )
+        grads = []
+        hess = []
+        for alpha_inx in range(len(alphas)):
+            _err_for_alpha = _y_train[alpha_inx] - _y_pred[alpha_inx]
+            _grad = grad_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
+            _hess = hess_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
+            grads.append(_grad)
+            hess.append(_hess)
+
+        return np.concatenate(grads), np.concatenate(hess)
+
+    return _compute_grads_hess
+
+
+check_loss_grad_hess = compute_grad_hess(grad_fn=_grad_rho, hess_fn=_hess_rho)
+huber_loss_grad_hess = compute_grad_hess(grad_fn=_grad_huber, hess_fn=_hess_huber)
+approx_loss_grad_hess = compute_grad_hess(grad_fn=_grad_approx, hess_fn=_hess_approx)


 def _eval_check_loss(
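Note: the refactor swaps functools.partial for a closure factory. With partial, grad_fn and hess_fn stayed visible (and overridable) as keyword arguments of the returned callable; the factory instead returns a function whose signature is exactly (y_pred, dtrain, alphas, **kwargs), matching CallGradHessLike. A standalone sketch of the pattern with toy grad/hess functions (mqboost's _grad_rho/_hess_rho are not shown in this diff; hess = 1 is a common choice for the check loss):

    import numpy as np

    def compute(grad_fn, hess_fn):
        """Factory: the returned function closes over grad_fn and hess_fn."""
        def _inner(error: np.ndarray, alpha: float):
            return grad_fn(error, alpha), hess_fn(error, alpha)
        return _inner

    demo = compute(
        grad_fn=lambda e, a: np.where(e < 0, 1.0 - a, -a),  # 1[u<0] - alpha
        hess_fn=lambda e, a: np.ones_like(e),
    )
    print(demo(np.array([-1.0, 2.0]), 0.5))  # (array([ 0.5, -0.5]), array([1., 1.]))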
3 changes: 1 addition & 2 deletions mqboost/optimize.py
@@ -12,7 +12,6 @@
     DtrainLike,
     FittingException,
     ModelName,
-    MQStr,
     ObjectiveName,
     ParamsLike,
 )
@@ -200,7 +199,7 @@ def __optuna_objective(
             model_params = dict(
                 params=params,
                 dtrain=dtrain,
-                evals=[(dvalid, MQStr.valid.value)],
+                evals=[(dvalid, "valid")],
             )
             _gbm = xgb.train(**model_params)
             _preds = _gbm.predict(data=deval)
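Note: "valid" is only the display name XGBoost attaches to the eval set in its per-round logs. A minimal sketch of the same call pattern on synthetic data:

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(100, 4), np.random.rand(100)
    dtrain = xgb.DMatrix(X[:80], label=y[:80])
    dvalid = xgb.DMatrix(X[80:], label=y[80:])
    model_params = dict(
        params={"max_depth": 3},
        dtrain=dtrain,
        evals=[(dvalid, "valid")],  # name shown in eval output
    )
    _gbm = xgb.train(**model_params, num_boost_round=10)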
4 changes: 2 additions & 2 deletions mqboost/regressor.py
@@ -2,7 +2,7 @@
 import numpy as np
 import xgboost as xgb

-from mqboost.base import FittingException, ModelName, MQStr, ObjectiveName, ParamsLike
+from mqboost.base import FittingException, ModelName, ObjectiveName, ParamsLike
 from mqboost.constraints import set_monotone_constraints
 from mqboost.dataset import MQDataset
 from mqboost.objective import MQObjective
@@ -84,7 +84,7 @@ def fit(
             epsilon=self._epsilon,
         )
         if self.__is_lgb:
-            params.update({MQStr.obj.value: self._MQObj.fobj})
+            params.update({"objective": self._MQObj.fobj})
             self.model = lgb.train(
                 train_set=dataset.dtrain,
                 params=params,
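Note: in the LightGBM branch, params["objective"] is set to a callable (the MQObjective's fobj) rather than a builtin loss name; LightGBM accepts this and expects the callable to return (grad, hess). A minimal sketch of that mechanism with a plain squared-error objective (not mqboost's):

    import lightgbm as lgb
    import numpy as np

    def l2_objective(y_pred, train_data):
        grad = y_pred - train_data.get_label()  # d(0.5 * err^2)/d(y_pred)
        hess = np.ones_like(y_pred)             # its second derivative
        return grad, hess

    X, y = np.random.rand(100, 4), np.random.rand(100)
    booster = lgb.train(
        params={"objective": l2_objective, "verbose": -1},
        train_set=lgb.Dataset(X, label=y),
        num_boost_round=10,
    )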
3 changes: 1 addition & 2 deletions mqboost/utils.py
@@ -6,7 +6,6 @@

 from mqboost.base import (
     AlphaLike,
-    MQStr,
     ParamsLike,
     ValidationException,
     XdataLike,
@@ -98,7 +97,7 @@ def epsilon_validate(epsilon: float) -> None:

 def params_validate(params: ParamsLike) -> None:
     """Validates the model parameter ensuring its key doesn't contain 'objective'."""
-    if MQStr.obj.value in params:
+    if "objective" in params:
         raise ValidationException(
             "The parameter named 'objective' must be excluded in params"
         )
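Note: the guard exists because mqboost installs its own quantile objective (see the regressor.py diff above), so a user-supplied one must be rejected. A quick usage sketch, assuming mqboost is installed:

    from mqboost.utils import params_validate

    params_validate({"learning_rate": 0.1})  # passes silently
    params_validate({"objective": "huber"})  # raises ValidationException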
98 changes: 1 addition & 97 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 0 additions & 2 deletions pyproject.toml
@@ -19,8 +19,6 @@ scikit-learn = "^1.5.1"

 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.7.1"
-isort = "^5.13.2"
-black = "^24.4.2"
 pytest = "^8.3.2"
 pytest-cov = "^5.0.0"


1 comment on commit 6d0acfe

@github-actions

Tests   Skipped   Failures   Errors   Time
86      0 💤      0 ❌       0 🔥     5.712s ⏱️
