diff --git a/dialogy/types/entity/base_entity.py b/dialogy/types/entity/base_entity.py index 7feb3d3d..e9c4117b 100644 --- a/dialogy/types/entity/base_entity.py +++ b/dialogy/types/entity/base_entity.py @@ -101,6 +101,9 @@ class BaseEntity: __properties_map = const.BASE_ENTITY_PROPS + def __attrs_post_init__(self) -> None: + self.entity_type = self.type + @classmethod def validate(cls, dict_: Dict[str, Any]) -> None: """ diff --git a/tests/types/entity/test_entities.py b/tests/types/entity/test_entities.py index b1f7b953..899d9a9c 100644 --- a/tests/types/entity/test_entities.py +++ b/tests/types/entity/test_entities.py @@ -318,3 +318,15 @@ def test_entity_grain_to_type() -> None: ) assert entity.entity_type == "hour" assert entity.type == "hour" + + +def test_both_entity_type_attributes_match() -> None: + body = "4 things" + value = {"value": 4} + entity = BaseEntity( + range={"from": 0, "to": len(body)}, + body=body, + type="base", + values=[value], + ) + assert entity.type == entity.entity_type