diff --git a/launch_testing_ros/launch_testing_ros/wait_for_topics.py b/launch_testing_ros/launch_testing_ros/wait_for_topics.py index 8e99423d..748ab6cf 100644 --- a/launch_testing_ros/launch_testing_ros/wait_for_topics.py +++ b/launch_testing_ros/launch_testing_ros/wait_for_topics.py @@ -19,6 +19,8 @@ from threading import Thread import rclpy +from rclpy.event_handler import QoSSubscriptionMatchedInfo +from rclpy.event_handler import SubscriptionEventCallbacks from rclpy.executors import SingleThreadedExecutor from rclpy.node import Node @@ -50,12 +52,28 @@ def method_2(): print(wait_for_topics.topics_received()) # Should be {'topic_1', 'topic_2'} print(wait_for_topics.messages_received('topic_1')) # Should be [message_1, ...] wait_for_topics.shutdown() + + # Method3, using a callback + def callback_function(arg): + print(f'Callback function called with argument: {arg}') + + def method_3(): + topic_list = [('topic_1', String), ('topic_2', String)] + with WaitForTopics(topic_list, callback=callback_function, callback_arguments="Hello"): + print('Given topics are receiving messages !') """ - def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10): + def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10, + callback=None, callback_arguments=None): self.topic_tuples = topic_tuples self.timeout = timeout self.messages_received_buffer_length = messages_received_buffer_length + self.callback = callback + if self.callback is not None and not callable(self.callback): + raise TypeError('The passed callback is not callable') + self.callback_arguments = ( + callback_arguments if callback_arguments is not None else [] + ) self.__ros_context = rclpy.Context() rclpy.init(context=self.__ros_context) self.__ros_executor = SingleThreadedExecutor(context=self.__ros_context) @@ -85,6 +103,14 @@ def _prepare_ros_node(self): def wait(self): self.__ros_node.start_subscribers(self.topic_tuples) + if self.callback: + if isinstance(self.callback_arguments, dict): + self.callback(**self.callback_arguments) + elif isinstance(self.callback_arguments, (list, set, tuple)): + self.callback(*self.callback_arguments) + else: + self.callback(self.callback_arguments) + self.__ros_node._any_publisher_connected.wait() return self.__ros_node.msg_event_object.wait(self.timeout) def shutdown(self): @@ -131,6 +157,13 @@ def __init__( self.expected_topics = set() self.received_topics = set() self.received_messages_buffer = {} + self._any_publisher_connected = Event() + + def _sub_matched_event_callback(self, info: QoSSubscriptionMatchedInfo): + if info.current_count != 0: + self._any_publisher_connected.set() + else: + self._any_publisher_connected.clear() def _reset(self): self.msg_event_object.clear() @@ -149,12 +182,16 @@ def start_subscribers(self, topic_tuples): maxlen=self.messages_received_buffer_length ) # Create a subscriber + sub_event_callback = SubscriptionEventCallbacks( + matched=self._sub_matched_event_callback + ) self.subscriber_list.append( self.create_subscription( topic_type, topic_name, self.callback_template(topic_name), - 10 + 10, + event_callbacks=sub_event_callback, ) ) diff --git a/launch_testing_ros/test/examples/repeater.py b/launch_testing_ros/test/examples/repeater.py new file mode 100644 index 00000000..10009b1e --- /dev/null +++ b/launch_testing_ros/test/examples/repeater.py @@ -0,0 +1,53 @@ +# Copyright 2019 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import rclpy +from rclpy.node import Node + +from std_msgs.msg import String + + +class Repeater(Node): + + def __init__(self): + super().__init__('repeater') + self.count = 0 + self.subscription = self.create_subscription( + String, 'input', self.callback, 10 + ) + self.publisher = self.create_publisher(String, 'output', 10) + + def callback(self, input_msg): + self.get_logger().info('I heard: [%s]' % input_msg.data) + output_msg_data = input_msg.data + self.get_logger().info('Publishing: "{0}"'.format(output_msg_data)) + self.publisher.publish(String(data=output_msg_data)) + + +def main(args=None): + rclpy.init(args=args) + + node = Repeater() + + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/launch_testing_ros/test/examples/wait_for_topic_inject_callback_test.py b/launch_testing_ros/test/examples/wait_for_topic_inject_callback_test.py new file mode 100644 index 00000000..08006b88 --- /dev/null +++ b/launch_testing_ros/test/examples/wait_for_topic_inject_callback_test.py @@ -0,0 +1,78 @@ +# Copyright 2021 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import unittest + +import launch +import launch.actions +import launch_ros.actions +import launch_testing.actions +import launch_testing.markers +from launch_testing_ros import WaitForTopics +import pytest +import rclpy +from std_msgs.msg import String + + +def generate_node(): + """Return node and remap the topic based on the index provided.""" + path_to_test = os.path.dirname(__file__) + return launch_ros.actions.Node( + executable=sys.executable, + arguments=[os.path.join(path_to_test, 'repeater.py')], + name='demo_node', + additional_env={'PYTHONUNBUFFERED': '1'}, + ) + + +def trigger_callback(): + rclpy.init() + node = rclpy.create_node('trigger') + publisher = node.create_publisher(String, 'input', 10) + while publisher.get_subscription_count() == 0: + rclpy.spin_once(node, timeout_sec=0.1) + msg = String() + msg.data = 'Hello World' + publisher.publish(msg) + print('Published message') + node.destroy_node() + rclpy.shutdown() + + +@pytest.mark.launch_test +@launch_testing.markers.keep_alive +def generate_test_description(): + description = [generate_node(), launch_testing.actions.ReadyToTest()] + return launch.LaunchDescription(description) + + +# TODO: Test cases fail on Windows debug builds +# https://github.com/ros2/launch_ros/issues/292 +if os.name != 'nt': + + class TestFixture(unittest.TestCase): + + def test_topics_successful(self): + """All the supplied topics should be read successfully.""" + topic_list = [('output', String)] + expected_topics = {'output'} + + # Method 1 : Using the magic methods and 'with' keyword + with WaitForTopics( + topic_list, timeout=10.0, callback=trigger_callback + ) as wait_for_node_object_1: + assert wait_for_node_object_1.topics_received() == expected_topics + assert wait_for_node_object_1.topics_not_received() == set() diff --git a/launch_testing_ros/test/examples/wait_for_topic_launch_test.py b/launch_testing_ros/test/examples/wait_for_topic_launch_test.py index c32ca361..ac445d5b 100644 --- a/launch_testing_ros/test/examples/wait_for_topic_launch_test.py +++ b/launch_testing_ros/test/examples/wait_for_topic_launch_test.py @@ -104,3 +104,20 @@ def test_topics_unsuccessful(self, count: int): assert wait_for_node_object.topics_received() == expected_topics assert wait_for_node_object.topics_not_received() == {'invalid_topic'} wait_for_node_object.shutdown() + + def test_callback(self, count): + topic_list = [('chatter_' + str(i), String) for i in range(count)] + expected_topics = {'chatter_' + str(i) for i in range(count)} + + # Method 1 : Using the magic methods and 'with' keyword + + is_callback_called = [[False]] + + def callback(arg): + arg[0] = True + + with WaitForTopics(topic_list, timeout=2.0, callback=callback, + callback_arguments=is_callback_called) as wait_for_node_object_1: + assert wait_for_node_object_1.topics_received() == expected_topics + assert wait_for_node_object_1.topics_not_received() == set() + assert is_callback_called[0]