feat(sklearn): add `copy_loss`, `apply_objective` and `apply_eval_metric` parameters to `apply_custom_loss()` (#108)
34j authored Nov 7, 2023
1 parent 0e53f08 commit e877604
Showing 2 changed files with 54 additions and 9 deletions.
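
For orientation before the diffs: the commit adds three keyword-only switches to `apply_custom_loss()`. A minimal usage sketch (`lightgbm` and `apply_custom_loss` are real; the `loss` object is a placeholder standing in for any `LossBase` instance and must be replaced before running):

    import lightgbm as lgb
    from sklearn.datasets import make_regression

    from boost_loss.sklearn import apply_custom_loss

    X, y = make_regression(n_samples=200, n_features=5, random_state=0)
    loss = ...  # placeholder: substitute a concrete LossBase instance

    estimator = apply_custom_loss(
        lgb.LGBMRegressor(),
        loss,
        copy=True,               # existing flag: clone the estimator before patching
        copy_loss=True,          # new: deepcopy the loss so the caller's object stays untouched
        apply_objective=True,    # new: install the custom loss as the training objective
        apply_eval_metric=True,  # new: also report the custom loss during evaluation
    )
    estimator.fit(X, y)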
10 changes: 10 additions & 0 deletions src/boost_loss/regression/sklearn.py
@@ -62,6 +62,8 @@ def __init__(
         random_state: int | None = None,
         m_type: Literal["mean", "median"] = "mean",
         var_type: Literal["var", "std", "range", "mae", "mse"] = "std",
+        apply_objective: bool = True,
+        apply_eval_metric: bool = True,
         target_transformer: BaseEstimator | Any | None = None,
         recursive: bool = True,
         recursive_strict: bool = False,
@@ -104,6 +106,10 @@ def __init__(
             M-statistics type to return from `predict` by default, by default "mean"
         var_type : Literal["var", "std", "range", "mae", "mse"], optional
             Variance type to return from `predict` by default, by default "std"
+        apply_objective : bool, optional
+            Whether to apply the custom loss to the estimator's objective, by default True
+        apply_eval_metric : bool, optional
+            Whether to apply the custom loss to the estimator's eval_metric, by default True
         target_transformer : BaseEstimator | Any | None, optional
             The transformer to use for transforming the target, by default None
             If `None`, no `TransformedTargetRegressor` is used.
@@ -130,6 +136,8 @@ def __init__(
         self.random_state = random_state
         self.m_type = m_type
         self.var_type = var_type
+        self.apply_objective = apply_objective
+        self.apply_eval_metric = apply_eval_metric
         self.target_transformer = target_transformer
         self.recursive = recursive
         self.recursive_strict = recursive_strict
@@ -163,6 +171,8 @@ def fit(self, X: Any, y: Any, **fit_params: Any) -> Self:
                 apply_custom_loss(
                     self.estimator,
                     AsymmetricLoss(self.loss, t=t),
+                    apply_objective=self.apply_objective,
+                    apply_eval_metric=self.apply_eval_metric,
                     target_transformer=self.target_transformer,
                     recursive=self.recursive,
                     recursive_strict=self.recursive_strict,
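
Only `__init__` and `fit` are visible in these hunks, so the enclosing class name is not shown. The pattern is standard scikit-learn plumbing: store the two new flags unchanged in `__init__`, then forward them when `fit` wraps each `AsymmetricLoss(self.loss, t=t)` via `apply_custom_loss()`. A schematic of that storage-and-forwarding idea, with a hypothetical class name:

    from typing import Any


    class _FlagForwardingEstimator:  # hypothetical name; the real class is outside the hunks
        """Stores the two new flags at construction and forwards them at fit time."""

        def __init__(self, estimator: Any, apply_objective: bool = True,
                     apply_eval_metric: bool = True) -> None:
            self.estimator = estimator
            # Stored unchanged, as scikit-learn's get_params/set_params contract requires.
            self.apply_objective = apply_objective
            self.apply_eval_metric = apply_eval_metric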
53 changes: 44 additions & 9 deletions src/boost_loss/sklearn.py
@@ -2,6 +2,7 @@

 import functools
 import importlib.util
+from copy import deepcopy
 from typing import Any, Literal, TypeVar, overload

 import catboost as cb
@@ -23,6 +24,9 @@ def apply_custom_loss(
     loss: LossBase,
     *,
     copy: bool = ...,
+    copy_loss: bool = ...,
+    apply_objective: bool = ...,
+    apply_eval_metric: bool = ...,
     target_transformer: None = ...,
     recursive: bool = ...,
     recursive_strict: bool = ...,
@@ -36,6 +40,9 @@ def apply_custom_loss(
     loss: LossBase,
     *,
     copy: bool = ...,
+    copy_loss: bool = ...,
+    apply_objective: bool = ...,
+    apply_eval_metric: bool = ...,
     target_transformer: BaseEstimator = ...,
     recursive: bool = ...,
     recursive_strict: bool = ...,
@@ -48,6 +55,9 @@
     loss: LossBase,
     *,
     copy: bool = True,
+    copy_loss: bool = True,
+    apply_objective: bool = True,
+    apply_eval_metric: bool = True,
     target_transformer: BaseEstimator | Any | None = StandardScaler(),
     recursive: bool = True,
     recursive_strict: bool = False,
@@ -62,6 +72,12 @@
         The custom loss to apply
     copy : bool, optional
         Whether to copy the estimator using `sklearn.base.clone`, by default True
+    copy_loss : bool, optional
+        Whether to copy the loss using `copy.deepcopy`, by default True
+    apply_objective : bool, optional
+        Whether to apply the custom loss to the estimator's objective, by default True
+    apply_eval_metric : bool, optional
+        Whether to apply the custom loss to the estimator's eval_metric, by default True
     target_transformer : BaseEstimator | Any | None, optional
         The target transformer to use, by default StandardScaler()
         (This option exists because some loss functions require the target
@@ -81,20 +97,30 @@ def apply_custom_loss(
     """
     if copy:
         estimator = clone(estimator)
+    if copy_loss:
+        loss = deepcopy(loss)
     if isinstance(estimator, cb.CatBoost):
-        estimator.set_params(loss_function=loss, eval_metric=loss)
+        if apply_objective:
+            estimator.set_params(loss_function=loss)
+        if apply_eval_metric:
+            estimator.set_params(eval_metric=loss)
     if isinstance(estimator, lgb.LGBMModel):
-        estimator.set_params(objective=loss)
-        estimator_fit = estimator.fit
+        if apply_objective:
+            estimator.set_params(objective=loss)
+        if apply_eval_metric:
+            estimator_fit = estimator.fit

-        @functools.wraps(estimator_fit)
-        def fit(X: Any, y: Any, **fit_params: Any) -> Any:
-            fit_params["eval_metric"] = loss.eval_metric_lgb
-            return estimator_fit(X, y, **fit_params)
+            @functools.wraps(estimator_fit)
+            def fit(X: Any, y: Any, **fit_params: Any) -> Any:
+                fit_params["eval_metric"] = loss.eval_metric_lgb
+                return estimator_fit(X, y, **fit_params)

-        setattr(estimator, "fit", fit)
+            setattr(estimator, "fit", fit)
     if isinstance(estimator, xgb.XGBModel):
-        estimator.set_params(objective=loss, eval_metric=loss.eval_metric_xgb_sklearn)
+        if apply_objective:
+            estimator.set_params(objective=loss)
+        if apply_eval_metric:
+            estimator.set_params(eval_metric=loss.eval_metric_xgb_sklearn)

     if recursive:
         if hasattr(estimator, "get_params") and hasattr(estimator, "set_params"):
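
One detail worth calling out in the hunk above: unlike CatBoost and XGBoost, LightGBM's scikit-learn wrapper accepts `eval_metric` as a `fit()` argument rather than a constructor parameter, which is why the `LGBMModel` branch patches `fit` instead of calling `set_params(eval_metric=...)`. The same wrapping idea in isolation, as a hedged standalone sketch:

    import functools
    from typing import Any, Callable

    def inject_eval_metric(fit_func: Callable[..., Any], metric: Any) -> Callable[..., Any]:
        """Return a fit() that always forwards `metric` as the eval_metric fit param."""

        @functools.wraps(fit_func)
        def fit(X: Any, y: Any, **fit_params: Any) -> Any:
            fit_params["eval_metric"] = metric  # injected at call time, as in the diff
            return fit_func(X, y, **fit_params)

        return fit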
@@ -106,6 +132,9 @@ def fit(X: Any, y: Any, **fit_params: Any) -> Any:
                     value,
                     loss,
                     copy=False,
+                    copy_loss=copy_loss,
+                    apply_objective=apply_objective,
+                    apply_eval_metric=apply_eval_metric,
                     target_transformer=None,
                     recursive=False,
                     recursive_strict=False,
@@ -119,6 +148,9 @@ def fit(X: Any, y: Any, **fit_params: Any) -> Any:
                     value,
                     loss,
                     copy=False,
+                    copy_loss=copy_loss,
+                    apply_objective=apply_objective,
+                    apply_eval_metric=apply_eval_metric,
                     target_transformer=None,
                     recursive=True,
                     recursive_strict=True,
@@ -130,6 +162,9 @@ def fit(X: Any, y: Any, **fit_params: Any) -> Any:
                     value,
                     loss,
                     copy=False,
+                    copy_loss=copy_loss,
+                    apply_objective=apply_objective,
+                    apply_eval_metric=apply_eval_metric,
                     target_transformer=None,
                     recursive=True,
                     recursive_strict=True,
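
Taken together, the switches make the two attachment points independent. A hedged example: keep XGBoost's built-in objective but still evaluate with the custom metric (`XGBRegressor` is real; the `loss` placeholder must be replaced with a concrete `LossBase` before running):

    import xgboost as xgb

    from boost_loss.sklearn import apply_custom_loss

    loss = ...  # placeholder: substitute a concrete LossBase instance

    est = apply_custom_loss(
        xgb.XGBRegressor(),
        loss,
        apply_objective=False,    # skip set_params(objective=loss); native objective kept
        apply_eval_metric=True,   # still sets eval_metric=loss.eval_metric_xgb_sklearn
        target_transformer=None,  # default is StandardScaler(); disabled here for brevity
    )

Separately, `copy_loss=True` (the default) deep-copies the loss before it is attached, so reusing one loss object across several estimators cannot leak mutable state between them.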
