From 2dbd9858252ecc5c3dee440a3ae213cac01c7582 Mon Sep 17 00:00:00 2001 From: Michael Demoret <42954918+mdemoret-nv@users.noreply.github.com> Date: Thu, 7 Mar 2024 10:38:46 -0500 Subject: [PATCH] Adding RoundRobinRouter node type for distributing values to downstream nodes (#449) 1. Adds a new C++ type `RoundRobinRouterTypeless` which is very similar to `BroadcastTypeless` except it only pushes values to one of the downstream connections instead of copying 2. Adds a new Python type `RoundRobinRouter` which allows using the `RoundRobinRouterTypeless` from python 3. Adds a C++ test to confirm connectivity 4. Adds Python tests to verify output Authors: - Michael Demoret (https://github.com/mdemoret-nv) Approvers: - Devin Robison (https://github.com/drobison00) URL: https://github.com/nv-morpheus/MRC/pull/449 --- .../operators/round_robin_router_typeless.hpp | 144 +++++++++++++++++ cpp/mrc/tests/test_edges.cpp | 18 ++- python/mrc/core/node.cpp | 12 +- python/tests/test_edges.py | 148 ++++++++++++++---- 4 files changed, 287 insertions(+), 35 deletions(-) create mode 100644 cpp/mrc/include/mrc/node/operators/round_robin_router_typeless.hpp diff --git a/cpp/mrc/include/mrc/node/operators/round_robin_router_typeless.hpp b/cpp/mrc/include/mrc/node/operators/round_robin_router_typeless.hpp new file mode 100644 index 000000000..0eafd8572 --- /dev/null +++ b/cpp/mrc/include/mrc/node/operators/round_robin_router_typeless.hpp @@ -0,0 +1,144 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/edge/deferred_edge.hpp" + +#include +#include +#include + +namespace mrc::node { + +class RoundRobinRouterTypeless : public edge::IWritableProviderBase, public edge::IWritableAcceptorBase +{ + public: + std::shared_ptr get_writable_edge_handle() const override + { + auto* self = const_cast(this); + + // Create a new upstream edge. On connection, have it attach to any downstreams + auto deferred_ingress = std::make_shared( + [self](std::shared_ptr deferred_edge) { + // Set the broadcast indices function + deferred_edge->set_indices_fn([self](edge::DeferredWritableMultiEdgeBase& deferred_edge) { + // Increment the index and return the key for that index + auto next_idx = self->m_current_idx++; + + auto current_keys = deferred_edge.edge_connection_keys(); + + return std::vector{current_keys[next_idx % current_keys.size()]}; + }); + + // Need to work with weak ptr here otherwise we will keep it from closing + std::weak_ptr weak_deferred_edge = deferred_edge; + + // Use a connector here in case the object never gets set to an edge + deferred_edge->add_connector([self, weak_deferred_edge]() { + // Lock whenever working on the handles + std::unique_lock lock(self->m_mutex); + + // Save to the upstream handles + self->m_upstream_handles.emplace_back(weak_deferred_edge); + + auto deferred_edge = weak_deferred_edge.lock(); + + CHECK(deferred_edge) << "Edge was destroyed before making connection."; + + for (const auto& downstream : self->m_downstream_handles) + { + auto count = deferred_edge->edge_connection_count(); + + // Connect + deferred_edge->set_writable_edge_handle(count, downstream); + } + + // Now add a disconnector that will remove it from the list + deferred_edge->add_disconnector([self, weak_deferred_edge]() { + // Need to lock here since this could be driven by different progress engines + std::unique_lock lock(self->m_mutex); + + bool is_expired = weak_deferred_edge.expired(); + + // Cull all expired ptrs from the list + auto iter = self->m_upstream_handles.begin(); + + while (iter != self->m_upstream_handles.end()) + { + if ((*iter).expired()) + { + iter = self->m_upstream_handles.erase(iter); + } + else + { + ++iter; + } + } + + // If there are no more upstream handles, then delete the downstream + if (self->m_upstream_handles.empty()) + { + self->m_downstream_handles.clear(); + } + }); + }); + }); + + return deferred_ingress; + } + + edge::EdgeTypeInfo writable_provider_type() const override + { + return edge::EdgeTypeInfo::create_deferred(); + } + + void set_writable_edge_handle(std::shared_ptr ingress) override + { + // Lock whenever working on the handles + std::unique_lock lock(m_mutex); + + // We have a new downstream object. Hold onto it + m_downstream_handles.push_back(ingress); + + // If we have an upstream object, try to make a connection now + for (auto& upstream_weak : m_upstream_handles) + { + auto upstream = upstream_weak.lock(); + + CHECK(upstream) << "Upstream edge went out of scope before downstream edges were connected"; + + auto count = upstream->edge_connection_count(); + + // Connect + upstream->set_writable_edge_handle(count, ingress); + } + } + + edge::EdgeTypeInfo writable_acceptor_type() const override + { + return edge::EdgeTypeInfo::create_deferred(); + } + + private: + std::mutex m_mutex; + std::atomic_size_t m_current_idx{0}; + std::vector> m_upstream_handles; + std::vector> m_downstream_handles; +}; + +} // namespace mrc::node diff --git a/cpp/mrc/tests/test_edges.cpp b/cpp/mrc/tests/test_edges.cpp index 91c6d4e09..6d79f37dd 100644 --- a/cpp/mrc/tests/test_edges.cpp +++ b/cpp/mrc/tests/test_edges.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,6 +29,7 @@ #include "mrc/node/operators/broadcast.hpp" #include "mrc/node/operators/combine_latest.hpp" #include "mrc/node/operators/node_component.hpp" +#include "mrc/node/operators/round_robin_router_typeless.hpp" #include "mrc/node/operators/router.hpp" #include "mrc/node/rx_node.hpp" #include "mrc/node/sink_channel_owner.hpp" @@ -666,6 +667,21 @@ TEST_F(TestEdges, SourceToRouterToDifferentSinks) sink1->run(); } +TEST_F(TestEdges, SourceToRoundRobinRouterTypelessToDifferentSinks) +{ + auto source = std::make_shared>(); + auto router = std::make_shared(); + auto sink1 = std::make_shared>(); + auto sink2 = std::make_shared>(); + + mrc::make_edge(*source, *router); + mrc::make_edge(*router, *sink1); + mrc::make_edge(*router, *sink2); + + source->run(); + sink1->run(); +} + TEST_F(TestEdges, SourceToBroadcastToSink) { auto source = std::make_shared>(); diff --git a/python/mrc/core/node.cpp b/python/mrc/core/node.cpp index bbbdfe658..cc1a43d1d 100644 --- a/python/mrc/core/node.cpp +++ b/python/mrc/core/node.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +20,7 @@ #include "pymrc/utils.hpp" #include "mrc/node/operators/broadcast.hpp" +#include "mrc/node/operators/round_robin_router_typeless.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" #include "mrc/utils/string_utils.hpp" @@ -58,6 +59,15 @@ PYBIND11_MODULE(node, py_mod) return node; })); + py::class_, + mrc::segment::ObjectProperties, + std::shared_ptr>>(py_mod, "RoundRobinRouter") + .def(py::init<>([](mrc::segment::IBuilder& builder, std::string name) { + auto node = builder.construct_object(name); + + return node; + })); + py_mod.attr("__version__") = MRC_CONCAT_STR(mrc_VERSION_MAJOR << "." << mrc_VERSION_MINOR << "." << mrc_VERSION_PATCH); } diff --git a/python/tests/test_edges.py b/python/tests/test_edges.py index 98ed11d0e..4dca8cc6f 100644 --- a/python/tests/test_edges.py +++ b/python/tests/test_edges.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -252,6 +252,16 @@ def add_broadcast(seg: mrc.Builder, *upstream: mrc.SegmentObject): return node +def add_round_robin_router(seg: mrc.Builder, *upstream: mrc.SegmentObject): + + node = mrc.core.node.RoundRobinRouter(seg, "RoundRobinRouter") + + for u in upstream: + seg.make_edge(u, node) + + return node + + # THIS TEST IS CAUSING ISSUES WHEN RUNNING ALL TESTS TOGETHER # @dataclasses.dataclass @@ -431,14 +441,15 @@ def fail_if_more_derived_type(combo: typing.Tuple): @pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) @pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink2_py"]) @pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"]) -@pytest.mark.parametrize("source_type,sink1_type,sink2_type", - gen_parameters("source", - "sink1", - "sink2", - is_fail_fn=fail_if_more_derived_type, - values={ - "base": m.Base, "derived": m.DerivedA - })) +@pytest.mark.parametrize( + "source_type,sink1_type,sink2_type", + gen_parameters("source", + "sink1", + "sink2", + is_fail_fn=fail_if_more_derived_type, + values={ + "base": m.Base, "derived": m.DerivedA + })) def test_source_to_broadcast_to_sinks(run_segment, sink1_component: bool, sink2_component: bool, @@ -503,13 +514,84 @@ def segment_init(seg: mrc.Builder): assert results == expected_node_counts +@pytest.mark.parametrize("sink1_component,sink2_component", + gen_parameters("sink1", "sink2", is_fail_fn=lambda x: False)) +@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) +@pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink2_py"]) +@pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"]) +@pytest.mark.parametrize( + "source_type,sink1_type,sink2_type", + gen_parameters("source", + "sink1", + "sink2", + is_fail_fn=fail_if_more_derived_type, + values={ + "base": m.Base, "derived": m.DerivedA + })) +def test_source_to_round_robin_router_to_sinks(run_segment, + sink1_component: bool, + sink2_component: bool, + source_cpp: bool, + sink1_cpp: bool, + sink2_cpp: bool, + source_type: type, + sink1_type: type, + sink2_type: type): + + def segment_init(seg: mrc.Builder): + + source = add_source(seg, is_cpp=source_cpp, data_type=source_type, is_component=False) + broadcast = add_round_robin_router(seg, source) + add_sink(seg, + broadcast, + is_cpp=sink1_cpp, + data_type=sink1_type, + is_component=sink1_component, + suffix="1", + count=3) + add_sink(seg, + broadcast, + is_cpp=sink2_cpp, + data_type=sink2_type, + is_component=sink2_component, + suffix="2", + count=2) + + results = run_segment(segment_init) + + assert results == expected_node_counts + + +@pytest.mark.parametrize("sink1_component,sink2_component", + gen_parameters("sink1", "sink2", is_fail_fn=lambda x: False)) +@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) +@pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink1_py"]) +@pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"]) +def test_multi_source_to_round_robin_router_to_multi_sink(run_segment, + sink1_component: bool, + sink2_component: bool, + source_cpp: bool, + sink1_cpp: bool, + sink2_cpp: bool): + + def segment_init(seg: mrc.Builder): + + source1 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="1") + source2 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="2") + broadcast = add_round_robin_router(seg, source1, source2) + add_sink(seg, broadcast, is_cpp=sink1_cpp, data_type=m.Base, is_component=sink1_component, suffix="1") + add_sink(seg, broadcast, is_cpp=sink2_cpp, data_type=m.Base, is_component=sink2_component, suffix="2") + + results = run_segment(segment_init) + + assert results == expected_node_counts + + @pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) -@pytest.mark.parametrize("source_type", - gen_parameters("source", - is_fail_fn=lambda _: False, - values={ - "base": m.Base, "derived": m.DerivedA - })) +@pytest.mark.parametrize( + "source_type", gen_parameters("source", is_fail_fn=lambda _: False, values={ + "base": m.Base, "derived": m.DerivedA + })) def test_source_to_null(run_segment, source_cpp: bool, source_type: type): def segment_init(seg: mrc.Builder): @@ -522,24 +604,24 @@ def segment_init(seg: mrc.Builder): assert results == expected_node_counts -@pytest.mark.parametrize("source_cpp,node_cpp", - gen_parameters("source", "node", is_fail_fn=lambda _: False, values={ - "cpp": True, "py": False - })) -@pytest.mark.parametrize("source_type,node_type", - gen_parameters("source", - "node", - is_fail_fn=fail_if_more_derived_type, - values={ - "base": m.Base, "derived": m.DerivedA - })) -@pytest.mark.parametrize("source_component,node_component", - gen_parameters("source", - "node", - is_fail_fn=lambda x: x[0] and x[1], - values={ - "run": False, "com": True - })) +@pytest.mark.parametrize( + "source_cpp,node_cpp", + gen_parameters("source", "node", is_fail_fn=lambda _: False, values={ + "cpp": True, "py": False + })) +@pytest.mark.parametrize( + "source_type,node_type", + gen_parameters("source", + "node", + is_fail_fn=fail_if_more_derived_type, + values={ + "base": m.Base, "derived": m.DerivedA + })) +@pytest.mark.parametrize( + "source_component,node_component", + gen_parameters("source", "node", is_fail_fn=lambda x: x[0] and x[1], values={ + "run": False, "com": True + })) def test_source_to_node_to_null(run_segment, source_cpp: bool, node_cpp: bool,