Skip to content

Commit

Permalink
Refactor constraints so that each constraint is it's own class (#23753)
Browse files Browse the repository at this point in the history
  • Loading branch information
tehampson authored and pull[bot] committed Oct 2, 2023
1 parent d262c9b commit 3685942
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 110 deletions.
1 change: 1 addition & 0 deletions src/controller/python/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ chip_python_wheel_action("chip-core") {
"chip/utils/CommissioningBuildingBlocks.py",
"chip/utils/__init__.py",
"chip/yaml/__init__.py",
"chip/yaml/constraints.py",
"chip/yaml/data_model_lookup.py",
"chip/yaml/errors.py",
"chip/yaml/format_converter.py",
Expand Down
218 changes: 218 additions & 0 deletions src/controller/python/chip/yaml/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
#
# Copyright (c) 2022 Project CHIP Authors
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC, abstractmethod
import chip.yaml.format_converter as Converter
from .variable_storage import VariableStorage


class ConstraintValidationError(Exception):
def __init__(self, message):
super().__init__(message)


class BaseConstraint(ABC):
'''Constrain Interface'''

@abstractmethod
def is_met(self, response) -> bool:
pass


class _LoadableConstraint(BaseConstraint):
'''Constraints where value might be stored in VariableStorage needing runtime load.'''

def __init__(self, value, field_type, variable_storage: VariableStorage):
self._variable_storage = variable_storage
# When not none _indirect_value_key is binding a name to the constraint value, and the
# actual value can only be looked-up dynamically, which is why this is a key name.
self._indirect_value_key = None
self._value = None

if value is None:
# Default values set above is all we need here.
return

if isinstance(value, str) and self._variable_storage.is_key_saved(value):
self._indirect_value_key = value
else:
self._value = Converter.convert_yaml_type(
value, field_type)

def get_value(self):
'''Gets the current value of the constraint.
This method accounts for getting the runtime saved value from DUT previous responses.
'''
if self._indirect_value_key:
return self._variable_storage.load(self._indirect_value_key)
return self._value


class _ConstraintHasValue(BaseConstraint):
def __init__(self, has_value):
self._has_value = has_value

def is_met(self, response) -> bool:
raise ConstraintValidationError('HasValue constraint currently not implemented')


class _ConstraintType(BaseConstraint):
def __init__(self, type):
self._type = type

def is_met(self, response) -> bool:
raise ConstraintValidationError('Type constraint currently not implemented')


class _ConstraintStartsWith(BaseConstraint):
def __init__(self, starts_with):
self._starts_with = starts_with

def is_met(self, response) -> bool:
return response.startswith(self._starts_with)


class _ConstraintEndsWith(BaseConstraint):
def __init__(self, ends_with):
self._ends_with = ends_with

def is_met(self, response) -> bool:
return response.endswith(self._ends_with)


class _ConstraintIsUpperCase(BaseConstraint):
def __init__(self, is_upper_case):
self._is_upper_case = is_upper_case

def is_met(self, response) -> bool:
return response.isupper() == self._is_upper_case


class _ConstraintIsLowerCase(BaseConstraint):
def __init__(self, is_lower_case):
self._is_lower_case = is_lower_case

def is_met(self, response) -> bool:
return response.islower() == self._is_lower_case


class _ConstraintMinValue(_LoadableConstraint):
def __init__(self, min_value, field_type, variable_storage: VariableStorage):
super().__init__(min_value, field_type, variable_storage)

def is_met(self, response) -> bool:
min_value = self.get_value()
return response >= min_value


class _ConstraintMaxValue(_LoadableConstraint):
def __init__(self, max_value, field_type, variable_storage: VariableStorage):
super().__init__(max_value, field_type, variable_storage)

def is_met(self, response) -> bool:
max_value = self.get_value()
return response <= max_value


class _ConstraintContains(BaseConstraint):
def __init__(self, contains):
self._contains = contains

def is_met(self, response) -> bool:
return set(self._contains).issubset(response)


class _ConstraintExcludes(BaseConstraint):
def __init__(self, excludes):
self._excludes = excludes

def is_met(self, response) -> bool:
return set(self._excludes).isdisjoint(response)


class _ConstraintHasMaskSet(BaseConstraint):
def __init__(self, has_masks_set):
self._has_masks_set = has_masks_set

def is_met(self, response) -> bool:
return all([(response & mask) == mask for mask in self._has_masks_set])


class _ConstraintHasMaskClear(BaseConstraint):
def __init__(self, has_masks_clear):
self._has_masks_clear = has_masks_clear

def is_met(self, response) -> bool:
return all([(response & mask) == 0 for mask in self._has_masks_clear])


class _ConstraintNotValue(_LoadableConstraint):
def __init__(self, not_value, field_type, variable_storage: VariableStorage):
super().__init__(not_value, field_type, variable_storage)

def is_met(self, response) -> bool:
not_value = self.get_value()
return response != not_value


def get_constraints(constraints, field_type,
variable_storage: VariableStorage) -> list[BaseConstraint]:
_constraints = []
if 'hasValue' in constraints:
_constraints.append(_ConstraintHasValue(constraints.get('hasValue')))

if 'type' in constraints:
_constraints.append(_ConstraintType(constraints.get('type')))

if 'startsWith' in constraints:
_constraints.append(_ConstraintStartsWith(constraints.get('startsWith')))

if 'endsWith' in constraints:
_constraints.append(_ConstraintEndsWith(constraints.get('endsWith')))

if 'isUpperCase' in constraints:
_constraints.append(_ConstraintIsUpperCase(constraints.get('isUpperCase')))

if 'isLowerCase' in constraints:
_constraints.append(_ConstraintIsLowerCase(constraints.get('isLowerCase')))

if 'minValue' in constraints:
_constraints.append(_ConstraintMinValue(
constraints.get('minValue'), field_type, variable_storage))

if 'maxValue' in constraints:
_constraints.append(_ConstraintMaxValue(
constraints.get('maxValue'), field_type, variable_storage))

if 'contains' in constraints:
_constraints.append(_ConstraintContains(constraints.get('contains')))

if 'excludes' in constraints:
_constraints.append(_ConstraintExcludes(constraints.get('excludes')))

if 'hasMasksSet' in constraints:
_constraints.append(_ConstraintHasMaskSet(constraints.get('hasMasksSet')))

if 'hasMasksClear' in constraints:
_constraints.append(_ConstraintHasMaskClear(constraints.get('hasMasksClear')))

if 'notValue' in constraints:
_constraints.append(_ConstraintNotValue(
constraints.get('notValue'), field_type, variable_storage))

return _constraints
116 changes: 6 additions & 110 deletions src/controller/python/chip/yaml/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from chip import ChipDeviceCtrl
from chip.clusters.Types import NullValue
from chip.tlv import float32
import yaml
import stringcase
Expand All @@ -30,6 +29,7 @@
from .data_model_lookup import *
import chip.yaml.format_converter as Converter
from .variable_storage import VariableStorage
from .constraints import get_constraints

_SUCCESS_STATUS_CODE = "SUCCESS"
_NODE_ID_DEFAULT = 0x12345
Expand All @@ -50,110 +50,6 @@ class _ExecutionContext:
config_values: dict = None


class _ConstraintValue:
'''Constraints that are numeric primitive data types'''

def __init__(self, value, field_type, context: _ExecutionContext):
self._variable_storage = context.variable_storage
# When not none _indirect_value_key is binding a name to the constraint value, and the
# actual value can only be looked-up dynamically, which is why this is a key name.
self._indirect_value_key = None
self._value = None

if value is None:
# Default values set above is all we need here.
return

if isinstance(value, str) and self._variable_storage.is_key_saved(value):
self._indirect_value_key = value
else:
self._value = Converter.convert_yaml_type(
value, field_type)

def get_value(self):
'''Gets the current value of the constraint.
This method accounts for getting the runtime saved value from DUT previous responses.
'''
if self._indirect_value_key:
return self._variable_storage.load(self._indirect_value_key)
return self._value


class _Constraints:
def __init__(self, constraints: dict, field_type, context: _ExecutionContext):
self._variable_storage = context.variable_storage
self._has_value = constraints.get('hasValue')
self._type = constraints.get('type')
self._starts_with = constraints.get('startsWith')
self._ends_with = constraints.get('endsWith')
self._is_upper_case = constraints.get('isUpperCase')
self._is_lower_case = constraints.get('isLowerCase')
self._min_value = _ConstraintValue(constraints.get('minValue'), field_type,
context)
self._max_value = _ConstraintValue(constraints.get('maxValue'), field_type,
context)
self._contains = constraints.get('contains')
self._excludes = constraints.get('excludes')
self._has_masks_set = constraints.get('hasMasksSet')
self._has_masks_clear = constraints.get('hasMasksClear')
self._not_value = _ConstraintValue(constraints.get('notValue'), field_type,
context)

def are_constrains_met(self, response) -> bool:
return_value = True

if self._has_value:
logger.warn(f'HasValue constraint currently not implemented, forcing failure')
return_value = False

if self._type:
logger.warn(f'Type constraint currently not implemented, forcing failure')
return_value = False

if self._starts_with and not response.startswith(self._starts_with):
return_value = False

if self._ends_with and not response.endswith(self._ends_with):
return_value = False

if self._is_upper_case and not response.isupper():
return_value = False

if self._is_lower_case and not response.islower():
return_value = False

min_value = self._min_value.get_value()
if response is not NullValue and min_value and response < min_value:
return_value = False

max_value = self._max_value.get_value()
if response is not NullValue and max_value and response > max_value:
return_value = False

if self._contains and not set(self._contains).issubset(response):
return_value = False

if self._excludes and not set(self._excludes).isdisjoint(response):
return_value = False

if self._has_masks_set:
for mask in self._has_masks_set:
if (response & mask) != mask:
return_value = False

if self._has_masks_clear:
for mask in self._has_masks_clear:
if (response & mask) != 0:
return_value = False

not_value = self._not_value.get_value()
if not_value and response == not_value:
return_value = False

return return_value


class _VariableToSave:
def __init__(self, variable_name: str, variable_storage: VariableStorage):
self._variable_name = variable_name
Expand Down Expand Up @@ -311,7 +207,7 @@ def __init__(self, item: dict, cluster: str, context: _ExecutionContext):
'''
super().__init__(item['label'])
self._attribute_name = stringcase.pascalcase(item['attribute'])
self._constraints = None
self._constraints = []
self._cluster = cluster
self._cluster_object = None
self._request_object = None
Expand Down Expand Up @@ -362,9 +258,9 @@ def __init__(self, item: dict, cluster: str, context: _ExecutionContext):

constraints = self._expected_raw_response.get('constraints')
if constraints:
self._constraints = _Constraints(constraints,
self._request_object.attribute_type.Type,
context)
self._constraints = get_constraints(constraints,
self._request_object.attribute_type.Type,
context.variable_storage)

def run_action(self, dev_ctrl: ChipDeviceCtrl, endpoint: int, node_id: int):
try:
Expand All @@ -391,7 +287,7 @@ def run_action(self, dev_ctrl: ChipDeviceCtrl, endpoint: int, node_id: int):
if self._variable_to_save is not None:
self._variable_to_save.save_response(parsed_resp)

if self._constraints and not self._constraints.are_constrains_met(parsed_resp):
if not all([constraint.is_met(parsed_resp) for constraint in self._constraints]):
logger.error(f'Constraints check failed')
# TODO how should we fail the test here?

Expand Down

0 comments on commit 3685942

Please sign in to comment.