From 76ecf9d2796ab64e1cd8c6e62216e4826ebc343d Mon Sep 17 00:00:00 2001 From: Eric Cousineau Date: Mon, 12 Nov 2018 22:45:29 -0500 Subject: [PATCH] wip initial numpy --- BUILD.bazel | 2 +- bindings/BUILD.bazel | 10 + bindings/pybind11_ext/BUILD.bazel | 49 + bindings/pybind11_ext/README.md | 3 + bindings/pybind11_ext/numpy_dtypes_user.h | 850 ++++++++++++++++++ bindings/pybind11_ext/numpy_ufunc.h | 363 ++++++++ .../test/numpy_dtypes_user_test.py | 427 +++++++++ .../test/numpy_dtypes_user_test_util_py.cc | 339 +++++++ bindings/pydrake/util/BUILD.bazel | 15 + bindings/pydrake/util/function_inference.h | 116 +++ .../util/test/function_inference_test.cc | 93 ++ bindings/pydrake/util/test/type_pack_test.cc | 19 + .../pydrake/util/test/wrap_function_test.cc | 8 +- bindings/pydrake/util/type_pack.h | 12 + bindings/pydrake/util/wrap_function.h | 91 +- tools/BUILD.bazel | 3 + tools/skylark/pybind.bzl | 4 + tools/workspace/BUILD.bazel | 1 + tools/workspace/default.bzl | 6 +- tools/workspace/numpy/BUILD.bazel | 23 +- tools/workspace/numpy/package.BUILD.bazel | 35 + tools/workspace/numpy/repository.bzl | 112 ++- .../numpy/test/numpy_install_test.py | 20 + tools/workspace/numpy/test/numpy_test.py | 18 + tools/workspace/pybind11/package.BUILD.bazel | 1 - tools/workspace/pybind11/repository.bzl | 4 +- 26 files changed, 2470 insertions(+), 154 deletions(-) create mode 100644 bindings/pybind11_ext/BUILD.bazel create mode 100644 bindings/pybind11_ext/README.md create mode 100644 bindings/pybind11_ext/numpy_dtypes_user.h create mode 100644 bindings/pybind11_ext/numpy_ufunc.h create mode 100644 bindings/pybind11_ext/test/numpy_dtypes_user_test.py create mode 100644 bindings/pybind11_ext/test/numpy_dtypes_user_test_util_py.cc create mode 100644 bindings/pydrake/util/function_inference.h create mode 100644 bindings/pydrake/util/test/function_inference_test.cc create mode 100644 tools/workspace/numpy/package.BUILD.bazel create mode 100755 tools/workspace/numpy/test/numpy_install_test.py create mode 100644 tools/workspace/numpy/test/numpy_test.py diff --git a/BUILD.bazel b/BUILD.bazel index 87c32bfe53ff..ede169d1641e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -56,7 +56,7 @@ install( docs = ["LICENSE.TXT"], deps = [ "//automotive/models:install_data", - "//bindings/pydrake:install", + "//bindings:install", "//common:install", "//common/proto:install", "//examples:install", diff --git a/bindings/BUILD.bazel b/bindings/BUILD.bazel index 2799004f4e66..ad01807bc70a 100644 --- a/bindings/BUILD.bazel +++ b/bindings/BUILD.bazel @@ -1,6 +1,7 @@ # -*- python -*- load("//tools/lint:lint.bzl", "add_lint_tests") +load("@drake//tools/install:install.bzl", "install") load( "@drake//tools/skylark:pybind.bzl", "drake_pybind_library", @@ -26,4 +27,13 @@ drake_pybind_library( package_info = get_bazel_workaround_4594_libdrake_package_info(), ) +install( + name = "install", + visibility = ["//:__pkg__"], + deps = [ + "//bindings/pybind11_ext:install", + "//bindings/pydrake:install", + ], +) + add_lint_tests() diff --git a/bindings/pybind11_ext/BUILD.bazel b/bindings/pybind11_ext/BUILD.bazel new file mode 100644 index 000000000000..ad2f150d9089 --- /dev/null +++ b/bindings/pybind11_ext/BUILD.bazel @@ -0,0 +1,49 @@ +# -*- python -*- + +load("@drake//tools/install:install.bzl", "install") +load("//tools/lint:lint.bzl", "add_lint_tests") +load("@drake//tools/skylark:drake_cc.bzl", "drake_cc_library") +load("@drake//tools/skylark:drake_py.bzl", "drake_py_unittest") +load("@drake//tools/skylark:pybind.bzl", "pybind_py_library") + +package(default_visibility = [ + "//bindings:__subpackages__", +]) + +drake_cc_library( + name = "numpy_dtypes_user", + hdrs = [ + "numpy_dtypes_user.h", + "numpy_ufunc.h", + ], + deps = [ + "//bindings/pydrake/util:type_pack", + "//bindings/pydrake/util:wrap_function", + "@pybind11", + ], +) + +install( + name = "install", + targets = ["numpy_dtypes_user"], + visibility = ["//visibility:public"], +) + +pybind_py_library( + name = "numpy_dtypes_user_test_util_py", + testonly = 1, + cc_deps = [ + ":numpy_dtypes_user", + "@fmt", + ], + cc_so_name = "numpy_dtypes_user_test_util", + cc_srcs = ["test/numpy_dtypes_user_test_util_py.cc"], + py_imports = ["."], +) + +drake_py_unittest( + name = "numpy_dtypes_user_test", + deps = [":numpy_dtypes_user_test_util_py"], +) + +add_lint_tests() diff --git a/bindings/pybind11_ext/README.md b/bindings/pybind11_ext/README.md new file mode 100644 index 000000000000..e0ac79c5be63 --- /dev/null +++ b/bindings/pybind11_ext/README.md @@ -0,0 +1,3 @@ +# pybind11 extensions + +Provides the ability to provide uesr-defined dtypes in NumPy. diff --git a/bindings/pybind11_ext/numpy_dtypes_user.h b/bindings/pybind11_ext/numpy_dtypes_user.h new file mode 100644 index 000000000000..629342b3ea5b --- /dev/null +++ b/bindings/pybind11_ext/numpy_dtypes_user.h @@ -0,0 +1,850 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "drake/bindings/pybind11_ext/numpy_ufunc.h" + +// N.B. For NumPy dtypes, `custom` tends to mean record-like structures, while +// `user-defined` means teaching NumPy about previously opaque C structures. + +// TODO(eric.cousineau): Figure out how to make this automatically hidden. +#pragma GCC visibility push(hidden) + +namespace pybind11 { +namespace detail { + +// The following code effectively creates a separate instance system than what +// pybind11 nominally has. This is done because, at present, it's difficult to +// have pybind11 extend other python types, in this case, `np.generic` / +// `PyGenericArrType_Type` (#1170). + +// TODO(eric.cousineau): Get rid of this structure if #1170 can be resolved. + +typedef PyObject* (*nb_conversion_t)(PyObject*); + +// Stores dtype-specific information, as well as static access to relevant +// internals. +// Effectively a watered down version of `detail::type_info`. +struct dtype_info { + handle cls; + int dtype_num{-1}; + std::map instance_to_py; + std::vector implicit_conversions; + std::map nb_implicit_conversions; + + // Provides mutable entry for a registered type, with option to create. + template + static dtype_info& get_mutable_entry(bool is_new = false) { + auto& internals = get_mutable_internals(); + std::type_index id(typeid(T)); + if (is_new) { + if (internals.find(id) != internals.end()) + pybind11_fail("Class already registered"); + return internals[id]; + } else { + return internals.at(id); + } + } + + // Provides immutable entry for a registered type. + template + static const dtype_info& get_entry() { + return get_mutable_internals().at(std::type_index(typeid(T))); + } + + // Provides immutable entry for a registered type, given the typeid. + static const dtype_info& get_entry(std::type_index id) { + return get_mutable_internals().at(id); + } + + // Provides immutable entry for a registered type, or nullptr. + static const dtype_info* maybe_get_entry(std::type_index id) { + const auto& internals = get_mutable_internals(); + auto iter = internals.find(id); + if (iter != internals.end()) { + return &iter->second; + } else { + return nullptr; + } + } + + // Finds the corresponding typeid for a `cls`, return `nullptr` if nothing is + // found. + static const std::type_index* find_entry(object cls) { + auto& map = get_internals(); + for (auto& iter : map) { + auto& entry = iter.second; + if (cls.ptr() == entry.cls.ptr()) + return &iter.first; + } + return nullptr; + } + + private: + using internals = std::map; + static const internals& get_internals() { + return get_mutable_internals(); + } + + // TODO(eric.cousineau): Store in internals. + static internals& get_mutable_internals() { + static internals* ptr = + &get_or_create_shared_data("_numpy_dtype_user_internals"); + return *ptr; + } +}; + +// CPython extension of `PyObject` for per-instance information of a +// user-defind dtype. Akin to `detail::instance`. +template +struct dtype_user_instance { + PyObject_HEAD + // TODO(eric.cousineau): Consider storing a unique_ptr to reduce the number + // of temporaries. + Class value; + + // Extracts C++ pointer from a given python object. No type checking is done. + static Class* load_raw(PyObject* src) { + dtype_user_instance* obj = reinterpret_cast(src); + return &obj->value; + } + + // Allocates an instance. + static dtype_user_instance* alloc_py() { + auto cls = dtype_info::get_entry().cls; + PyTypeObject* cls_raw = reinterpret_cast(cls.ptr()); + auto obj = reinterpret_cast( + cls_raw->tp_alloc(cls_raw, 0)); + // Ensure we clear out the memory. + memset(&obj->value, 0, sizeof(Class)); + return obj; + } + + // Implementation for `tp_new` slot. + static PyObject* tp_new( + PyTypeObject* /*type*/, PyObject* /*args*/, PyObject* /*kwds*/) { + // N.B. `__init__` should call the in-place constructor. + auto obj = alloc_py(); + // // Register. + auto& entry = dtype_info::get_mutable_entry(); + PyObject* pyobj = reinterpret_cast(obj); + entry.instance_to_py[&obj->value] = pyobj; + return pyobj; + } + + // Implementation for `tp_dealloc` slot. + static void tp_dealloc(PyObject* pself) { + Class* value = load_raw(pself); + // Call destructor. + value->~Class(); + // Deregister. + auto& entry = dtype_info::get_mutable_entry(); + entry.instance_to_py.erase(value); + } + + // Instance finding. Returns empty `object` if nothing is found. + static object find_existing(const Class* value) { + auto& entry = dtype_info::get_entry(); + void* raw = const_cast(value); + auto it = entry.instance_to_py.find(raw); + if (it == entry.instance_to_py.end()) { + return {}; + } else { + return reinterpret_borrow(it->second); + } + } +}; + +// Implementation of `type_caster` to interface `dtype_user_instance<>`s. +template +struct dtype_user_caster { + static constexpr auto name = detail::_(); + using DTypePyObject = dtype_user_instance; + + // Casts a const lvalue reference to a Python object. + static handle cast(const Class& src, return_value_policy, handle) { + object h = DTypePyObject::find_existing(&src); + // TODO(eric.cousineau): Handle parenting? + if (!h) { + // Make new instance. + DTypePyObject* obj = DTypePyObject::alloc_py(); + obj->value = src; + h = reinterpret_borrow(reinterpret_cast(obj)); + return h.release(); + } + return h.release(); + } + + // Casts a pointer to a Python object. + static handle cast(const Class* src, return_value_policy policy, handle) { + object h = DTypePyObject::find_existing(src); + if (h) { + return h.release(); + } else { + if (policy == return_value_policy::automatic_reference || + policy == return_value_policy::reference) { + throw cast_error("Cannot find existing instance"); + } else { + // Copy the instance. + DTypePyObject* obj = DTypePyObject::alloc_py(); + obj->value = *src; + delete src; + h = reinterpret_borrow(reinterpret_cast(obj)); + return h.release(); + } + } + } + + // Load from Python to C++. The result will be retrieved by the casting + // operators below. + bool load(handle src, bool convert) { + auto& entry = dtype_info::get_entry(); + auto cls = entry.cls; + object obj; + if (!isinstance(src, cls)) { + // Check if it's an `np.array` with matching dtype. + handle array = reinterpret_cast( + npy_api::get().PyArray_Type_); + if (isinstance(src, array)) { + tuple shape = src.attr("shape"); + if (shape.size() == 0) { + obj = src.attr("item")(); + } + } + if (!obj && convert) { + // Try implicit conversions. + for (auto& converter : entry.implicit_conversions) { + auto temp = converter( + src.ptr(), reinterpret_cast(cls.ptr())); + if (temp) { + obj = reinterpret_steal(temp); + loader_life_support::add_patient(obj); + break; + } + } + } + } else { + obj = reinterpret_borrow(src); + } + if (!obj) { + return false; + } else { + ptr_ = DTypePyObject::load_raw(obj.ptr()); + return true; + } + } + + // Copy `type_caster_base`. + template using cast_op_type = + pybind11::detail::cast_op_type; + + // Retrieves result after `load()`. + operator Class&() { return *ptr_; } + // Retrieves result after `load()`. + operator Class*() { return ptr_; } + + private: + // Stores result after `load()`. + Class* ptr_{}; +}; + +// Ensures that `dtype_user_caster` can cast pointers. See `cast.h`. +template +struct cast_is_known_safe>, make_caster>::value>> + : public std::true_type {}; + +// Maps a common Python function name to a NumPy UFunc name, or just returns +// the original name (for trigonometric functions). +inline std::string get_ufunc_name(std::string name) { + static const std::map m = { + // Anything that is mapped to `nullptr` implies that NumPy does not support + // this ufunc. + // https://docs.python.org/2.7/reference/datamodel.html#emulating-numeric-types // NOLINT(whitespace/line_length) + // Use nominal ordering (e.g. `__add__`, not `__radd__`) as ordering will + // be handled by ufunc registration. + // Use Python 3 operator names (e.g. `__truediv__`) + // https://docs.scipy.org/doc/numpy/reference/routines.math.html + {"__add__", "add"}, + {"__iadd__", nullptr}, + {"__neg__", "negative"}, + {"__pos__", nullptr}, + {"__mul__", "multiply"}, + {"__imul__", nullptr}, + // https://docs.scipy.org/doc/numpy/reference/routines.bitwise.html + {"__and__", "bitwise_and"}, + {"__iand__", nullptr}, + {"__or__", "bitwise_or"}, + {"__ior__", nullptr}, + {"__xor__", "bitwise_xor"}, + {"__ixor__", nullptr}, + // TODO(eric.cousineau): Figure out how to appropriately map `true_divide` + // vs. `divide` when the output type is adjusted? + {"__truediv__", "divide"}, + {"__itruediv__", nullptr}, + {"__pow__", "power"}, + {"__sub__", "subtract"}, + {"__isub__", nullptr}, + {"__abs__", "absolute"}, + // https://docs.scipy.org/doc/numpy/reference/routines.logic.html + {"__gt__", "greater"}, + {"__ge__", "greater_equal"}, + {"__lt__", "less"}, + {"__le__", "less_equal"}, + {"__eq__", "equal"}, + {"__ne__", "not_equal"}, + {"__bool__", "nonzero"}, // Python3 + {"__nonzero__", "nonzero"}, // Python2.7 + {"__invert__", "logical_not"}, + // Are these necessary? + {"min", "fmin"}, + {"max", "fmax"}, + // TODO(eric.cousineau): Add something for junction-style logic? + }; + auto iter = m.find(name); + if (iter != m.end()) { + if (!iter->second) { + throw std::runtime_error("Invalid NumPy operator: " + name); + } + return iter->second; + } else { + return name; + } +} + +// Provides implementation of `npy_format_decsriptor` for a user-defined dtype. +template +struct dtype_user_npy_format_descriptor { + static constexpr auto name = detail::_(); + static pybind11::dtype dtype() { + int dtype_num = dtype_info::get_entry().dtype_num; + if (auto ptr = detail::npy_api::get().PyArray_DescrFromType_(dtype_num)) + return reinterpret_borrow(ptr); + pybind11_fail("Unsupported buffer format!"); + } +}; + +// Stores information about a conversion. +template +struct dtype_conversion_t { + Func func; + bool allow_implicit_coercion{}; +}; + +// Infers the correct signature for `dtype_conversion_t` from a function. +template +static auto dtype_conversion_impl( + FuncIn&& func_in, bool allow_implicit_coercion) { + auto func_infer = detail::infer_function_info(func_in); + using FuncInfer = decltype(func_infer); + using From = detail::intrinsic_t< + typename FuncInfer::Args::template type_at<0>>; + using To = detail::intrinsic_t; + using Func = typename FuncInfer::Func; + return dtype_conversion_t{ + std::forward(func_infer.func), allow_implicit_coercion}; +} + +} // namespace detail + +/// Provides user control over definition of UFuncs. +struct dtype_method { + /// Defines `np.dot` for a given type. + struct dot {}; + + /// Uses constructor / casting for explicit conversion. + template + static auto explicit_conversion() { + return detail::dtype_conversion_impl([](const From& in) { + return To(in); + }, false); + } + + /// Provides function for explicit conversion. + template + static auto explicit_conversion(Func&& func) { + return detail::dtype_conversion_impl(std::forward(func), false); + } + + /// Uses constructor / casting for implicit conversion. + template + static auto implicit_conversion() { + return detail::dtype_conversion_impl( + [](const From& in) -> To { return in; }, true); + } + + /// Provides function for implicit conversion. + template + static auto implicit_conversion(Func&& func) { + return detail::dtype_conversion_impl(std::forward(func), true); + } + + /// Implies that only a `ufunc` should be defined, and the corresponding class + /// method should not be defined. + struct ufunc_only {}; +}; + +/** +Defines a user-defined dtype. + +Constraints: +* The type must be copy-constructible and assignable. +* The type *may* not have its constructor called; however, its memory *will* be +initialized to zero, so it's assignment should be robust against being assigned +from zero memory. +* This type's instance *won't* always be destroyed, because NumPy does not have +slots to define this yet. + */ +template +class dtype_user : public object { + public: + static_assert( + !std::is_polymorphic::value, + "Cannot define NumPy dtypes for polymorphics classes."); + + using PyClass = class_; + using Class = Class_; + using DTypePyObject = detail::dtype_user_instance; + + dtype_user(handle scope, const char* name, const char* doc = "") + : cls_(none()) { + register_type(name, doc); + scope.attr(name) = self(); + auto& entry = detail::dtype_info::get_mutable_entry(true); + entry.cls = self(); + // Register numpy type. + // (Note that not registering the type will result in infinte recursion). + entry.dtype_num = register_numpy(); + + // Register default ufunc cast to `object`. + // N.B. Given how general this is, it should *NEVER* be implicit, as it + // would interfere with more meaningful casts. + // N.B. This works because `object` is defined to have the same memory + // layout as `PyObject*`, thus can be registered in lieu of `PyObject*` - + // this also effectively increases the refcount and releases the object. + this->def_loop(dtype_method::explicit_conversion( + [](const Class& self) -> object { return pybind11::cast(self); })); + object cls = self(); + auto object_to_cls = [cls](object obj) -> Class { + // N.B. We use the *constructor* rather than implicit conversions because + // implicit conversions may not be sufficient when dealing with `object` + // dtypes. As an example, a class can only explicitly cast to float, but + // the array is constructed as `np.array([1., Class(2)])`. The inferred + // dtype in this case will be `object`. + if (!isinstance(obj, cls)) { + // This will catch type mismatch errors. + // TODO(eric.cousineau): Not having the correct constructor registered + // can causes segfaults when the error is begin printed out, due to the + // indirection of `_dtype_init`. Consider changing this... + obj = cls(obj); + } + return obj.cast(); + }; + this->def_loop(dtype_method::explicit_conversion(object_to_cls)); + } + + ~dtype_user() { + // This will be called once the `pybind11` module ends, and thus should + // warn the user if they've left this class in a bad state. + check(); + } + + /// Forwards method definition to `py::class_`. + template + dtype_user& def(const char* name, Args&&... args) { + cls().def(name, std::forward(args)...); + return *this; + } + + /// Defines a constructor. + template + dtype_user& def( + detail::initimpl::constructor&&, const char* doc = "") { + // See notes in `add_init`. + // N.B. Do NOT use `Class*` as the argument, since that may incur recursion. + add_init([](object py_self, Args... args) { + // Old-style. No factories for now. + Class* self = DTypePyObject::load_raw(py_self.ptr()); + new (self) Class(std::forward(args)...); + }, doc); + return *this; + } + + /// Defines UFunc loop operator. + template + dtype_user& def_loop( + const detail::op_&, dtype_method::ufunc_only) { + // Register ufunction with builtin name. + // Use `op_l`. Mapping `__radd__` to `add` would require remapping argument + // order, and screw that. We can just use the fact that `op_impl` is + // generic. + constexpr auto ot_norm = (ot == detail::op_r) ? detail::op_l : ot; + using op_norm_ = detail::op_; + using op_norm_impl = typename op_norm_::template info::op; + std::string ufunc_name = detail::get_ufunc_name(op_norm_impl::name()); + ufunc::get_builtin(ufunc_name.c_str()) // BR + .def_loop(&op_norm_impl::execute); + if (ufunc_name == "divide") { + ufunc::get_builtin("true_divide").def_loop(&op_norm_impl::execute); + } + return *this; + } + + /// Defines Python and UFunc loop operator. + template + dtype_user& def_loop(const detail::op_& op) { + // Define Python class operator. + using op_ = detail::op_; + using op_impl = typename op_::template info::op; + this->def(op_impl::name(), &op_impl::execute, is_operator()); + // Define dtype operators. + return def_loop(op, dtype_method::ufunc_only()); + } + + /// Defines a scalar function and overloads an existing NumPy UFunc loop, + /// mapping to a buitlin name in `numpy`. + template + dtype_user& def_loop(const char* name, const Func& func) { + cls().def(name, func); + std::string ufunc_name = detail::get_ufunc_name(name); + ufunc::get_builtin(ufunc_name.c_str()).def_loop(func); + return *this; + } + + /// Forwards operator defintiion to `py::class_`. + template + dtype_user& def( + const detail::op_& op, const Extra&... extra) { + cls().def(op, extra...); + return *this; + } + + /// Defines loop cast, and optionally permit implicit conversions. + template + dtype_user& def_loop( + detail::dtype_conversion_t conv, + dtype from = {}, dtype to = {}) { + detail::ufunc_register_cast( + conv.func, conv.allow_implicit_coercion, from, to); + // Define implicit conversion on the class. + if (conv.allow_implicit_coercion) { + if (std::is_same::value) { + // TODO(eric.cousineau): Is this a good idea? It's quite confusing to + // discard the function here. + auto& entry = detail::dtype_info::get_mutable_entry(); + entry.implicit_conversions.push_back( + detail::create_implicit_caster()); + } else { + auto enabled = std::is_same{}; + register_nb_conversion(enabled, conv.func); + } + } + return *this; + } + + /// Defines dot product. + template + dtype_user& def_loop(dtype_method::dot) { + // TODO(eric.cousineau): See if there is a way to define `dot` for an + // algebra that is not closed under addition / multiplication (e.g. + // symbolic variable -> symbolic expression). + if (arrfuncs_->dotfunc) + pybind11_fail("dtype: Cannot redefine `dot`"); + using detail::npy_intp; + arrfuncs_->dotfunc = reinterpret_cast(+[]( + void* ip0_, npy_intp is0, void* ip1_, npy_intp is1, + void* op, npy_intp n, void* /*arr*/) { + const char *ip0 = reinterpret_cast(ip0_); + const char *ip1 = reinterpret_cast(ip1_); + Class r{}; + for (npy_intp i = 0; i < n; i++) { + const Class& v1 = *reinterpret_cast(ip0); + const Class& v2 = *reinterpret_cast(ip1); + r += v1 * v2; + ip0 += is0; + ip1 += is1; + } + *reinterpret_cast(op) = r; + }); + return *this; + } + + /// Access a `py::class_` view of the type. Please be careful when adding + /// methods or attributes, as they may conflict with how NumPy works. + PyClass& cls() { return cls_; } + + private: + // Provides mutable explicit upcast reference (for assignment). + object& self() { return *this; } + const object& self() const { return *this; } + + // Checks definition invariants. + void check() const { + auto warn = [](const std::string& msg) { + // TODO(eric.cousineau): Figure out better warning type. + PyErr_WarnEx(PyExc_UserWarning, msg.c_str(), 0); + }; + // This `dict` should indicate whether we've directly overridden methods. + dict d = self().attr("__dict__"); + // Without these, numpy goes into infinite recursion. Haven't bothered to + // figure out exactly why. + if (!d.contains("__repr__")) + warn("dtype: Class is missing explicit __repr__!"); + if (!d.contains("__str__")) + warn("dtype: Class is missing explicit __str__!"); + } + + // Adds constructor. See comments within for explanation. + template + void add_init(Func&& f, const char* doc) { + // Do not construct this with the name `__init__` as `pybind11`s + // constructor implementations via `cpp_function` are rigidly fixed to + // its instance registration system (which we don't want). + // Because of this, if there is an error in constructors when testing + // overloads, `repr()` may be called on an object in an invalid state. + this->def("_dtype_init", std::forward(f)); + // Ensure that this is called by a non-pybind11-instance `__init__`. + dict d = self().attr("__dict__"); + if (!d.contains("__init__")) { + auto init = self().attr("_dtype_init"); + auto func = cpp_function( + [init](handle self, args args, kwargs kwargs) { + // Dispatch. + init(self, *args, **kwargs); + }, is_method(self()), doc); + self().attr("__init__") = func; + } + } + + // Handles conversions from `nb_*` methods in Python type objects. Uses + // available conversions if possible, otherwise will cause an error to be + // thrown when called. + template + static PyObject* handle_nb_conversion(PyObject* from) { + auto& entry = detail::dtype_info::get_entry(); + auto& conversions = entry.nb_implicit_conversions; + // Check for available conversions. + std::type_index id(typeid(T)); + auto iter = conversions.find(id); + if (iter != conversions.end()) { + return iter->second(from); + } else { + PyErr_SetString( + PyExc_TypeError, + "dtype_user: Direct casting via Python not supported"); + return nullptr; + } + } + + // Handles registering native `nb_*` type conversions. + template + void register_nb_conversion(std::true_type, const Func& func) { + auto& entry = detail::dtype_info::get_mutable_entry(); + std::type_index id(typeid(To)); + auto& conversions = entry.nb_implicit_conversions; + assert(conversions.find(id) == conversions.end()); + static Func func_static = func; + detail::nb_conversion_t nb_conversion = + +[](PyObject* from_py) -> PyObject* { + Class* from = pybind11::cast(from_py); + To to = func_static(*from); + return pybind11::cast(to).release().ptr(); + }; + conversions[id] = nb_conversion; + } + + template + void register_nb_conversion(std::false_type, const Func&) {} + + // Disables `nb_coerce`. + static int disable_nb_coerce(PyObject**, PyObject**) { + PyErr_SetString( + PyExc_TypeError, + "dtype_user: Direct coercion via Python not supported"); + return 1; + } + + // Registers Python type. + void register_type(const char* name, const char* doc) { + // Ensure we initialize NumPy before accessing `PyGenericArrType_Type`. + auto& api = detail::npy_api::get(); + // Loosely uses https://stackoverflow.com/a/12505371/7829525 as well. + auto heap_type = reinterpret_cast( + PyType_Type.tp_alloc(&PyType_Type, 0)); + if (!heap_type) + pybind11_fail("dtype_user: Could not register heap type"); + heap_type->ht_name = pybind11::str(name).release().ptr(); + // It's painful to inherit from `np.generic`, because it has no `tp_new`. + auto& ClassObject_Type = heap_type->ht_type; + ClassObject_Type.tp_base = api.PyGenericArrType_Type_; + ClassObject_Type.tp_new = &DTypePyObject::tp_new; + ClassObject_Type.tp_dealloc = &DTypePyObject::tp_dealloc; + ClassObject_Type.tp_name = name; // Er... scope? + ClassObject_Type.tp_basicsize = sizeof(DTypePyObject); + ClassObject_Type.tp_getset = 0; + ClassObject_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE; + ClassObject_Type.tp_doc = doc; + if (PyType_Ready(&ClassObject_Type) != 0) + pybind11_fail("dtype_user: Unable to initialize class"); + // TODO(eric.cousineau): Figure out how to catch recursions with + // `tp_as_number` and casting, when it's not defined + static auto tp_as_number = *ClassObject_Type.tp_as_number; + ClassObject_Type.tp_as_number = &tp_as_number; + // TODO(eric.cousineau): Figure out how to use more generic dispatch on + // this object. If we use the `np.generic` stuff, we end up getting + // recursive loops. + tp_as_number.nb_float = &handle_nb_conversion; + tp_as_number.nb_int = &handle_nb_conversion; +#if PY_VERSION_HEX < 0x03000000 + tp_as_number.nb_long = &handle_nb_conversion; + tp_as_number.nb_coerce = &disable_nb_coerce; +#endif + // Create views into created type. + self() = reinterpret_borrow( + reinterpret_cast(&ClassObject_Type)); + cls_ = self(); + } + + // Registers NumPy dtype entry. + int register_numpy() { + using detail::npy_api; + // Adapted from `numpy/core/multiarray/src/test_rational.c.src`. + // Define NumPy description. + auto type = reinterpret_cast(self().ptr()); + struct align_test { char c; Class r; }; + + static detail::PyArray_ArrFuncs arrfuncs; + static detail::PyArray_Descr descr = { + PyObject_HEAD_INIT(0) + type, /* typeobj */ + 'V', /* kind (V = arbitrary) */ + 'r', /* type */ + '=', /* byteorder */ + npy_api::NPY_NEEDS_PYAPI_ | npy_api::NPY_USE_GETITEM_ | + npy_api::NPY_USE_SETITEM_ | + npy_api::NPY_NEEDS_INIT_, /* flags */ + 0, /* type_num */ + sizeof(Class), /* elsize */ + offsetof(align_test, r), /* alignment */ + 0, /* subarray */ + 0, /* fields */ + 0, /* names */ + &arrfuncs, /* f */ + }; + + auto& api = npy_api::get(); + Py_TYPE(&descr) = api.PyArrayDescr_Type_; + + api.PyArray_InitArrFuncs_(&arrfuncs); + + using detail::npy_intp; + + // https://docs.scipy.org/doc/numpy/reference/c-api.types-and-structures.html + arrfuncs.getitem = reinterpret_cast( + +[](void* in, void* /*arr*/) -> PyObject* { + auto item = reinterpret_cast(in); + return pybind11::cast(*item).release().ptr(); + }); + arrfuncs.setitem = reinterpret_cast( + +[](PyObject* in, void* out, void* /*arr*/) { + detail::loader_life_support guard{}; + detail::dtype_user_caster caster; + if (!caster.load(in, true)) { + PyErr_SetString( + PyExc_TypeError, + "dtype_user: Could not convert during `setitem`"); + return -1; + } + *reinterpret_cast(out) = caster; + return 0; + }); + arrfuncs.copyswap = reinterpret_cast( + +[](void* dst, void* src, int swap, void* /*arr*/) { + if (!src) return; + Class* r_dst = reinterpret_cast(dst); + Class* r_src = reinterpret_cast(src); + if (swap) { + PyErr_SetString( + PyExc_NotImplementedError, + "dtype_user: `swap` not implemented"); + } else { + *r_dst = *r_src; + } + }); + arrfuncs.copyswapn = reinterpret_cast( + +[](void* dst, npy_intp dstride, void* src, + npy_intp sstride, npy_intp n, int swap, void*) { + if (!src) return; + if (swap) { + PyErr_SetString( + PyExc_NotImplementedError, + "dtype_user: `swap` not implemented"); + } else { + char* c_dst = reinterpret_cast(dst); + char* c_src = reinterpret_cast(src); + for (int k = 0; k < n; k++) { + Class* r_dst = reinterpret_cast(c_dst); + Class* r_src = reinterpret_cast(c_src); + *r_dst = *r_src; + c_dst += dstride; + c_src += sstride; + } + } + }); + // - Ensure this doesn't overwrite our `equal` unfunc. + arrfuncs.compare = reinterpret_cast( + +[](const void* /*d1*/, const void* /*d2*/, void* /*arr*/) { + pybind11_fail( + "dtype: `compare` should not be called for pybind11 " + "custom dtype"); + }); + arrfuncs.fillwithscalar = reinterpret_cast( + +[](void* buffer_raw, npy_intp length, void* value_raw, + void* /*arr*/) { + const Class* value = reinterpret_cast(value_raw); + Class* buffer = reinterpret_cast(buffer_raw); + for (int k = 0; k < length; k++) { + buffer[k] = *value; + } + return 0; + }); + int dtype_num = api.PyArray_RegisterDataType_(&descr); + if (dtype_num < 0) { + pybind11_fail("dtype_user: Could not register!"); + } + self().attr("dtype") = + reinterpret_borrow(reinterpret_cast(&descr)); + arrfuncs_ = &arrfuncs; + return dtype_num; + } + + PyClass cls_; + detail::PyArray_ArrFuncs* arrfuncs_{}; +}; + +} // namespace pybind11 + +// Ensures that we can (a) cast the type (semi) natively, and (b) integrate +// with NumPy functionality. +#define PYBIND11_NUMPY_DTYPE_USER(Type) \ + namespace pybind11 { namespace detail { \ + template <> \ + struct type_caster : public dtype_user_caster {}; \ + template <> \ + struct npy_format_descriptor \ + : public dtype_user_npy_format_descriptor {}; \ + }} + +#pragma GCC visibility pop diff --git a/bindings/pybind11_ext/numpy_ufunc.h b/bindings/pybind11_ext/numpy_ufunc.h new file mode 100644 index 000000000000..9226a70c2381 --- /dev/null +++ b/bindings/pybind11_ext/numpy_ufunc.h @@ -0,0 +1,363 @@ +#pragma once + +/// @file +/// Simple glue for NumPy UFuncs. + +#include +#include +#include + +#include + +#include "drake/bindings/pydrake/util/function_inference.h" +#include "drake/bindings/pydrake/util/type_pack.h" + +// TODO(eric.cousineau): Figure out how to make this automatically hidden. +#pragma GCC visibility push(hidden) + +namespace pybind11 { +namespace detail { + +// Since this code lives in Drake, use existing headers. +using drake::type_pack; +using drake::type_pack_apply; +using drake::type_pack_concat; +using drake::pydrake::detail::infer_function_info; + +// Utilities + +// Builtins registered using +// numpy/build/{...}/numpy/core/include/numpy/__umath_generated.c + +template +struct ufunc_ptr { + PyUFuncGenericFunction func{}; + void* data{}; +}; + +// Unary ufunc. +template +auto ufunc_to_ptr(Func func, type_pack) { + auto ufunc = []( + char** args, npy_intp* dimensions, npy_intp* steps, void* data) { + Func& func_inner = *reinterpret_cast(data); + npy_intp step_0 = steps[0]; + npy_intp step_out = steps[1]; + npy_intp n = *dimensions; + char *in_0 = args[0], *out = args[1]; + for (npy_intp k = 0; k < n; k++) { + // TODO(eric.cousineau): Support pointers being changed. + *reinterpret_cast(out) = func_inner(*reinterpret_cast(in_0)); + in_0 += step_0; + out += step_out; + } + }; + // N.B. `new Func(...)` will never be destroyed. + return ufunc_ptr{ufunc, new Func(func)}; +} + +// Binary ufunc. +template +auto ufunc_to_ptr(Func func, type_pack) { + auto ufunc = []( + char** args, npy_intp* dimensions, npy_intp* steps, void* data) { + Func& func_inner = *reinterpret_cast(data); + npy_intp step_0 = steps[0]; + npy_intp step_1 = steps[1]; + npy_intp step_out = steps[2]; + npy_intp n = *dimensions; + char *in_0 = args[0], *in_1 = args[1], *out = args[2]; + for (npy_intp k = 0; k < n; k++) { + // TODO(eric.cousineau): Support pointers being fed in. + *reinterpret_cast(out) = func_inner( + *reinterpret_cast(in_0), *reinterpret_cast(in_1)); + in_0 += step_0; + in_1 += step_1; + out += step_out; + } + }; + // N.B. `new Func(...)` will never be destroyed. + return ufunc_ptr{ufunc, new Func(func)}; +} + +// Generic dispatch. +template +auto ufunc_to_ptr(Func func) { + auto info = detail::infer_function_info(func); + using Info = decltype(info); + auto type_args = type_pack_apply( + type_pack_concat( + typename Info::Args{}, + type_pack{})); + return ufunc_to_ptr(func, type_args); +} + +template +void assert_ufunc_dtype_valid() { + auto T_dtype = dtype::of(); + int num = T_dtype.num(); + bool is_object = std::is_same::value; + if (num == npy_api::NPY_OBJECT_ && !is_object) { + std::string message = + "ufunc: Cannot handle `dtype=object` when T != `py::object` "; + message += "(where T = " + type_id() + "). "; + message += "Please register function using `py::object`"; + pybind11_fail(message.c_str()); + } +} + +template +void ufunc_register_cast( + Func&& func, bool allow_coercion, + dtype from = {}, dtype to = {}, type_pack = {}) { + assert_ufunc_dtype_valid(); + assert_ufunc_dtype_valid(); + static auto cast_lambda = detail::infer_function_info(func).func; + auto cast_func = +[]( + void* from_, void* to_, npy_intp n, + void* /*fromarr*/, void* /*toarr*/) { + const From* from_inner = reinterpret_cast(from_); + To* to_inner = reinterpret_cast(to_); + for (npy_intp i = 0; i < n; i++) + to_inner[i] = cast_lambda(from_inner[i]); + }; + auto& api = npy_api::get(); + if (!from) { + from = npy_format_descriptor::dtype(); + } + if (!to) { + to = npy_format_descriptor::dtype(); + } + int to_num = to.num(); + auto from_raw = reinterpret_cast(from.ptr()); + if (from.num() == npy_api::NPY_OBJECT_ && !std::is_same::value) + pybind11_fail( + "ufunc: Registering conversion from `dtype=object` with " + "From != `py::object` is not supported"); + if (api.PyArray_RegisterCastFunc_(from_raw, to_num, cast_func) < 0) { + pybind11_fail("ufunc: Cannot register cast"); + } + if (allow_coercion) { + if (api.PyArray_RegisterCanCast_( + from_raw, to_num, npy_api::NPY_NOSCALAR_) < 0) + pybind11_fail( + "ufunc: Cannot register implicit / coercion cast capability"); + } +} + +} // namespace detail + +/** +Defines a UFunc in NumPy. +Handles either 1 or 2 arguments (for unary or binary arguments). +@pre All classes used must have corresponding dtypes in NumPy. +@note You must specify the class that is going to own the UFunc, due to +NumPy's API design. + +Example: + + py::ufunc(m, "custom_ufunc") + .def_loop([](const CustomClass& a) { ... }); + */ +class ufunc : public object { + public: + /// Enables defining a new UFunc `name` in `scope`. + ufunc(handle scope, const char* name) : scope_{scope} { + entries.reset(new entries_t(name)); + } + + // Wraps a reference to an existing UFunc. + // NOLINTNEXTLINE(runtime/explicit) + ufunc(object ptr_in) : object(ptr_in) { + // TODO(eric.cousineau): Check type. + if (!self() || self().is_none()) { + pybind11_fail("ufunc: Cannot wrap from empty or None object"); + } + entries.reset(new entries_t(ptr())); + } + + /// Constructs from a raw pointer. + explicit ufunc(detail::PyUFuncObject* ptr_in) + : ufunc(reinterpret_borrow(reinterpret_cast(ptr_in))) + {} + + ufunc(const ufunc&) = default; + + /// "Flushes" queued UFunc definitions. + ~ufunc() { + if (entries) { + finalize(); + } + } + + /// Gets a NumPy builtin UFunc by name. + static ufunc get_builtin(const char* name) { + module numpy = module::import("numpy"); + return ufunc(numpy.attr(name)); + } + + /// Queues a function to be realized as a UFunc loop. + template + ufunc& def_loop(Func func_in) { + auto func = detail::infer_function_info(func_in).func; + do_register(detail::ufunc_to_ptr(func)); + return *this; + } + + /// Retrieves raw pointer. + detail::PyUFuncObject* ptr() const { + return reinterpret_cast(self().ptr()); + } + + private: + object& self() { return *this; } + const object& self() const { return *this; } + + // Create UFunc object with core type functions if needed, and register user + // functions. + void finalize() { + if (!entries) + pybind11_fail("Object already finalized"); + if (!self()) { + // Create object and register core functions. + auto* h = entries->create_core(); + self() = reinterpret_borrow(reinterpret_cast(h)); + scope_.attr(entries->name()) = self(); + } + // Register user type functions. + entries->create_user(ptr()); + // Leak memory for now so that data lives longer than UFunc. + // TODO(eric.cousineau): Embed this in a capsule and use `keep_alive`. + (void)new std::shared_ptr(entries); + } + + // Registers a function pointer as a UFunc, mapping types to dtype nums. + template + void do_register(detail::ufunc_ptr user) { + constexpr int N = sizeof...(Args); + constexpr int nin = N - 1; + constexpr int nout = 1; + entries->init_or_check_args(nin, nout); + + const int dtype = dtype::of().num(); + const int dummy[] = {(detail::assert_ufunc_dtype_valid(), 0)...}; + (void)dummy; + const std::vector dtype_args = {dtype::of().num()...}; + bool is_core = true; + for (int i = 0; i < N; ++i) { + const size_t ii = static_cast(i); + if (dtype_args[ii] >= detail::npy_api::constants::NPY_USERDEF_) + is_core = false; + } + if (is_core) { + // TODO(eric.cousineau): Consider supporting + // `PyUFunc_ReplaceLoopBySignature_`? + if (self()) + pybind11_fail( + "ufunc: Can't add/replace signatures for core types for an " + "existing ufunc"); + entries->queue_core(user.func, user.data, dtype_args); + } else { + entries->queue_user(user.func, user.data, dtype, dtype_args); + } + } + + // These are only used if we have something new. + handle scope_{}; + + // Contains UFunc entries to flush into actual UFunc registrations. + class entries_t { + public: + // Initialize from existing object. + explicit entries_t(detail::PyUFuncObject* h) { + nin_ = h->nin; + nout_ = h->nout; + } + + // Set up to create a new instance. + explicit entries_t(const char* name) : name_(name) {} + + void init_or_check_args(int nin, int nout) { + if (nin_ != -1 && nout_ != -1) { + if (nin_ != nin) + pybind11_fail("ufunc: Input count mismatch"); + if (nout_ != nout) + pybind11_fail("ufunc: Output count mismatch"); + } + nin_ = nin; + nout_ = nout; + } + + const char* name() const { return name_.c_str(); } + + void queue_core( + detail::PyUFuncGenericFunction func, void* data, + const std::vector& dtype_args) { + assert(nin_ != -1 && nout_ != -1); + assert(static_cast(dtype_args.size()) == nin_ + nout_); + // Store core functionn. + core_funcs_.push_back(func); + core_data_.push_back(data); + const size_t ncore = core_funcs_.size(); + size_t t_index = core_type_args_.size(); + int nargs = nin_ + nout_; + core_type_args_.resize(ncore * static_cast(nargs)); + for (size_t i = 0; i < dtype_args.size(); ++i) { + core_type_args_.at(t_index++) = static_cast(dtype_args[i]); + } + } + + void queue_user( + detail::PyUFuncGenericFunction func, void* data, int dtype, + const std::vector& dtype_args) { + assert(nin_ != -1 && nout_ != -1); + assert(static_cast(dtype_args.size()) == nin_ + nout_); + user_funcs_.push_back(func); + user_data_.push_back(data); + user_types_.push_back(dtype); + user_type_args_.push_back(dtype_args); + } + + detail::PyUFuncObject* create_core() { + const int ncore = static_cast(core_funcs_.size()); + char* name_raw = const_cast(name()); + return reinterpret_cast( + detail::npy_api::get().PyUFunc_FromFuncAndData_( + core_funcs_.data(), core_data_.data(), core_type_args_.data(), + ncore, nin_, nout_, + detail::npy_api::constants::PyUFunc_None_, name_raw, nullptr, 0)); + } + + void create_user(detail::PyUFuncObject* h) { + const size_t nuser = user_funcs_.size(); + for (size_t i = 0; i < nuser; ++i) { + if (detail::npy_api::get().PyUFunc_RegisterLoopForType_( + h, user_types_[i], user_funcs_[i], + user_type_args_[i].data(), user_data_[i]) < 0) + pybind11_fail("ufunc: Failed to register custom ufunc"); + } + } + + private: + int nin_{-1}; + int nout_{-1}; + std::string name_{}; + + // Core. + std::vector core_funcs_; + std::vector core_data_; + std::vector core_type_args_; + + // User. + std::vector user_funcs_; + std::vector user_data_; + std::vector user_types_; + std::vector> user_type_args_; + }; + + std::shared_ptr entries; +}; + +} // namespace pybind11 + +#pragma GCC visibility pop diff --git a/bindings/pybind11_ext/test/numpy_dtypes_user_test.py b/bindings/pybind11_ext/test/numpy_dtypes_user_test.py new file mode 100644 index 000000000000..9cf83a50132e --- /dev/null +++ b/bindings/pybind11_ext/test/numpy_dtypes_user_test.py @@ -0,0 +1,427 @@ +import copy +import unittest + +import numpy as np + +import numpy_dtypes_user_test_util as mut + + +class TestNumpyDtypesUser(unittest.TestCase): + def test_scalar_meta(self): + """Tests basic metadata.""" + self.assertTrue(issubclass(mut.Symbol, np.generic)) + self.assertIsInstance(np.dtype(mut.Symbol), np.dtype) + + def check_scalar(self, actual, expected): + accepted = (mut.Symbol, mut.StrValueExplicit, mut.LengthValueImplicit) + if isinstance(actual, accepted): + self.assertEqual(actual.value(), expected) + else: + raise RuntimeError("Invalid scalar: {}".format(repr(expected))) + + def check_array(self, value, expected): + expected = np.array(expected, dtype=np.object) + self.assertEqual(value.shape, expected.shape) + for a, b in zip(value.flat, expected.flat): + self.check_scalar(a, b) + + def test_scalar_basics(self): + """ + Tests basics for scalars. + Important to do since we had to redo the instance registry to inherit + from `np.generic` :( + """ + # TODO(eric.cousineau): Consider using `pybind11`s `ConstructorStats` + # to do instance tracking. + c1 = mut.Symbol() + c2 = mut.Symbol() + self.assertIsNot(c1, c2) + self.assertIs(c1, c1.self_reference()) + # Test functions. + a = mut.Symbol("a") + self.assertEqual(repr(a), "") + self.assertEqual(str(a), "a") + self.assertEqual(a.value(), "a") + # Copying. + # N.B. Normally, `pybind11` does not implicitly define copy semantics. + # However, for these NumPy dtypes it is made implicit (relying on the + # copy constructor). + b = copy.copy(a) + self.assertIsNot(a, b) + self.assertEqual(a.value(), b.value()) + b = copy.deepcopy(a) + self.assertIsNot(a, b) + self.assertEqual(a.value(), b.value()) + + def test_array_creation_basics(self): + # Uniform creation. + A = np.array([mut.Symbol("a")]) + self.assertEqual(A.dtype, mut.Symbol) + self.assertEqual(A[0].value(), "a") + + def test_array_cast_explicit(self): + # Check idempotent round-trip casts. + A = np.array([mut.Symbol("a")]) + for dtype in (mut.Symbol, np.object, mut.StrValueExplicit): + B = A.astype(dtype) + self.assertEqual(B.dtype, dtype) + C = B.astype(mut.Symbol) + self.assertEqual(C.dtype, mut.Symbol) + self.check_scalar(C[0], "a") + # Check registered explicit casts. + # - From. + from_float = np.array([1.]).astype(mut.Symbol) + self.check_array(from_float, ["float(1)"]) + from_str = np.array([mut.StrValueExplicit("abc")]).astype(mut.Symbol) + self.check_array(from_str, ["abc"]) + from_length = np.array([mut.LengthValueImplicit(1)]).astype(mut.Symbol) + self.check_array(from_length, ["length(1)"]) + # - To. + # N.B. `np.int` may not be the same as `np.int32`; C++ uses `np.int32`. + to_int = A.astype(np.int32) + self.assertEqual(to_int[0], 1) + to_str = A.astype(mut.StrValueExplicit) + self.check_array(to_str, ["a"]) + to_length = A.astype(mut.LengthValueImplicit) + self.check_array(to_length, [1]) + + def test_array_cast_implicit(self): + # By assignment. + a = mut.Symbol("a") + A = np.array([a]) + + def reset(): + A[:] = a + + b_length = mut.LengthValueImplicit(1) + # - Implicitly convertible types. + A[0] = b_length + self.check_array(A, ["length(1)"]) + A[:] = b_length + self.check_array(A, ["length(1)"]) + # - Permitted as in place operation. + reset() + A += mut.LengthValueImplicit(1) + self.check_array(A, ["(a) + (length(1))"]) + # Explicit: Scalar assignment not permitted. + b_str = mut.StrValueExplicit("b") + with self.assertRaises(TypeError): + A[0] = b_str + # N.B. For some reason, NumPy considers this explicit coercion... + A[:] = b_str + self.check_array(A, ["b"]) + # - Permitted as in place operation. + reset() + A += mut.StrValueExplicit("b") + self.check_array(A, ["(a) + (b)"]) + reset() + + def test_array_creation_mixed(self): + # Mixed creation with implicitly convertible types. + with self.assertRaises(TypeError): + # No type specified, NumPy gets confused. + O_ = np.array([mut.Symbol(), mut.LengthValueImplicit(1)]) + A = np.array([ + mut.Symbol(), mut.LengthValueImplicit(1)], dtype=mut.Symbol) + self.check_array(A, ["", "length(1)"]) + + # Mixed creation without implicit casts, yields dtype=object. + O_ = np.array([mut.Symbol(), 1.]) + self.assertEqual(O_.dtype, np.object) + # - Explicit Cast. + A = O_.astype(mut.Symbol) + self.assertEqual(A.dtype, mut.Symbol) + self.check_array(A, ["", "float(1)"]) + + # Mixed creation with explicitly convertible types - does not work. + with self.assertRaises(TypeError): + A = np.array([ + mut.Symbol(), mut.StrValueExplicit("a")], dtype=mut.Symbol) + + def test_array_creation_constants(self): + # Zeros: More so an `empty` array. + Z = np.full((2,), mut.Symbol()) + self.assertEqual(Z.dtype, mut.Symbol) + self.check_array(Z, 2 * [""]) + + # Zeros: For making an "empty" array, but using float conversion. + Z_from_float = np.zeros((2,)).astype(mut.Symbol) + self.check_array(Z_from_float, 2 * ["float(0)"]) + + # Ones: Uses float conversion. + O_from_float = np.ones((2,)).astype(mut.Symbol) + self.check_array(O_from_float, 2 * ["float(1)"]) + + # Linear algebra. + I_from_float = np.eye(2).astype(mut.Symbol) + self.check_array( + I_from_float, + [["float(1)", "float(0)"], ["float(0)", "float(1)"]]) + self.check_array(np.diag(I_from_float), 2 * ["float(1)"]) + + def test_array_creation_constants_bad(self): + """ + WARNING: The following are all BAD. AVOID THEM (as of NumPy v1.15.2). + """ + # BAD Memory: `np.empty` works with uninitialized memory. + # Printing will most likely cause a segfault. + E = np.empty((2,), dtype=mut.Symbol) + self.assertEqual(E.dtype, mut.Symbol) + # BAD Memory: `np.zeros` works by using `memzero`. + # Printing will most likely cause a segfault. + Z = np.zeros((2,), dtype=mut.Symbol) + self.assertEqual(Z.dtype, mut.Symbol) + # BAD Semantics: This requires that `np.long` be added as an implicit + # conversion. + # Could add implicit conversion, but that may wreak havoc. + with self.assertRaises(ValueError): + I_ = np.ones((2,), dtype=mut.Symbol) + + def test_array_ufunc(self): + # - Symbol + a = mut.Symbol("a") + b = mut.Symbol("b") + self.check_scalar( + mut.custom_binary_ufunc(a, b), "custom-symbol(a, b)") + A = [a, a] + B = [b, b] + self.check_array( + mut.custom_binary_ufunc(A, B), ["custom-symbol(a, b)"] * 2) + + # Duplicating values for other tests. + # - LengthValueImplicit + x_length = mut.LengthValueImplicit(10) + self.check_scalar(mut.custom_binary_ufunc(x_length, x_length), 20) + X_length = [x_length, x_length] + self.check_array(mut.custom_binary_ufunc(X_length, X_length), 2 * [20]) + # - StrValueExplicit + x_str = mut.StrValueExplicit("x") + self.check_scalar( + mut.custom_binary_ufunc(x_str, x_str), "custom-str(x, x)") + X_str = [x_str, x_str] + self.check_array( + mut.custom_binary_ufunc(X_str, X_str), 2 * ["custom-str(x, x)"]) + + # - Mixing. + # N.B. For UFuncs, order affects the resulting output when implicit or + # explicit convesions are present. + # - - Symbol + LengthValueImplicit + self.check_scalar( + mut.custom_binary_ufunc(x_length, a), 11) + self.check_array( + mut.custom_binary_ufunc(X_length, A), 2 * [11]) + self.check_scalar( + mut.custom_binary_ufunc(a, x_length), + "custom-symbol(a, length(10))") + self.check_array( + mut.custom_binary_ufunc(A, X_length), + 2 * ["custom-symbol(a, length(10))"]) + # - - Symbol + StrValueExplicit + self.check_scalar( + mut.custom_binary_ufunc(x_str, a), "custom-str(x, a)") + self.check_array( + mut.custom_binary_ufunc(X_str, A), 2 * ["custom-str(x, a)"]) + self.check_scalar( + mut.custom_binary_ufunc(a, x_str), + "custom-symbol(a, x)") + self.check_array( + mut.custom_binary_ufunc(A, X_str), + 2 * ["custom-symbol(a, x)"]) + # - - Symbol + OperandExplicit + x_order = mut.OperandExplicit() + X_order = [x_order, x_order] + self.check_scalar( + mut.custom_binary_ufunc(x_order, a), "custom-operand-lhs(a)") + self.check_array( + mut.custom_binary_ufunc(X_order, A), 2 * ["custom-operand-lhs(a)"]) + self.check_scalar( + mut.custom_binary_ufunc(a, x_order), "custom-operand-rhs(a)") + self.check_array( + mut.custom_binary_ufunc(A, X_order), 2 * ["custom-operand-rhs(a)"]) + + def test_eigen_aliases(self): + a = mut.Symbol("a") + A = np.array([a, a]) + mut.add_one(A) + self.check_scalar(A[0], "(a) + (float(1))") + # Check reference to live stuff. + c = mut.SymbolContainer(2, 2) + mut.add_one(c.symbols()) + self.check_array(c.symbols(), 2 * [2 * ["() + (float(1))"]]) + + def check_binary(self, a, b, fop, value): + """Checks a binary operator for both scalar and array cases.""" + self.check_scalar(fop(a, b), value) + A = np.array([a, a]) + B = np.array([b, b]) + c1, c2 = fop(A, B) + self.check_scalar(c1, value) + self.check_scalar(c2, value) + + def check_binary_with_inplace( + self, a, b, fop, fiop, value, inplace_same=True): + """ + Args: + a: Left-hand operand. + b: Right-hand operand. + fop: Binary operator function function (x, y) -> z. + fiop: Binary operator inplace function (x, y). Must return `x`. + value: Expected value. + inplace_same: + For the scalar case, expects that `a += b` will not implicitly + create a new instance (per Python's math rules). If False, a + new instance must be created. + """ + # Scalar. + self.check_scalar(fop(a, b), value) + c = mut.Symbol(a) + d = fiop(c, b) + if inplace_same: + self.assertIs(c, d) + else: + self.assertIsNot(c, d) + self.check_scalar(d, value) + + # Array. + A = np.array([a, a]) + B = np.array([b, b]) + c1, c2 = fop(A, B) + self.check_scalar(c1, value) + self.check_scalar(c2, value) + C = np.array(A) + D = fiop(C, B) + # Regardless of the operation, numpy arrays should not generate + # temporaries for inplace operations. + self.assertIs(C, D) + c1, c2 = C + self.check_scalar(c1, value) + self.check_scalar(c2, value) + + def test_algebra_closed(self): + """Tests scalar and array algebra with implicit conversions.""" + a = mut.Symbol("a") + b = mut.Symbol("b") + + # Operators. + def fop(x, y): return x + y + + def fiop(x, y): + x += y + return x + self.check_binary_with_inplace(a, a, fop, fiop, "(a) + (a)") + self.check_binary_with_inplace(a, b, fop, fiop, "(a) + (b)") + + def fop(x, y): return x - y + + def fiop(x, y): + x -= y + return x + self.check_binary_with_inplace(a, a, fop, fiop, "(a) - (a)") + self.check_binary_with_inplace(a, b, fop, fiop, "(a) - (b)") + + def fop(x, y): return x * y + + def fiop(x, y): + x *= y + return x + self.check_binary_with_inplace(a, a, fop, fiop, "(a) * (a)") + self.check_binary_with_inplace(a, b, fop, fiop, "(a) * (b)") + + def fop(x, y): + return x / y + + def fiop(x, y): + x /= y + return x + self.check_binary_with_inplace(a, a, fop, fiop, "(a) / (a)") + self.check_binary_with_inplace(a, b, fop, fiop, "(a) / (b)") + + def fop(x, y): return x & y + + def fiop(x, y): + x &= y + return x + self.check_binary_with_inplace(a, a, fop, fiop, "(a) & (a)") + self.check_binary_with_inplace(a, b, fop, fiop, "(a) & (b)") + + def fop(x, y): return x | y + + def fiop(x, y): + x |= y + return x + self.check_binary_with_inplace(a, a, fop, fiop, "(a) | (a)") + self.check_binary_with_inplace(a, b, fop, fiop, "(a) | (b)") + + # Logical. + def fop(x, y): return x == y + self.check_binary(a, a, fop, "(a) == (a)") + self.check_binary(a, b, fop, "(a) == (b)") + + def fop(x, y): return x != y + self.check_binary(a, a, fop, "(a) != (a)") + self.check_binary(a, b, fop, "(a) != (b)") + + def fop(x, y): return x < y + self.check_binary(a, a, fop, "(a) < (a)") + self.check_binary(a, b, fop, "(a) < (b)") + + def fop(x, y): return x <= y + self.check_binary(a, a, fop, "(a) <= (a)") + self.check_binary(a, b, fop, "(a) <= (b)") + + def fop(x, y): return x > y + self.check_binary(a, a, fop, "(a) > (a)") + self.check_binary(a, b, fop, "(a) > (b)") + + def fop(x, y): return x >= y + self.check_binary(a, a, fop, "(a) >= (a)") + self.check_binary(a, b, fop, "(a) >= (b)") + + def test_linear_algebra(self): + a = mut.Symbol("a") + b = mut.Symbol("b") + L = np.array([a, b]) + R = np.array([b, a]) + self.check_scalar(np.dot(L, R), "(() + ((a) * (b))) + ((b) * (a))") + # Vector. + L.shape = (1, 2) + R.shape = (2, 1) + Y = np.dot(L, R) + self.assertEqual(Y.shape, (1, 1)) + self.check_scalar(Y[0, 0], "(() + ((a) * (b))) + ((b) * (a))") + + def test_algebra_order_check(self): + # By construction, `OperandExplicit` only interfaces with `Symbol` by + # explicit operator overloads; no casting / construction is done. + a = mut.Symbol("a") + operand = mut.OperandExplicit() + + def fop(x, y): return x + y + + def fiop(x, y): + x += y + return x + self.check_binary_with_inplace(a, operand, fop, fiop, "(a) + operand") + self.check_binary(operand, a, fop, "operand + (a)") + + def test_algebra_implicit_casting(self): + # N.B. Only tested on a single operator, `__add__` and `__iadd__`. + a = mut.Symbol("a") + + def fop(x, y): return x + y + + def fiop(x, y): + x += y + return x + + # N.B. Implicitly convertible types will enable true in-place + # operations. Explicitly convertible types requires a new value. + b_length = mut.LengthValueImplicit(1) + self.check_binary_with_inplace( + a, b_length, fop, fiop, "(a) + (length(1))", inplace_same=True) + + b_str = mut.StrValueExplicit("b") + self.check_binary_with_inplace( + a, b_str, fop, fiop, "(a) + (b)", inplace_same=False) + + # TODO(eric.cousineau): Check trigonometric UFuncs. diff --git a/bindings/pybind11_ext/test/numpy_dtypes_user_test_util_py.cc b/bindings/pybind11_ext/test/numpy_dtypes_user_test_util_py.cc new file mode 100644 index 000000000000..6431a363d551 --- /dev/null +++ b/bindings/pybind11_ext/test/numpy_dtypes_user_test_util_py.cc @@ -0,0 +1,339 @@ +/// @file +/// Tests NumPy user dtypes, using `pybind11`s C++ testing infrastructure. + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "drake/bindings/pybind11_ext/numpy_dtypes_user.h" + +using std::string; +using std::to_string; +using std::unique_ptr; + +namespace py = pybind11; + +namespace { + +template +using MatrixX = Eigen::Matrix; + +/* +Goals: + + * Show API + * Show operator overloads + * Exercise memory bugs (allocation, etc.) + * Show implicit / explicit conversions. + +The simplest mechanism is to do super simple symbolics. +*/ + +// Captures length of a given `Symbol`. Can implicitly convert to and from a +// `Symbol. +class LengthValueImplicit { + public: + // NOLINTNEXTLINE(runtime/explicit) + LengthValueImplicit(int value) : value_(value) {} + int value() const { return value_; } + + bool operator==(const LengthValueImplicit& other) const { + return value_ == other.value_; + } + private: + int value_{}; +}; + +// Captures value of a given `Symbol`. Can explicitly convert to and from a +// `Symbol. +class StrValueExplicit { + public: + explicit StrValueExplicit(const string& value) : value_(new string(value)) {} + const string& value() const { return *value_; } + private: + std::shared_ptr value_{}; +}; + +// No construction possible from this; purely for testing explicit operator +// overloads. +struct OperandExplicit {}; + +class Symbol { + public: + Symbol() : Symbol("") {} + + Symbol(const Symbol& other) : Symbol(other.value()) {} + // `operator=` must be overloaded so that we do not copy the underyling + // `shared_ptr` (when creating an array repeating from the same scalar). + Symbol& operator=(const Symbol& other) { + // WARNING: Because NumPy can assign from `memzero` memory, we must handle + // this case. + if (!str_) str_.reset(new string()); + *str_ = *other.str_; + return *this; + } + + explicit Symbol(string str) : str_(new string(str)) {} + + // Explicit conversion. + explicit Symbol(const StrValueExplicit& other) : Symbol(other.value()) {} + + // Implicit conversion. + // NOLINTNEXTLINE(runtime/explicit) + Symbol(const LengthValueImplicit& other) + : Symbol(fmt::format("length({})", other.value())) {} + explicit Symbol(double value) : Symbol(fmt::format("float({})", value)) {} + + // N.B. Due to constraints of `pybind11`s architecture, we must try to handle + // `str` conversion from an invalid state. See `add_init`. + // WARNING: If the user encounters `memzero` memory, this case must handled. + string value() const { return str_ ? *str_ : ""; } + + // To be explicit. + operator int() const { return str_->size(); } + operator StrValueExplicit() const { return StrValueExplicit(*str_); } + // To be implicit. + operator LengthValueImplicit() const { return str_->size(); } + + template + static Symbol format(string pattern, const Args&... args) { + return Symbol(fmt::format(pattern, args...)); + } + + // Closed under the following operators: + Symbol& operator+=(const Symbol& rhs) { return inplace_binary("+", rhs); } + Symbol operator+(const Symbol& rhs) const { return binary("+", rhs); } + Symbol& operator-=(const Symbol& rhs) { return inplace_binary("-", rhs); } + Symbol operator-(const Symbol& rhs) const { return binary("-", rhs); } + Symbol& operator*=(const Symbol& rhs) { return inplace_binary("*", rhs); } + Symbol operator*(const Symbol& rhs) const { return binary("*", rhs); } + Symbol& operator/=(const Symbol& rhs) { return inplace_binary("/", rhs); } + Symbol operator/(const Symbol& rhs) const { return binary("/", rhs); } + Symbol& operator&=(const Symbol& rhs) { return inplace_binary("&", rhs); } + Symbol operator&(const Symbol& rhs) const { return binary("&", rhs); } + Symbol& operator|=(const Symbol& rhs) { return inplace_binary("|", rhs); } + Symbol operator|(const Symbol& rhs) const { return binary("|", rhs); } + Symbol operator==(const Symbol& rhs) const { return binary("==", rhs); } + Symbol operator!=(const Symbol& rhs) const { return binary("!=", rhs); } + Symbol operator<(const Symbol& rhs) const { return binary("<", rhs); } + Symbol operator<=(const Symbol& rhs) const { return binary("<=", rhs); } + Symbol operator>(const Symbol& rhs) const { return binary(">", rhs); } + Symbol operator>=(const Symbol& rhs) const { return binary(">=", rhs); } + Symbol operator&&(const Symbol& rhs) const { return binary("&&", rhs); } + Symbol operator||(const Symbol& rhs) const { return binary("||", rhs); } + + // - Not closed. + Symbol& operator+=(const OperandExplicit&) { + *str_ = fmt::format("({}) + operand", *this); + return *this; + } + Symbol operator+(const OperandExplicit&) const { + Symbol lhs(*this); + lhs += OperandExplicit{}; + return lhs; + } + + private: + Symbol binary(const char* op, const Symbol& rhs) const { + Symbol lhs(*this); + lhs.inplace_binary(op, rhs); + return lhs; + } + Symbol& inplace_binary(const char* op, const Symbol& rhs) { + *str_ = fmt::format("({}) {} ({})", value(), op, rhs.value()); + return *this; + } + + // Data member to ensure that we do not get segfaults when carrying around + // `shared_ptr`s, and to ensure that the data is memcpy-moveable. + // N.B. This is not used for Copy-on-Write optimizations. + std::shared_ptr str_; +}; + +Symbol operator+(const OperandExplicit&, const Symbol& rhs) { + return Symbol::format("operand + ({})", rhs); +} + +std::ostream& operator<<(std::ostream& os, const Symbol& s) { + return os << s.value(); +} + +namespace math { + +Symbol abs(const Symbol& s) { return Symbol::format("abs({})", s); } +Symbol cos(const Symbol& s) { return Symbol::format("cos({})", s); } +Symbol sin(const Symbol& s) { return Symbol::format("sin({})", s); } +Symbol pow(const Symbol& a, const Symbol& b) { + return Symbol::format("({}) ^ ({})", a, b); +} + +} // namespace math + +template +auto MakeRepr(const string& name, Return (Class::*method)() const) { + return [name, method](Class* self) { + return py::str("<{} '{}'>").format(name, (self->*method)()); + }; +} + +template +auto MakeStr(Return (Class::*method)() const) { + return [method](Class* self) { + return py::str("{}").format((self->*method)()); + }; +} + +// Simple container to check referencing of symbols. +class SymbolContainer { + public: + SymbolContainer(int rows, int cols) : symbols_(rows, cols) {} + Eigen::Ref> symbols() { return symbols_; } + + private: + MatrixX symbols_; +}; + +} // namespace + +PYBIND11_NUMPY_DTYPE_USER(LengthValueImplicit); +PYBIND11_NUMPY_DTYPE_USER(StrValueExplicit); +PYBIND11_NUMPY_DTYPE_USER(OperandExplicit); +PYBIND11_NUMPY_DTYPE_USER(Symbol); + +namespace { + +PYBIND11_MODULE(numpy_dtypes_user_test_util, m) { + // N.B. You must pre-declare all types that must interact using UFuncs, as + // they must already be registered at that point of defining the UFunc. + py::dtype_user length(m, "LengthValueImplicit"); + py::dtype_user str(m, "StrValueExplicit"); + py::dtype_user operand(m, "OperandExplicit"); + py::dtype_user sym(m, "Symbol"); + + length // BR + .def(py::init()) + .def("value", &LengthValueImplicit::value) + .def("__repr__", + MakeRepr("LengthValueImplicit", &LengthValueImplicit::value)) + .def("__str__", MakeStr(&LengthValueImplicit::value)) + .def_loop(py::self == py::self); + + str // BR + .def(py::init()) + .def("value", &StrValueExplicit::value) + .def("__repr__", MakeRepr("StrValueExplicit", &StrValueExplicit::value)) + .def("__str__", MakeStr(&StrValueExplicit::value)); + + operand // BR + .def(py::init()) + .def("__repr__", + [](const OperandExplicit&) { return ""; }) + .def("__str__", + [](const OperandExplicit&) { return ""; }); + + sym // BR + // Nominal definitions. + .def(py::init()) + .def(py::init()) + .def(py::init()) + // N.B. Constructing `StrValueExplicit` only matters for user experience, + // since it's explicit. However, implicit conversions *must* have an + // accompanying constructor. + .def(py::init()) + .def(py::init()) + .def("__repr__", MakeRepr("Symbol", &Symbol::value)) + .def("__str__", MakeStr(&Symbol::value)) + .def("value", &Symbol::value) + // - Test referencing. + .def("self_reference", + [](const Symbol& self) { return &self; }, + py::return_value_policy::reference) + // Casting. + // - From + // WARNING: See above about implicit conversions + constructors. + .def_loop(py::dtype_method::explicit_conversion()) + .def_loop(py::dtype_method::explicit_conversion< + StrValueExplicit, Symbol>()) + .def_loop(py::dtype_method::implicit_conversion< + LengthValueImplicit, Symbol>()) + // - To + .def_loop(py::dtype_method::explicit_conversion()) + .def_loop(py::dtype_method::explicit_conversion< + Symbol, StrValueExplicit>()) + .def_loop(py::dtype_method::implicit_conversion< + Symbol, LengthValueImplicit>()) + // Operators. + // N.B. Inplace operators do not have UFuncs in NumPy. + // - Math. + .def_loop(py::self + py::self) + .def(py::self += py::self) + .def_loop(py::self - py::self) + .def(py::self -= py::self) + .def_loop(py::self * py::self) + .def(py::self *= py::self) + .def_loop(py::self / py::self) + .def(py::self /= py::self) + // - Bitwise. + .def_loop(py::self & py::self) + .def(py::self &= py::self) + .def_loop(py::self | py::self) + .def(py::self |= py::self) + // - Logical. + .def_loop(py::self == py::self) + .def_loop(py::self != py::self) + .def_loop(py::self < py::self) + .def_loop(py::self <= py::self) + .def_loop(py::self > py::self) + .def_loop(py::self >= py::self) + // - Not closed. + .def_loop(py::self + OperandExplicit{}) + // NOLINTNEXTLINE(whitespace/braces) + .def_loop(OperandExplicit{} + py::self) + .def(py::self += OperandExplicit{}) + // .def_loop(py::self && py::self) + // .def_loop(py::self || py::self) + // Explicit UFunc. + .def_loop(py::dtype_method::dot()) + .def_loop("__pow__", &math::pow) + .def_loop("abs", &math::abs) + .def_loop("cos", &math::cos) + .def_loop("sin", &math::sin); + + py::ufunc(m, "custom_binary_ufunc") + .def_loop([](const Symbol& lhs, const Symbol& rhs) { + return Symbol::format("custom-symbol({}, {})", lhs, rhs); + }) + .def_loop([](const Symbol& lhs, const OperandExplicit& rhs) { + return Symbol::format("custom-operand-rhs({})", lhs); + }) + .def_loop([](const OperandExplicit& lhs, const Symbol& rhs) { + return Symbol::format("custom-operand-lhs({})", rhs); + }) + .def_loop( + [](const LengthValueImplicit& lhs, const LengthValueImplicit& rhs) { + return LengthValueImplicit(lhs.value() + rhs.value()); + }) + .def_loop( + [](const StrValueExplicit& lhs, const StrValueExplicit& rhs) { + return StrValueExplicit(fmt::format( + "custom-str({}, {})", lhs.value(), rhs.value())); + }); + + m.def("add_one", + [](Eigen::Ref> value) { + value.array() += Symbol(1.); + }); + + py::class_(m, "SymbolContainer") + .def(py::init(), py::arg("rows"), py::arg("cols")) + .def("symbols", &SymbolContainer::symbols, + py::return_value_policy::reference_internal); +} + +} // namespace diff --git a/bindings/pydrake/util/BUILD.bazel b/bindings/pydrake/util/BUILD.bazel index 39941da963ae..90e267e4dce0 100644 --- a/bindings/pydrake/util/BUILD.bazel +++ b/bindings/pydrake/util/BUILD.bazel @@ -39,10 +39,18 @@ drake_cc_library( visibility = ["//visibility:public"], ) +drake_cc_library( + name = "function_inference", + hdrs = ["function_inference.h"], + visibility = ["//visibility:public"], + deps = [":type_pack"], +) + drake_cc_library( name = "wrap_function", hdrs = ["wrap_function.h"], visibility = ["//visibility:public"], + deps = [":function_inference"], ) drake_cc_library( @@ -235,6 +243,13 @@ drake_cc_googletest( ], ) +drake_cc_googletest( + name = "function_inference_test", + deps = [ + ":function_inference", + ], +) + drake_cc_googletest( name = "wrap_function_test", deps = [ diff --git a/bindings/pydrake/util/function_inference.h b/bindings/pydrake/util/function_inference.h new file mode 100644 index 000000000000..7d3da2f3b1c5 --- /dev/null +++ b/bindings/pydrake/util/function_inference.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include + +#include "drake/bindings/pydrake/util/type_pack.h" + +// TODO(eric.cousineau): Figure out how to make this automatically hidden. +#pragma GCC visibility push(hidden) + +namespace drake { +namespace pydrake { +namespace detail { + +struct function_inference { + // Collects both a functor object and its signature for ease of inference. + template + struct info { + // TODO(eric.cousineau): Ensure that this permits copy elision when combined + // with `std::forward(func)`, while still behaving well with primitive + // types. + using Return = ReturnT; + using Args = type_pack; + using Func = std::decay_t; + Func func; + }; + + // Factory method for `info<>`, to be used by `run`. + template + static auto make_inferred_info( + Func&& func, Return (*infer)(Args...) = nullptr) { + (void)infer; + return info{std::forward(func)}; + } + + // Infers `info<>` from a function pointer. + template + static auto run(Return (*func)(Args...)) { + return make_inferred_info(func); + } + + // Infers `info<>` from a mutable method pointer. + template + static auto run(Return (Class::*method)(Args...)) { + auto func = [method](Class& self, Args... args) { + return (self.*method)(std::forward(args)...); + }; + return make_inferred_info(func); + } + + // Infers `info<>` from a const method pointer. + template + static auto run(Return (Class::*method)(Args...) const) { + auto func = [method](const Class& self, Args... args) { + return (self.*method)(std::forward(args)...); + }; + return make_inferred_info(func); + } + + // Helpers for general functor objects. + struct infer_helper { + // Removes class from mutable method pointer for inferring signature + // of functor. + template + static auto remove_class_from_ptr(Return (Class::*)(Args...)) { + using Ptr = Return (*)(Args...); + return Ptr{}; + } + + // Removes class from const method pointer for inferring signature of + // functor. + template + static auto remove_class_from_ptr(Return (Class::*)(Args...) const) { + using Ptr = Return (*)(Args...); + return Ptr{}; + } + + // Infers funtion pointer from functor. + // @pre `Func` must have only *one* overload of `operator()`. + template + static auto infer_function_ptr() { + return remove_class_from_ptr(&Func::operator()); + } + }; + + // SFINAE for functors. + // N.B. This *only* distinguished between function / method pointers and + // lambda objects. It does *not* distinguish among other types. + template + using enable_if_lambda_t = + std::enable_if_t>::value, T>; + + // Infers `info<>` from a generic functor. + template > + static auto run(Func&& func) { + return make_inferred_info( + std::forward(func), + infer_helper::infer_function_ptr>()); + } +}; + +// Infers `function_inference::info<>` from a generic functor, converting form +// the left form to the right form: +// -> Return(*)(Args...) +// Return (Class::*)(Args...) -> Return (*)(Class&, Args...) +// Return (Class::*)(Args...) const -> Return (*)(const Class&, Args...) +template +auto infer_function_info(Func&& func) { + return function_inference::run(std::forward(func)); +} + +} // namespace detail +} // namespace pydrake +} // namespace drake + +#pragma GCC visibility pop diff --git a/bindings/pydrake/util/test/function_inference_test.cc b/bindings/pydrake/util/test/function_inference_test.cc new file mode 100644 index 000000000000..70b35b58cd75 --- /dev/null +++ b/bindings/pydrake/util/test/function_inference_test.cc @@ -0,0 +1,93 @@ +#include "drake/bindings/pydrake/util/function_inference.h" + +#include +#include + +#include + +namespace drake { +namespace pydrake { + +using detail::infer_function_info; + +// Compares types, generating a static_assert that should have a helpful +// context of which types do not match. +template +void check_type() { + // Use this function to inspect types when failure is encountered. + static_assert(std::is_same::value, "Mismatch"); +} + +// Checks signature of a generic functor. +template +void check_pack(type_pack, type_pack) { + using Dummy = int[]; + (void)Dummy{(check_type(), 0)...}; +} + +template +void check_signature(InfoT&& info) { + using Info = std::decay_t; + check_type(); + check_pack(typename Info::Args{}, type_pack{}); +} + +// Simple function. +void IntToVoid(int) {} + +GTEST_TEST(WrapFunction, FunctionPointer) { + auto info = infer_function_info(IntToVoid); + check_signature(info); + int value = 1; + info.func(value); +} + +// Lambdas / basic functors. +GTEST_TEST(WrapFunction, Lambda) { + int value{0}; + auto lambda = [](int) {}; + { + auto info = infer_function_info(lambda); + check_signature(info); + info.func(value); + } + + { + std::function func = lambda; + auto info = infer_function_info(func); + info.func(value); + } +} + +// Class methods. +class MyClass { + public: + static int MethodStatic(int value) { return value; } + int MethodMutable(int value) { return value + value_; } + int MethodConst(int value) const { return value * value_; } + + private: + int value_{10}; +}; + +GTEST_TEST(WrapFunction, Methods) { + int value = 2; + + MyClass c; + const MyClass& c_const{c}; + + auto info1 = infer_function_info(&MyClass::MethodStatic); + check_signature(info1); + EXPECT_EQ(info1.func(value), 2); + + auto info2 = infer_function_info(&MyClass::MethodMutable); + check_signature(info2); + EXPECT_EQ(info2.func(c, value), 12); + + auto info3 = infer_function_info(&MyClass::MethodConst); + check_signature(info3); + EXPECT_EQ(info3.func(c_const, value), 20); +} + +} // namespace pydrake +} // namespace drake diff --git a/bindings/pydrake/util/test/type_pack_test.cc b/bindings/pydrake/util/test/type_pack_test.cc index 7b32ea82d8eb..bbd9313a36f8 100644 --- a/bindings/pydrake/util/test/type_pack_test.cc +++ b/bindings/pydrake/util/test/type_pack_test.cc @@ -49,6 +49,25 @@ GTEST_TEST(TypeUtilTest, TypeTags) { decltype(pack_check), type_pack>::value)); } +GTEST_TEST(TypeUtilTest, Concat) { + using A = type_pack; + using B = type_pack; + using AB = type_pack; + EXPECT_TRUE((std::is_same< + decltype(type_pack_concat(A{}, B{})), AB>::value)); +} + +// Adds a pointer to a given type. +template +using Ptr = T*; + +GTEST_TEST(TypeUtilTest, Apply) { + using A = type_pack; + using B = type_pack; + EXPECT_TRUE((std::is_same< + decltype(type_pack_apply(A{})), B>::value)); +} + GTEST_TEST(TypeUtilTest, Bind) { using T_0 = Pack::bind; EXPECT_TRUE((std::is_same< diff --git a/bindings/pydrake/util/test/wrap_function_test.cc b/bindings/pydrake/util/test/wrap_function_test.cc index 48f74bfe23ed..366535b9b1f6 100644 --- a/bindings/pydrake/util/test/wrap_function_test.cc +++ b/bindings/pydrake/util/test/wrap_function_test.cc @@ -10,6 +10,7 @@ namespace drake { namespace pydrake { +namespace { // N.B. Anonymous namespace not used as it makes failure messages // (static_assert) harder to interpret. @@ -67,10 +68,10 @@ GTEST_TEST(WrapFunction, Methods) { EXPECT_EQ(WrapIdentity(&MyClass::MethodStatic)(value), 2); // Wrapped signature: int (MyClass*, int) auto method_mutable = WrapIdentity(&MyClass::MethodMutable); - EXPECT_EQ(method_mutable(&c, value), 12); + EXPECT_EQ(method_mutable(c, value), 12); // method_mutable(&c_const, value); // Should fail. // Wrapped signature: int (const MyClass*, int) - EXPECT_EQ(WrapIdentity(&MyClass::MethodConst)(&c_const, value), 20); + EXPECT_EQ(WrapIdentity(&MyClass::MethodConst)(c_const, value), 20); } // Move-only arguments. @@ -245,7 +246,7 @@ struct check_signature { template static void run_impl( - const detail::function_info< + const detail::function_inference::info< FuncActual, ReturnActual, ArgsActual...>& info) { check_type(); using Dummy = int[]; @@ -390,5 +391,6 @@ GTEST_TEST(WrapFunction, ChangeCallbackOnly) { check_expected::run(wrapped); } +} // namespace } // namespace pydrake } // namespace drake diff --git a/bindings/pydrake/util/type_pack.h b/bindings/pydrake/util/type_pack.h index d422c4448194..920766a9dc85 100644 --- a/bindings/pydrake/util/type_pack.h +++ b/bindings/pydrake/util/type_pack.h @@ -111,6 +111,18 @@ struct type_pack { using type_at = typename drake::type_at::type; }; +/// Concatenates two packs. +template +auto type_pack_concat(type_pack = {}, type_pack = {}) { + return type_pack{}; +} + +/// Applys a template to the parameters of a pack. +template