Skip to content

Commit

Permalink
Implement callback number 1
Browse files Browse the repository at this point in the history
Signed-off-by: Simone Orru <simone.orru@secomind.com>
  • Loading branch information
sorru94 committed Sep 7, 2023
1 parent b67a4fd commit 3e67ed9
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 15 deletions.
11 changes: 8 additions & 3 deletions astarte/device/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,18 @@ class Device(ABC):
"""

@abstractmethod
def __init__(self, *args):
def __init__(self, loop):
"""
Parameters
----------
args :
TODO.
loop : asyncio.loop (optional)
An optional loop which will be used for invoking callbacks. When this is not none,
device will call any specified callback through loop.call_soon_threadsafe, ensuring
that the callbacks will be run in thread the loop belongs to. Usually, you want
to set this to get_running_loop(). When not sent, callbacks will be invoked as a
standard function - keep in mind this means your callbacks might create deadlocks.
"""
self._loop = loop
self._introspection = Introspection()

@abstractmethod
Expand Down
113 changes: 110 additions & 3 deletions astarte/device/device_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from collections.abc import Callable
import threading
import grpc
import logging
import asyncio

# pylint: disable=no-name-in-module
from google.protobuf.timestamp_pb2 import Timestamp
Expand All @@ -45,6 +47,7 @@ def __init__(
self,
server_addr: str,
node_uuid: str,
loop: asyncio.AbstractEventLoop | None = None,
):
"""
Parameters
Expand All @@ -54,7 +57,7 @@ def __init__(
node_uuid : str
Unique identifier for this node.
"""
super().__init__()
super().__init__(loop)

self.server_addr = server_addr
self.node_uuid = node_uuid
Expand Down Expand Up @@ -174,8 +177,112 @@ def _rx_stream_handler(self, stream): # pylint: disable=no-self-use
The GRPC receive stream.
"""
for event in stream:
print(event)

interface_name = event.interface_name
path = event.path
payload = None
timestamp = None
if event.HasField("astarte_data"):
if event.astarte_data.HasField("astarte_individual"):
if event.astarte_data.astarte_individual.HasField('astarte_double'):
payload = event.astarte_data.astarte_individual.astarte_double
elif event.astarte_data.astarte_individual.HasField('astarte_double_array'):
payload = event.astarte_data.astarte_individual.astarte_double_array.values
payload = [e for e in payload]
elif event.astarte_data.astarte_individual.HasField('astarte_integer'):
payload = event.astarte_data.astarte_individual.astarte_integer
elif event.astarte_data.astarte_individual.HasField('astarte_integer_array'):
payload = event.astarte_data.astarte_individual.astarte_integer_array.values
payload = [e for e in payload]
elif event.astarte_data.astarte_individual.HasField('astarte_boolean'):
payload = event.astarte_data.astarte_individual.astarte_boolean
elif event.astarte_data.astarte_individual.HasField('astarte_boolean_array'):
payload = event.astarte_data.astarte_individual.astarte_boolean_array.values
payload = [e for e in payload]
elif event.astarte_data.astarte_individual.HasField('astarte_long_integer'):
payload = event.astarte_data.astarte_individual.astarte_long_integer
elif event.astarte_data.astarte_individual.HasField('astarte_long_integer_array'):
payload = event.astarte_data.astarte_individual.astarte_long_integer_array.values
payload = [e for e in payload]
elif event.astarte_data.astarte_individual.HasField('astarte_string'):
payload = event.astarte_data.astarte_individual.astarte_string
elif event.astarte_data.astarte_individual.HasField('astarte_string_array'):
payload = event.astarte_data.astarte_individual.astarte_string_array.values
payload = [e for e in payload]
elif event.astarte_data.astarte_individual.HasField('astarte_binary_blob'):
payload = event.astarte_data.astarte_individual.astarte_binary_blob
elif event.astarte_data.astarte_individual.HasField('astarte_binary_blob_array'):
payload = event.astarte_data.astarte_individual.astarte_binary_blob_array.values
payload = [e for e in payload]
elif event.astarte_data.astarte_individual.HasField('astarte_date_time'):
payload = event.astarte_data.astarte_individual.astarte_date_time.ToDatetime()
elif event.astarte_data.astarte_individual.HasField('astarte_date_time_array'):
payload = event.astarte_data.astarte_individual.astarte_date_time_array.values
payload = [e for e in payload]
else:
# Handle event.astarte_data.astarte_object
pass
else:
# Handle event.astarte_unset
pass
if event.HasField("timestamp"):
# Handle event.timestamp
pass
self._on_message_checks(interface_name, path, payload)

def _on_message_checks(self, interface_name, path, payload):
# Check if interface name is correct
interface = self._introspection.get_interface(interface_name)
if not interface:
logging.warning(
"Received unexpected message for unregistered interface %s: %s, %s",
interface_name,
path,
payload,
)
return

# Check over ownership of the interface
if not interface.is_server_owned():
logging.warning(
"Received unexpected message for device owned interface %s: %s, %s",
interface_name,
path,
payload,
)
return

# Check the received path corresponds to the one in the interface
if interface.validate_path(path, payload):
logging.warning(
"Received message on incorrect endpoint for interface %s: %s, %s",
interface_name,
path,
payload,
)
return

# Check the payload matches with the interface
if payload:
if interface.validate_payload(path, payload):
logging.warning(
"Received incompatible payload for interface %s: %s, %s",
interface_name,
path,
payload,
)
return

if self._loop:
# Use threadsafe, as we're in a different thread here
self._loop.call_soon_threadsafe(
self.on_data_received,
self,
interface_name,
path,
payload,
)
else:
self.on_data_received(self, interface_name, path, payload)

def _parse_individual_payload(
interface: Interface, path: str, payload: object | collections.abc.Mapping | None
Expand Down
15 changes: 7 additions & 8 deletions astarte/device/device_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
PersistencyDirectoryNotFoundError
If the provided persistency directory does not exists.
"""
super().__init__()
super().__init__(loop)

if not os.path.isdir(persistency_dir):
raise PersistencyDirectoryNotFoundError(f"{persistency_dir} is not a directory")
Expand Down Expand Up @@ -157,7 +157,6 @@ def __init__(
# self.__jwt_token: str | None = None
self.__is_crypto_setup = False
self.__is_connected = False
self.__loop = loop
self.__ignore_ssl_errors = ignore_ssl_errors

self.on_connected: Callable[DeviceMqtt, None] | None = None
Expand Down Expand Up @@ -402,9 +401,9 @@ def __on_connect(self, _client, _userdata, flags: dict, rc):
self.__send_device_owned_properties()

if self.on_connected:
if self.__loop:
if self._loop:
# Use threadsafe, as we're in a different thread here
self.__loop.call_soon_threadsafe(self.on_connected, self)
self._loop.call_soon_threadsafe(self.on_connected, self)
else:
self.on_connected(self)

Expand All @@ -428,9 +427,9 @@ def __on_disconnect(self, _client, _userdata, rc):
self.__is_connected = False

if self.on_disconnected:
if self.__loop:
if self._loop:
# Use threadsafe, as we're in a different thread here
self.__loop.call_soon_threadsafe(self.on_disconnected, self, rc)
self._loop.call_soon_threadsafe(self.on_disconnected, self, rc)
else:
self.on_disconnected(self, rc)

Expand Down Expand Up @@ -547,9 +546,9 @@ def __on_message(self, _client, _userdata, msg):
interface.name, interface.version_major, interface_path, data_payload
)

if self.__loop:
if self._loop:
# Use threadsafe, as we're in a different thread here
self.__loop.call_soon_threadsafe(
self._loop.call_soon_threadsafe(
self.on_data_received,
self,
interface_name,
Expand Down
16 changes: 15 additions & 1 deletion examples/grpc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,24 @@
from datetime import datetime, timezone
from pathlib import Path
import time
from termcolor import cprint

from astarte.device import DeviceGrpc

_ROOT_DIR = Path(__file__).parent.absolute()
_INTERFACES_DIR = _ROOT_DIR.joinpath("interfaces")

def on_data_received_cbk(device: DeviceGrpc, interface_name: str, path: str, payload: dict):
"""
Callback for a data reception event.
"""
cprint(
f"Received message for interface: {interface_name} and path: {path}.",
color="cyan",
flush=True,
)
cprint(f" Payload: {payload}", color="cyan", flush=True)

# If called as a script
if __name__ == "__main__":
# Instantiate the device
Expand All @@ -38,6 +50,8 @@
)
# Load all the interfaces
device.add_interfaces_from_dir(_INTERFACES_DIR)
# Set all the callback functions
device.on_data_received = on_data_received_cbk
# # Connect the device
device.connect()

Expand Down Expand Up @@ -141,6 +155,6 @@
datetime.now(tz=timezone.utc),
)

time.sleep(1)
time.sleep(60)

device.disconnect()

0 comments on commit 3e67ed9

Please sign in to comment.