Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RulePolicy: Abstract loop interface + TwoStageFallbackPolicy #5933

Merged
merged 6 commits into from
Jun 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/rules/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ pipeline:
- name: EntitySynonymMapper
- name: FallbackClassifier
threshold: 0.5
fallback_intent_name: nlu_fallback

policies:
- name: RulePolicy
7 changes: 4 additions & 3 deletions examples/rules/data/stories.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@
- utter_greet


>> fallback story
>> Implementation of the TwoStageFallbackPolicy
- ...
* nlu_fallback
- action_default_fallback
* nlu_fallback <!-- like request_restaurant -->
- two_stage_fallback <!-- Activate and run form -->
- form{"name": "two_stage_fallback"}
2 changes: 1 addition & 1 deletion examples/rules/domain.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,4 @@ responses:
utter_revert_fallback_and_reapply_last_intent:
- text: "utter_revert_fallback_and_reapply_last_intent"
utter_default:
- text: "please rephrase"
- text: "I give up."
20 changes: 17 additions & 3 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import aiohttp

import rasa.core
from rasa.constants import DOCS_BASE_URL
from rasa.constants import DOCS_BASE_URL, DEFAULT_NLU_FALLBACK_INTENT_NAME
from rasa.core import events
from rasa.core.constants import (
DEFAULT_REQUEST_TIMEOUT,
Expand All @@ -21,6 +21,7 @@
DEFAULT_OPEN_UTTERANCE_TYPE,
OPEN_UTTERANCE_PREDICTION_KEY,
RESPONSE_SELECTOR_PROPERTY_NAME,
INTENT_RANKING_KEY,
)

from rasa.core.events import (
Expand Down Expand Up @@ -61,8 +62,10 @@
ACTION_BACK_NAME = "action_back"


def default_actions() -> List["Action"]:
def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["Action"]:
"""List default actions."""
from rasa.core.actions.two_stage_fallback import TwoStageFallbackAction

return [
ActionListen(),
ActionRestart(),
Expand All @@ -72,6 +75,7 @@ def default_actions() -> List["Action"]:
ActionRevertFallbackEvents(),
ActionDefaultAskAffirmation(),
ActionDefaultAskRephrase(),
TwoStageFallbackAction(action_endpoint),
ActionBack(),
]

Expand Down Expand Up @@ -109,7 +113,8 @@ def action_from_name(
) -> "Action":
"""Return an action instance for the name."""

defaults = {a.name(): a for a in default_actions()}
# TODO: Why do we need to create instances of everything if just need one thing?!
defaults = {a.name(): a for a in default_actions(action_endpoint)}

if name in defaults and name not in user_actions:
return defaults[name]
Expand Down Expand Up @@ -721,6 +726,15 @@ async def run(
domain: "Domain",
) -> List[Event]:
intent_to_affirm = tracker.latest_message.intent.get("name")

# TODO: Simplify once the RulePolicy is out of prototype stage
intent_ranking = tracker.latest_message.intent.get(INTENT_RANKING_KEY, [])
if (
intent_to_affirm == DEFAULT_NLU_FALLBACK_INTENT_NAME
and len(intent_ranking) > 1
):
intent_to_affirm = intent_ranking[1]["name"]

affirmation_message = f"Did you mean '{intent_to_affirm}'?"

message = {
Expand Down
133 changes: 62 additions & 71 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import logging

from rasa.core.actions import action
from rasa.core.actions.loops import LoopAction
from rasa.core.channels import OutputChannel
from rasa.core.domain import Domain

from rasa.core.actions import Action
from rasa.core.actions.action import ActionExecutionRejection, RemoteAction
from rasa.core.events import Event, SlotSet, Form
from rasa.core.nlg import NaturalLanguageGenerator
Expand All @@ -24,7 +24,7 @@
# - add proper docstrings


class FormAction(Action):
class FormAction(LoopAction):
def __init__(
self, form_name: Text, action_endpoint: Optional[EndpointConfig]
) -> None:
Expand All @@ -35,13 +35,13 @@ def __init__(
def name(self) -> Text:
return self._form_name

def required_slots(self) -> List[Text]:
def required_slots(self, domain: Domain) -> List[Text]:
"""A list of required slots that the form has to fill.
Returns:
A list of slot names.
"""
return list(self.slot_mappings().keys())
return list(self.slot_mappings(domain).keys())

def from_entity(
self,
Expand Down Expand Up @@ -146,7 +146,9 @@ def from_text(
return {"type": "from_text", "intent": intent, "not_intent": not_intent}

# noinspection PyMethodMayBeStatic
def slot_mappings(self) -> Dict[Text, Union[Dict, List[Dict[Text, Any]]]]:
def slot_mappings(
self, domain: Domain
) -> Dict[Text, Union[Dict, List[Dict[Text, Any]]]]:
"""A dictionary to map required slots.
Options:
Expand All @@ -160,26 +162,21 @@ def slot_mappings(self) -> Dict[Text, Union[Dict, List[Dict[Text, Any]]]]:
the slot to the extracted entity with the same name
"""

if not self._domain:
return {}

return next(
(
form[self._form_name]
for form in self._domain.forms
if self._form_name in form.keys()
),
(form[self.name()] for form in domain.forms if self.name() in form.keys()),
{},
)

def get_mappings_for_slot(self, slot_to_fill: Text) -> List[Dict[Text, Any]]:
def get_mappings_for_slot(
self, slot_to_fill: Text, domain: Domain
) -> List[Dict[Text, Any]]:
"""Get mappings for requested slot.
If None, map requested slot to an entity with the same name
"""

requested_slot_mappings = self._to_list(
self.slot_mappings().get(slot_to_fill, self.from_entity(slot_to_fill))
self.slot_mappings(domain).get(slot_to_fill, self.from_entity(slot_to_fill))
)
# check provided slot mappings
for requested_slot_mapping in requested_slot_mappings:
Expand Down Expand Up @@ -280,11 +277,11 @@ def extract_other_slots(
slot_to_fill = tracker.get_slot(REQUESTED_SLOT)

slot_values = {}
for slot in self.required_slots():
for slot in self.required_slots(domain):
# look for other slots
if slot != slot_to_fill:
# list is used to cover the case of list slot type
other_slot_mappings = self.get_mappings_for_slot(slot)
other_slot_mappings = self.get_mappings_for_slot(slot, domain)

for other_slot_mapping in other_slot_mappings:
# check whether the slot should be filled by an entity in the input
Expand Down Expand Up @@ -331,7 +328,7 @@ def extract_requested_slot(
logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...")

# get mapping for requested slot
requested_slot_mappings = self.get_mappings_for_slot(slot_to_fill)
requested_slot_mappings = self.get_mappings_for_slot(slot_to_fill, domain)

for requested_slot_mapping in requested_slot_mappings:
logger.debug(f"Got mapping '{requested_slot_mapping}'")
Expand Down Expand Up @@ -449,11 +446,11 @@ async def request_next_slot(
domain: Domain,
output_channel: OutputChannel,
nlg: NaturalLanguageGenerator,
) -> Optional[List[Event]]:
) -> List[Event]:
"""Request the next slot and utter template if needed,
else return None"""

for slot in self.required_slots():
for slot in self.required_slots(domain):
if self._should_request_slot(tracker, slot):
logger.debug(f"Request next slot '{slot}'")

Expand All @@ -463,7 +460,7 @@ async def request_next_slot(
return [SlotSet(REQUESTED_SLOT, slot), *bot_message_events]

# no more required slots to fill
return None
return [SlotSet(REQUESTED_SLOT, None)]

@staticmethod
async def _ask_for_slot(
Expand All @@ -482,13 +479,6 @@ async def _ask_for_slot(
)
return events_to_ask_for_next_slot

def deactivate(self) -> List[Event]:
"""Return `Form` event with `None` as name to deactivate the form
and reset the requested slot"""

logger.debug(f"Deactivating the form '{self.name()}'")
return [Form(None), SlotSet(REQUESTED_SLOT, None)]

# helpers
@staticmethod
def _to_list(x: Optional[Any]) -> List[Any]:
Expand Down Expand Up @@ -516,15 +506,6 @@ def _list_intents(

return self._to_list(intent), self._to_list(not_intent)

def _log_form_slots(self, tracker: "DialogueStateTracker") -> None:
"""Logs the values of all required slots before submitting the form."""
slot_values = "\n".join(
[f"\t{slot}: {tracker.get_slot(slot)}" for slot in self.required_slots()]
)
logger.debug(
f"No slots left to request, all required slots are filled:\n{slot_values}"
)

async def _activate_if_required(
self,
tracker: "DialogueStateTracker",
Expand Down Expand Up @@ -553,7 +534,7 @@ async def _activate_if_required(
# collect values of required slots filled before activation
prefilled_slots = {}

for slot_name in self.required_slots():
for slot_name in self.required_slots(domain):
if not self._should_request_slot(tracker, slot_name):
prefilled_slots[slot_name] = tracker.get_slot(slot_name)

Expand Down Expand Up @@ -597,49 +578,59 @@ def _should_request_slot(tracker: "DialogueStateTracker", slot_name: Text) -> bo

return tracker.get_slot(slot_name) is None

async def run(
def __str__(self) -> Text:
return f"FormAction('{self.name()}')"

async def activate(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
"""Execute the side effects of this form.
Steps:
- activate if needed
- validate user input if needed
- set validated slots
- utter_ask_{slot} template with the next required slot
- submit the form if all required slots are set
- deactivate the form
"""
# collect values of required slots filled before activation
prefilled_slots = {}

self._domain = domain
for slot_name in self.required_slots(domain):
if not self._should_request_slot(tracker, slot_name):
prefilled_slots[slot_name] = tracker.get_slot(slot_name)

# activate the form (we don't return these events in case form immediately
# finishes)
events = await self._activate_if_required(tracker, domain, output_channel, nlg)
# validate user input
events += await self._validate_if_required(tracker, domain, output_channel, nlg)
# check that the form wasn't deactivated in validation
if Form(None) not in events:
temp_tracker = self._temporary_tracker(tracker, events, domain)
if not prefilled_slots:
logger.debug("No pre-filled required slots to validate.")
return []

next_slot_events = await self.request_next_slot(
temp_tracker, domain, output_channel, nlg
)
logger.debug(f"Validating pre-filled required slots: {prefilled_slots}")
return await self.validate_slots(
prefilled_slots, tracker, domain, output_channel, nlg
)

if next_slot_events is not None:
# request next slot
events += next_slot_events
else:
# there is nothing more to request, so we can submit
self._log_form_slots(temp_tracker)
# deactivate the form after submission
events.extend(self.deactivate()) # type: ignore
async def do(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
events_so_far: List[Event],
) -> List[Event]:
events = await self._validate_if_required(tracker, domain, output_channel, nlg)

temp_tracker = self._temporary_tracker(tracker, events_so_far + events, domain)
events += await self.request_next_slot(
temp_tracker, domain, output_channel, nlg
)

return events

def __str__(self) -> Text:
return f"FormAction('{self.name()}')"
async def is_done(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
events_so_far: List[Event],
) -> bool:
return SlotSet(REQUESTED_SLOT, None) in events_so_far

async def deactivate(self, *args: Any, **kwargs: Any) -> List[Event]:
logger.debug(f"Deactivating the form '{self.name()}'")
return []
wochinge marked this conversation as resolved.
Show resolved Hide resolved
Loading