From b4c7bd7fd106b082ae1c91e89d58b08e8ab3c015 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Mon, 28 Nov 2022 16:51:51 -0700 Subject: [PATCH] Adds missing functionality to allow dynamic modules to forward their configuration to nested children. (#224) Adds missing functionality to allow dynamic modules to forward their configuration to nested children. Closes #223 Authors: - Devin Robison (https://github.com/drobison00) Approvers: - Michael Demoret (https://github.com/mdemoret-nv) URL: https://github.com/nv-morpheus/SRF/pull/224 --- include/srf/segment/builder.hpp | 6 ++++ python/srf/_pysrf/include/pysrf/segment.hpp | 32 ++------------------- python/srf/_pysrf/src/segment.cpp | 7 +++++ python/srf/core/segment.cpp | 2 ++ python/tests/test_module_registry.py | 29 +++++++++++++++---- src/public/segment/builder.cpp | 17 +++++++++++ 6 files changed, 58 insertions(+), 35 deletions(-) diff --git a/include/srf/segment/builder.hpp b/include/srf/segment/builder.hpp index 1e26bcd64..4cd9cc952 100644 --- a/include/srf/segment/builder.hpp +++ b/include/srf/segment/builder.hpp @@ -198,6 +198,12 @@ class Builder final */ void register_module_input(std::string input_name, std::shared_ptr object); + /** + * Get the json configuration for the current module under configuration. + * @return nlohmann::json object. + */ + nlohmann::json get_current_module_config(); + /** * Register an output port on the given module -- note: this in generally only necessary for dynamically * created modules that use an alternate initializer function independent of the derived class. diff --git a/python/srf/_pysrf/include/pysrf/segment.hpp b/python/srf/_pysrf/include/pysrf/segment.hpp index af2adb77e..cf589eff2 100644 --- a/python/srf/_pysrf/include/pysrf/segment.hpp +++ b/python/srf/_pysrf/include/pysrf/segment.hpp @@ -27,7 +27,6 @@ #include #include // IWYU pragma: keep -#include #include #include #include @@ -176,16 +175,6 @@ class BuilderProxy const std::string& name, std::function sub_fn); - static void make_py2cxx_edge_adapter(srf::segment::Builder& self, - std::shared_ptr source, - std::shared_ptr sink, - pybind11::object& sink_t); - - static void make_cxx2py_edge_adapter(srf::segment::Builder& self, - std::shared_ptr source, - std::shared_ptr sink, - pybind11::object& source_t); - static void make_edge(srf::segment::Builder& self, std::shared_ptr source, std::shared_ptr sink); @@ -210,26 +199,9 @@ class BuilderProxy std::string output_name, std::shared_ptr object); - static void init_module(srf::segment::Builder& self, std::shared_ptr module); - - static std::shared_ptr make_file_reader(srf::segment::Builder& self, - const std::string& name, - const std::string& filename); - - static std::shared_ptr debug_float_source(srf::segment::Builder& self, - const std::string& name, - std::size_t iterations); + static pybind11::dict get_current_module_config(srf::segment::Builder& self); - static std::shared_ptr debug_float_passthrough(srf::segment::Builder& self, - const std::string& name); - - static std::shared_ptr flatten_list(srf::segment::Builder& self, const std::string& name); - - static std::shared_ptr debug_string_passthrough(srf::segment::Builder& self, - const std::string& name); - - static std::shared_ptr debug_float_sink(srf::segment::Builder& self, - const std::string& name); + static void init_module(srf::segment::Builder& self, std::shared_ptr module); }; #pragma GCC visibility pop diff --git a/python/srf/_pysrf/src/segment.cpp b/python/srf/_pysrf/src/segment.cpp index 153890723..aec7685b2 100644 --- a/python/srf/_pysrf/src/segment.cpp +++ b/python/srf/_pysrf/src/segment.cpp @@ -319,6 +319,13 @@ void BuilderProxy::register_module_output(srf::segment::Builder& self, self.register_module_output(std::move(output_name), object); } +py::dict BuilderProxy::get_current_module_config(srf::segment::Builder& self) +{ + auto json_config = self.get_current_module_config(); + + return cast_from_json(json_config); +} + void BuilderProxy::make_edge(srf::segment::Builder& self, std::shared_ptr source, std::shared_ptr sink) diff --git a/python/srf/core/segment.cpp b/python/srf/core/segment.cpp index df349b268..34e3ede4c 100644 --- a/python/srf/core/segment.cpp +++ b/python/srf/core/segment.cpp @@ -184,6 +184,8 @@ PYBIND11_MODULE(segment, module) Builder.def( "register_module_output", &BuilderProxy::register_module_output, py::arg("output_name"), py::arg("object")); + Builder.def("get_current_module_config", &BuilderProxy::get_current_module_config); + Builder.def("make_node_full", &BuilderProxy::make_node_full, py::return_value_policy::reference_internal); /** Segment Module Interface Declarations **/ diff --git a/python/tests/test_module_registry.py b/python/tests/test_module_registry.py index 4f0fd1e91..4c5882762 100644 --- a/python/tests/test_module_registry.py +++ b/python/tests/test_module_registry.py @@ -135,18 +135,19 @@ def test_get_module_constructor(): def test_module_intitialize(): - module_name = "test_py_source_from_cpp" config = {"source_count": 42} registry = srf.ModuleRegistry def module_initializer(builder: srf.Builder): + local_config = builder.get_current_module_config() + assert ("source_count" in local_config) + assert (local_config["source_count"] == config["source_count"]) - source_mod = builder.load_module("SourceModule", "srf_unittest", "ModuleSourceTest_mod1", config) + source_mod = builder.load_module("SourceModule", "srf_unittest", "ModuleSourceTest_mod1", local_config) builder.register_module_output("source", source_mod.output_port("source")) def init_wrapper(builder: srf.Builder): - global packet_count packet_count = 0 @@ -200,6 +201,10 @@ def test_py_registered_nested_modules(): def module_initializer(builder: srf.Builder): global packet_count + local_config = builder.get_current_module_config() + assert (isinstance(local_config, type({}))) + assert (len(local_config.keys()) == 0) + def on_next(data): global packet_count packet_count += 1 @@ -246,6 +251,16 @@ def test_py_registered_nested_copied_modules(): global packet_count def module_initializer(builder: srf.Builder): + local_config = builder.get_current_module_config() + assert (isinstance(local_config, type({}))) + if ("test1" in local_config): + assert ("test2" not in local_config) + assert (local_config["test1"] == "module_1") + else: + assert ("test1" not in local_config) + assert ("test2" in local_config) + assert (local_config["test2"] == "module_2") + global packet_count def on_next(data): @@ -271,8 +286,12 @@ def on_complete(): def init_wrapper(builder: srf.Builder): global packet_count packet_count = 0 - builder.load_module("test_py_registered_nested_copied_module", "srf_unittests", "my_loaded_module!", {}) - builder.load_module("test_py_registered_nested_copied_module", "srf_unittests", "my_loaded_module_copy!", {}) + builder.load_module("test_py_registered_nested_copied_module", + "srf_unittests", + "my_loaded_module!", {"test1": "module_1"}) + builder.load_module("test_py_registered_nested_copied_module", + "srf_unittests", + "my_loaded_module_copy!", {"test2": "module_2"}) pipeline = srf.Pipeline() pipeline.make_segment("ModuleAsSource_Segment", init_wrapper) diff --git a/src/public/segment/builder.cpp b/src/public/segment/builder.cpp index 5e8e3ac9f..b49d14a6d 100644 --- a/src/public/segment/builder.cpp +++ b/src/public/segment/builder.cpp @@ -151,4 +151,21 @@ void Builder::register_module_output(std::string output_name, sp_obj_prop_t obje current_module->register_output_port(std::move(output_name), object); } +nlohmann::json Builder::get_current_module_config() +{ + if (m_module_stack.empty()) + { + std::stringstream sstream; + + sstream << "Failed to acquire module configuration -> no module context exists"; + VLOG(2) << sstream.str(); + + throw std::invalid_argument(sstream.str()); + } + + auto current_module = m_module_stack.back(); + + return current_module->config(); +} + } // namespace srf::segment