Skip to content

Commit

Permalink
pydrake: Add UpdateGlobalsFromModule test utility
Browse files Browse the repository at this point in the history
  • Loading branch information
EricCousineau-TRI committed Dec 10, 2018
1 parent a2fc39e commit d774035
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 16 deletions.
8 changes: 8 additions & 0 deletions bindings/pydrake/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ drake_cc_library(
visibility = ["//visibility:public"],
)

drake_cc_library(
name = "test_util_pybind",
testonly = 1,
hdrs = ["test/test_util_pybind.h"],
declare_installed_headers = 0,
visibility = ["//visibility:public"],
)

drake_cc_library(
name = "autodiff_types_pybind",
hdrs = ["autodiff_types_pybind.h"],
Expand Down
8 changes: 7 additions & 1 deletion bindings/pydrake/common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,10 @@ drake_py_unittest(

drake_pybind_cc_googletest(
name = "cpp_param_pybind_test",
cc_deps = [":cpp_param_pybind"],
cc_deps = [
":cpp_param_pybind",
"//bindings/pydrake:test_util_pybind",
],
py_deps = [":cpp_param_py"],
)

Expand All @@ -355,6 +358,7 @@ drake_pybind_cc_googletest(
name = "cpp_template_pybind_test",
cc_deps = [
":cpp_template_pybind",
"//bindings/pydrake:test_util_pybind",
"//common:nice_type_name",
"//common/test_utilities:expect_throws_message",
],
Expand All @@ -365,6 +369,7 @@ drake_pybind_cc_googletest(
name = "drake_variant_pybind_test",
cc_deps = [
":drake_variant_pybind",
"//bindings/pydrake:test_util_pybind",
],
)

Expand All @@ -391,6 +396,7 @@ drake_pybind_cc_googletest(
name = "type_safe_index_pybind_test",
cc_deps = [
":type_safe_index_pybind",
"//bindings/pydrake:test_util_pybind",
"//common:nice_type_name",
],
)
Expand Down
5 changes: 3 additions & 2 deletions bindings/pydrake/common/test/cpp_param_pybind_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "pybind11/eval.h"
#include "pybind11/pybind11.h"

#include "drake/bindings/pydrake/test/test_util_pybind.h"

using std::string;

namespace drake {
Expand Down Expand Up @@ -86,8 +88,7 @@ int main(int argc, char** argv) {
// Define custom class only once here.
py::class_<CustomCppType>(m, "CustomCppType");

// For Python3
py::globals().attr("update")(m.attr("__dict__"));
test::UpdateGlobalsFromModule(m);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
Expand Down
17 changes: 7 additions & 10 deletions bindings/pydrake/common/test/cpp_template_pybind_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "pybind11/eval.h"
#include "pybind11/pybind11.h"

#include "drake/bindings/pydrake/test/test_util_pybind.h"
#include "drake/common/nice_type_name.h"
#include "drake/common/test_utilities/expect_throws_message.h"

Expand All @@ -22,6 +23,8 @@ namespace drake {
namespace pydrake {
namespace {

using test::UpdateGlobalsFromModule;

template <typename... Ts>
struct SimpleTemplate {
vector<string> GetNames() {
Expand All @@ -45,12 +48,6 @@ void CheckValue(const string& expr, const T& expected) {
EXPECT_EQ(py::eval(expr).cast<T>(), expected);
}

// TODO(eric.cousineau): Figure out why this is necessary.
// Necessary for Python3.
void sync(py::module m) {
py::globals().attr("update")(m.attr("__dict__"));
}

GTEST_TEST(CppTemplateTest, TemplateClass) {
py::module m("__main__");

Expand All @@ -59,14 +56,14 @@ GTEST_TEST(CppTemplateTest, TemplateClass) {

const vector<string> expected_1 = {"int"};
const vector<string> expected_2 = {"int", "double"};
sync(m);
UpdateGlobalsFromModule(m);

CheckValue("DefaultInst().GetNames()", expected_1);
CheckValue("SimpleTemplate[int]().GetNames()", expected_1);
CheckValue("SimpleTemplate[int, float]().GetNames()", expected_2);

m.def("simple_func", [](const SimpleTemplate<int>&) {});
sync(m);
UpdateGlobalsFromModule(m);

// Check error message if a function is called with the incorrect arguments.
// N.B. We use `[^\0]` because C++ regex does not have an equivalent of
Expand All @@ -93,7 +90,7 @@ GTEST_TEST(CppTemplateTest, TemplateFunction) {

const vector<string> expected_1 = {"int"};
const vector<string> expected_2 = {"int", "double"};
sync(m);
UpdateGlobalsFromModule(m);
CheckValue("SimpleFunction[int]()", expected_1);
CheckValue("SimpleFunction[int, float]()", expected_2);
}
Expand All @@ -120,7 +117,7 @@ GTEST_TEST(CppTemplateTest, TemplateMethod) {

const vector<string> expected_1 = {"int"};
const vector<string> expected_2 = {"int", "double"};
sync(m);
UpdateGlobalsFromModule(m);
CheckValue("SimpleType().SimpleMethod[int]()", expected_1);
CheckValue("SimpleType().SimpleMethod[int, float]()", expected_2);
}
Expand Down
5 changes: 4 additions & 1 deletion bindings/pydrake/common/test/drake_variant_pybind_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
#include "pybind11/pybind11.h"

#include "drake/bindings/pydrake/pydrake_pybind.h"
#include "drake/bindings/pydrake/test/test_util_pybind.h"

using std::string;

namespace drake {
namespace pydrake {
namespace {

using test::UpdateGlobalsFromModule;

string VariantToString(const variant<int, double, string>& value) {
const bool is_int = holds_alternative<int>(value);
const bool is_double = holds_alternative<double>(value);
Expand All @@ -36,7 +39,7 @@ GTEST_TEST(VariantTest, CheckCasting) {
py::module m("__main__");

m.def("VariantToString", &VariantToString, py::arg("value"));
py::globals().attr("update")(m.attr("__dict__"));
UpdateGlobalsFromModule(m);
ExpectString("VariantToString(1)", "int(1)");
ExpectString("VariantToString(0.5)", "double(0.5)");
ExpectString("VariantToString('foo')", "string(foo)");
Expand Down
8 changes: 6 additions & 2 deletions bindings/pydrake/common/test/type_safe_index_pybind_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
#include "pybind11/eval.h"
#include "pybind11/pybind11.h"

#include "drake/bindings/pydrake/test/test_util_pybind.h"

using std::string;
using std::vector;

namespace drake {
namespace pydrake {
namespace {

using test::UpdateGlobalsFromModule;

template <typename T>
void CheckValue(const string& expr, const T& expected) {
EXPECT_EQ(py::eval(expr).cast<T>(), expected);
Expand All @@ -34,7 +38,7 @@ GTEST_TEST(TypeSafeIndexTest, CheckCasting) {
EXPECT_EQ(x, 10);
return x;
});
py::globals().attr("update")(m.attr("__dict__")); // For Python3
UpdateGlobalsFromModule(m);
CheckValue("pass_thru_int(10)", 10);
CheckValue("pass_thru_int(Index(10))", 10);
// TypeSafeIndex<> is not implicitly constructible from an int.
Expand All @@ -46,7 +50,7 @@ GTEST_TEST(TypeSafeIndexTest, CheckCasting) {
return x;
});

py::globals().attr("update")(m.attr("__dict__")); // For Python3
UpdateGlobalsFromModule(m);

// TypeSafeIndex<> is not implicitly constructible from an int.
// TODO(eric.cousineau): Consider relaxing this to *only* accept `int`s, and
Expand Down
22 changes: 22 additions & 0 deletions bindings/pydrake/test/test_util_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include "pybind11/pybind11.h"

#include "drake/bindings/pydrake/pydrake_pybind.h"

namespace drake {
namespace pydrake {
namespace test {

// TODO(eric.cousineau): Figure out if there is a better solution than this
// hack.
/// pybind11's Python3 implementation seems to disconnect the `globals()` from
/// a embedded interpreter's `__main__` module. To remedy this, we must
/// manually synchronize these variables.
inline void UpdateGlobalsFromModule(py::module m) {
py::globals().attr("update")(m.attr("__dict__"));
}

} // namespace test
} // namespace pydrake
} // namespace drake

0 comments on commit d774035

Please sign in to comment.