Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve validation raising exceptions in place #121

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions astarte/device/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ def send(
raise ValidationError("Payload should be different from None")
if isinstance(payload, collections.abc.Mapping):
raise ValidationError("Payload for individual interfaces should not be a dictionary")
validation_result = interface.validate(interface_path, payload, timestamp)
if validation_result:
raise validation_result
interface.validate_payload_and_timestamp(interface_path, payload, timestamp)

self._send_generic(
interface,
Expand Down Expand Up @@ -282,9 +280,7 @@ def send_aggregate(
raise ValidationError("Payload should be different from None")
if not isinstance(payload, collections.abc.Mapping):
raise ValidationError("Payload for aggregate interfaces should be a dictionary")
validation_result = interface.validate(interface_path, payload, timestamp)
if validation_result:
raise validation_result
interface.validate_payload_and_timestamp(interface_path, payload, timestamp)

self._send_generic(
interface,
Expand Down Expand Up @@ -401,7 +397,9 @@ def _on_message_generic(self, interface_name, path, payload):
return

# Check the received path corresponds to the one in the interface
if interface.validate_path(path, payload):
try:
interface.validate_path(path, payload)
except ValidationError:
logging.warning(
"Received message on incorrect endpoint for interface %s: %s, %s",
interface_name,
Expand All @@ -412,7 +410,9 @@ def _on_message_generic(self, interface_name, path, payload):

# Check the payload matches with the interface
if payload:
if interface.validate_payload(path, payload):
try:
interface.validate_payload(path, payload)
except ValidationError:
logging.warning(
"Received incompatible payload for interface %s: %s, %s",
interface_name,
Expand Down
213 changes: 127 additions & 86 deletions astarte/device/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

import re
from datetime import datetime

from astarte.device.mapping import Mapping
Expand All @@ -30,6 +31,10 @@
DEVICE = "device"
SERVER = "server"

name_regex = re.compile(
r"^([a-zA-Z][a-zA-Z0-9]*\.([a-zA-Z0-9][a-zA-Z0-9-]*\.)*)?[a-zA-Z][a-zA-Z0-9]*$"
)


class Interface:
"""
Expand Down Expand Up @@ -74,36 +79,59 @@ def __init__(self, interface_definition: dict):
ValueError
if both version_major and version_minor numbers are set to 0
"""

self.name: str = interface_definition.get("interface_name")
if not isinstance(self.name, str):
raise InterfaceFileDecodeError(
"Interface name is a required interface field and should be a string."
)
if name_regex.match(self.name) is None:
raise InterfaceFileDecodeError(
f"Interface name is not correctly formatted: {self.name}"
)

self.version_major: int = interface_definition.get("version_major")
if not isinstance(self.version_major, int):
raise InterfaceFileDecodeError(
"Major version is a required interface field and should be an integer."
)

self.version_minor: int = interface_definition.get("version_minor")
self.type: str = interface_definition.get("type")
self.ownership = interface_definition.get("ownership")

if not (
isinstance(self.name, str)
and isinstance(self.version_major, int)
and isinstance(self.version_minor, int)
and self.type in {"datastream", "properties"}
and self.ownership in (DEVICE, SERVER)
):
if not isinstance(self.version_minor, int):
raise InterfaceFileDecodeError(
f"Error parsing the following interface definition: {interface_definition}"
"Minor version is a required interface field and should be an integer."
)

if not self.version_major and not self.version_minor:
if (not self.version_major) and (not self.version_minor):
sorru94 marked this conversation as resolved.
Show resolved Hide resolved
raise InterfaceFileDecodeError(
f"Both Major and Minor versions set to 0 for interface {self.name}"
)

self.aggregation = interface_definition.get("aggregation", "individual")
self.type: str = interface_definition.get("type")
if self.type not in {"datastream", "properties"}:
raise InterfaceFileDecodeError(
"Interface type can be one of 'datastream' and 'properties."
)

self.ownership: str = interface_definition.get("ownership")
if self.ownership not in (DEVICE, SERVER):
raise InterfaceFileDecodeError(
f"Interface ownership can be one of '{DEVICE}' and '{SERVER}'."
)

self.aggregation: str = interface_definition.get("aggregation", "individual")
if self.aggregation not in {"individual", "object"}:
raise InterfaceFileDecodeError(f"Invalid aggregation type for interface {self.name}.")

self.mappings = []
if (self.type == "properties") and (self.aggregation == "object"):
sorru94 marked this conversation as resolved.
Show resolved Hide resolved
raise InterfaceFileDecodeError(
"Invalid aggregation type 'object', properties can only be 'individual'."
)

self.mappings: list[Mapping] = []
endpoints = []
for mapping_definition in interface_definition.get("mappings", []):
mapping = Mapping(mapping_definition, self.type)
mapping = Mapping(mapping_definition, self.type == "datastream")
if mapping.endpoint in endpoints:
raise InterfaceFileDecodeError(
f"Duplicated mapping {mapping.endpoint} for interface {self.name}."
Expand All @@ -114,6 +142,14 @@ def __init__(self, interface_definition: dict):
if not self.mappings:
raise InterfaceFileDecodeError(f"No mappings in interface {self.name}.")

if self.aggregation == "object":
expl_ts_and_qos = [(m.explicit_timestamp, m.reliability) for m in self.mappings]
if len(set(expl_ts_and_qos)) != 1:
raise InterfaceFileDecodeError(
"All the mappings for objects should have the same explicit_timestamp and "
"reliability fields."
)

def is_aggregation_object(self) -> bool:
"""
Check if the current Interface is a datastream with aggregation object
Expand Down Expand Up @@ -177,8 +213,11 @@ def get_mapping(self, endpoint) -> Mapping | None:
The Mapping if found, None otherwise
"""
for mapping in self.mappings:
if not mapping.validate_path(endpoint):
try:
mapping.validate_path(endpoint)
return mapping
except ValidationError:
pass
return None

def get_reliability(self, endpoint: str) -> int:
Expand Down Expand Up @@ -207,7 +246,7 @@ def get_reliability(self, endpoint: str) -> int:
return mapping.reliability
return 2

def validate_path(self, path: str, payload) -> ValidationError | None:
def validate_path(self, path: str, payload):
"""
Validate that the provided path conforms to the interface.

Expand All @@ -220,23 +259,22 @@ def validate_path(self, path: str, payload) -> ValidationError | None:
payload: object
Payload used to extrapolate the remaining endpoints for aggregated interfaces.

Returns
-------
ValidationError or None
None in case of successful validation, ValidationError otherwise
Raises
------
ValidationError
When validation has failed.
"""
if not self.is_aggregation_object():
if not self.get_mapping(path):
return ValidationError(f"Path {path} not in the {self.name} interface.")
raise ValidationError(f"Path {path} not in the {self.name} interface.")
else:
for k in payload:
if not self.get_mapping(f"{path}/{k}"):
return ValidationError(f"Path {path}/{k} not in the {self.name} interface.")
return None
raise ValidationError(f"Path {path}/{k} not in the {self.name} interface.")

def validate_payload(self, path: str, payload) -> ValidationError | None:
def validate_payload(self, path: str, payload):
"""
Validate that the payload conforms to the interface.
Validate that the payload conforms to the interface definition.

Parameters
----------
Expand All @@ -247,54 +285,36 @@ def validate_payload(self, path: str, payload) -> ValidationError | None:
payload: object
Data to validate

Returns
-------
ValidationError or None
None in case of successful validation, ValidationError otherwise
Raises
------
ValidationError
When validation has failed.
"""

# Validate the payload for the individual mapping
if not self.is_aggregation_object():
return self.get_mapping(path).validate_payload(payload)
mapping: Mapping = self.get_mapping(path)
if mapping is None:
raise ValidationError(f"Mapping not found for path {path}.")
mapping.validate_payload(payload)
return

# Validate the payload for the aggregate mapping
if not isinstance(payload, dict):
return ValidationError(f"Payload not a dict for aggregated interface {self.name}.")
raise ValidationError(f"Payload not a dict for aggregated interface {self.name}.")
for k, v in payload.items():
payload_invalid = self.get_mapping(f"{path}/{k}").validate_payload(v)
if payload_invalid:
return payload_invalid
# Check all the interface endpoints are present in the payload
return self.validate_object_complete(path, payload)

def validate_object_complete(self, path: str, payload):
"""
Validate that the payload contains all the endpoints for an aggregated interface.
Shall only be used on device owned interfaces, as server interfaces could be sent
incomplete.

Parameters
----------
path: str
Path on which the payload has been received. This is assumed to correspond to a valid
partial mapping.
payload: object
Data to validate
mapping: Mapping = self.get_mapping(f"{path}/{k}")
if mapping is None:
raise ValidationError(f"Mapping not found for path {path}/{k}.")
mapping.validate_payload(v)

Returns
-------
ValidationError or None
None in case of successful validation, ValidationError otherwise
"""
# Check all the interface endpoints are present in the payload
if not self.is_server_owned():
path_segments = path.count("/") + 1
for endpoint in [m.endpoint for m in self.mappings]:
non_common_endpoint = "/".join(endpoint.split("/")[path_segments:])
if non_common_endpoint not in payload:
return ValidationError(
f"Path {endpoint} of {self.name} interface not in payload."
)
return None
self._validate_object_completeness(path, payload)

def validate(self, path: str, payload, timestamp: datetime | None) -> ValidationError | None:
def validate_payload_and_timestamp(self, path: str, payload, timestamp: datetime | None):
"""
Interface Data validation.
Validate that path, payload and timestamp conform to the interface definition.

Parameters
----------
Expand All @@ -305,35 +325,56 @@ def validate(self, path: str, payload, timestamp: datetime | None) -> Validation
timestamp: datetime or None
Timestamp associated to the payload

Returns
-------
ValidationError or None
None in case of successful validation, ValidationError otherwise
Raises
------
ValidationError
When validation has failed.
"""

# Validate the payload for the individual mapping
if not self.is_aggregation_object():
mapping = self.get_mapping(path)
if mapping:
p_err = mapping.validate_payload(payload)
t_err = mapping.validate_timestamp(timestamp)
return p_err if p_err else t_err
return ValidationError(f"Path {path} not in the {self.name} interface.")
if mapping is None:
raise ValidationError(f"Path {path} not in the {self.name} interface.")
mapping.validate_payload(payload)
mapping.validate_timestamp(timestamp)
return

# Validate the payload for the aggregate mapping
if not isinstance(payload, dict):
return ValidationError(
f"The interface {self.name} is aggregate, but the payload is not a dictionary."
)
raise ValidationError(f"Interface {self.name} is aggregate, payload not a dictionary.")
for k, v in payload.items():
mapping = self.get_mapping(f"{path}/{k}")
if mapping:
p_err = mapping.validate_payload(v)
t_err = mapping.validate_timestamp(timestamp)
if p_err or t_err:
return p_err if p_err else t_err
else:
return ValidationError(f"Path {path}/{k} not in the {self.name} interface.")
if mapping is None:
raise ValidationError(f"Path {path}/{k} not in the {self.name} interface.")
mapping.validate_payload(v)
mapping.validate_timestamp(timestamp)

# Check all the mappings are present in the payload
return self.validate_object_complete(path, payload)
if not self.is_server_owned():
self._validate_object_completeness(path, payload)

def _validate_object_completeness(self, path: str, payload):
"""
Validate that the payload contains all the endpoints for an aggregated interface.
Shall only be used on device owned interfaces, as server interfaces could be sent
incomplete.

Parameters
----------
path: str
Path on which the payload has been received. This is assumed to correspond to a valid
partial mapping.
payload: object
Data to validate

Raises
------
ValidationError
When validation has failed.
"""
path_segments = path.count("/") + 1
for endpoint in [m.endpoint for m in self.mappings]:
non_common_endpoint = "/".join(endpoint.split("/")[path_segments:])
if non_common_endpoint not in payload:
raise ValidationError(f"Path {endpoint} of {self.name} interface not in payload.")
Loading
Loading