
Commit

fix(sklearn): check if estimator has `get_params()` and `set_params()` in `apply_custom_loss()` (#106)
34j authored Nov 5, 2023
1 parent c59f4f9 commit 3b77017
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions src/boost_loss/sklearn.py
@@ -97,20 +97,21 @@ def fit(X: Any, y: Any, **fit_params: Any) -> Any:
     estimator.set_params(objective=loss, eval_metric=loss.eval_metric_xgb_sklearn)

     if recursive:
-        for key, value in estimator.get_params(deep=True).items():
-            if hasattr(value, "fit"):
-                estimator.set_params(
-                    **{
-                        key: apply_custom_loss(
-                            value,
-                            loss,
-                            copy=False,
-                            target_transformer=None,
-                            recursive=False,
-                            recursive_strict=False,
-                        )
-                    }
-                )
+        if hasattr(estimator, "get_params") and hasattr(estimator, "set_params"):
+            for key, value in estimator.get_params(deep=True).items():
+                if hasattr(value, "fit"):
+                    estimator.set_params(
+                        **{
+                            key: apply_custom_loss(
+                                value,
+                                loss,
+                                copy=False,
+                                target_transformer=None,
+                                recursive=False,
+                                recursive_strict=False,
+                            )
+                        }
+                    )
     if recursive_strict:
         if hasattr(estimator, "__dict__"):
             for _, value in estimator.__dict__.items():
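
The new `hasattr` guard exists because `get_params()`/`set_params()` belong to the scikit-learn estimator API and are not guaranteed on every object passed in; the old code called `estimator.get_params(deep=True)` unconditionally, which raises `AttributeError` for objects that lack it. A minimal sketch of the same pattern in isolation (the helper name `apply_recursively` and the `rewrite` callback are hypothetical, not part of the library):

    def apply_recursively(estimator, rewrite):
        """Hypothetical helper mirroring the guarded recursive branch above."""
        # Only objects exposing the scikit-learn get_params/set_params API are
        # inspected; anything else is returned untouched instead of raising.
        if hasattr(estimator, "get_params") and hasattr(estimator, "set_params"):
            for key, value in estimator.get_params(deep=True).items():
                if hasattr(value, "fit"):  # only rewrite nested estimators
                    estimator.set_params(**{key: rewrite(value)})
        return estimator
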
@@ -281,13 +282,15 @@ def predict(
             return prediction[:, 0]
         if return_std:
             # see virtual_ensembles_predict() for details
-            return prediction, np.sqrt(
+            return (
+                prediction,
                 predict_var(
                     data,
                     ntree_end=ntree_end,  # 0
                     thread_count=thread_count,  # -1
                     verbose=verbose,  # None
                 )
+                ** 0.5,
             )
         return prediction
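
The second hunk only restructures the `return_std` branch: the variance from `predict_var()` is still converted to a standard deviation, with `** 0.5` replacing the `np.sqrt()` wrapper around the whole call. A quick check of that equivalence (example variances are made up):

    import numpy as np

    var = np.array([0.25, 1.0, 4.0])              # made-up variances
    assert np.allclose(np.sqrt(var), var ** 0.5)  # both give [0.5, 1.0, 2.0]
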

