Skip to content

Commit

Permalink
don't persist changed entities + autofill slots for policy entities (#…
Browse files Browse the repository at this point in the history
…7553)

* don't persist changed entities

* make tracker state return combined `UserUttered` event

* autofill slots for policy entities

* made if more explicit

* use constants

* rename `DefinePrevUserUtteredEntities` to `EntitiesAdded`

* rename and make `DefinePrevUserUttered` more general
  • Loading branch information
wochinge authored Dec 16, 2020
1 parent c9793d3 commit 3f47282
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 58 deletions.
4 changes: 2 additions & 2 deletions rasa/core/policies/ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
FEATURIZERS,
ENTITY_RECOGNITION,
)
from rasa.shared.core.events import DefinePrevUserUtteredEntities, Event
from rasa.shared.core.events import EntitiesAdded, Event
from rasa.shared.nlu.training_data.message import Message

if TYPE_CHECKING:
Expand Down Expand Up @@ -664,7 +664,7 @@ def _create_optional_event_for_entities(
for entity in entities:
entity[EXTRACTOR] = "TEDPolicy"

return [DefinePrevUserUtteredEntities(entities)]
return [EntitiesAdded(entities)]

def persist(self, path: Union[Text, Path]) -> None:
"""Persists the policy to a storage."""
Expand Down
81 changes: 34 additions & 47 deletions rasa/shared/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,18 @@ def __eq__(self, other: Any) -> bool:
return True


class SkipEventInMDStoryMixin(Event, ABC):
"""Skips the visualization of an event in Markdown stories."""

def as_story_string(self) -> None:
"""Returns the event as story string.
Returns:
None, as this event should not appear inside the story.
"""
return


class UserUttered(Event):
"""The user has said something to the bot.
Expand Down Expand Up @@ -400,9 +412,11 @@ def __init__(
self.use_text_for_featurization = False

self.parse_data = {
"intent": self.intent,
"entities": self.entities,
"text": self.text,
INTENT: self.intent,
# Copy entities so that changes to `self.entities` don't affect
# `self.parse_data` and hence don't get persisted
ENTITIES: self.entities.copy(),
TEXT: self.text,
"message_id": self.message_id,
"metadata": self.metadata,
}
Expand All @@ -421,8 +435,8 @@ def _from_parse_data(
):
return UserUttered(
text,
parse_data.get("intent"),
parse_data.get("entities", []),
parse_data.get(INTENT),
parse_data.get(ENTITIES, []),
parse_data,
timestamp,
input_channel,
Expand Down Expand Up @@ -606,19 +620,7 @@ def create_external(
)


class DefinePrevUserUttered(Event, ABC):
"""Defines the family of events that are used to update previous user utterance."""

def as_story_string(self) -> None:
"""Returns the event as story string.
Returns:
None, as this event should not appear inside the story.
"""
return


class DefinePrevUserUtteredFeaturization(DefinePrevUserUttered):
class DefinePrevUserUtteredFeaturization(SkipEventInMDStoryMixin):
"""Stores information whether action was predicted based on text or intent."""

type_name = "user_featurization"
Expand Down Expand Up @@ -686,21 +688,22 @@ def __eq__(self, other) -> bool:
return self.use_text_for_featurization == other.use_text_for_featurization


class DefinePrevUserUtteredEntities(DefinePrevUserUttered):
"""Event that is used to set entities on a previous user uttered event."""
class EntitiesAdded(SkipEventInMDStoryMixin):
"""Event that is used to add extracted entities to the tracker state."""

type_name = "user_entities"
type_name = "entities"

def __init__(
self,
entities: List[Dict[Text, Any]],
timestamp: Optional[float] = None,
metadata: Optional[Dict[Text, Any]] = None,
) -> None:
"""Initializes a DefinePrevUserUtteredEntities event.
"""Initializes event.
Args:
entities: the entities of a previous user uttered event
entities: Entities extracted from previous user message. This can either
be done by NLU components or end-to-end policy predictions.
timestamp: the timestamp
metadata: some optional metadata
"""
Expand All @@ -710,19 +713,19 @@ def __init__(
def __str__(self) -> Text:
"""Returns the string representation of the event."""
entity_str = [e[ENTITY_ATTRIBUTE_TYPE] for e in self.entities]
return f"DefinePrevUserUtteredEntities({entity_str})"
return f"{self.__class__.__name__}({entity_str})"

def __hash__(self) -> int:
"""Returns the hash value of the event."""
return hash(self.entities)

def __eq__(self, other) -> bool:
"""Compares this event with another event."""
return isinstance(other, DefinePrevUserUtteredEntities)
return isinstance(other, EntitiesAdded)

@classmethod
def _from_parameters(cls, parameters) -> "DefinePrevUserUtteredEntities":
return DefinePrevUserUtteredEntities(
def _from_parameters(cls, parameters) -> "EntitiesAdded":
return EntitiesAdded(
parameters.get(ENTITIES),
parameters.get("timestamp"),
parameters.get("metadata"),
Expand Down Expand Up @@ -754,7 +757,7 @@ def apply_to(self, tracker: "DialogueStateTracker") -> None:
tracker.latest_message.entities.append(entity)


class BotUttered(Event):
class BotUttered(Event, SkipEventInMDStoryMixin):
"""The bot has said something to the user.
This class is not used in the story training as it is contained in the
Expand Down Expand Up @@ -813,10 +816,6 @@ def apply_to(self, tracker: "DialogueStateTracker") -> None:
"""Applies event to current conversation state."""
tracker.latest_bot_utterance = self

def as_story_string(self) -> None:
"""Skips representing the event in stories."""
return None

def message(self) -> Dict[Text, Any]:
"""Return the complete message as a dictionary."""

Expand Down Expand Up @@ -1542,7 +1541,7 @@ def apply_to(self, tracker: "DialogueStateTracker") -> None:
tracker.clear_followup_action()


class AgentUttered(Event):
class AgentUttered(Event, SkipEventInMDStoryMixin):
"""The agent has said something to the user.
This class is not used in the story training as it is contained in the
Expand Down Expand Up @@ -1583,10 +1582,6 @@ def __str__(self) -> Text:
self.text, json.dumps(self.data)
)

def as_story_string(self) -> None:
"""Skips representing the event in stories."""
return None

def as_dict(self) -> Dict[Text, Any]:
"""Returns serialized event."""
d = super().as_dict()
Expand Down Expand Up @@ -1687,7 +1682,7 @@ def as_dict(self) -> Dict[Text, Any]:
return d


class LoopInterrupted(Event):
class LoopInterrupted(Event, SkipEventInMDStoryMixin):
"""Event added by FormPolicy and RulePolicy.
Notifies form action whether or not to validate the user input.
Expand Down Expand Up @@ -1730,10 +1725,6 @@ def __eq__(self, other) -> bool:

return self.is_interrupted == other.is_interrupted

def as_story_string(self) -> None:
"""Skips representing event in stories."""
return None

@classmethod
def _from_parameters(cls, parameters) -> "LoopInterrupted":
return LoopInterrupted(
Expand Down Expand Up @@ -1791,7 +1782,7 @@ def as_dict(self) -> Dict[Text, Any]:
return d


class ActionExecutionRejected(Event):
class ActionExecutionRejected(Event, SkipEventInMDStoryMixin):
"""Notify Core that the execution of the action has been rejected."""

type_name = "action_execution_rejected"
Expand Down Expand Up @@ -1847,10 +1838,6 @@ def _from_parameters(cls, parameters) -> "ActionExecutionRejected":
parameters.get("metadata"),
)

def as_story_string(self) -> None:
"""Skips representing this event in stories."""
return None

def as_dict(self) -> Dict[Text, Any]:
"""Returns serialized event."""
d = super().as_dict()
Expand Down
48 changes: 39 additions & 9 deletions rasa/shared/core/trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ENTITY_ATTRIBUTE_ROLE,
ACTION_TEXT,
ACTION_NAME,
ENTITIES,
)
from rasa.shared.core import events
from rasa.shared.core.constants import (
Expand Down Expand Up @@ -57,7 +58,7 @@
ActiveLoop,
SessionStarted,
ActionExecutionRejected,
DefinePrevUserUttered,
EntitiesAdded,
)
from rasa.shared.core.domain import Domain, State
from rasa.shared.core.slots import Slot
Expand Down Expand Up @@ -136,11 +137,26 @@ def from_events(
slots: Optional[Iterable[Slot]] = None,
max_event_history: Optional[int] = None,
sender_source: Optional[Text] = None,
):
domain: Optional[Domain] = None,
) -> "DialogueStateTracker":
"""Creates tracker from existing events.
Args:
sender_id: The ID of the conversation.
evts: Existing events which should be applied to the new tracker.
slots: Slots which can be set.
max_event_history: Maximum number of events which should be stored.
sender_source: File source of the messages.
domain: The current model domain.
Returns:
Instantiated tracker with its state updated according to the given
events.
"""
tracker = cls(sender_id, slots, max_event_history, sender_source)

for e in evts:
tracker.update(e)
tracker.update(e, domain)

return tracker

Expand Down Expand Up @@ -196,8 +212,7 @@ def __init__(
def current_state(
self, event_verbosity: EventVerbosity = EventVerbosity.NONE
) -> Dict[Text, Any]:
"""Return the current tracker state as an object."""

"""Returns the current tracker state as an object."""
_events = self._events_for_verbosity(event_verbosity)
if _events:
_events = [e.as_dict() for e in _events]
Expand All @@ -208,7 +223,7 @@ def current_state(
return {
"sender_id": self.sender_id,
"slots": self.current_slot_values(),
"latest_message": self.latest_message.parse_data,
"latest_message": self._latest_message_data(),
"latest_event_time": latest_event_time,
FOLLOWUP_ACTION: self.followup_action,
"paused": self.is_paused(),
Expand All @@ -231,6 +246,14 @@ def _events_for_verbosity(

return None

def _latest_message_data(self) -> Dict[Text, Any]:
parse_data_with_nlu_state = self.latest_message.parse_data.copy()
# Combine entities predicted by NLU with entities predicted by policies so that
# users can access them together via `latest_message` (e.g. in custom actions)
parse_data_with_nlu_state["entities"] = self.latest_message.entities

return parse_data_with_nlu_state

@staticmethod
def freeze_current_state(state: State) -> FrozenState:
"""Convert State dict into a hashable format FrozenState.
Expand Down Expand Up @@ -600,9 +623,16 @@ def update(self, event: Event, domain: Optional[Domain] = None) -> None:
self.events.append(event)
event.apply_to(self)

if domain and isinstance(event, UserUttered):
# store all entities as slots
for e in domain.slots_for_entities(event.parse_data["entities"]):
if domain and isinstance(event, (UserUttered, EntitiesAdded)):
if isinstance(event, UserUttered):
# Rather get entities from `parse_data` as
# `DefinePrevUserUtteredEntities` might have already affected the
# `UserUttered.entities` attribute
entities = event.parse_data[ENTITIES]
else:
entities = event.entities

for e in domain.slots_for_entities(entities):
self.update(e)

def update_with_events(
Expand Down
Loading

0 comments on commit 3f47282

Please sign in to comment.