Skip to content

Commit

Permalink
[solvers] SolverOptions is serializable
Browse files Browse the repository at this point in the history
Newly deprecated:
- drake::solvers::MathematicalProgram::GetSolverOptionsDouble
- drake::solvers::MathematicalProgram::GetSolverOptionsInt
- drake::solvers::MathematicalProgram::GetSolverOptionsStr
- drake::solvers::SolverOptions::CheckOptionsKeysForSolver
- drake::solvers::SolverOptions::GetOptions
- drake::solvers::SolverOptions::GetSolverIds
- drake::solvers::SolverOptions::GetSolverOptionsDouble
- drake::solvers::SolverOptions::GetSolverOptionsInt
- drake::solvers::SolverOptions::GetSolverOptionsStr
- drake::solvers::SolverOptions::common_solver_options
- drake::solvers::SolverOptions::get_max_threads
- drake::solvers::SolverOptions::get_print_file_name
- drake::solvers::SolverOptions::get_print_to_console
- drake::solvers::SolverOptions::get_standalone_reproduction_file_name
- drake::solvers::SolverOptions::operator<<
  • Loading branch information
jwnimmer-tri committed Nov 15, 2024
1 parent b630cb1 commit 7b0dddf
Show file tree
Hide file tree
Showing 15 changed files with 668 additions and 502 deletions.
21 changes: 12 additions & 9 deletions bindings/pydrake/geometry/test/optimization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,10 +798,10 @@ def test_graph_of_convex_sets(self):
ClpSolver.id(), "log_level", 3)
options.parallelism = True
self.assertIn("scaling",
options.solver_options.GetOptions(ClpSolver.id()))
options.solver_options.options[ClpSolver().id().name()])
self.assertIn("log_level",
options.restriction_solver_options.GetOptions(
ClpSolver.id()))
options.restriction_solver_options.options[
ClpSolver().id().name()])
self.assertIn("convex_relaxation", repr(options))

spp = mut.GraphOfConvexSets()
Expand Down Expand Up @@ -1157,8 +1157,9 @@ def test_CspaceFreePolytope_Options(self):
self.assertEqual(find_separation_options.solver_id, ScsSolver.id())
self.assertFalse(find_separation_options.terminate_at_failure)
self.assertEqual(
find_separation_options.solver_options.common_solver_options()[
CommonSolverOption.kPrintToConsole], 1)
find_separation_options.solver_options.options[
"Drake"]["kPrintToConsole"],
1)

# FindSeparationCertificateGivenPolytopeOptions
lagrangian_options = \
Expand Down Expand Up @@ -1191,8 +1192,9 @@ def test_CspaceFreePolytope_Options(self):
self.assertFalse(
lagrangian_options.terminate_at_failure)
self.assertEqual(
lagrangian_options.solver_options.common_solver_options()[
CommonSolverOption.kPrintToConsole], 1)
lagrangian_options.solver_options.options[
"Drake"]["kPrintToConsole"],
1)
self.assertTrue(
lagrangian_options.ignore_redundant_C)

Expand Down Expand Up @@ -1230,8 +1232,9 @@ def test_CspaceFreePolytope_Options(self):
polytope_options.solver_id,
ScsSolver.id())
self.assertEqual(
polytope_options.solver_options.common_solver_options()[
CommonSolverOption.kPrintToConsole], 1)
polytope_options.solver_options.options[
"Drake"]["kPrintToConsole"],
1)
np.testing.assert_array_almost_equal(
polytope_options.s_inner_pts, np.zeros(
(2, 1)), 1e-5)
Expand Down
15 changes: 6 additions & 9 deletions bindings/pydrake/planning/test/graph_algorithms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def test_max_clique_solver_via_mip_methods(self):
# Test the default constructor.
solver_default = mut.MaxCliqueSolverViaMip()
self.assertIsNone(solver_default.GetInitialGuess())
self.assertFalse(
solver_default.GetSolverOptions().get_print_to_console())
self.assertEqual(len(solver_default.GetSolverOptions().options), 0)

# Test the argument constructor.
solver_options = SolverOptions()
Expand All @@ -68,21 +67,19 @@ def test_max_clique_solver_via_mip_methods(self):
solver = mut.MaxCliqueSolverViaMip(solver_options=solver_options,
initial_guess=initial_guess)
# Test the getters.
numpy_compare.assert_equal(
solver.GetInitialGuess(), initial_guess
numpy_compare.assert_equal(solver.GetInitialGuess(), initial_guess)
self.assertTrue(
solver.GetSolverOptions().options["Drake"]["kPrintToConsole"],
)
self.assertTrue(solver.GetSolverOptions().get_print_to_console())

# Test the setters.
new_guess = np.zeros(graph.shape[0])
solver.SetInitialGuess(initial_guess=new_guess)
numpy_compare.assert_equal(
solver.GetInitialGuess(), new_guess
)
numpy_compare.assert_equal(solver.GetInitialGuess(), new_guess)

new_options = SolverOptions()
solver.SetSolverOptions(solver_options=new_options)
self.assertFalse(solver.GetSolverOptions().get_print_to_console())
self.assertEqual(len(solver.GetSolverOptions().options), 0)

# Test solve max clique.
if GurobiOrMosekSolverAvailable():
Expand Down
48 changes: 29 additions & 19 deletions bindings/pydrake/solvers/solvers_py_mathematicalprogram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "drake/bindings/pydrake/autodiff_types_pybind.h"
#include "drake/bindings/pydrake/common/cpp_param_pybind.h"
#include "drake/bindings/pydrake/common/cpp_template_pybind.h"
#include "drake/bindings/pydrake/common/deprecation_pybind.h"
#include "drake/bindings/pydrake/common/eigen_pybind.h"
#include "drake/bindings/pydrake/documentation_pybind.h"
#include "drake/bindings/pydrake/pydrake_pybind.h"
Expand Down Expand Up @@ -1345,27 +1346,36 @@ void BindMathematicalProgram(py::module m) {
doc.MathematicalProgram.SetSolverOptions.doc)
.def("solver_options", &MathematicalProgram::solver_options,
py_rvp::reference_internal,
doc.MathematicalProgram.solver_options.doc)
// TODO(m-chaturvedi) Add Pybind11 documentation.
doc.MathematicalProgram.solver_options.doc);
// Deprecated 2025-05.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
prog_cls // BR
.def("GetSolverOptions",
[](MathematicalProgram& prog, SolverId solver_id) {
py::dict out;
py::object update = out.attr("update");
update(prog.GetSolverOptionsDouble(solver_id));
update(prog.GetSolverOptionsInt(solver_id));
update(prog.GetSolverOptionsStr(solver_id));
return out;
})
WrapDeprecated(
doc.MathematicalProgram.GetSolverOptionsDouble.doc_deprecated,
[](MathematicalProgram& prog, SolverId solver_id) {
py::dict out;
py::object update = out.attr("update");
update(prog.GetSolverOptionsDouble(solver_id));
update(prog.GetSolverOptionsInt(solver_id));
update(prog.GetSolverOptionsStr(solver_id));
return out;
}))
.def("GetSolverOptions",
[](MathematicalProgram& prog, SolverType solver_type) {
py::dict out;
py::object update = out.attr("update");
const SolverId id = SolverTypeConverter::TypeToId(solver_type);
update(prog.GetSolverOptionsDouble(id));
update(prog.GetSolverOptionsInt(id));
update(prog.GetSolverOptionsStr(id));
return out;
})
WrapDeprecated(
doc.MathematicalProgram.GetSolverOptionsDouble.doc_deprecated,
[](MathematicalProgram& prog, SolverType solver_type) {
py::dict out;
py::object update = out.attr("update");
const SolverId id = SolverTypeConverter::TypeToId(solver_type);
update(prog.GetSolverOptionsDouble(id));
update(prog.GetSolverOptionsInt(id));
update(prog.GetSolverOptionsStr(id));
return out;
}));
#pragma GCC diagnostic pop
prog_cls // BR
.def("generic_costs", &MathematicalProgram::generic_costs,
doc.MathematicalProgram.generic_costs.doc)
.def("generic_constraints", &MathematicalProgram::generic_constraints,
Expand Down
96 changes: 50 additions & 46 deletions bindings/pydrake/solvers/solvers_py_options.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "drake/bindings/pydrake/common/deprecation_pybind.h"
#include "drake/bindings/pydrake/documentation_pybind.h"
#include "drake/bindings/pydrake/pydrake_pybind.h"
#include "drake/bindings/pydrake/solvers/solvers_py.h"
Expand Down Expand Up @@ -28,60 +29,63 @@ void DefineSolversOptions(py::module m) {
}

{
// TODO(jwnimmer-tri) Bind the accessors for SolverOptions.
py::class_<SolverOptions> cls(m, "SolverOptions", doc.SolverOptions.doc);
cls // BR
.def(py::init<>(), doc.SolverOptions.ctor.doc)
.def(py::init<>())
.def("SetOption",
py::overload_cast<const SolverId&, const std::string&, double>(
&SolverOptions::SetOption),
py::arg("solver_id"), py::arg("solver_option"),
py::arg("option_value"),
doc.SolverOptions.SetOption.doc_double_option)
.def("SetOption",
py::overload_cast<const SolverId&, const std::string&, int>(
&SolverOptions::SetOption),
py::arg("solver_id"), py::arg("solver_option"),
py::arg("option_value"), doc.SolverOptions.SetOption.doc_int_option)
.def("SetOption",
py::overload_cast<const SolverId&, const std::string&,
const std::string&>(&SolverOptions::SetOption),
py::arg("solver_id"), py::arg("solver_option"),
py::arg("option_value"), doc.SolverOptions.SetOption.doc_str_option)
py::overload_cast<const SolverId&, std::string,
SolverOptions::OptionValue>(&SolverOptions::SetOption),
py::arg("solver_id"), py::arg("key"), py::arg("value"),
doc.SolverOptions.SetOption.doc_3args)
.def("SetOption",
py::overload_cast<CommonSolverOption, SolverOptions::OptionValue>(
&SolverOptions::SetOption),
py::arg("key"), py::arg("value"),
doc.SolverOptions.SetOption.doc_common_option)
.def(
"GetOptions",
[](const SolverOptions& solver_options, SolverId solver_id) {
py::dict out;
py::object update = out.attr("update");
update(solver_options.GetOptionsDouble(solver_id));
update(solver_options.GetOptionsInt(solver_id));
update(solver_options.GetOptionsStr(solver_id));
return out;
},
py::arg("solver_id"), doc.SolverOptions.GetOptionsDouble.doc)
.def("common_solver_options", &SolverOptions::common_solver_options,
doc.SolverOptions.common_solver_options.doc)
.def("get_print_file_name", &SolverOptions::get_print_file_name,
doc.SolverOptions.get_print_file_name.doc)
.def("get_print_to_console", &SolverOptions::get_print_to_console,
doc.SolverOptions.get_print_to_console.doc)
.def("get_standalone_reproduction_file_name",
&SolverOptions::get_standalone_reproduction_file_name,
doc.SolverOptions.get_standalone_reproduction_file_name.doc)
.def("get_max_threads", &SolverOptions::get_max_threads,
doc.SolverOptions.get_max_threads.doc)
.def("__repr__", [](const SolverOptions&) -> std::string {
// This is a minimal implementation that serves to avoid displaying
// memory addresses in pydrake docs and help strings. In the future,
// we should enhance this to provide more details.
return "<SolverOptions>";
});
doc.SolverOptions.SetOption.doc_2args)
.def_readwrite("options", &SolverOptions::options)
.def("__repr__",
[](const SolverOptions& self) { return fmt::to_string(self); });
DefCopyAndDeepCopy(&cls);

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
// Deprecated 2025-05.
cls // BR
.def("GetOptions",
WrapDeprecated(doc.SolverOptions.GetOptionsDouble.doc_deprecated,
[](const SolverOptions& solver_options, SolverId solver_id) {
py::dict out;
py::object update = out.attr("update");
update(solver_options.GetOptionsDouble(solver_id));
update(solver_options.GetOptionsInt(solver_id));
update(solver_options.GetOptionsStr(solver_id));
return out;
}),
py::arg("solver_id"),
doc.SolverOptions.GetOptionsDouble.doc_deprecated)
.def("common_solver_options",
WrapDeprecated(
doc.SolverOptions.common_solver_options.doc_deprecated,
&SolverOptions::common_solver_options),
doc.SolverOptions.common_solver_options.doc_deprecated)
.def("get_print_file_name",
WrapDeprecated(doc.SolverOptions.get_print_file_name.doc_deprecated,
&SolverOptions::get_print_file_name),
doc.SolverOptions.get_print_file_name.doc_deprecated)
.def("get_print_to_console",
WrapDeprecated(
doc.SolverOptions.get_print_to_console.doc_deprecated,
&SolverOptions::get_print_to_console),
doc.SolverOptions.get_print_to_console.doc_deprecated)
.def("get_standalone_reproduction_file_name",
WrapDeprecated(
doc.SolverOptions.get_standalone_reproduction_file_name
.doc_deprecated,
&SolverOptions::get_standalone_reproduction_file_name))
.def("get_max_threads",
WrapDeprecated(doc.SolverOptions.get_max_threads.doc_deprecated,
&SolverOptions::get_max_threads),
doc.SolverOptions.get_max_threads.doc_deprecated);
}
}

Expand Down
37 changes: 28 additions & 9 deletions bindings/pydrake/solvers/test/mathematicalprogram_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydrake.autodiffutils import AutoDiffXd
from pydrake.common import kDrakeAssertIsArmed, Parallelism
from pydrake.common.test_utilities import numpy_compare
from pydrake.common.test_utilities.deprecation import catch_drake_warnings
from pydrake.forwarddiff import jacobian
from pydrake.math import ge
from pydrake.solvers import (
Expand Down Expand Up @@ -1305,13 +1306,18 @@ def test_mathematical_program_solver_options(self):
prog.SetSolverOption(solver, "india", 2)
prog.SetSolverOption(solver, "sierra", "3")
expected = {"foxtrot": 1.0, "india": 2, "sierra": "3"}
self.assertDictEqual(prog.GetSolverOptions(solver), expected)
with catch_drake_warnings(expected_count=1):
self.assertDictEqual(prog.GetSolverOptions(solver), expected)
old_options = prog.solver_options()
self.assertEqual(old_options.options, {
gurobi_id.name(): expected,
})
new_options = copy.deepcopy(old_options)
new_options.SetOption(gurobi_id, "india", 4)
prog.SetSolverOptions(new_options)
expected["india"] = 4
self.assertDictEqual(prog.GetSolverOptions(solver), expected)
with catch_drake_warnings(expected_count=1):
self.assertDictEqual(prog.GetSolverOptions(solver), expected)

def test_solver_options(self):
CSO = mp.CommonSolverOption
Expand All @@ -1331,13 +1337,26 @@ def test_solver_options(self):
CSO.kStandaloneReproductionFileName: "repro.txt",
CSO.kMaxThreads: 4,
}
self.assertDictEqual(dut.GetOptions(solver_id), expected_dummy)
self.assertEqual(dut.common_solver_options(), expected_common)
self.assertEqual(dut.get_print_to_console(), True)
self.assertEqual(dut.get_print_file_name(), "print.log")
self.assertEqual(dut.get_standalone_reproduction_file_name(),
"repro.txt")
self.assertEqual(dut.get_max_threads(), 4)
with catch_drake_warnings(expected_count=1):
self.assertDictEqual(dut.GetOptions(solver_id), expected_dummy)
with catch_drake_warnings(expected_count=1):
self.assertEqual(dut.common_solver_options(), expected_common)
self.assertEqual(dut.options, {
"dummy": expected_dummy,
"Drake": dict(
(key.name, value)
for key, value in expected_common.items()
)
})
with catch_drake_warnings(expected_count=1):
self.assertEqual(dut.get_print_to_console(), True)
with catch_drake_warnings(expected_count=1):
self.assertEqual(dut.get_print_file_name(), "print.log")
with catch_drake_warnings(expected_count=1):
self.assertEqual(dut.get_standalone_reproduction_file_name(),
"repro.txt")
with catch_drake_warnings(expected_count=1):
self.assertEqual(dut.get_max_threads(), 4)
copy.deepcopy(dut)

def test_infeasible_constraints(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ GTEST_TEST(MaxCliqueSolverViaMipTest, TestConstructorSettersAndGetters) {
solvers::SolverOptions options{};
options.SetOption(solvers::CommonSolverOption::kPrintToConsole, 1);
solver.SetSolverOptions(options);
EXPECT_TRUE(solver.GetSolverOptions().get_print_to_console());
EXPECT_EQ(solver.GetSolverOptions(), options);

// Test the constructor with the initial guess and solver options passed.
MaxCliqueSolverViaMip solver2{initial_guess, options};
EXPECT_TRUE(solver2.GetInitialGuess().has_value());
EXPECT_TRUE(
CompareMatrices(solver2.GetInitialGuess().value(), initial_guess));
EXPECT_TRUE(solver2.GetSolverOptions().get_print_to_console());
EXPECT_EQ(solver2.GetSolverOptions(), options);
}

GTEST_TEST(MaxCliqueSolverViaMipTest, TestClone) {
Expand All @@ -88,7 +88,7 @@ GTEST_TEST(MaxCliqueSolverViaMipTest, TestClone) {
ASSERT_FALSE(solver_clone_mip == nullptr);
EXPECT_TRUE(CompareMatrices(solver.GetInitialGuess().value(),
solver_clone_mip->GetInitialGuess().value()));
EXPECT_EQ(solver_clone_mip->GetSolverOptions().get_print_to_console(), 1);
EXPECT_EQ(solver_clone_mip->GetSolverOptions(), options);
}

GTEST_TEST(MaxCliqueSolverViaMipTest, CompleteGraph) {
Expand Down
4 changes: 4 additions & 0 deletions solvers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ drake_cc_library(
],
deps = [
":solver_id",
"//common:string_container",
],
implementation_deps = [
"//common:overloaded",
],
)

Expand Down
Loading

0 comments on commit 7b0dddf

Please sign in to comment.