diff --git a/mqboost/engine.py b/mqboost/engine.py index a9e69f4..dfd7c49 100644 --- a/mqboost/engine.py +++ b/mqboost/engine.py @@ -207,13 +207,13 @@ def _study_func(trial: optuna.Trial) -> float: get_params_func=get_params_func, ) - study = optuna.create_study( + self._study = optuna.create_study( study_name=f"MQBoost_{self._model}", direction="minimize", load_if_exists=True, ) - study.optimize(_study_func, n_trials=n_trials) - return study.best_params + self._study.optimize(_study_func, n_trials=n_trials) + return self._study.best_params def __optuna_objective( self, @@ -282,3 +282,7 @@ def __is_xgb(self) -> bool: def __is_fitted(self) -> None: if not getattr(self, "_fitted", False): raise FittingException("train must be executed before predict") + + @property + def study(self) -> optuna.Study: + return getattr(self, "_study", None)