Skip to content

Commit

Permalink
Improved validation
Browse files Browse the repository at this point in the history
Signed-off-by: Simone Orru <simone.orru@secomind.com>
  • Loading branch information
sorru94 committed Sep 28, 2023
1 parent 941b4a6 commit 172ff8c
Show file tree
Hide file tree
Showing 6 changed files with 1,389 additions and 641 deletions.
16 changes: 8 additions & 8 deletions astarte/device/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ def send(
raise ValidationError("Payload should be different from None")
if isinstance(payload, collections.abc.Mapping):
raise ValidationError("Payload for individual interfaces should not be a dictionary")
validation_result = interface.validate(interface_path, payload, timestamp)
if validation_result:
raise validation_result
interface.validate_payload_and_timestamp(interface_path, payload, timestamp)

self._send_generic(
interface,
Expand Down Expand Up @@ -282,9 +280,7 @@ def send_aggregate(
raise ValidationError("Payload should be different from None")
if not isinstance(payload, collections.abc.Mapping):
raise ValidationError("Payload for aggregate interfaces should be a dictionary")
validation_result = interface.validate(interface_path, payload, timestamp)
if validation_result:
raise validation_result
interface.validate_payload_and_timestamp(interface_path, payload, timestamp)

self._send_generic(
interface,
Expand Down Expand Up @@ -401,7 +397,9 @@ def _on_message_generic(self, interface_name, path, payload):
return

# Check the received path corresponds to the one in the interface
if interface.validate_path(path, payload):
try:
interface.validate_path(path, payload)
except ValidationError:
logging.warning(
"Received message on incorrect endpoint for interface %s: %s, %s",
interface_name,
Expand All @@ -412,7 +410,9 @@ def _on_message_generic(self, interface_name, path, payload):

# Check the payload matches with the interface
if payload:
if interface.validate_payload(path, payload):
try:
interface.validate_payload(path, payload)
except ValidationError:
logging.warning(
"Received incompatible payload for interface %s: %s, %s",
interface_name,
Expand Down
211 changes: 125 additions & 86 deletions astarte/device/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

import re
from datetime import datetime

from astarte.device.mapping import Mapping
Expand All @@ -30,6 +31,10 @@
DEVICE = "device"
SERVER = "server"

name_regex = re.compile(
r"^([a-zA-Z][a-zA-Z0-9]*\.([a-zA-Z0-9][a-zA-Z0-9-]*\.)*)?[a-zA-Z][a-zA-Z0-9]*$"
)


class Interface:
"""
Expand Down Expand Up @@ -74,36 +79,57 @@ def __init__(self, interface_definition: dict):
ValueError
if both version_major and version_minor numbers are set to 0
"""

self.name: str = interface_definition.get("interface_name")
if not isinstance(self.name, str):
raise InterfaceFileDecodeError(
"Interface name is a required interface field and should be a string."
)
if name_regex.match(self.name) is None:
raise InterfaceFileDecodeError(
f"Interface name is not correctly formatted: {self.name}"
)

self.version_major: int = interface_definition.get("version_major")
if not isinstance(self.version_major, int):
raise InterfaceFileDecodeError(
"Major version is a required interface field and should be an integer."
)

self.version_minor: int = interface_definition.get("version_minor")
self.type: str = interface_definition.get("type")
self.ownership = interface_definition.get("ownership")

if not (
isinstance(self.name, str)
and isinstance(self.version_major, int)
and isinstance(self.version_minor, int)
and self.type in {"datastream", "properties"}
and self.ownership in (DEVICE, SERVER)
):
if not isinstance(self.version_minor, int):
raise InterfaceFileDecodeError(
f"Error parsing the following interface definition: {interface_definition}"
"Minor version is a required interface field and should be an integer."
)

if not self.version_major and not self.version_minor:
if (not self.version_major) and (not self.version_minor):
raise InterfaceFileDecodeError(
f"Both Major and Minor versions set to 0 for interface {self.name}"
)

self.aggregation = interface_definition.get("aggregation", "individual")
self.type: str = interface_definition.get("type")
if self.type not in {"datastream", "properties"}:
raise InterfaceFileDecodeError(
"Interface type can be one of 'datastream' and 'properties."
)

self.ownership: str = interface_definition.get("ownership")
if self.ownership not in (DEVICE, SERVER):
raise InterfaceFileDecodeError(
f"Interface ownership can be one of '{DEVICE}' and '{SERVER}'."
)

self.aggregation: str = interface_definition.get("aggregation", "individual")
if self.aggregation not in {"individual", "object"}:
raise InterfaceFileDecodeError(f"Invalid aggregation type for interface {self.name}.")

self.mappings = []
if (self.type == "properties") and (self.aggregation == "object"):
raise InterfaceFileDecodeError("Properties can only be objects.")

self.mappings: list[Mapping] = []
endpoints = []
for mapping_definition in interface_definition.get("mappings", []):
mapping = Mapping(mapping_definition, self.type)
mapping = Mapping(mapping_definition, self.type == "datastream")
if mapping.endpoint in endpoints:
raise InterfaceFileDecodeError(
f"Duplicated mapping {mapping.endpoint} for interface {self.name}."
Expand All @@ -114,6 +140,14 @@ def __init__(self, interface_definition: dict):
if not self.mappings:
raise InterfaceFileDecodeError(f"No mappings in interface {self.name}.")

if self.aggregation == "object":
expl_ts_and_qos = [(m.explicit_timestamp, m.reliability) for m in self.mappings]
if len(set(expl_ts_and_qos)) != 1:
raise InterfaceFileDecodeError(
"All the mappings for objects should have the same explicit_timestamp and "
"reliability fields."
)

def is_aggregation_object(self) -> bool:
"""
Check if the current Interface is a datastream with aggregation object
Expand Down Expand Up @@ -177,8 +211,11 @@ def get_mapping(self, endpoint) -> Mapping | None:
The Mapping if found, None otherwise
"""
for mapping in self.mappings:
if not mapping.validate_path(endpoint):
try:
mapping.validate_path(endpoint)
return mapping
except ValidationError:
pass
return None

def get_reliability(self, endpoint: str) -> int:
Expand Down Expand Up @@ -207,7 +244,7 @@ def get_reliability(self, endpoint: str) -> int:
return mapping.reliability
return 2

def validate_path(self, path: str, payload) -> ValidationError | None:
def validate_path(self, path: str, payload):
"""
Validate that the provided path conforms to the interface.
Expand All @@ -220,23 +257,22 @@ def validate_path(self, path: str, payload) -> ValidationError | None:
payload: object
Payload used to extrapolate the remaining endpoints for aggregated interfaces.
Returns
-------
ValidationError or None
None in case of successful validation, ValidationError otherwise
Raises
------
ValidationError
When validation has failed.
"""
if not self.is_aggregation_object():
if not self.get_mapping(path):
return ValidationError(f"Path {path} not in the {self.name} interface.")
raise ValidationError(f"Path {path} not in the {self.name} interface.")
else:
for k in payload:
if not self.get_mapping(f"{path}/{k}"):
return ValidationError(f"Path {path}/{k} not in the {self.name} interface.")
return None
raise ValidationError(f"Path {path}/{k} not in the {self.name} interface.")

def validate_payload(self, path: str, payload) -> ValidationError | None:
def validate_payload(self, path: str, payload):
"""
Validate that the payload conforms to the interface.
Validate that the payload conforms to the interface definition.
Parameters
----------
Expand All @@ -247,54 +283,36 @@ def validate_payload(self, path: str, payload) -> ValidationError | None:
payload: object
Data to validate
Returns
-------
ValidationError or None
None in case of successful validation, ValidationError otherwise
Raises
------
ValidationError
When validation has failed.
"""

# Validate the payload for the individual mapping
if not self.is_aggregation_object():
return self.get_mapping(path).validate_payload(payload)
mapping: Mapping = self.get_mapping(path)
if mapping is None:
raise ValidationError(f"Mapping not found for path {path}.")
mapping.validate_payload(payload)
return

# Validate the payload for the aggregate mapping
if not isinstance(payload, dict):
return ValidationError(f"Payload not a dict for aggregated interface {self.name}.")
raise ValidationError(f"Payload not a dict for aggregated interface {self.name}.")
for k, v in payload.items():
payload_invalid = self.get_mapping(f"{path}/{k}").validate_payload(v)
if payload_invalid:
return payload_invalid
# Check all the interface endpoints are present in the payload
return self.validate_object_complete(path, payload)

def validate_object_complete(self, path: str, payload):
"""
Validate that the payload contains all the endpoints for an aggregated interface.
Shall only be used on device owned interfaces, as server interfaces could be sent
incomplete.
Parameters
----------
path: str
Path on which the payload has been received. This is assumed to correspond to a valid
partial mapping.
payload: object
Data to validate
mapping: Mapping = self.get_mapping(f"{path}/{k}")
if mapping is None:
raise ValidationError(f"Mapping not found for path {path}/{k}.")
mapping.validate_payload(v)

Returns
-------
ValidationError or None
None in case of successful validation, ValidationError otherwise
"""
# Check all the interface endpoints are present in the payload
if not self.is_server_owned():
path_segments = path.count("/") + 1
for endpoint in [m.endpoint for m in self.mappings]:
non_common_endpoint = "/".join(endpoint.split("/")[path_segments:])
if non_common_endpoint not in payload:
return ValidationError(
f"Path {endpoint} of {self.name} interface not in payload."
)
return None
self._validate_object_completeness(path, payload)

def validate(self, path: str, payload, timestamp: datetime | None) -> ValidationError | None:
def validate_payload_and_timestamp(self, path: str, payload, timestamp: datetime | None):
"""
Interface Data validation.
Validate that path, payload and timestamp conform to the interface definition.
Parameters
----------
Expand All @@ -305,35 +323,56 @@ def validate(self, path: str, payload, timestamp: datetime | None) -> Validation
timestamp: datetime or None
Timestamp associated to the payload
Returns
-------
ValidationError or None
None in case of successful validation, ValidationError otherwise
Raises
------
ValidationError
When validation has failed.
"""

# Validate the payload for the individual mapping
if not self.is_aggregation_object():
mapping = self.get_mapping(path)
if mapping:
p_err = mapping.validate_payload(payload)
t_err = mapping.validate_timestamp(timestamp)
return p_err if p_err else t_err
return ValidationError(f"Path {path} not in the {self.name} interface.")
if mapping is None:
raise ValidationError(f"Path {path} not in the {self.name} interface.")
mapping.validate_payload(payload)
mapping.validate_timestamp(timestamp)
return

# Validate the payload for the aggregate mapping
if not isinstance(payload, dict):
return ValidationError(
f"The interface {self.name} is aggregate, but the payload is not a dictionary."
)
raise ValidationError(f"Interface {self.name} is aggregate, payload not a dictionary.")
for k, v in payload.items():
mapping = self.get_mapping(f"{path}/{k}")
if mapping:
p_err = mapping.validate_payload(v)
t_err = mapping.validate_timestamp(timestamp)
if p_err or t_err:
return p_err if p_err else t_err
else:
return ValidationError(f"Path {path}/{k} not in the {self.name} interface.")
if mapping is None:
raise ValidationError(f"Path {path}/{k} not in the {self.name} interface.")
mapping.validate_payload(v)
mapping.validate_timestamp(timestamp)

# Check all the mappings are present in the payload
return self.validate_object_complete(path, payload)
if not self.is_server_owned():
self._validate_object_completeness(path, payload)

def _validate_object_completeness(self, path: str, payload):
"""
Validate that the payload contains all the endpoints for an aggregated interface.
Shall only be used on device owned interfaces, as server interfaces could be sent
incomplete.
Parameters
----------
path: str
Path on which the payload has been received. This is assumed to correspond to a valid
partial mapping.
payload: object
Data to validate
Raises
------
ValidationError
When validation has failed.
"""
path_segments = path.count("/") + 1
for endpoint in [m.endpoint for m in self.mappings]:
non_common_endpoint = "/".join(endpoint.split("/")[path_segments:])
if non_common_endpoint not in payload:
raise ValidationError(f"Path {endpoint} of {self.name} interface not in payload.")
Loading

0 comments on commit 172ff8c

Please sign in to comment.