Skip to content

Commit

Permalink
Support TF trajectories in ROS 2 bags (#672)
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <michel@ekumenlabs.com>
Co-authored-by: Michael Grupp <grupp@magazino.eu>
  • Loading branch information
hidmic and MichaelGrupp authored Jun 13, 2024
1 parent 700247d commit d8270a4
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[MASTER]

ignore=fastentrypoints.py,transformations.py

max-branches=13
15 changes: 5 additions & 10 deletions evo/tools/file_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,17 +276,12 @@ def read_bag_trajectory(reader: typing.Union[Rosbag1Reader,
"or rosbags.rosbags2.reader.Reader - "
"rosbag.Bag() is not supported by evo anymore")

# TODO: Support TF also with ROS2 bags.
if tf_id.check_id(topic):
if isinstance(reader, Rosbag1Reader):
# Use TfCache instead if it's a TF transform ID.
from evo.tools import tf_cache
tf_tree_cache = (tf_cache.instance(reader.__hash__())
if cache_tf_tree else tf_cache.TfCache())
return tf_tree_cache.get_trajectory(reader, identifier=topic)
else:
raise FileInterfaceException(
"TF support for ROS2 bags is not implemented")
# Use TfCache instead if it's a TF transform ID.
from evo.tools import tf_cache
tf_tree_cache = (tf_cache.instance(reader.__hash__())
if cache_tf_tree else tf_cache.TfCache())
return tf_tree_cache.get_trajectory(reader, identifier=topic)

if topic not in reader.topics:
raise FileInterfaceException("no messages for topic '" + topic +
Expand Down
174 changes: 139 additions & 35 deletions evo/tools/tf_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@
along with evo. If not, see <http://www.gnu.org/licenses/>.
"""

import dataclasses
import logging
import math
import warnings
from collections import defaultdict
from typing import DefaultDict, List, Optional
from typing import (DefaultDict, List, Optional, Protocol, Union,
runtime_checkable)

import numpy as np
import rospy
import tf2_py
from geometry_msgs.msg import TransformStamped
from rosbags.rosbag1 import Reader as Rosbag1Reader
from rosbags.rosbag2 import Reader as Rosbag2Reader
from rosbags.typesys import get_typestore, get_types_from_msg, Stores
from rosbags.typesys.store import Typestore
from std_msgs.msg import Header

from evo import EvoException
from evo.core.trajectory import PoseTrajectory3D
Expand All @@ -48,13 +50,77 @@ class TfCacheException(EvoException):
pass


@runtime_checkable
class Ros1TimeLike(Protocol): # pylint: disable=too-few-public-methods
"""
Basic ROS 1 compatible time instance protocol.
"""
def to_sec(self) -> float:
"""
Gets scalar time, in seconds.
"""


@runtime_checkable
class Ros2TimeLike(Protocol): # pylint: disable=too-few-public-methods
"""
Basic ROS 2 compatible time instance protocol.
"""
@property
def nanoseconds(self) -> int:
"""
Gets underlying scalar timestamp, in nanoseconds.
"""


@runtime_checkable
class Ros2StampLike(Protocol): # pylint: disable=too-few-public-methods
"""
Basic ROS 2 compatible message stamp protocol.
"""
sec: int
nanosec: int


@dataclasses.dataclass
class TfDuration(Ros1TimeLike, Ros2TimeLike, Ros2StampLike):
"""
A duration representation that is TF compatible in ROS 1 and ROS 2.
"""
sec: int
nanosec: int

@classmethod
def from_sec(cls, sec: float) -> "TfDuration":
"""Instantiates a duration given a number of seconds."""
frac, whole = math.modf(sec)
return cls(sec=int(whole), nanosec=int(frac * 1e9))

@property
def nanoseconds(self) -> int:
return self.sec * 1000000000 + self.nanosec

def to_sec(self) -> float:
return self.sec + self.nanosec * 1e-9


def to_sec(
timestamp: Union[Ros1TimeLike, Ros2TimeLike, Ros2StampLike]) -> float:
"""Converts any given `timestamp` to a scalar time, in seconds."""
if isinstance(timestamp, Ros1TimeLike):
return timestamp.to_sec()
if isinstance(timestamp, Ros2TimeLike):
return timestamp.nanoseconds * 1e-9
return timestamp.sec + timestamp.nanosec * 1e-9


class TfCache(object):
"""
For caching TF messages and looking up trajectories of specific transforms.
"""
def __init__(self):
self.buffer = tf2_py.BufferCore(
rospy.Duration.from_sec(SETTINGS.tf_cache_max_time))
TfDuration.from_sec(SETTINGS.tf_cache_max_time))
self.topics = []
self.bags = []

Expand All @@ -68,18 +134,22 @@ def clear(self) -> None:
# update the ROS1 typestore with the interface definition from the bag.
# https://ternaris.gitlab.io/rosbags/examples/register_types.html
@staticmethod
def _setup_typestore(reader: Rosbag1Reader) -> Typestore:
typestore = get_typestore(Stores.ROS1_NOETIC)
for connection in reader.connections:
if connection.msgtype == SUPPORTED_TF_MSG:
typestore.register(
get_types_from_msg(connection.msgdef, connection.msgtype))
break
def _setup_typestore(
reader: Union[Rosbag1Reader, Rosbag2Reader]) -> Typestore:
if isinstance(reader, Rosbag1Reader):
typestore = get_typestore(Stores.ROS1_NOETIC)
for connection in reader.connections:
if connection.msgtype == SUPPORTED_TF_MSG:
typestore.register(
get_types_from_msg(connection.msgdef,
connection.msgtype))
break
else:
typestore = get_typestore(Stores.LATEST)
return typestore

# TODO: support also ROS2 bag reader.
def from_bag(self, reader: Rosbag1Reader, topic: str = "/tf",
static_topic: str = "/tf_static") -> None:
def from_bag(self, reader: Union[Rosbag1Reader, Rosbag2Reader],
topic: str = "/tf", static_topic: str = "/tf_static") -> None:
"""
Loads the TF topics from a bagfile into the buffer,
if it's not already cached.
Expand Down Expand Up @@ -113,24 +183,42 @@ def from_bag(self, reader: Rosbag1Reader, topic: str = "/tf",
raise TfCacheException(
f"Expected {SUPPORTED_TF_MSG} message type for topic "
f"{tf_topic}, got: {connection.msgtype}")
msg = typestore.deserialize_ros1(rawdata, connection.msgtype)
if isinstance(reader, Rosbag1Reader):
msg = typestore.deserialize_ros1(rawdata,
connection.msgtype)
else:
msg = typestore.deserialize_cdr(rawdata,
connection.msgtype)
for tf in msg.transforms: # type: ignore
# Convert from rosbags.typesys.types to native ROS.
# Related: https://gitlab.com/ternaris/rosbags/-/issues/13
stamp = rospy.Time()
stamp.secs = tf.header.stamp.sec
stamp.nsecs = tf.header.stamp.nanosec
tf = TransformStamped(Header(0, stamp, tf.header.frame_id),
tf.child_frame_id, tf.transform)
native_msg = TransformStamped()
if hasattr(native_msg.header.stamp, "nsecs"):
native_msg.header.stamp.secs = tf.header.stamp.sec
native_msg.header.stamp.nsecs = tf.header.stamp.nanosec
else:
native_msg.header.stamp.sec = tf.header.stamp.sec
native_msg.header.stamp.nanosec = tf.header.stamp.nanosec
native_msg.header.frame_id = tf.header.frame_id
native_msg.child_frame_id = tf.child_frame_id
native_msg.transform.translation.x = tf.transform.translation.x
native_msg.transform.translation.y = tf.transform.translation.y
native_msg.transform.translation.z = tf.transform.translation.z
native_msg.transform.rotation.x = tf.transform.rotation.x
native_msg.transform.rotation.y = tf.transform.rotation.y
native_msg.transform.rotation.z = tf.transform.rotation.z
native_msg.transform.rotation.w = tf.transform.rotation.w
if tf_topic == static_topic:
self.buffer.set_transform_static(tf, __name__)
self.buffer.set_transform_static(native_msg, __name__)
else:
self.buffer.set_transform(tf, __name__)
self.buffer.set_transform(native_msg, __name__)
self.topics.append(tf_topic)
self.bags.append(reader.path.name)

def lookup_trajectory(self, parent_frame: str, child_frame: str,
timestamps: List[rospy.Time]) -> PoseTrajectory3D:
def lookup_trajectory(
self, parent_frame: str, child_frame: str,
timestamps: Union[List[Ros1TimeLike], List[Ros2TimeLike]]
) -> PoseTrajectory3D:
"""
Look up the trajectory of a transform chain from the cache's TF buffer.
:param parent_frame, child_frame: TF transform frame IDs
Expand All @@ -147,7 +235,7 @@ def lookup_trajectory(self, parent_frame: str, child_frame: str,
child_frame, timestamp)
except tf2_py.ExtrapolationException:
continue
stamps.append(tf.header.stamp.to_sec())
stamps.append(to_sec(tf.header.stamp))
x, q = _get_xyz_quat_from_transform_stamped(tf)
xyz.append(x)
quat.append(q)
Expand All @@ -160,8 +248,10 @@ def lookup_trajectory(self, parent_frame: str, child_frame: str,
return trajectory

def get_trajectory(
self, reader: Rosbag1Reader, identifier: str,
timestamps: Optional[List[rospy.Time]] = None) -> PoseTrajectory3D:
self, reader: Union[Rosbag1Reader, Rosbag2Reader], identifier: str,
timestamps: Optional[Union[List[Ros1TimeLike],
List[Ros2TimeLike]]] = None
) -> PoseTrajectory3D:
"""
Get a TF trajectory from a bag file. Updates or uses the cache.
:param reader: opened bag reader (rosbags.rosbag1)
Expand All @@ -176,22 +266,36 @@ def get_trajectory(
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
self.from_bag(reader, topic, static_topic)
try:
latest_time = self.buffer.get_latest_common_time(parent, child)
except (tf2_py.LookupException, tf2_py.TransformException) as e:
raise TfCacheException("Could not load trajectory: " + str(e))
# rosbags Reader start_time is in nanoseconds.
start_time = rospy.Time.from_sec(reader.start_time * 1e-9)
if timestamps is None:
timestamps = []

try:
latest_time = self.buffer.get_latest_common_time(parent, child)
except (tf2_py.LookupException, tf2_py.TransformException) as e:
raise TfCacheException("Could not load trajectory: " + str(e))

if hasattr(latest_time, "nsecs"):
from rospy import Time, Duration # pylint: disable=import-outside-toplevel

# rosbags Reader start_time is in nanoseconds.
start_time = Time.from_sec(reader.start_time * 1e-9)
step = Duration.from_sec(1. /
SETTINGS.tf_cache_lookup_frequency)
else:
from rclpy.time import Time # pylint: disable=import-outside-toplevel
from rclpy.duration import Duration # pylint: disable=import-outside-toplevel

# rosbags Reader start_time is in nanoseconds.
start_time = Time(nanoseconds=reader.start_time)
step = Duration(seconds=1. /
SETTINGS.tf_cache_lookup_frequency)

# Static TF have zero timestamp in the buffer, which will be lower
# than the bag start time. Looking up a static TF is a valid request,
# so this should be possible.
if latest_time < start_time:
timestamps.append(latest_time)
else:
step = rospy.Duration.from_sec(
1. / SETTINGS.tf_cache_lookup_frequency)
time = start_time
while time <= latest_time:
timestamps.append(time)
Expand Down
39 changes: 32 additions & 7 deletions evo/tools/tf_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,46 @@

from evo import EvoException

ROS_NAME_REGEX = re.compile(r"([\/|_|0-9|a-z|A-Z]+)")
ROS_NAME_REGEX = re.compile(r"[\/|a-z|A-Z][\/|_|0-9|a-z|A-Z]+")


class TfIdException(EvoException):
pass


def split_id(identifier: str) -> tuple:
match = ROS_NAME_REGEX.findall(identifier)
# If a fourth component exists, it's interpreted as the static TF name.
if not len(match) in (3, 4):
tf_topic, _, identifier = identifier.partition(":")
if ":" in identifier:
identifier, _, tf_static_topic = identifier.rpartition(":")
else:
tf_static_topic = None
parent_frame_id, _, child_frame_id = identifier.partition(".")

if ROS_NAME_REGEX.match(tf_topic) is None:
raise TfIdException(
f"ID string malformed, {tf_topic} is not a valid topic name, "
"ID string should look like /tf:map.base_footprint(:/tf_static)")

if not parent_frame_id:
raise TfIdException(
"ID string malformed, parent frame ID is missing, ID string "
"should look like /tf:map.base_footprint(:/tf_static)")

if not child_frame_id:
raise TfIdException(
"ID string malformed, it should look similar to this: "
"/tf:map.base_footprint")
return tuple(match)
"ID string malformed, child frame ID is missing, ID string "
"should look like /tf:map.base_footprint(:/tf_static)")

if tf_static_topic:
if ROS_NAME_REGEX.match(tf_static_topic) is None:
raise TfIdException(
f"ID string malformed, {tf_static_topic} is not a valid topic name, "
"ID string should look like /tf:map.base_footprint(:/tf_static)"
)

return (tf_topic, parent_frame_id, child_frame_id, tf_static_topic)

return (tf_topic, parent_frame_id, child_frame_id)


def check_id(identifier: str) -> bool:
Expand Down

0 comments on commit d8270a4

Please sign in to comment.