Skip to content

Commit

Permalink
use type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
ledmonster committed Dec 31, 2020
1 parent 8fb9bc7 commit 76be163
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 56 deletions.
64 changes: 19 additions & 45 deletions src/mqtt_bridge/bridge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from typing import Optional, Type, Dict, Union

import inject
import paho.mqtt.client as mqtt
Expand All @@ -7,15 +8,10 @@
from .util import lookup_object, extract_values, populate_instance


def create_bridge(factory, msg_type, topic_from, topic_to, **kwargs):
""" bridge generator function
:param (str|class) factory: Bridge class
:param (str|class) msg_type: ROS message type
:param str topic_from: incoming topic path
:param str topic_to: outgoing topic path
:param (float|None) frequency: publish frequency
:return Bridge: bridge object
def create_bridge(factory: Union[str, "Bridge"], msg_type: Union[str, Type[rospy.Message]], topic_from: str,
topic_to: str, frequency: Optional[float] = None, **kwargs) -> "Bridge":
""" generate bridge instance using factory callable and arguments. if `factory` or `meg_type` is provided as string,
this function will convert it to a corresponding object.
"""
if isinstance(factory, str):
factory = lookup_object(factory)
Expand All @@ -28,17 +24,11 @@ def create_bridge(factory, msg_type, topic_from, topic_to, **kwargs):
"msg_type should be rospy.Message instance or its string"
"reprensentation")
return factory(
topic_from=topic_from, topic_to=topic_to, msg_type=msg_type, **kwargs)
topic_from=topic_from, topic_to=topic_to, msg_type=msg_type, frequency=frequency, **kwargs)


class Bridge(object, metaclass=ABCMeta):
""" Bridge base class
:param mqtt.Client _mqtt_client: MQTT client
:param _serialize: message serialize callable
:param _deserialize: message deserialize callable
"""

""" Bridge base class """
_mqtt_client = inject.attr(mqtt.Client)
_serialize = inject.attr('serializer')
_deserialize = inject.attr('deserializer')
Expand All @@ -48,43 +38,36 @@ class Bridge(object, metaclass=ABCMeta):
class RosToMqttBridge(Bridge):
""" Bridge from ROS topic to MQTT
:param str topic_from: incoming ROS topic path
:param str topic_to: outgoing MQTT topic path
:param class msg_type: subclass of ROS Message
:param (float|None) frequency: publish frequency
bridge ROS messages on `topic_from` to MQTT topic `topic_to`. expect `msg_type` ROS message type.
"""

def __init__(self, topic_from, topic_to, msg_type, frequency=None):
def __init__(self, topic_from: str, topic_to: str, msg_type: rospy.Message, frequency: Optional[float] = None):
self._topic_from = topic_from
self._topic_to = self._extract_private_path(topic_to)
self._last_published = rospy.get_time()
self._interval = 0 if frequency is None else 1.0 / frequency
rospy.Subscriber(topic_from, msg_type, self._callback_ros)

def _callback_ros(self, msg):
def _callback_ros(self, msg: rospy.Message):
rospy.logdebug("ROS received from {}".format(self._topic_from))
now = rospy.get_time()
if now - self._last_published >= self._interval:
self._publish(msg)
self._last_published = now

def _publish(self, msg):
def _publish(self, msg: rospy.Message):
payload = self._serialize(extract_values(msg))
self._mqtt_client.publish(topic=self._topic_to, payload=payload)


class MqttToRosBridge(Bridge):
""" Bridge from MQTT to ROS topic
:param str topic_from: incoming MQTT topic path
:param str topic_to: outgoing ROS topic path
:param class msg_type: subclass of ROS Message
:param (float|None) frequency: publish frequency
:param int queue_size: ROS publisher's queue size
bridge MQTT messages on `topic_from` to ROS topic `topic_to`. MQTT messages will be converted to `msg_type`.
"""

def __init__(self, topic_from, topic_to, msg_type, frequency=None,
queue_size=10):
def __init__(self, topic_from: str, topic_to: str, msg_type: Type[rospy.Message],
frequency: Optional[float] = None, queue_size: int = 10):
self._topic_from = self._extract_private_path(topic_from)
self._topic_to = topic_to
self._msg_type = msg_type
Expand All @@ -97,13 +80,8 @@ def __init__(self, topic_from, topic_to, msg_type, frequency=None,
self._publisher = rospy.Publisher(
self._topic_to, self._msg_type, queue_size=self._queue_size)

def _callback_mqtt(self, client, userdata, mqtt_msg):
""" callback from MQTT
:param mqtt.Client client: MQTT client used in connection
:param userdata: user defined data
:param mqtt.MQTTMessage mqtt_msg: MQTT message
"""
def _callback_mqtt(self, client: mqtt.Client, userdata: Dict, mqtt_msg: mqtt.MQTTMessage):
""" callback from MQTT """
rospy.logdebug("MQTT received from {}".format(mqtt_msg.topic))
now = rospy.get_time()

Expand All @@ -115,12 +93,8 @@ def _callback_mqtt(self, client, userdata, mqtt_msg):
except Exception as e:
rospy.logerr(e)

def _create_ros_message(self, mqtt_msg):
""" create ROS message from MQTT payload
:param mqtt.Message mqtt_msg: MQTT Message
:return rospy.Message: ROS Message
"""
def _create_ros_message(self, mqtt_msg: mqtt.MQTTMessage) -> rospy.Message:
""" create ROS message from MQTT payload """
# Hack to enable both, messagepack and json deserialization.
if self._serialize.__name__ == "packb":
msg_dict = self._deserialize(mqtt_msg.payload, raw=False)
Expand Down
12 changes: 5 additions & 7 deletions src/mqtt_bridge/mqtt_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import paho.mqtt.client as mqtt
from typing import Dict, Callable

import paho.mqtt.client as mqtt

def default_mqtt_client_factory(params):
""" MQTT Client factory

:param dict param: configuration parameters
:return mqtt.Client: MQTT Client
"""
def default_mqtt_client_factory(params: Dict) -> mqtt.Client:
""" MQTT Client factory """
# create client
client_params = params.get('client', {})
client = mqtt.Client(**client_params)
Expand Down Expand Up @@ -49,7 +47,7 @@ def default_mqtt_client_factory(params):
return client


def create_private_path_extractor(mqtt_private_path):
def create_private_path_extractor(mqtt_private_path: str) -> Callable[[str], str]:
def extractor(topic_path):
if topic_path.startswith('~/'):
return '{}/{}'.format(mqtt_private_path, topic_path[2:])
Expand Down
10 changes: 6 additions & 4 deletions src/mqtt_bridge/util.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from importlib import import_module
from typing import Any, Callable, Dict

import rospy
from rosbridge_library.internal import message_conversion


def lookup_object(object_path, package='mqtt_bridge'):
def lookup_object(object_path: str, package: str='mqtt_bridge') -> Any:
""" lookup object from a some.module:object_name specification. """
module_name, obj_name = object_path.split(":")
module = import_module(module_name, package)
obj = getattr(module, obj_name)
return obj

extract_values = message_conversion.extract_values
populate_instance = message_conversion.populate_instance

extract_values = message_conversion.extract_values # type: Callable[[rospy.Message], Dict]
populate_instance = message_conversion.populate_instance # type: Callable[[Dict, rospy.Message], rospy.Message]


__all__ = ['lookup_object', 'extract_values', 'populate_instance']

0 comments on commit 76be163

Please sign in to comment.