diff --git a/changelog/7707.bugfix.md b/changelog/7707.bugfix.md
new file mode 100644
index 000000000000..9d81c3ccd708
--- /dev/null
+++ b/changelog/7707.bugfix.md
@@ -0,0 +1,2 @@
+Add an option to `TEDPolicy` to configure whether extracted entities should be split by comma (`","`) or not. Fixes
+a crash when this parameter is accessed during extraction.
\ No newline at end of file
diff --git a/docs/docs/policies.mdx b/docs/docs/policies.mdx
index 6dcd139907cf..e53c9eaf6caf 100644
--- a/docs/docs/policies.mdx
+++ b/docs/docs/policies.mdx
@@ -176,6 +176,27 @@ If you want to fine-tune your model, start by modifying the following parameters
   set `weight_sparsity` to 1 as this would result in all kernel weights being 0, i.e. the model is not able to learn.
 
+* `split_entities_by_comma`:
+  This parameter defines whether adjacent entities separated by a comma should be treated as one, or split. For example,
+  entities with the type `ingredients`, like "apple, banana" can be split into "apple" and "banana". An entity with type
+  `address`, like "Schönhauser Allee 175, 10119 Berlin" should be treated as one.
+
+  Can either be `True`/`False` globally:
+  ```yaml-rasa title="config.yml"
+  policies:
+  - name: TEDPolicy
+    split_entities_by_comma: True
+  ```
+  or set per entity type, such as:
+  ```yaml-rasa title="config.yml"
+  policies:
+  - name: TEDPolicy
+    split_entities_by_comma:
+      address: False
+      ingredients: True
+  ```
+
 The above configuration parameters are the ones you should configure to fit your model to your data.
 However, additional parameters exist that can be adapted.
 
@@ -320,6 +341,15 @@ However, additional parameters exist that can be adapted.
 | entity_recognition                    | True                   | If 'True' entity recognition is trained and entities are    |
 |                                       |                        | extracted.                                                   |
 +---------------------------------------+------------------------+--------------------------------------------------------------+
+| split_entities_by_comma               | True                   | Splits a list of extracted entities by comma to treat each  |
+|                                       |                        | one of them as a single entity. Can either be `True`/`False`|
+|                                       |                        | globally, or set per entity type, such as:                   |
+|                                       |                        | ```                                                          |
+|                                       |                        | - name: TEDPolicy                                            |
+|                                       |                        |   split_entities_by_comma:                                   |
+|                                       |                        |     address: False                                           |
+|                                       |                        | ```                                                          |
++---------------------------------------+------------------------+--------------------------------------------------------------+
 ```
 
 :::note
diff --git a/rasa/core/policies/ted_policy.py b/rasa/core/policies/ted_policy.py
index f06530b26a54..38b3088922e6 100644
--- a/rasa/core/policies/ted_policy.py
+++ b/rasa/core/policies/ted_policy.py
@@ -32,6 +32,8 @@
     ENTITY_ATTRIBUTE_TYPE,
     ENTITY_TAGS,
     EXTRACTOR,
+    SPLIT_ENTITIES_BY_COMMA,
+    SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
 )
 from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter
 from rasa.core.policies.policy import Policy, PolicyPrediction
@@ -272,6 +274,10 @@ class TEDPolicy(Policy):
         FEATURIZERS: [],
         # If set to true, entities are predicted in user utterances.
         ENTITY_RECOGNITION: True,
+        # Split entities by comma: this makes sense e.g. for a list of
+        # ingredients in a recipe, but it doesn't make sense for the parts
+        # of an address.
+        SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
     }
 
     @staticmethod
@@ -292,6 +298,11 @@ def __init__(
         **kwargs: Any,
     ) -> None:
         """Declare instance variables with default values."""
+        self.split_entities_config = rasa.utils.train_utils.init_split_entities(
+            kwargs.get(SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE),
+            self.defaults[SPLIT_ENTITIES_BY_COMMA],
+        )
+
         if not featurizer:
             featurizer = self._standard_featurizer(max_history)
 
@@ -662,7 +673,11 @@ def _create_optional_event_for_entities(
         parsed_message = interpreter.featurize_message(Message(data={TEXT: text}))
         tokens = parsed_message.get(TOKENS_NAMES[TEXT])
         entities = EntityExtractor.convert_predictions_into_entities(
-            text, tokens, predicted_tags, confidences=confidence_values
+            text,
+            tokens,
+            predicted_tags,
+            self.split_entities_config,
+            confidences=confidence_values,
         )
 
         # add the extractor name
diff --git a/rasa/nlu/extractors/extractor.py b/rasa/nlu/extractors/extractor.py
index 4d79b89ab622..7d51e3b40f17 100644
--- a/rasa/nlu/extractors/extractor.py
+++ b/rasa/nlu/extractors/extractor.py
@@ -28,17 +28,40 @@
     SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
     SINGLE_ENTITY_ALLOWED_INTERLEAVING_CHARSET,
 )
+import rasa.utils.train_utils
 
 
 class EntityExtractor(Component):
+    """Entity extractors are components which extract entities.
+
+    They can be placed in the pipeline like other components, and can extract
+    entities like a person's name, or a location.
+    """
+
     def add_extractor_name(
         self, entities: List[Dict[Text, Any]]
     ) -> List[Dict[Text, Any]]:
+        """Adds this extractor's name to a list of entities.
+
+        Args:
+            entities: the extracted entities.
+
+        Returns:
+            the modified entities.
+        """
         for entity in entities:
             entity[EXTRACTOR] = self.name
         return entities
 
     def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:
+        """Adds this extractor's name to the list of processors for this entity.
+
+        Args:
+            entity: the extracted entity and its metadata.
+
+        Returns:
+            the modified entity.
+        """
         if "processors" in entity:
             entity["processors"].append(self.name)
         else:
@@ -46,18 +69,23 @@ def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:
 
         return entity
 
-    def init_split_entities(self):
-        """Initialise the behaviour for splitting entities by comma (or not)."""
+    def init_split_entities(self) -> Dict[Text, bool]:
+        """Initialises the behaviour for splitting entities by comma (or not).
+
+        Returns:
+            Defines desired behaviour for splitting specific entity types and
+            default behaviour for splitting any entity types for which no
+            behaviour is defined.
+        """
+ """ split_entities_config = self.component_config.get( SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE ) - if isinstance(split_entities_config, bool): - split_entities_config = {SPLIT_ENTITIES_BY_COMMA: split_entities_config} - else: - split_entities_config[SPLIT_ENTITIES_BY_COMMA] = self.defaults[ - SPLIT_ENTITIES_BY_COMMA - ] - return split_entities_config + default_value = self.defaults.get( + SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE + ) + return rasa.utils.train_utils.init_split_entities( + split_entities_config, default_value + ) @staticmethod def filter_irrelevant_entities(extracted: list, requested_dimensions: set) -> list: diff --git a/rasa/utils/train_utils.py b/rasa/utils/train_utils.py index fb9ea1faf6ed..6fffa4589714 100644 --- a/rasa/utils/train_utils.py +++ b/rasa/utils/train_utils.py @@ -23,7 +23,12 @@ NUM_TRANSFORMER_LAYERS, DENSE_DIMENSION, ) -from rasa.shared.nlu.constants import ACTION_NAME, INTENT, ENTITIES +from rasa.shared.nlu.constants import ( + ACTION_NAME, + INTENT, + ENTITIES, + SPLIT_ENTITIES_BY_COMMA, +) from rasa.shared.core.constants import ACTIVE_LOOP, SLOTS from rasa.core.constants import DIALOGUE @@ -335,3 +340,23 @@ def override_defaults( config[key] = custom[key] return config + + +def init_split_entities( + split_entities_config, default_split_entity +) -> Dict[Text, bool]: + """Initialise the behaviour for splitting entities by comma (or not). + + Returns: + Defines desired behaviour for splitting specific entity types and + default behaviour for splitting any entity types for which no behaviour + is defined. + """ + if isinstance(split_entities_config, bool): + # All entities will be split according to `split_entities_config` + split_entities_config = {SPLIT_ENTITIES_BY_COMMA: split_entities_config} + else: + # All entities not named in split_entities_config will be split + # according to `split_entities_config` + split_entities_config[SPLIT_ENTITIES_BY_COMMA] = default_split_entity + return split_entities_config diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 8400a2be68e9..9ec906d9606e 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -1,8 +1,15 @@ +from typing import Any, Dict + import numpy as np +import pytest import rasa.utils.train_utils as train_utils from rasa.nlu.constants import NUMBER_OF_SUB_TOKENS from rasa.nlu.tokenizers.tokenizer import Token +from rasa.shared.nlu.constants import ( + SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE, + SPLIT_ENTITIES_BY_COMMA, +) def test_align_token_features(): @@ -26,3 +33,31 @@ def test_align_token_features(): assert np.all(actual_features[0][3] == np.mean(token_features[0][3:5], axis=0)) # embedding is split into 4 sub-tokens assert np.all(actual_features[0][4] == np.mean(token_features[0][5:10], axis=0)) + + +@pytest.mark.parametrize( + "split_entities_config, expected_initialized_config", + [ + ( + SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE, + {SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE}, + ), + ( + {"address": False, "ingredients": True}, + { + "address": False, + "ingredients": True, + SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE, + }, + ), + ], +) +def test_init_split_entities_config( + split_entities_config: Any, expected_initialized_config: Dict[(str, bool)], +): + assert ( + train_utils.init_split_entities( + split_entities_config, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE + ) + == expected_initialized_config + )