Skip to content

Commit

Permalink
remove non usable input in ftns and add **kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Sep 10, 2024
1 parent c0532eb commit c213fd9
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions mqboost/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _rho(error: np.ndarray, alpha: float) -> np.ndarray:
return -error * _grad_rho(error=error, alpha=alpha)


def _hess_rho(error: np.ndarray, alpha: float) -> np.ndarray:
def _hess_rho(error: np.ndarray, **kwargs) -> np.ndarray:
"""Compute the Hessian of the check."""
return np.ones_like(error)

Expand All @@ -36,7 +36,7 @@ def _grad_huber(error: np.ndarray, alpha: float, delta: float) -> np.ndarray:
return _r * _smaller_delta + _grad * _bigger_delta


def _hess_huber(error: np.ndarray, alpha: float, delta: float) -> np.ndarray:
def _hess_huber(error: np.ndarray, **kwargs) -> np.ndarray:
"""Compute the Hessian of the huber loss function."""
return np.ones_like(error)

Expand All @@ -48,7 +48,7 @@ def _grad_approx(error: np.ndarray, alpha: float, epsilon: float):
return _grad


def _hess_approx(error: np.ndarray, alpha: float, epsilon: float):
def _hess_approx(error: np.ndarray, epsilon: float, **kwargs):
"""Compute the Hessian of the approx of the smooth approximated check loss function."""
_hess = 1 / (2 * (epsilon + np.abs(error)))
return _hess
Expand All @@ -70,7 +70,7 @@ def _compute_grads_hess(
alphas: list[float],
grad_fn: Callable[[np.ndarray, float, Any], np.ndarray],
hess_fn: Callable[[np.ndarray, float, Any], np.ndarray],
**kwargs: Any,
**kwargs: dict[str, float],
) -> tuple[np.ndarray, np.ndarray]:
"""Compute gradients and hessians for the given loss function."""
_len_alpha = len(alphas)
Expand Down

0 comments on commit c213fd9

Please sign in to comment.