Skip to content

Commit

Permalink
Merge pull request #32 from agonzc34/main
Browse files Browse the repository at this point in the history
Lifecycle nodes
  • Loading branch information
mgonzs13 authored Apr 9, 2024
2 parents 5061786 + ef27143 commit 22cfb30
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 73 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ $ ros2 launch yolov8_bringup yolov8_3d.launch.py
- **target_frame**: frame to transform the 3D boxes (default: base_link)
- **maximum_detection_threshold**: maximum detection threshold in the z axis (default: 0.3)

## Lifecyle nodes

Previous updates add Lifecycle Nodes support to all the nodes available in the package.
This implementation tries to reduce the workload in the unconfigured and inactive states by only loading the models and activating the subscriber on the active state.

These are some resource comparisons using the 'yolov8m.pt' model on a 30fps video stream.

| State | CPU Usage (i7 12th Gen) | VRAM Usage | Bandwidth Usage |
|-------------|--------------------------|-------------|-----------------|
| Active | 40-50% in one core | 628 MB | Up to 200 Mbps |
| Inactive | ~5-7% in one core | 338 MB | 0-20 Kbps |

## Demos

## Object Detection
Expand Down
41 changes: 35 additions & 6 deletions yolov8_ros/yolov8_ros/debug_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from rclpy.qos import QoSHistoryPolicy
from rclpy.qos import QoSDurabilityPolicy
from rclpy.qos import QoSReliabilityPolicy
from rclpy.lifecycle import LifecycleNode
from rclpy.lifecycle import TransitionCallbackReturn
from rclpy.lifecycle import LifecycleState

import message_filters
from cv_bridge import CvBridge
Expand All @@ -41,7 +44,7 @@
from yolov8_msgs.msg import DetectionArray


class DebugNode(Node):
class DebugNode(LifecycleNode):

def __init__(self) -> None:
super().__init__("debug_node")
Expand All @@ -52,7 +55,11 @@ def __init__(self) -> None:
# params
self.declare_parameter("image_reliability",
QoSReliabilityPolicy.BEST_EFFORT)
image_qos_profile = QoSProfile(

self.get_logger().info("Debug node created")

def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn:
self.image_qos_profile = QoSProfile(
reliability=self.get_parameter(
"image_reliability").get_parameter_value().integer_value,
history=QoSHistoryPolicy.KEEP_LAST,
Expand All @@ -67,16 +74,36 @@ def __init__(self) -> None:
self._kp_markers_pub = self.create_publisher(
MarkerArray, "dgb_kp_markers", 10)

return TransitionCallbackReturn.SUCCESS

def on_activate(self, state: LifecycleState) -> TransitionCallbackReturn:
# subs
image_sub = message_filters.Subscriber(
self, Image, "image_raw", qos_profile=image_qos_profile)
detections_sub = message_filters.Subscriber(
self.image_sub = message_filters.Subscriber(
self, Image, "image_raw", qos_profile=self.image_qos_profile)
self.detections_sub = message_filters.Subscriber(
self, DetectionArray, "detections", qos_profile=10)

self._synchronizer = message_filters.ApproximateTimeSynchronizer(
(image_sub, detections_sub), 10, 0.5)
(self.image_sub, self.detections_sub), 10, 0.5)
self._synchronizer.registerCallback(self.detections_cb)

return TransitionCallbackReturn.SUCCESS

def on_deactivate(self, state: LifecycleState) -> TransitionCallbackReturn:
self.destroy_subscription(self.image_sub.sub)
self.destroy_subscription(self.detections_sub.sub)

del self._synchronizer

return TransitionCallbackReturn.SUCCESS

def on_cleanup(self, state: LifecycleState) -> TransitionCallbackReturn:
self.destroy_publisher(self._dbg_pub)
self.destroy_publisher(self._bb_markers_pub)
self.destroy_publisher(self._kp_markers_pub)

return TransitionCallbackReturn.SUCCESS

def draw_box(self, cv_image: np.array, detection: Detection, color: Tuple[int]) -> np.array:

# get detection info
Expand Down Expand Up @@ -259,6 +286,8 @@ def detections_cb(self, img_msg: Image, detection_msg: DetectionArray) -> None:
def main():
rclpy.init()
node = DebugNode()
node.trigger_configure()
node.trigger_activate()
rclpy.spin(node)
node.destroy_node()
rclpy.shutdown()
74 changes: 52 additions & 22 deletions yolov8_ros/yolov8_ros/detect_3d_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from rclpy.qos import QoSHistoryPolicy
from rclpy.qos import QoSDurabilityPolicy
from rclpy.qos import QoSReliabilityPolicy
from rclpy.lifecycle import LifecycleNode
from rclpy.lifecycle import TransitionCallbackReturn
from rclpy.lifecycle import LifecycleState

import message_filters
from cv_bridge import CvBridge
Expand All @@ -39,66 +42,91 @@
from yolov8_msgs.msg import BoundingBox3D


class Detect3DNode(Node):
class Detect3DNode(LifecycleNode):

def __init__(self) -> None:
super().__init__("bbox3d_node")

# parameters
self.declare_parameter("target_frame", "base_link")
self.declare_parameter("maximum_detection_threshold", 0.3)
self.declare_parameter("depth_image_units_divisor", 1000)
self.declare_parameter("depth_image_reliability",
QoSReliabilityPolicy.BEST_EFFORT)
self.declare_parameter("depth_info_reliability",
QoSReliabilityPolicy.BEST_EFFORT)

# aux
self.tf_buffer = Buffer()
self.cv_bridge = CvBridge()


def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn:
self.target_frame = self.get_parameter(
"target_frame").get_parameter_value().string_value

self.declare_parameter("maximum_detection_threshold", 0.3)
self.maximum_detection_threshold = self.get_parameter(
"maximum_detection_threshold").get_parameter_value().double_value

self.declare_parameter("depth_image_units_divisor", 1000)
self.depth_image_units_divisor = self.get_parameter(
"depth_image_units_divisor").get_parameter_value().integer_value

self.declare_parameter("depth_image_reliability",
QoSReliabilityPolicy.BEST_EFFORT)
depth_image_qos_profile = QoSProfile(
reliability=self.get_parameter(
"depth_image_reliability").get_parameter_value().integer_value,
dimg_reliability=self.get_parameter(
"depth_image_reliability").get_parameter_value().integer_value

self.depth_image_qos_profile = QoSProfile(
dimg_reliability,
history=QoSHistoryPolicy.KEEP_LAST,
durability=QoSDurabilityPolicy.VOLATILE,
depth=1
)

self.declare_parameter("depth_info_reliability",
QoSReliabilityPolicy.BEST_EFFORT)
depth_info_qos_profile = QoSProfile(
reliability=self.get_parameter(
"depth_info_reliability").get_parameter_value().integer_value,
dinfo_reliability=self.get_parameter(
"depth_info_reliability").get_parameter_value().integer_value

self.depth_info_qos_profile = QoSProfile(
dinfo_reliability,
history=QoSHistoryPolicy.KEEP_LAST,
durability=QoSDurabilityPolicy.VOLATILE,
depth=1
)

# aux
self.tf_buffer = Buffer()
self.tf_listener = TransformListener(self.tf_buffer, self)
self.cv_bridge = CvBridge()

# pubs
self._pub = self.create_publisher(DetectionArray, "detections_3d", 10)

return TransitionCallbackReturn.SUCCESS

def on_activate(self, state: LifecycleState) -> TransitionCallbackReturn:
# subs
self.depth_sub = message_filters.Subscriber(
self, Image, "depth_image",
qos_profile=depth_image_qos_profile)
qos_profile=self.depth_image_qos_profile)
self.depth_info_sub = message_filters.Subscriber(
self, CameraInfo, "depth_info",
qos_profile=depth_info_qos_profile)
qos_profile=self.depth_info_qos_profile)
self.detections_sub = message_filters.Subscriber(
self, DetectionArray, "detections")

self._synchronizer = message_filters.ApproximateTimeSynchronizer(
(self.depth_sub, self.depth_info_sub, self.detections_sub), 10, 0.5)
self._synchronizer.registerCallback(self.on_detections)

return TransitionCallbackReturn.SUCCESS

def on_deactivate(self, state: LifecycleState) -> TransitionCallbackReturn:
self.destroy_subscription(self.depth_sub.sub)
self.destroy_subscription(self.depth_info_sub.sub)
self.destroy_subscription(self.detections_sub.sub)

del self._synchronizer

return TransitionCallbackReturn.SUCCESS

def on_cleanup(self, state: LifecycleState) -> TransitionCallbackReturn:
del self.tf_listener

self.destroy_publisher(self._pub)

return TransitionCallbackReturn.SUCCESS

def on_detections(
self,
depth_msg: Image,
Expand Down Expand Up @@ -340,6 +368,8 @@ def qv_mult(q: np.ndarray, v: np.ndarray) -> np.ndarray:
def main():
rclpy.init()
node = Detect3DNode()
node.trigger_configure()
node.trigger_activate()
rclpy.spin(node)
node.destroy_node()
rclpy.shutdown()
55 changes: 42 additions & 13 deletions yolov8_ros/yolov8_ros/tracking_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from rclpy.qos import QoSHistoryPolicy
from rclpy.qos import QoSDurabilityPolicy
from rclpy.qos import QoSReliabilityPolicy
from rclpy.lifecycle import LifecycleNode
from rclpy.lifecycle import TransitionCallbackReturn
from rclpy.lifecycle import LifecycleState

import message_filters
from cv_bridge import CvBridge
Expand All @@ -37,32 +40,40 @@
from yolov8_msgs.msg import DetectionArray


class TrackingNode(Node):
class TrackingNode(LifecycleNode):

def __init__(self) -> None:
super().__init__("tracking_node")

# params
self.declare_parameter("tracker", "bytetrack.yaml")
tracker = self.get_parameter(
"tracker").get_parameter_value().string_value

self.declare_parameter("image_reliability",
QoSReliabilityPolicy.BEST_EFFORT)
image_qos_profile = QoSProfile(
reliability=self.get_parameter(
"image_reliability").get_parameter_value().integer_value,
history=QoSHistoryPolicy.KEEP_LAST,
durability=QoSDurabilityPolicy.VOLATILE,
depth=1
)

self.cv_bridge = CvBridge()
self.tracker = self.create_tracker(tracker)

# pubs

def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn:
tracker_name = self.get_parameter(
"tracker").get_parameter_value().string_value

self.image_reliability = self.get_parameter(
"image_reliability").get_parameter_value().integer_value

self.tracker = self.create_tracker(tracker_name)
self._pub = self.create_publisher(DetectionArray, "tracking", 10)

return TransitionCallbackReturn.SUCCESS


def on_activate(self, state: LifecycleState) -> TransitionCallbackReturn:
image_qos_profile = QoSProfile(
reliability=self.image_reliability,
history=QoSHistoryPolicy.KEEP_LAST,
durability=QoSDurabilityPolicy.VOLATILE,
depth=1
)

# subs
image_sub = message_filters.Subscriber(
self, Image, "image_raw", qos_profile=image_qos_profile)
Expand All @@ -72,6 +83,22 @@ def __init__(self) -> None:
self._synchronizer = message_filters.ApproximateTimeSynchronizer(
(image_sub, detections_sub), 10, 0.5)
self._synchronizer.registerCallback(self.detections_cb)

return TransitionCallbackReturn.SUCCESS

def on_deactivate(self, state: LifecycleState) -> TransitionCallbackReturn:
self.destroy_subscription(self.image_sub.sub)
self.destroy_subscription(self.detections_sub.sub)

del self._synchronizer
self._synchronizer = None

return TransitionCallbackReturn.SUCCESS

def on_cleanup(self, state: LifecycleState) -> TransitionCallbackReturn:
del self.tracker

return TransitionCallbackReturn.SUCCESS

def create_tracker(self, tracker_yaml: str) -> BaseTrack:

Expand Down Expand Up @@ -153,6 +180,8 @@ def detections_cb(self, img_msg: Image, detections_msg: DetectionArray) -> None:
def main():
rclpy.init()
node = TrackingNode()
node.trigger_configure()
node.trigger_activate()
rclpy.spin(node)
node.destroy_node()
rclpy.shutdown()
Loading

0 comments on commit 22cfb30

Please sign in to comment.