diff --git a/dialogy/preprocess/text/duckling_plugin.py b/dialogy/preprocess/text/duckling_plugin.py index 1a493549..0b4fda6d 100644 --- a/dialogy/preprocess/text/duckling_plugin.py +++ b/dialogy/preprocess/text/duckling_plugin.py @@ -114,6 +114,7 @@ def __init__( url: str = "http://0.0.0.0:8000/parse", access: Optional[PluginFn] = None, mutate: Optional[PluginFn] = None, + custom_entity_map: Optional[Dict[str, Any]] = None, debug: bool = False, ) -> None: """ @@ -129,6 +130,11 @@ def __init__( "Content-Type": "application/x-www-form-urlencoded; charset=UTF-8" } + if isinstance(custom_entity_map, dict): + self.dimension_entity_map = {**dimension_entity_map, **custom_entity_map} + else: + self.dimension_entity_map = dimension_entity_map + def __set_timezone(self) -> Optional[BaseTzInfo]: """ Set timezone as BaseTzInfo from compatible timezone string. @@ -214,9 +220,9 @@ def _reshape( if entity[EntityKeys.VALUE][EntityKeys.TYPE] == EntityKeys.INTERVAL: # Duckling entities with interval type have a different structure for value(s). # They have a need to express units in "from", "to" format. - cls = dimension_entity_map[entity[EntityKeys.DIM]][EntityKeys.INTERVAL] # type: ignore + cls = self.dimension_entity_map[entity[EntityKeys.DIM]][EntityKeys.INTERVAL] # type: ignore else: - cls = dimension_entity_map[entity[EntityKeys.DIM]][EntityKeys.VALUE] # type: ignore + cls = self.dimension_entity_map[entity[EntityKeys.DIM]][EntityKeys.VALUE] # type: ignore # The most appropriate class is picked for making an object from the dict. duckling_entity = cls.from_dict(entity) # Depending on the type of entity, the value is searched and filled. diff --git a/tests/preprocess/text/test_duckling_plugin.py b/tests/preprocess/text/test_duckling_plugin.py index 1ba56c08..56114d17 100644 --- a/tests/preprocess/text/test_duckling_plugin.py +++ b/tests/preprocess/text/test_duckling_plugin.py @@ -4,11 +4,26 @@ import pytest from dialogy.preprocess.text.duckling_plugin import DucklingPlugin +from dialogy.types import BaseEntity from dialogy.workflow import Workflow from tests import EXCEPTIONS, load_tests, request_builder -# == Test missing i/o == +def test_plugin_with_custom_entity_map() -> None: + """ + Here we are checking if the plugin has access to workflow. + Since we haven't provided `access`, `mutate` to `DucklingPlugin` + we will receive a `TypeError`. + """ + parser = DucklingPlugin( + locale="en_IN", + timezone="Asia/Kolkata", + dimensions=["time"], + custom_entity_map={"number": {"value": BaseEntity}}, + ) + assert parser.dimension_entity_map["number"]["value"] == BaseEntity + + def test_plugin_io_missing() -> None: """ Here we are checking if the plugin has access to workflow.