diff --git a/mqboost/objective.py b/mqboost/objective.py index 3652a77..aa42bf2 100644 --- a/mqboost/objective.py +++ b/mqboost/objective.py @@ -9,9 +9,12 @@ CHECK_LOSS: str = "check_loss" GradFnLike = Callable[[Any], np.ndarray] HessFnLike = Callable[[Any], np.ndarray] -CallGradHessLike = Callable[ +ObjLike = Callable[ [np.ndarray, DtrainLike, list[float], Any], tuple[np.ndarray, np.ndarray] ] +EvalLike = Callable[ + [np.ndarray, DtrainLike, list[float]], tuple[str, float, bool] | tuple[str, float] +] # check loss @@ -63,7 +66,7 @@ def _train_pred_reshape( # Compute gradient hessian logic -def compute_grad_hess(grad_fn: GradFnLike, hess_fn: HessFnLike) -> CallGradHessLike: +def compute_grad_hess(grad_fn: GradFnLike, hess_fn: HessFnLike) -> ObjLike: """Return computing gradient hessian function.""" def _compute_grads_hess( @@ -132,30 +135,32 @@ def _lgb_eval_loss( return CHECK_LOSS, loss, False -def validate_parameters(objective: ObjectiveName, delta: float, epsilon: float): +def validate_parameters(objective: ObjectiveName, delta: float, epsilon: float) -> None: if objective == ObjectiveName.huber: delta_validate(delta=delta) elif objective == ObjectiveName.approx: epsilon_validate(epsilon=epsilon) -def get_fobj_function(objective, alphas, delta: float, epsilon: float): - objective_mapping = { +def get_fobj_function( + objective: ObjectiveName, alphas: list[float], delta: float, epsilon: float +) -> ObjLike: + objective_mapping: dict[ObjectiveName, ObjLike] = { ObjectiveName.check: partial(check_loss_grad_hess, alphas=alphas), ObjectiveName.huber: partial(huber_loss_grad_hess, alphas=alphas, delta=delta), ObjectiveName.approx: partial( approx_loss_grad_hess, alphas=alphas, epsilon=epsilon ), } - return objective_mapping.get(objective) + return objective_mapping[objective] -def get_feval_function(model: ModelName, alphas): - model_mapping = { +def get_feval_function(model: ModelName, alphas: list[float]) -> EvalLike: + model_mapping: dict[ModelName, EvalLike] = { ModelName.lightgbm: partial(_lgb_eval_loss, alphas=alphas), ModelName.xgboost: partial(_xgb_eval_loss, alphas=alphas), } - return model_mapping.get(model) + return model_mapping[model] class MQObjective: