-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Chain several subsolvers - Next subsolver warm started by the best solution of the previous one* - subsolvers must inherit from WarmstartMixin, except for first one - **kwargs needed in all solvers __init__() method as we apply - subsolver.__init__(problem=problem, **kwargs) - subsolver.init_model(**kwargs) - subsolver.solve(**kwargs)
- Loading branch information
Showing
16 changed files
with
209 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
discrete_optimization/generic_tools/sequential_metasolver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import logging | ||
from typing import Any, List, Optional | ||
|
||
from discrete_optimization.generic_tools.callbacks.callback import ( | ||
Callback, | ||
CallbackList, | ||
) | ||
from discrete_optimization.generic_tools.do_problem import ( | ||
ParamsObjectiveFunction, | ||
Problem, | ||
) | ||
from discrete_optimization.generic_tools.do_solver import SolverDO, WarmstartMixin | ||
from discrete_optimization.generic_tools.hyperparameters.hyperparameter import SubBrick | ||
from discrete_optimization.generic_tools.result_storage.result_storage import ( | ||
ResultStorage, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SequentialMetasolver(SolverDO): | ||
"""Sequential metasolver. | ||
The problem will be solved sequentially, each subsolver being warm started by the previous one. | ||
Therefore each subsolver must inherit from WarmstartMixin, except the first one. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
problem: Problem, | ||
params_objective_function: Optional[ParamsObjectiveFunction] = None, | ||
list_subbricks: Optional[List[SubBrick]] = None, | ||
**kwargs, | ||
): | ||
""" | ||
Args: | ||
list_subbricks: list of subsolvers class and kwargs to be used sequentially | ||
""" | ||
super().__init__( | ||
problem=problem, params_objective_function=params_objective_function | ||
) | ||
self.list_subbricks = list_subbricks | ||
self.nb_solvers = len(list_subbricks) | ||
|
||
# checks | ||
if len(self.list_subbricks) == 0: | ||
raise ValueError("list_subbricks must contain at least one subbrick.") | ||
for i_subbrick, subbrick in enumerate(self.list_subbricks): | ||
if not issubclass(subbrick.cls, SolverDO): | ||
raise ValueError("Each subsolver must inherit SolverDO.") | ||
if i_subbrick > 0 and not issubclass(subbrick.cls, WarmstartMixin): | ||
raise ValueError( | ||
"Each subsolver except the first one must inherit WarmstartMixin." | ||
) | ||
|
||
def solve( | ||
self, callbacks: Optional[List[Callback]] = None, **kwargs: Any | ||
) -> ResultStorage: | ||
# wrap all callbacks in a single one | ||
callbacks_list = CallbackList(callbacks=callbacks) | ||
# start of solve callback | ||
callbacks_list.on_solve_start(solver=self) | ||
|
||
# iterate over next solvers | ||
res_tot = self.create_result_storage() | ||
for i_subbrick, subbrick in enumerate(self.list_subbricks): | ||
subsolver: SolverDO = subbrick.cls(problem=self.problem, **subbrick.kwargs) | ||
subsolver.init_model(**subbrick.kwargs) | ||
if i_subbrick > 0: | ||
subsolver.set_warm_start(res.get_best_solution()) | ||
res = subsolver.solve(**subbrick.kwargs) | ||
res_tot.extend(res) | ||
|
||
# end of step callback: stopping? | ||
stopping = callbacks_list.on_step_end( | ||
step=i_subbrick, res=res_tot, solver=self | ||
) | ||
if len(res) == 0: | ||
# no solution => warning + stopping if first subsolver | ||
logger.warning(f"Subsolver #{i_subbrick} did not find any solution.") | ||
if i_subbrick == 0: | ||
stopping = True | ||
if stopping: | ||
break | ||
|
||
# end of solve callback | ||
callbacks_list.on_solve_end(res=res_tot, solver=self) | ||
return res_tot |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright (c) 2024 AIRBUS and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import logging | ||
import random | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from discrete_optimization.generic_tools.callbacks.early_stoppers import TimerStopper | ||
from discrete_optimization.generic_tools.callbacks.loggers import ( | ||
NbIterationTracker, | ||
ObjectiveLogger, | ||
) | ||
from discrete_optimization.generic_tools.cp_tools import ParametersCP | ||
from discrete_optimization.generic_tools.hyperparameters.hyperparameter import SubBrick | ||
from discrete_optimization.generic_tools.ls.local_search import ( | ||
ModeMutation, | ||
RestartHandlerLimit, | ||
) | ||
from discrete_optimization.generic_tools.ls.simulated_annealing import ( | ||
SimulatedAnnealing, | ||
TemperatureSchedulingFactor, | ||
) | ||
from discrete_optimization.generic_tools.mutations.mixed_mutation import ( | ||
BasicPortfolioMutation, | ||
) | ||
from discrete_optimization.generic_tools.mutations.mutation_catalog import ( | ||
get_available_mutations, | ||
) | ||
from discrete_optimization.generic_tools.sequential_metasolver import ( | ||
SequentialMetasolver, | ||
) | ||
from discrete_optimization.rcpsp.rcpsp_parser import get_data_available, parse_file | ||
from discrete_optimization.rcpsp.solver import PileSolverRCPSP | ||
from discrete_optimization.rcpsp.solver.cpsat_solver import CPSatRCPSPSolver | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
@pytest.fixture | ||
def random_seed(): | ||
random.seed(0) | ||
np.random.seed(0) | ||
|
||
|
||
def test_sequential_metasolver_rcpsp(random_seed): | ||
logging.basicConfig(level=logging.INFO) | ||
|
||
files_available = get_data_available() | ||
file = [f for f in files_available if "j1201_1.sm" in f][0] | ||
rcpsp_problem = parse_file(file) | ||
|
||
# kwargs SA | ||
solution = rcpsp_problem.get_dummy_solution() | ||
_, list_mutation = get_available_mutations(rcpsp_problem, solution) | ||
list_mutation = [ | ||
mutate[0].build(rcpsp_problem, solution, **mutate[1]) | ||
for mutate in list_mutation | ||
] | ||
mixed_mutation = BasicPortfolioMutation( | ||
list_mutation, np.ones((len(list_mutation))) | ||
) | ||
restart_handler = RestartHandlerLimit(3000) | ||
temperature_handler = TemperatureSchedulingFactor(1000, restart_handler, 0.99) | ||
|
||
# kwargs cpsat | ||
parameters_cp = ParametersCP.default_cpsat() | ||
parameters_cp.time_limit = 20 | ||
parameters_cp.time_limit_iter0 = 20 | ||
|
||
list_subbricks = [ | ||
SubBrick(cls=PileSolverRCPSP, kwargs=dict()), | ||
SubBrick( | ||
cls=SimulatedAnnealing, | ||
kwargs=dict( | ||
mutator=mixed_mutation, | ||
restart_handler=restart_handler, | ||
temperature_handler=temperature_handler, | ||
mode_mutation=ModeMutation.MUTATE, | ||
nb_iteration_max=5000, | ||
), | ||
), | ||
SubBrick(cls=CPSatRCPSPSolver, kwargs=dict(parameters_cp=parameters_cp)), | ||
] | ||
|
||
solver = SequentialMetasolver(problem=rcpsp_problem, list_subbricks=list_subbricks) | ||
result_storage = solver.solve( | ||
callbacks=[ | ||
NbIterationTracker(step_verbosity_level=logging.INFO), | ||
ObjectiveLogger( | ||
step_verbosity_level=logging.INFO, end_verbosity_level=logging.INFO | ||
), | ||
TimerStopper(total_seconds=30), | ||
], | ||
) | ||
solution, fit = result_storage.get_best_solution_fit() | ||
print(solution, fit) | ||
assert rcpsp_problem.satisfy(solution) |