diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f91c35e..cd59e0b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 0.7.1 +- [x] [fix](https://github.com/Vernacular-ai/dialogy/issues/60): Entity scoring within `EntityExtractor` and `DucklingPlugin`. +- [x] [fix](https://github.com/Vernacular-ai/dialogy/issues/58): CurrencyEntity added to operate on `amount-of-money` dimension. +- [x] add: TimeIntervalEntities sometimes may contain a hybrid structure that resembles some values as `TimeEntities`. + # 0.7.0 - [x] add: `KeywordEntity` entity-type class. - [x] refactor: `ListEntityPlugin` doesn't need an entity map. Uses `KeywordEntity` instead. diff --git a/dialogy/base/entity_extractor.py b/dialogy/base/entity_extractor.py index 7bc90d55..7839fed4 100644 --- a/dialogy/base/entity_extractor.py +++ b/dialogy/base/entity_extractor.py @@ -79,18 +79,17 @@ def aggregate_entities( def apply_filters(self, entities: List[BaseEntity]) -> List[BaseEntity]: """ - [summary] + Conditionally remove entities. - :param entities: [description] + :param entities: A list of entities. :type entities: List[BaseEntity] - :return: [description] + :return: A list of entities. This can be at most the same length as `entities`. :rtype: List[BaseEntity] """ return self.remove_low_scoring_entities(entities) - @staticmethod def entity_consensus( - entities: List[BaseEntity], input_size: int + self, entities: List[BaseEntity], input_size: int ) -> List[BaseEntity]: """ Combine entities by type and value. @@ -108,4 +107,7 @@ def entity_consensus( entity_type_value_group = py_.group_by( entities, lambda entity: (entity.type, entity.get_value()) ) - return EntityExtractor.aggregate_entities(entity_type_value_group, input_size) + aggregate_entities = EntityExtractor.aggregate_entities( + entity_type_value_group, input_size + ) + return self.apply_filters(aggregate_entities) diff --git a/dialogy/plugins/preprocess/text/duckling_plugin.py b/dialogy/plugins/preprocess/text/duckling_plugin.py index ee92c478..d3223b82 100644 --- a/dialogy/plugins/preprocess/text/duckling_plugin.py +++ b/dialogy/plugins/preprocess/text/duckling_plugin.py @@ -246,7 +246,7 @@ def select_datetime( def apply_filters(self, entities: List[BaseEntity]) -> List[BaseEntity]: """ - Filter entities by configurable criteria. + Conditionally remove entities. The utility of this method is tracked here: https://github.com/Vernacular-ai/dialogy/issues/42 @@ -261,7 +261,9 @@ def apply_filters(self, entities: List[BaseEntity]) -> List[BaseEntity]: :rtype: List[BaseEntity] """ if self.datetime_filters: - entities = self.select_datetime(entities, self.datetime_filters) + return self.select_datetime(entities, self.datetime_filters) + + # We call the filters that exist on the EntityExtractor class like threshold filtering. return super().apply_filters(entities) @dbg(log) @@ -396,7 +398,10 @@ def utility(self, *args: Any) -> List[BaseEntity]: for (alternative_index, entities) in enumerate(list_of_entities): shaped_entities.append(self._reshape(entities, alternative_index)) - filtered_entities = self.apply_filters(py_.flatten(shaped_entities)) - return EntityExtractor.entity_consensus(filtered_entities, input_size) + shaped_entities_flattened = py_.flatten(shaped_entities) + aggregate_entities = self.entity_consensus( + shaped_entities_flattened, input_size + ) + return self.apply_filters(aggregate_entities) except ValueError as value_error: raise ValueError(str(value_error)) from value_error diff --git a/dialogy/plugins/preprocess/text/list_entity_plugin.py b/dialogy/plugins/preprocess/text/list_entity_plugin.py index 13d2db0a..6eebc38b 100644 --- a/dialogy/plugins/preprocess/text/list_entity_plugin.py +++ b/dialogy/plugins/preprocess/text/list_entity_plugin.py @@ -201,8 +201,8 @@ def get_entities(self, transcripts: List[str]) -> List[BaseEntity]: log.debug("Parsed entities") log.debug(entities) - filtered_entities = self.apply_filters(entities) - return EntityExtractor.entity_consensus(filtered_entities, len(transcripts)) + aggregated_entities = self.entity_consensus(entities, len(transcripts)) + return self.apply_filters(aggregated_entities) @dbg(log) def utility(self, *args: Any) -> Any: diff --git a/dialogy/types/entity/__init__.py b/dialogy/types/entity/__init__.py index b7d14a29..c475e885 100644 --- a/dialogy/types/entity/__init__.py +++ b/dialogy/types/entity/__init__.py @@ -6,6 +6,7 @@ """ from dialogy.types.entity.base_entity import BaseEntity, entity_synthesis +from dialogy.types.entity.currency_entity import CurrencyEntity from dialogy.types.entity.duration_entity import DurationEntity from dialogy.types.entity.keyword_entity import KeywordEntity from dialogy.types.entity.location_entity import LocationEntity @@ -19,4 +20,5 @@ "people": {"value": PeopleEntity}, "time": {"value": TimeEntity, "interval": TimeIntervalEntity}, "duration": {"value": DurationEntity}, + "amount-of-money": {"value": CurrencyEntity}, } diff --git a/dialogy/types/entity/currency_entity.py b/dialogy/types/entity/currency_entity.py new file mode 100644 index 00000000..1600c3b3 --- /dev/null +++ b/dialogy/types/entity/currency_entity.py @@ -0,0 +1,54 @@ +""" +.. _currency_entity: +Module provides access to entity types that can be parsed to currencies and their value. + +Import classes: + - CurrencyEntity +""" +from typing import Any, Dict + +import attr + +from dialogy import constants as const +from dialogy.types.entity.numerical_entity import NumericalEntity + + +@attr.s +class CurrencyEntity(NumericalEntity): + """ + Numerical Entity Type + + Use this type for handling all entities that can be parsed to obtain: + - numbers + - date + - time + - datetime + + Attributes: + - `dim` dimension of the entity from duckling parser + - `type` is the type of the entity which can have values in ["value", "interval"] + """ + unit = attr.ib( + type=str, validator=attr.validators.instance_of(str), kw_only=True + ) + + @classmethod + def reshape(cls, dict_: Dict[str, Any]) -> Dict[str, Any]: + unit = dict_[const.EntityKeys.VALUE].get(const.EntityKeys.UNIT) + dict_ = super().reshape(dict_) + dict_[const.EntityKeys.UNIT] = unit + return dict_ + + def get_value(self, reference: Any = None) -> Any: + """ + Getter for CurrencyEntity. + + We are yet to decide the pros and cons of the output. It seems retaining {"value": float, "unit": } + + :param reference: [description], defaults to None + :type reference: Any, optional + :return: [description] + :rtype: Any + """ + value = super().get_value(reference=reference) + return f"{self.unit}{value:.2f}" diff --git a/dialogy/types/entity/time_interval_entity.py b/dialogy/types/entity/time_interval_entity.py index cd610e5f..e1d98db5 100644 --- a/dialogy/types/entity/time_interval_entity.py +++ b/dialogy/types/entity/time_interval_entity.py @@ -34,16 +34,15 @@ class TimeIntervalEntity(TimeEntity): @classmethod def reshape(cls, dict_: Dict[str, Any]) -> Dict[str, Any]: dict_ = super(TimeIntervalEntity, cls).reshape(dict_) - if all( - value[const.EntityKeys.TYPE] == const.EntityKeys.INTERVAL - for value in dict_[const.EntityKeys.VALUES] - ): - date_range = dict_[const.EntityKeys.VALUES][0].get( - const.EntityKeys.FROM - ) or dict_[const.EntityKeys.VALUES][0].get(const.EntityKeys.TO) - if not date_range: - raise TypeError(f"{dict_} does not match TimeIntervalEntity format") - dict_[const.EntityKeys.GRAIN] = date_range[const.EntityKeys.GRAIN] + for value in dict_[const.EntityKeys.VALUES]: + if value[const.EntityKeys.TYPE] == const.EntityKeys.INTERVAL: + date_range = dict_[const.EntityKeys.VALUES][0].get( + const.EntityKeys.FROM + ) or dict_[const.EntityKeys.VALUES][0].get(const.EntityKeys.TO) + if not date_range: + raise TypeError(f"{dict_} does not match TimeIntervalEntity format") + dict_[const.EntityKeys.GRAIN] = date_range[const.EntityKeys.GRAIN] + break return dict_ def get_value(self, reference: Any = None) -> Any: @@ -76,6 +75,8 @@ def get_value(self, reference: Any = None) -> Any: if date_dict: return datetime.fromisoformat(date_dict.get(const.EntityKeys.VALUE)) + elif reference.get(const.EntityKeys.VALUE): + return datetime.fromisoformat(reference.get(const.EntityKeys.VALUE)) else: raise KeyError( f"Expected at least 1 of `from` or `to` in {self.values} for {self}" diff --git a/docs/_modules/dialogy/base/entity_extractor.html b/docs/_modules/dialogy/base/entity_extractor.html index 4e27da52..46db8bec 100644 --- a/docs/_modules/dialogy/base/entity_extractor.html +++ b/docs/_modules/dialogy/base/entity_extractor.html @@ -5,7 +5,7 @@ -