Skip to content

Commit

Permalink
Add possibility to finetune ortools-cpsat solvers
Browse files Browse the repository at this point in the history
Update ortools.sat.python.cp_model.CPSolver.parameters according to the
new argument ortools_cpsat_solver_kwargs.

Add an Example of ortools-spsat finetuning with optuna.
  • Loading branch information
nhuet authored and g-poveda committed Jul 4, 2024
1 parent 1885977 commit 17b2177
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 1 deletion.
9 changes: 8 additions & 1 deletion discrete_optimization/generic_tools/ortools_cpsat_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# LICENSE file in the root directory of this source tree.
import logging
from abc import abstractmethod
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional

from ortools.sat.python.cp_model import (
FEASIBLE,
Expand Down Expand Up @@ -59,6 +59,7 @@ def solve(
self,
callbacks: Optional[List[Callback]] = None,
parameters_cp: Optional[ParametersCP] = None,
ortools_cpsat_solver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> ResultStorage:
"""Solve the problem with a CPSat solver drom ortools library.
Expand All @@ -67,6 +68,8 @@ def solve(
callbacks: list of callbacks used to hook into the various stage of the solve
parameters_cp: parameters specific to cp solvers.
We use here only `parameters_cp.time_limit` and `parameters_cp.nb_process`.
ortools_cpsat_solver_kwargs: used to customize the underlying ortools solver.
Each key/value will update the corresponding attribute from the ortools.sat.python.cp_model.CPSolver
**kwargs: keyword arguments passed to `self.init_model()`
Returns:
Expand All @@ -88,6 +91,10 @@ def solve(
solver = CpSolver()
solver.parameters.max_time_in_seconds = parameters_cp.time_limit
solver.parameters.num_workers = parameters_cp.nb_process
if ortools_cpsat_solver_kwargs is not None:
# customize solver
for k, v in ortools_cpsat_solver_kwargs.items():
setattr(solver.parameters, k, v)
ortools_callback = OrtoolsCallback(do_solver=self, callback=callbacks_list)
status = solver.Solve(self.cp_model, ortools_callback)
self.status_solver = cpstatus_to_dostatus(status_from_cpsat=status)
Expand Down
209 changes: 209 additions & 0 deletions examples/coloring/optuna_ortools_cpsat_finetuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Copyright (c) 2024 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Example using OPTUNA to choose a solving method and tune its hyperparameters for coloring.
This example show different features of optuna integration with discrete-optimization:
- use of `suggest_hyperparameters_with_optuna()` to get hyperparameters values
- use of a dedicated callback to report intermediate results with corresponding time to optuna
and potentially prune the trial
- time-based pruner
- how to fix some parameters/hyperparameters
Results can be viewed on optuna-dashboard with:
optuna-dashboard optuna-journal.log
"""
import logging
from collections import defaultdict
from typing import Any, Dict, List, Type

from ortools.sat.sat_parameters_pb2 import SatParameters

from discrete_optimization.coloring.coloring_parser import (
get_data_available,
parse_file,
)
from discrete_optimization.coloring.coloring_solvers import (
ColoringASPSolver,
ColoringLP,
ParametersMilp,
solvers_map,
toulbar2_available,
)
from discrete_optimization.coloring.solvers.coloring_cp_solvers import ColoringCP
from discrete_optimization.coloring.solvers.coloring_cpsat_solver import (
ColoringCPSatSolver,
)
from discrete_optimization.coloring.solvers.coloring_lp_solvers import ColoringLP_MIP
from discrete_optimization.coloring.solvers.coloring_toulbar_solver import (
ToulbarColoringSolver,
)
from discrete_optimization.coloring.solvers.greedy_coloring import (
NXGreedyColoringMethod,
)
from discrete_optimization.generic_tools.cp_tools import ParametersCP
from discrete_optimization.generic_tools.do_solver import SolverDO
from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
CategoricalHyperparameter,
IntegerHyperparameter,
SubBrickKwargsHyperparameter,
)
from discrete_optimization.generic_tools.hyperparameters.hyperparametrizable import (
Hyperparametrizable,
)
from discrete_optimization.generic_tools.lp_tools import gurobi_available
from discrete_optimization.generic_tools.optuna.utils import (
generic_optuna_experiment_monoproblem,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(message)s")


seed = 42 # set this to an integer to get reproducible results, else to None
n_trials = 100 # number of trials to launch
create_another_study = True # True: generate a study name with timestamp to avoid overwriting previous study, False: keep same study name
max_time_per_solver = 20 # max duration per solver (seconds)
min_time_per_solver = 5 # min duration before pruning a solver (seconds)

modelfilename = "gc_70_9" # filename of the model used

study_basename = f"coloring-ortools-cpsat-finetune-{modelfilename}"

# solvers to test
solvers_to_test: List[Type[SolverDO]] = [ColoringCPSatSolver]
# fixed kwargs per solver: either hyperparameters we do not want to search, or other parameters like time limits
p = ParametersCP.default_cpsat()
p.nb_process = 6
p.time_limit = max_time_per_solver
p_m = ParametersMilp.default()
p_m.time_limit = max_time_per_solver
kwargs_fixed_by_solver: Dict[Type[SolverDO], Dict[str, Any]] = defaultdict(
dict, # default kwargs for unspecified solvers
{
ColoringCPSatSolver: dict(parameters_cp=p, warmstart=True),
ColoringCP: dict(parameters_cp=p),
ColoringLP: dict(parameters_milp=p_m),
ColoringASPSolver: dict(timeout_seconds=max_time_per_solver),
ToulbarColoringSolver: dict(time_limit=max_time_per_solver),
},
)

# restrict some hyperparameters choices, for some solvers (making use of `kwargs_by_name` of `suggest_with_optuna`)
suggest_optuna_kwargs_by_name_by_solver: Dict[
Type[SolverDO], Dict[str, Dict[str, Any]]
] = defaultdict(
dict, # default kwargs_by_name for unspecified solvers
{
ToulbarColoringSolver: { # options for ToulbarColoringSolver hyperparameters
"tolerance_delta_max": dict(low=1, high=2), # we restrict to [1, 2]
"greedy_method": dict( # we restrict the available choices for greedy_method
choices=[
NXGreedyColoringMethod.best,
NXGreedyColoringMethod.largest_first,
NXGreedyColoringMethod.random_sequential,
]
),
}
},
)

# finetuning hyperparameters for ortools-cpsat


class OrtoolsCpsatSolverKwargs(Hyperparametrizable):
hyperparameters = [
CategoricalHyperparameter(name="optimize_with_core", choices=[True, False]),
CategoricalHyperparameter(
name="search_branching",
choices={
"AUTOMATIC_SEARCH": SatParameters.AUTOMATIC_SEARCH,
"FIXED_SEARCH": SatParameters.FIXED_SEARCH,
"PORTFOLIO_SEARCH": SatParameters.PORTFOLIO_SEARCH,
"LP_SEARCH": SatParameters.LP_SEARCH,
"PSEUDO_COST_SEARCH": SatParameters.PSEUDO_COST_SEARCH,
"PORTFOLIO_WITH_QUICK_RESTART_SEARCH": SatParameters.PORTFOLIO_WITH_QUICK_RESTART_SEARCH,
"HINT_SEARCH": SatParameters.HINT_SEARCH,
"PARTIAL_FIXED_SEARCH": SatParameters.PARTIAL_FIXED_SEARCH,
"RANDOMIZED_SEARCH": SatParameters.RANDOMIZED_SEARCH,
},
),
IntegerHyperparameter(
name="boolean_encoding_level",
low=0,
high=3,
),
IntegerHyperparameter(
name="linearization_level",
low=0,
high=2,
),
IntegerHyperparameter(
name="cp_model_probing_level",
low=0,
high=3,
),
CategoricalHyperparameter(name="cp_model_presolve", choices=[True, False]),
CategoricalHyperparameter(
name="clause_cleanup_ordering",
choices={
"CLAUSE_ACTIVITY": SatParameters.CLAUSE_ACTIVITY,
"CLAUSE_LBD": SatParameters.CLAUSE_LBD,
},
),
CategoricalHyperparameter(
name="binary_minimization_algorithm",
choices={
"NO_BINARY_MINIMIZATION": SatParameters.NO_BINARY_MINIMIZATION,
"BINARY_MINIMIZATION_FIRST": SatParameters.BINARY_MINIMIZATION_FIRST,
"BINARY_MINIMIZATION_FIRST_WITH_TRANSITIVE_REDUCTION": SatParameters.BINARY_MINIMIZATION_FIRST_WITH_TRANSITIVE_REDUCTION,
"BINARY_MINIMIZATION_WITH_REACHABILITY": SatParameters.BINARY_MINIMIZATION_WITH_REACHABILITY,
"EXPERIMENTAL_BINARY_MINIMIZATION": SatParameters.EXPERIMENTAL_BINARY_MINIMIZATION,
},
),
CategoricalHyperparameter(
name="minimization_algorithm",
choices={
"NONE": SatParameters.NONE,
"SIMPLE": SatParameters.SIMPLE,
"RECURSIVE": SatParameters.RECURSIVE,
"EXPERIMENTAL": SatParameters.EXPERIMENTAL,
},
),
CategoricalHyperparameter(name="use_phase_saving", choices=[True, False]),
]


additional_hyperparameters_by_solver = defaultdict(
list,
{
ColoringCPSatSolver: [
SubBrickKwargsHyperparameter(
name="ortools_cpsat_solver_kwargs",
subbrick_cls=OrtoolsCpsatSolverKwargs,
)
]
},
)


# problem definition
file = [f for f in get_data_available() if "gc_70_9" in f][0]
problem = parse_file(file)

# generate and launch the optuna study
generic_optuna_experiment_monoproblem(
problem=problem,
solvers_to_test=solvers_to_test,
kwargs_fixed_by_solver=kwargs_fixed_by_solver,
suggest_optuna_kwargs_by_name_by_solver=suggest_optuna_kwargs_by_name_by_solver,
additional_hyperparameters_by_solver=additional_hyperparameters_by_solver,
n_trials=n_trials,
computation_time_in_study=True,
study_basename=study_basename,
create_another_study=create_another_study,
seed=seed,
min_time_per_solver=min_time_per_solver,
)
31 changes: 31 additions & 0 deletions tests/coloring/test_coloring.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,37 @@ def test_cpsat_solver(modeling):
assert color_problem.satisfy(solution)


def test_cpsat_solver_finetuned():
small_example = [f for f in get_data_available() if "gc_20_1" in f][0]
color_problem = parse_file(small_example)
solver = ColoringCPSatSolver(color_problem)
solver.init_model(nb_colors=20)
p = ParametersCP.default()

# must use existing attribute name for ortools CPSolver
with pytest.raises(AttributeError):
result_store = solver.solve(
parameters_cp=p, ortools_cpsat_solver_kwargs=dict(toto=4)
)
# must use correct value
with pytest.raises(ValueError):
result_store = solver.solve(
parameters_cp=p, ortools_cpsat_solver_kwargs=dict(search_branching=-4)
)
# works
from ortools.sat.sat_parameters_pb2 import SatParameters

result_store = solver.solve(
parameters_cp=p,
ortools_cpsat_solver_kwargs=dict(
search_branching=SatParameters.PSEUDO_COST_SEARCH
),
)

solution, fit = result_store.get_best_solution_fit()
assert color_problem.satisfy(solution)


def test_asp_solver():
small_example = [f for f in get_data_available() if "gc_20_1" in f][0]
color_problem = parse_file(small_example)
Expand Down

0 comments on commit 17b2177

Please sign in to comment.