Skip to content

Commit

Permalink
Fix CategoricalHyperparameter.suggest_with_optuna()
Browse files Browse the repository at this point in the history
Was not taking into account choices arg

Also add kwargs to all suggest_with_optuna signature
  • Loading branch information
nhuet authored and g-poveda committed Apr 11, 2024
1 parent 7804d36 commit dd3f577
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,10 @@ def suggest_with_optuna(
Returns:
"""
if self.choices is not None and "choices" not in kwargs:
kwargs["choices"] = self.choices
return trial.suggest_categorical(name=prefix + self.name, **kwargs)
if choices is None:
choices = self.choices

return trial.suggest_categorical(name=prefix + self.name, choices=choices, **kwargs) # type: ignore


class EnumHyperparameter(CategoricalHyperparameter):
Expand Down Expand Up @@ -280,7 +281,7 @@ def suggest_with_optuna(
choices = self.choices
choices_str = [self.choices_cls2str[c] for c in choices]
choice_str = trial.suggest_categorical(
name=prefix + self.name, choices=choices_str
name=prefix + self.name, choices=choices_str, **kwargs # type: ignore
)
return self.choices_str2cls[choice_str]

Expand Down Expand Up @@ -312,6 +313,7 @@ def suggest_with_optuna(
kwargs_by_name: Optional[Dict[str, Dict[str, Any]]] = None,
fixed_hyperparameters: Optional[Dict[str, Any]] = None,
prefix: str = "",
**kwargs,
) -> Dict[str, Any]:
"""Suggest hyperparameter value for an Optuna trial.
Expand All @@ -331,6 +333,7 @@ def suggest_with_optuna(
if the subsolver class is not suggested by this method, but already fixed.
prefix: prefix to add to optuna corresponding parameter name
(useful for disambiguating hyperparameters from subsolvers in case of meta-solvers)
**kwargs: passed to `trial.suggest_categorical()`
Returns:
Expand All @@ -344,4 +347,5 @@ def suggest_with_optuna(
kwargs_by_name=kwargs_by_name,
fixed_hyperparameters=fixed_hyperparameters,
prefix=prefix,
**kwargs, # type: ignore
)
10 changes: 7 additions & 3 deletions tests/generic_tools/hyperparameters/test_hyperparameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ def objective(trial: optuna.Trial) -> float:
suggested_hyperparameters_kwargs = (
DummySolver.suggest_hyperparameters_with_optuna(
trial=trial,
kwargs_by_name={"coeff": dict(step=0.5), "nb": dict(high=1)},
kwargs_by_name={
"coeff": dict(step=0.5),
"nb": dict(high=1),
"use_it": dict(choices=[True]),
},
)
)
assert len(suggested_hyperparameters_kwargs) == 4
Expand All @@ -142,7 +146,7 @@ def objective(trial: optuna.Trial) -> float:
assert 1 >= suggested_hyperparameters_kwargs["nb"]
assert -1.0 <= suggested_hyperparameters_kwargs["coeff"]
assert 1.0 >= suggested_hyperparameters_kwargs["coeff"]
assert suggested_hyperparameters_kwargs["use_it"] in (True, False)
assert suggested_hyperparameters_kwargs["use_it"] is True

return 0.0

Expand All @@ -151,7 +155,7 @@ def objective(trial: optuna.Trial) -> float:
)
study.optimize(objective)

assert len(study.trials) == 2 * 2 * 5 * 2
assert len(study.trials) == 2 * 2 * 5 * 1


def test_suggest_with_optuna_meta_solver():
Expand Down

0 comments on commit dd3f577

Please sign in to comment.