Skip to content

Commit

Permalink
[Chore] add type hint (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk authored Oct 6, 2024
1 parent d3c2098 commit 72d1525
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions mqboost/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
CHECK_LOSS: str = "check_loss"
GradFnLike = Callable[[Any], np.ndarray]
HessFnLike = Callable[[Any], np.ndarray]
CallGradHessLike = Callable[
ObjLike = Callable[
[np.ndarray, DtrainLike, list[float], Any], tuple[np.ndarray, np.ndarray]
]
EvalLike = Callable[
[np.ndarray, DtrainLike, list[float]], tuple[str, float, bool] | tuple[str, float]
]


# check loss
Expand Down Expand Up @@ -63,7 +66,7 @@ def _train_pred_reshape(


# Compute gradient hessian logic
def compute_grad_hess(grad_fn: GradFnLike, hess_fn: HessFnLike) -> CallGradHessLike:
def compute_grad_hess(grad_fn: GradFnLike, hess_fn: HessFnLike) -> ObjLike:
"""Return computing gradient hessian function."""

def _compute_grads_hess(
Expand Down Expand Up @@ -132,30 +135,32 @@ def _lgb_eval_loss(
return CHECK_LOSS, loss, False


def validate_parameters(objective: ObjectiveName, delta: float, epsilon: float):
def validate_parameters(objective: ObjectiveName, delta: float, epsilon: float) -> None:
if objective == ObjectiveName.huber:
delta_validate(delta=delta)
elif objective == ObjectiveName.approx:
epsilon_validate(epsilon=epsilon)


def get_fobj_function(objective, alphas, delta: float, epsilon: float):
objective_mapping = {
def get_fobj_function(
objective: ObjectiveName, alphas: list[float], delta: float, epsilon: float
) -> ObjLike:
objective_mapping: dict[ObjectiveName, ObjLike] = {
ObjectiveName.check: partial(check_loss_grad_hess, alphas=alphas),
ObjectiveName.huber: partial(huber_loss_grad_hess, alphas=alphas, delta=delta),
ObjectiveName.approx: partial(
approx_loss_grad_hess, alphas=alphas, epsilon=epsilon
),
}
return objective_mapping.get(objective)
return objective_mapping[objective]


def get_feval_function(model: ModelName, alphas):
model_mapping = {
def get_feval_function(model: ModelName, alphas: list[float]) -> EvalLike:
model_mapping: dict[ModelName, EvalLike] = {
ModelName.lightgbm: partial(_lgb_eval_loss, alphas=alphas),
ModelName.xgboost: partial(_xgb_eval_loss, alphas=alphas),
}
return model_mapping.get(model)
return model_mapping[model]


class MQObjective:
Expand Down

1 comment on commit 72d1525

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
91 0 💤 0 ❌ 0 🔥 6.753s ⏱️

Please sign in to comment.