Skip to content

Commit

Permalink
pydrake systems: Update semantics to leverage dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
EricCousineau-TRI committed May 11, 2018
1 parent 7ab9719 commit 38d3d38
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 105 deletions.
3 changes: 2 additions & 1 deletion bindings/pydrake/systems/framework_py_semantics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ void DefineFrameworkPySemantics(py::module m) {
// `AddValueInstantiation` for more information.
// Keep alive, ownership: `value` keeps `self` alive.
py::keep_alive<2, 1>(), py::arg("abstract_params"))
.def("SetFrom", &Parameters<T>::SetFrom);
.def("SetFrom", &Parameters<T>::SetFrom)
.def("CopyFrom", &Parameters<T>::CopyFrom);

// State.
DefineTemplateClassWithDefault<State<T>>(m, "State", GetPyParam<T>())
Expand Down
14 changes: 7 additions & 7 deletions bindings/pydrake/systems/framework_py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "drake/bindings/pydrake/pydrake_pybind.h"
#include "drake/bindings/pydrake/systems/systems_pybind.h"
#include "drake/bindings/pydrake/util/eigen_pybind.h"
#include "drake/bindings/pydrake/util/wrap_pybind.h"
#include "drake/systems/framework/basic_vector.h"
#include "drake/systems/framework/subvector.h"
#include "drake/systems/framework/supervector.h"
Expand Down Expand Up @@ -43,18 +44,17 @@ void DefineFrameworkPyValues(py::module m) {
// N.B. Place `init<VectorX<T>>` `init<int>` so that we do not implicitly
// convert scalar-size `np.array` objects to `int` (since this is normally
// permitted).
.def(py::init<VectorX<T>>())
// N.B. Also ensure that we use `greedy_arg` to prevent ambiguous
// overloads when using scalars vs. lists vs. numpy arrays. See
// `greedy_arg` for more information.
.def(py::init([](greedy_arg<VectorX<T>> in) {
return new BasicVector<T>(*in);
}))
.def(py::init<int>())
.def("get_value",
[](const BasicVector<T>* self) -> Eigen::Ref<const VectorX<T>> {
return self->get_value();
}, py_reference_internal)
// TODO(eric.cousineau): Remove this once `get_value` is changed, or
// reference semantics are changed for custom dtypes.
.def("_get_value_copy",
[](const BasicVector<T>* self) -> VectorX<T> {
return self->get_value();
})
.def("get_mutable_value",
[](BasicVector<T>* self) -> Eigen::Ref<VectorX<T>> {
return self->get_mutable_value();
Expand Down
151 changes: 79 additions & 72 deletions bindings/pydrake/systems/test/test_util_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ using systems::ConstantVectorSource;
namespace pydrake {
namespace {

using T = double;

// Informs listener when this class is deleted.
template <typename T>
class DeleteListenerSystem : public ConstantVectorSource<T> {
public:
using Base = ConstantVectorSource<T>;

explicit DeleteListenerSystem(std::function<void()> delete_callback)
: ConstantVectorSource(VectorX<T>::Constant(1, 0.)),
: Base(VectorX<T>::Constant(1, 0.)),
delete_callback_(delete_callback) {}

~DeleteListenerSystem() override {
Expand All @@ -34,10 +35,11 @@ class DeleteListenerSystem : public ConstantVectorSource<T> {
std::function<void()> delete_callback_;
};

template <typename T>
class DeleteListenerVector : public BasicVector<T> {
public:
explicit DeleteListenerVector(std::function<void()> delete_callback)
: BasicVector(VectorX<T>::Constant(1, 0.)),
: BasicVector<T>(VectorX<T>::Constant(1, 0.)),
delete_callback_(delete_callback) {}

~DeleteListenerVector() override {
Expand Down Expand Up @@ -89,10 +91,10 @@ PYBIND11_MODULE(test_util, m) {
py::module::import("pydrake.systems.framework");
py::module::import("pydrake.systems.primitives");

py::class_<DeleteListenerSystem, ConstantVectorSource<T>>(
py::class_<DeleteListenerSystem<double>, ConstantVectorSource<double>>(
m, "DeleteListenerSystem")
.def(py::init<std::function<void()>>());
py::class_<DeleteListenerVector, BasicVector<T>>(
py::class_<DeleteListenerVector<double>, BasicVector<double>>(
m, "DeleteListenerVector")
.def(py::init<std::function<void()>>());

Expand All @@ -104,80 +106,85 @@ PYBIND11_MODULE(test_util, m) {
pysystems::AddValueInstantiation<MoveOnlyType>(m);

// A 2-dimensional subclass of BasicVector.
py::class_<MyVector2<T>, BasicVector<T>>(m, "MyVector2")
py::class_<MyVector2<double>, BasicVector<double>>(m, "MyVector2")
.def(py::init<const Eigen::Vector2d&>(), py::arg("data"));

m.def("make_unknown_abstract_value", []() {
return AbstractValue::Make(UnknownType{});
});

// Call overrides to ensure a custom Python class can override these methods.

auto clone_vector = [](const VectorBase<T>& vector) {
auto copy = std::make_unique<BasicVector<T>>(vector.size());
copy->SetFrom(vector);
return copy;
auto bind_common_scalar_types = [&m](auto dummy) {
using T = decltype(dummy);

auto clone_vector = [](const VectorBase<T>& vector) {
auto copy = std::make_unique<BasicVector<T>>(vector.size());
copy->SetFrom(vector);
return copy;
};

m.def("call_leaf_system_overrides", [clone_vector](
const LeafSystem<T>& system) {
py::dict results;
auto context = system.AllocateContext();
{
// Call `Publish` to test `DoPublish`.
auto events =
LeafEventCollection<PublishEvent<T>>::MakeForcedEventCollection();
system.Publish(*context, *events);
}
{
// Call `HasDirectFeedthrough` to test `DoHasDirectFeedthrough`.
results["has_direct_feedthrough"] = system.HasDirectFeedthrough(0, 0);
}
{
// Call `CalcDiscreteVariableUpdates` to test
// `DoCalcDiscreteVariableUpdates`.
auto& state = context->get_mutable_discrete_state();
DiscreteValues<T> state_copy(clone_vector(state.get_vector()));
system.CalcDiscreteVariableUpdates(*context, &state_copy);

// From t=0, return next update time for testing discrete time.
// If there is an abstract / unrestricted update, this assumes that
// `dt_discrete < dt_abstract`.
systems::LeafCompositeEventCollection<T> events;
results["discrete_next_t"] = system.CalcNextUpdateTime(
*context, &events);
}
return results;
});

m.def("call_vector_system_overrides", [clone_vector](
const VectorSystem<T>& system, Context<T>* context,
bool is_discrete, double dt) {
// While this is not convention, update state first to ensure that our
// output incorporates it correctly, for testing purposes.
// TODO(eric.cousineau): Add (Continuous|Discrete)State::Clone().
if (is_discrete) {
auto& state = context->get_mutable_discrete_state();
DiscreteValues<T> state_copy(
clone_vector(state.get_vector()));
system.CalcDiscreteVariableUpdates(
*context, &state_copy);
state.CopyFrom(state_copy);
} else {
auto& state = context->get_mutable_continuous_state();
ContinuousState<T> state_dot(
clone_vector(state.get_vector()),
state.get_generalized_position().size(),
state.get_generalized_velocity().size(),
state.get_misc_continuous_state().size());
system.CalcTimeDerivatives(*context, &state_dot);
state.SetFromVector(
state.CopyToVector() + dt * state_dot.CopyToVector());
}
// Calculate output.
auto output = system.AllocateOutput(*context);
system.CalcOutput(*context, output.get());
return output;
});
};

m.def("call_leaf_system_overrides", [clone_vector](
const LeafSystem<T>& system) {
py::dict results;
auto context = system.AllocateContext();
{
// Call `Publish` to test `DoPublish`.
auto events =
LeafEventCollection<PublishEvent<T>>::MakeForcedEventCollection();
system.Publish(*context, *events);
}
{
// Call `HasDirectFeedthrough` to test `DoHasDirectFeedthrough`.
results["has_direct_feedthrough"] = system.HasDirectFeedthrough(0, 0);
}
{
// Call `CalcDiscreteVariableUpdates` to test
// `DoCalcDiscreteVariableUpdates`.
auto& state = context->get_mutable_discrete_state();
DiscreteValues<T> state_copy(clone_vector(state.get_vector()));
system.CalcDiscreteVariableUpdates(*context, &state_copy);

// From t=0, return next update time for testing discrete time.
// If there is an abstract / unrestricted update, this assumes that
// `dt_discrete < dt_abstract`.
systems::LeafCompositeEventCollection<double> events;
results["discrete_next_t"] = system.CalcNextUpdateTime(*context, &events);
}
return results;
});

m.def("call_vector_system_overrides", [clone_vector](
const VectorSystem<T>& system, Context<T>* context,
bool is_discrete, double dt) {
// While this is not convention, update state first to ensure that our
// output incorporates it correctly, for testing purposes.
// TODO(eric.cousineau): Add (Continuous|Discrete)State::Clone().
if (is_discrete) {
auto& state = context->get_mutable_discrete_state();
DiscreteValues<T> state_copy(
clone_vector(state.get_vector()));
system.CalcDiscreteVariableUpdates(
*context, &state_copy);
state.SetFrom(state_copy);
} else {
auto& state = context->get_mutable_continuous_state();
ContinuousState<T> state_dot(
clone_vector(state.get_vector()),
state.get_generalized_position().size(),
state.get_generalized_velocity().size(),
state.get_misc_continuous_state().size());
system.CalcTimeDerivatives(*context, &state_dot);
state.SetFromVector(
state.CopyToVector() + dt * state_dot.CopyToVector());
}
// Calculate output.
auto output = system.AllocateOutput(*context);
system.CalcOutput(*context, output.get());
return output;
});
type_visit(bind_common_scalar_types, pysystems::CommonScalarPack{});
}

} // namespace pydrake
Expand Down
67 changes: 42 additions & 25 deletions bindings/pydrake/systems/test/value_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import unittest
import numpy as np

from pydrake.autodiffutils import AutoDiffXd
from pydrake.symbolic import Expression
from pydrake.systems.framework import (
AbstractValue,
BasicVector,
Parameters,
BasicVector, BasicVector_,
Parameters, Parameters_,
Value,
VectorBase,
)
Expand All @@ -23,45 +25,53 @@ def pass_through(x):
return x


# TODO(eric.cousineau): Add negative (or positive) test cases for AutoDiffXd
# and Symbolic once they are in the bindings.
def int_list(x):
return [int(xi) for xi in x]


class TestValue(unittest.TestCase):
def test_basic_vector_double(self):
def assertArrayEqual(self, lhs, rhs):
# TODO(eric.cousineau): Place in `pydrake.test.unittest_mixins`.
lhs, rhs = np.array(lhs), np.array(rhs)
if lhs.dtype == Expression or rhs.dtype == Expression:
lhs, rhs = lhs.astype(Expression), rhs.astype(Expression)
self.assertTrue(Expression.equal_to(lhs, rhs).all())
else:
self.assertTrue(np.allclose(lhs, rhs))

def test_basic_vector(self):
map(self._check_basic_vector, (float, AutoDiffXd, Expression))

def _check_basic_vector(self, T):
# Test constructing vectors of sizes [0, 1, 2], and ensure that we can
# construct from both lists and `np.array` objects with no ambiguity.
for n in [0, 1, 2]:
for wrap in [pass_through, np.array]:
for wrap in [pass_through, int_list, np.array]:
# Ensure that we can get vectors templated on double by
# reference.
expected_init = wrap(map(float, range(n)))
expected_add = wrap([x + 1 for x in expected_init])
expected_set = wrap([x + 10 for x in expected_init])

value_data = BasicVector(expected_init)
value_data = BasicVector_[T](expected_init)
value = value_data.get_mutable_value()
self.assertTrue(np.allclose(value, expected_init))
self.assertArrayEqual(value, expected_init)

# Add value directly.
# TODO(eric.cousineau): Determine if there is a way to extract
# the pointer referred to by the buffer (e.g. `value.data`).
value[:] += 1
self.assertTrue(np.allclose(value, expected_add))
self.assertTrue(
np.allclose(value_data.get_value(), expected_add))
self.assertTrue(
np.allclose(value_data.get_mutable_value(), expected_add))
self.assertArrayEqual(value, expected_add)
self.assertArrayEqual(value_data.get_value(), expected_add)
self.assertArrayEqual(
value_data.get_mutable_value(), expected_add)

# Set value from `BasicVector`.
value_data.SetFromVector(expected_set)
self.assertTrue(np.allclose(value, expected_set))
self.assertTrue(
np.allclose(value_data.get_value(), expected_set))
self.assertTrue(
np.allclose(value_data.get_mutable_value(), expected_set))
self.assertArrayEqual(value, expected_set)
self.assertArrayEqual(value_data.get_value(), expected_set)
self.assertArrayEqual(
value_data.get_mutable_value(), expected_set)
# Ensure we can construct from size.
value_data = BasicVector(n)
value_data = BasicVector_[T](n)
self.assertEquals(value_data.size(), n)
# Ensure we can clone.
value_copies = [
Expand Down Expand Up @@ -142,13 +152,20 @@ def test_abstract_value_unknown(self):
]), cm.exception.message)

def test_parameters_api(self):
map(self._check_parameters_api, (float, AutoDiffXd, Expression))

def _check_parameters_api(self, T):
Parameters = Parameters_[T]
BasicVector = BasicVector_[T]

def compare(actual, expected):
self.assertEquals(type(actual), type(expected))
if isinstance(actual, VectorBase):
self.assertTrue(
np.allclose(actual.get_value(), expected.get_value()))
if isinstance(actual, BasicVector):
self.assertArrayEqual(actual.get_value(), expected.get_value())
else:
assert isinstance(actual, Value[str])
# Strings getting converted to numpy arrays is no bueno. Do
# scalar comparison.
self.assertEquals(actual.get_value(), expected.get_value())

model_numeric = BasicVector([0.])
Expand All @@ -170,7 +187,7 @@ def compare(actual, expected):
params.set_abstract_parameters(
params.get_abstract_parameters().Clone())
# WARNING: This may invalidate old references!
params.SetFrom(copy.deepcopy(params))
params.CopyFrom(copy.deepcopy(params))

# Test alternative constructors.
ctor_test = [
Expand Down
Loading

0 comments on commit 38d3d38

Please sign in to comment.