Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add split_entities_config to TEDPolicy #7716

Merged
merged 17 commits into from
Jan 25, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/7707.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add the option to configure whether extracted entities should be split by comma (`","`) or not to TEDPolicy. Fixes
crash when this parameter is accessed during extraction.
18 changes: 18 additions & 0 deletions docs/docs/policies.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,24 @@ If you want to fine-tune your model, start by modifying the following parameters
set `weight_sparsity` to 1 as this would result in all kernel weights being 0, i.e. the model is not able
to learn.

* `split_entities_by_comma`:
koernerfelicia marked this conversation as resolved.
Show resolved Hide resolved
This parameter defines whether adjacent entities separated by a comma should be treated as one, or split. For example,
whether the entities with type ingredients "apple, banana" should be split into "apple" and "banana". Can either be
`True`/`False` globally:
```yaml-rasa title="config.yml"
policies:
- name: TEDPolicy
split_entities_by_comma: True
```
or set per entity type, such as:
```yaml-rasa title="config.yml"
policies:
- name: TEDPolicy
split_entities_by_comma:
address: False
ingredients: True
```

The above configuration parameters are the ones you should configure to fit your model to your data.
However, additional parameters exist that can be adapted.

Expand Down
26 changes: 25 additions & 1 deletion rasa/core/policies/ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
ENTITY_ATTRIBUTE_TYPE,
ENTITY_TAGS,
EXTRACTOR,
SPLIT_ENTITIES_BY_COMMA,
SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
)
from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter
from rasa.core.policies.policy import Policy, PolicyPrediction
Expand Down Expand Up @@ -272,6 +274,10 @@ class TEDPolicy(Policy):
FEATURIZERS: [],
# If set to true, entities are predicted in user utterances.
ENTITY_RECOGNITION: True,
# Split entities by comma, this makes sense e.g. for a list of
# ingredients in a recipe, but it doesn't make sense for the parts of
# an address
SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
}

@staticmethod
Expand All @@ -280,6 +286,16 @@ def _standard_featurizer(max_history: Optional[int] = None) -> TrackerFeaturizer
SingleStateFeaturizer(), max_history=max_history
)

def init_split_entities(self, split_entities_config):
koernerfelicia marked this conversation as resolved.
Show resolved Hide resolved
"""Initialise the behaviour for splitting entities by comma (or not)."""
if isinstance(split_entities_config, bool):
split_entities_config = {SPLIT_ENTITIES_BY_COMMA: split_entities_config}
else:
split_entities_config[SPLIT_ENTITIES_BY_COMMA] = self.defaults[
SPLIT_ENTITIES_BY_COMMA
]
return split_entities_config

def __init__(
self,
featurizer: Optional[TrackerFeaturizer] = None,
Expand All @@ -292,6 +308,10 @@ def __init__(
**kwargs: Any,
) -> None:
"""Declare instance variables with default values."""
self.split_entities_config = self.init_split_entities(
kwargs.get(SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE)
)

if not featurizer:
featurizer = self._standard_featurizer(max_history)

Expand Down Expand Up @@ -662,7 +682,11 @@ def _create_optional_event_for_entities(
parsed_message = interpreter.featurize_message(Message(data={TEXT: text}))
tokens = parsed_message.get(TOKENS_NAMES[TEXT])
entities = EntityExtractor.convert_predictions_into_entities(
text, tokens, predicted_tags, confidences=confidence_values
text,
tokens,
predicted_tags,
self.split_entities_config,
confidences=confidence_values,
)

# add the extractor name
Expand Down
36 changes: 34 additions & 2 deletions tests/core/test_policies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Type, List, Text, Tuple, Optional, Any
from typing import Type, List, Text, Tuple, Optional, Any, Dict
from unittest.mock import Mock, patch

import numpy as np
Expand All @@ -18,7 +18,12 @@
from rasa.shared.core.training_data.story_writer.markdown_story_writer import (
MarkdownStoryWriter,
)
from rasa.shared.nlu.constants import ACTION_NAME, INTENT_NAME_KEY
from rasa.shared.nlu.constants import (
ACTION_NAME,
INTENT_NAME_KEY,
SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
SPLIT_ENTITIES_BY_COMMA,
)
from rasa.shared.core.constants import (
USER_INTENT_RESTART,
USER_INTENT_BACK,
Expand Down Expand Up @@ -556,6 +561,33 @@ async def test_gen_batch(self, trained_policy: TEDPolicy, default_domain: Domain
or batch_slots_sentence_shape[1] == 0
)

@pytest.mark.parametrize(
"split_entities_config, expected_initialized_config",
[
(
SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
{SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE},
),
(
{"address": False, "ingredients": True},
{
"address": False,
"ingredients": True,
SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
},
),
],
)
def test_init_split_entities_config(
self,
trained_policy: TEDPolicy,
split_entities_config: Any,
expected_initialized_config: Dict[(str, bool)],
):
assert trained_policy.init_split_entities(
split_entities_config=split_entities_config
)


class TestTEDPolicyMargin(TestTEDPolicy):
def create_policy(
Expand Down