Skip to content

Commit

Permalink
fixup! make test pass, simplify overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
EricCousineau-TRI committed May 14, 2021
1 parent 7869d6e commit dbe7506
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 31 deletions.
7 changes: 2 additions & 5 deletions bindings/pydrake/symbolic_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -408,14 +408,11 @@ PYBIND11_MODULE(symbolic, m) {
.def("Differentiate", &Expression::Differentiate,
doc.Expression.Differentiate.doc)
.def("Jacobian", &Expression::Jacobian, doc.Expression.Jacobian.doc);
// TODO(eric.cousineau): Clean this overload stuff up (#15041).
pydrake::internal::BindSymbolicMathOverloads<Expression>(&expr_cls);
pydrake::internal::BindSymbolicMathOverloads<Expression>(&m);
DefCopyAndDeepCopy(&expr_cls);

// TODO(eric.cousineau): These should actually exist on the class, and should
// be should be consolidated with the above repeated definitions. This would
// yield the same parity with AutoDiff.
pydrake::internal::BindSymbolicMathModuleOverloads(m);

m.def("if_then_else", &symbolic::if_then_else);

m.def(
Expand Down
12 changes: 3 additions & 9 deletions bindings/pydrake/symbolic_types_pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,9 @@ void BindSymbolicMathOverloads(PyObject* obj) {
.def("ceil", &symbolic::ceil, doc.ceil.doc)
.def("__ceil__", &symbolic::ceil, doc.ceil.doc)
.def("floor", &symbolic::floor, doc.floor.doc)
.def("__floor__", &symbolic::floor, doc.floor.doc);
}

inline void BindSymbolicMathModuleOverloads(py::module m) {
using symbolic::Expression;
BindSymbolicMathOverloads<Expression>(&m);
m // BR
// TODO(eric.cousineau): This is not a NumPy-overridable method using
// dtype=object. Deprecate and move solely into `pydrake.math`.
.def("__floor__", &symbolic::floor, doc.floor.doc)
// TODO(eric.cousineau): This is not a NumPy-overridable method using
// dtype=object. Deprecate and move solely into `pydrake.math`.
.def(
"inv",
[](const MatrixX<Expression>& X) -> MatrixX<Expression> {
Expand Down
47 changes: 30 additions & 17 deletions bindings/pydrake/test/math_overloads_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def supports(self, func):
supported = backwards_compat
if func.__name__ in backwards_compat:
# Check backwards compatibility.
assert hasattr(self.m, func.__name__)
assert hasattr(self.m, func.__name__), self.m.__name__
return func.__name__ in supported

def to_float(self, y_T):
Expand Down Expand Up @@ -158,22 +158,35 @@ def check_eval(functions, nargs):
args_T = list(map(overload.to_type, args_float))
# Check each supported function.
for f_drake, f_builtin in functions:
if not overload.supports(f_drake):
continue
debug_print(
"- Functions: ", qualname(f_drake), qualname(f_builtin))
y_builtin = f_builtin(*args_float)
y_float = f_drake(*args_float)
debug_print(" - - Float Eval:", repr(y_builtin), repr(y_float))
self.assertEqual(y_float, y_builtin)
self.assertIsInstance(y_float, float)
# Test method current overload, and ensure value is accurate.
y_T = f_drake(*args_T)
y_T_float = overload.to_float(y_T)
debug_print(" - - Overload Eval:", repr(y_T), repr(y_T_float))
self.assertIsInstance(y_T, overload.T)
# - Ensure the translated value is accurate.
self.assertEqual(y_T_float, y_float)
with self.subTest(function=f_drake.__name__, nargs=nargs):
if not overload.supports(f_drake):
continue
debug_print(
"- Functions: ",
qualname(f_drake),
qualname(f_builtin),
)
y_builtin = f_builtin(*args_float)
y_float = f_drake(*args_float)
debug_print(
" - - Float Eval:",
repr(y_builtin),
repr(y_float),
)
self.assertEqual(y_float, y_builtin)
self.assertIsInstance(y_float, float)
# Test method current overload, and ensure value is
# accurate.
y_T = f_drake(*args_T)
y_T_float = overload.to_float(y_T)
debug_print(
" - - Overload Eval:",
repr(y_T),
repr(y_T_float),
)
self.assertIsInstance(y_T, overload.T)
# - Ensure the translated value is accurate.
self.assertEqual(y_T_float, y_float)

debug_print("\n\nOverload: ", qualname(type(overload)))
float_overload = FloatOverloads()
Expand Down

0 comments on commit dbe7506

Please sign in to comment.