diff --git a/examples/rules/config.yml b/examples/rules/config.yml index b30f3b4f6615..be06b1a15a7b 100644 --- a/examples/rules/config.yml +++ b/examples/rules/config.yml @@ -14,7 +14,6 @@ pipeline: - name: EntitySynonymMapper - name: FallbackClassifier threshold: 0.5 - fallback_intent_name: nlu_fallback policies: - name: RulePolicy diff --git a/examples/rules/data/stories.md b/examples/rules/data/stories.md index 923b193a5425..9a6e5fef9d3b 100644 --- a/examples/rules/data/stories.md +++ b/examples/rules/data/stories.md @@ -102,7 +102,8 @@ - utter_greet ->> fallback story +>> Implementation of the TwoStageFallbackPolicy - ... -* nlu_fallback - - action_default_fallback +* nlu_fallback + - two_stage_fallback + - form{"name": "two_stage_fallback"} \ No newline at end of file diff --git a/examples/rules/domain.yml b/examples/rules/domain.yml index b66089819243..9518e84651c1 100644 --- a/examples/rules/domain.yml +++ b/examples/rules/domain.yml @@ -77,4 +77,4 @@ responses: utter_revert_fallback_and_reapply_last_intent: - text: "utter_revert_fallback_and_reapply_last_intent" utter_default: - - text: "please rephrase" + - text: "I give up." diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index ef49ae0901e0..3c738e159567 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -7,7 +7,7 @@ import aiohttp import rasa.core -from rasa.constants import DOCS_BASE_URL +from rasa.constants import DOCS_BASE_URL, DEFAULT_NLU_FALLBACK_INTENT_NAME from rasa.core import events from rasa.core.constants import ( DEFAULT_REQUEST_TIMEOUT, @@ -21,6 +21,7 @@ DEFAULT_OPEN_UTTERANCE_TYPE, OPEN_UTTERANCE_PREDICTION_KEY, RESPONSE_SELECTOR_PROPERTY_NAME, + INTENT_RANKING_KEY, ) from rasa.core.events import ( @@ -61,8 +62,10 @@ ACTION_BACK_NAME = "action_back" -def default_actions() -> List["Action"]: +def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["Action"]: """List default actions.""" + from rasa.core.actions.two_stage_fallback import TwoStageFallbackAction + return [ ActionListen(), ActionRestart(), @@ -72,6 +75,7 @@ def default_actions() -> List["Action"]: ActionRevertFallbackEvents(), ActionDefaultAskAffirmation(), ActionDefaultAskRephrase(), + TwoStageFallbackAction(action_endpoint), ActionBack(), ] @@ -109,7 +113,8 @@ def action_from_name( ) -> "Action": """Return an action instance for the name.""" - defaults = {a.name(): a for a in default_actions()} + # TODO: Why do we need to create instances of everything if just need one thing?! + defaults = {a.name(): a for a in default_actions(action_endpoint)} if name in defaults and name not in user_actions: return defaults[name] @@ -721,6 +726,15 @@ async def run( domain: "Domain", ) -> List[Event]: intent_to_affirm = tracker.latest_message.intent.get("name") + + # TODO: Simplify once the RulePolicy is out of prototype stage + intent_ranking = tracker.latest_message.intent.get(INTENT_RANKING_KEY, []) + if ( + intent_to_affirm == DEFAULT_NLU_FALLBACK_INTENT_NAME + and len(intent_ranking) > 1 + ): + intent_to_affirm = intent_ranking[1]["name"] + affirmation_message = f"Did you mean '{intent_to_affirm}'?" message = { diff --git a/rasa/core/actions/forms.py b/rasa/core/actions/forms.py index 0b2cd5caaff8..2d21744d0882 100644 --- a/rasa/core/actions/forms.py +++ b/rasa/core/actions/forms.py @@ -2,10 +2,10 @@ import logging from rasa.core.actions import action +from rasa.core.actions.loops import LoopAction from rasa.core.channels import OutputChannel from rasa.core.domain import Domain -from rasa.core.actions import Action from rasa.core.actions.action import ActionExecutionRejection, RemoteAction from rasa.core.events import Event, SlotSet, Form from rasa.core.nlg import NaturalLanguageGenerator @@ -24,7 +24,7 @@ # - add proper docstrings -class FormAction(Action): +class FormAction(LoopAction): def __init__( self, form_name: Text, action_endpoint: Optional[EndpointConfig] ) -> None: @@ -35,13 +35,13 @@ def __init__( def name(self) -> Text: return self._form_name - def required_slots(self) -> List[Text]: + def required_slots(self, domain: Domain) -> List[Text]: """A list of required slots that the form has to fill. Returns: A list of slot names. """ - return list(self.slot_mappings().keys()) + return list(self.slot_mappings(domain).keys()) def from_entity( self, @@ -146,7 +146,9 @@ def from_text( return {"type": "from_text", "intent": intent, "not_intent": not_intent} # noinspection PyMethodMayBeStatic - def slot_mappings(self) -> Dict[Text, Union[Dict, List[Dict[Text, Any]]]]: + def slot_mappings( + self, domain: Domain + ) -> Dict[Text, Union[Dict, List[Dict[Text, Any]]]]: """A dictionary to map required slots. Options: @@ -160,26 +162,21 @@ def slot_mappings(self) -> Dict[Text, Union[Dict, List[Dict[Text, Any]]]]: the slot to the extracted entity with the same name """ - if not self._domain: - return {} - return next( - ( - form[self._form_name] - for form in self._domain.forms - if self._form_name in form.keys() - ), + (form[self.name()] for form in domain.forms if self.name() in form.keys()), {}, ) - def get_mappings_for_slot(self, slot_to_fill: Text) -> List[Dict[Text, Any]]: + def get_mappings_for_slot( + self, slot_to_fill: Text, domain: Domain + ) -> List[Dict[Text, Any]]: """Get mappings for requested slot. If None, map requested slot to an entity with the same name """ requested_slot_mappings = self._to_list( - self.slot_mappings().get(slot_to_fill, self.from_entity(slot_to_fill)) + self.slot_mappings(domain).get(slot_to_fill, self.from_entity(slot_to_fill)) ) # check provided slot mappings for requested_slot_mapping in requested_slot_mappings: @@ -280,11 +277,11 @@ def extract_other_slots( slot_to_fill = tracker.get_slot(REQUESTED_SLOT) slot_values = {} - for slot in self.required_slots(): + for slot in self.required_slots(domain): # look for other slots if slot != slot_to_fill: # list is used to cover the case of list slot type - other_slot_mappings = self.get_mappings_for_slot(slot) + other_slot_mappings = self.get_mappings_for_slot(slot, domain) for other_slot_mapping in other_slot_mappings: # check whether the slot should be filled by an entity in the input @@ -331,7 +328,7 @@ def extract_requested_slot( logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...") # get mapping for requested slot - requested_slot_mappings = self.get_mappings_for_slot(slot_to_fill) + requested_slot_mappings = self.get_mappings_for_slot(slot_to_fill, domain) for requested_slot_mapping in requested_slot_mappings: logger.debug(f"Got mapping '{requested_slot_mapping}'") @@ -449,11 +446,11 @@ async def request_next_slot( domain: Domain, output_channel: OutputChannel, nlg: NaturalLanguageGenerator, - ) -> Optional[List[Event]]: + ) -> List[Event]: """Request the next slot and utter template if needed, else return None""" - for slot in self.required_slots(): + for slot in self.required_slots(domain): if self._should_request_slot(tracker, slot): logger.debug(f"Request next slot '{slot}'") @@ -463,7 +460,7 @@ async def request_next_slot( return [SlotSet(REQUESTED_SLOT, slot), *bot_message_events] # no more required slots to fill - return None + return [SlotSet(REQUESTED_SLOT, None)] @staticmethod async def _ask_for_slot( @@ -482,13 +479,6 @@ async def _ask_for_slot( ) return events_to_ask_for_next_slot - def deactivate(self) -> List[Event]: - """Return `Form` event with `None` as name to deactivate the form - and reset the requested slot""" - - logger.debug(f"Deactivating the form '{self.name()}'") - return [Form(None), SlotSet(REQUESTED_SLOT, None)] - # helpers @staticmethod def _to_list(x: Optional[Any]) -> List[Any]: @@ -516,15 +506,6 @@ def _list_intents( return self._to_list(intent), self._to_list(not_intent) - def _log_form_slots(self, tracker: "DialogueStateTracker") -> None: - """Logs the values of all required slots before submitting the form.""" - slot_values = "\n".join( - [f"\t{slot}: {tracker.get_slot(slot)}" for slot in self.required_slots()] - ) - logger.debug( - f"No slots left to request, all required slots are filled:\n{slot_values}" - ) - async def _activate_if_required( self, tracker: "DialogueStateTracker", @@ -553,7 +534,7 @@ async def _activate_if_required( # collect values of required slots filled before activation prefilled_slots = {} - for slot_name in self.required_slots(): + for slot_name in self.required_slots(domain): if not self._should_request_slot(tracker, slot_name): prefilled_slots[slot_name] = tracker.get_slot(slot_name) @@ -597,49 +578,59 @@ def _should_request_slot(tracker: "DialogueStateTracker", slot_name: Text) -> bo return tracker.get_slot(slot_name) is None - async def run( + def __str__(self) -> Text: + return f"FormAction('{self.name()}')" + + async def activate( self, output_channel: "OutputChannel", nlg: "NaturalLanguageGenerator", tracker: "DialogueStateTracker", domain: "Domain", ) -> List[Event]: - """Execute the side effects of this form. - - Steps: - - activate if needed - - validate user input if needed - - set validated slots - - utter_ask_{slot} template with the next required slot - - submit the form if all required slots are set - - deactivate the form - """ + # collect values of required slots filled before activation + prefilled_slots = {} - self._domain = domain + for slot_name in self.required_slots(domain): + if not self._should_request_slot(tracker, slot_name): + prefilled_slots[slot_name] = tracker.get_slot(slot_name) - # activate the form (we don't return these events in case form immediately - # finishes) - events = await self._activate_if_required(tracker, domain, output_channel, nlg) - # validate user input - events += await self._validate_if_required(tracker, domain, output_channel, nlg) - # check that the form wasn't deactivated in validation - if Form(None) not in events: - temp_tracker = self._temporary_tracker(tracker, events, domain) + if not prefilled_slots: + logger.debug("No pre-filled required slots to validate.") + return [] - next_slot_events = await self.request_next_slot( - temp_tracker, domain, output_channel, nlg - ) + logger.debug(f"Validating pre-filled required slots: {prefilled_slots}") + return await self.validate_slots( + prefilled_slots, tracker, domain, output_channel, nlg + ) - if next_slot_events is not None: - # request next slot - events += next_slot_events - else: - # there is nothing more to request, so we can submit - self._log_form_slots(temp_tracker) - # deactivate the form after submission - events.extend(self.deactivate()) # type: ignore + async def do( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + events_so_far: List[Event], + ) -> List[Event]: + events = await self._validate_if_required(tracker, domain, output_channel, nlg) + + temp_tracker = self._temporary_tracker(tracker, events_so_far + events, domain) + events += await self.request_next_slot( + temp_tracker, domain, output_channel, nlg + ) return events - def __str__(self) -> Text: - return f"FormAction('{self.name()}')" + async def is_done( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + events_so_far: List[Event], + ) -> bool: + return SlotSet(REQUESTED_SLOT, None) in events_so_far + + async def deactivate(self, *args: Any, **kwargs: Any) -> List[Event]: + logger.debug(f"Deactivating the form '{self.name()}'") + return [] diff --git a/rasa/core/actions/loops.py b/rasa/core/actions/loops.py new file mode 100644 index 000000000000..5308499d3ca3 --- /dev/null +++ b/rasa/core/actions/loops.py @@ -0,0 +1,90 @@ +from typing import List + +from rasa.core.actions import Action +from rasa.core.channels import OutputChannel +from rasa.core.domain import Domain +from rasa.core.events import Event, Form +from rasa.core.nlg import NaturalLanguageGenerator +from rasa.core.trackers import DialogueStateTracker + + +class LoopAction(Action): + async def run( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + ) -> List[Event]: + events = [] + if not await self.is_activated(output_channel, nlg, tracker, domain): + events += self._default_activation_events() + events += await self.activate(output_channel, nlg, tracker, domain) + + if not await self.is_done(output_channel, nlg, tracker, domain, events): + events += await self.do(output_channel, nlg, tracker, domain, events) + + if await self.is_done(output_channel, nlg, tracker, domain, events): + events += self._default_deactivation_events() + events += await self.deactivate( + output_channel, nlg, tracker, domain, events + ) + + return events + + async def is_activated( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + ) -> bool: + return tracker.active_form.get("name") == self.name() + + # default implementation checks if form active + def _default_activation_events(self) -> List[Event]: + return [Form(self.name())] + + async def activate( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + ) -> List[Event]: + # can be overwritten + return [] + + async def do( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + events_so_far: List[Event], + ) -> List[Event]: + raise NotImplementedError() + + async def is_done( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + events_so_far: List[Event], + ) -> bool: + raise NotImplementedError() + + def _default_deactivation_events(self) -> List[Event]: + return [Form(None)] + + async def deactivate( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + events_so_far: List[Event], + ) -> List[Event]: + # can be overwritten + return [] diff --git a/rasa/core/actions/two_stage_fallback.py b/rasa/core/actions/two_stage_fallback.py new file mode 100644 index 000000000000..3b93d3b69674 --- /dev/null +++ b/rasa/core/actions/two_stage_fallback.py @@ -0,0 +1,197 @@ +import copy +import time +from typing import List, Text, Optional + +from rasa.constants import DEFAULT_NLU_FALLBACK_INTENT_NAME +from rasa.core.actions import action +from rasa.core.actions.action import ( + ACTION_DEFAULT_ASK_AFFIRMATION_NAME, + ACTION_LISTEN_NAME, + ACTION_DEFAULT_FALLBACK_NAME, + ACTION_DEFAULT_ASK_REPHRASE_NAME, +) +from rasa.core.actions.loops import LoopAction +from rasa.core.channels import OutputChannel +from rasa.core.constants import USER_INTENT_OUT_OF_SCOPE +from rasa.core.domain import Domain +from rasa.core.events import ( + Event, + UserUtteranceReverted, + ActionExecuted, + UserUttered, + Form, +) +from rasa.core.nlg import NaturalLanguageGenerator +from rasa.core.trackers import DialogueStateTracker +from rasa.utils.endpoints import EndpointConfig + +ACTION_TWO_STAGE_FALLBACK_NAME = "two_stage_fallback" + + +class TwoStageFallbackAction(LoopAction): + def __init__(self, action_endpoint: Optional[EndpointConfig] = None) -> None: + self._action_endpoint = action_endpoint + + def name(self) -> Text: + return ACTION_TWO_STAGE_FALLBACK_NAME + + async def do( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + events_so_far: List[Event], + ) -> List[Event]: + if _user_should_affirm(tracker, events_so_far): + return await self._ask_affirm(output_channel, nlg, tracker, domain) + + return await self._ask_rephrase(output_channel, nlg, tracker, domain) + + async def _ask_affirm( + self, + output_channel: OutputChannel, + nlg: NaturalLanguageGenerator, + tracker: DialogueStateTracker, + domain: Domain, + ) -> List[Event]: + affirm_action = action.action_from_name( + ACTION_DEFAULT_ASK_AFFIRMATION_NAME, + self._action_endpoint, + domain.user_actions, + ) + + return await affirm_action.run(output_channel, nlg, tracker, domain) + + async def _ask_rephrase( + self, + output_channel: OutputChannel, + nlg: NaturalLanguageGenerator, + tracker: DialogueStateTracker, + domain: Domain, + ) -> List[Event]: + rephrase = action.action_from_name( + ACTION_DEFAULT_ASK_REPHRASE_NAME, self._action_endpoint, domain.user_actions + ) + + return await rephrase.run(output_channel, nlg, tracker, domain) + + async def is_done( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + events_so_far: List[Event], + ) -> bool: + _user_clarified = _last_intent_name(tracker) not in [ + DEFAULT_NLU_FALLBACK_INTENT_NAME, + USER_INTENT_OUT_OF_SCOPE, + ] + return ( + _user_clarified + or _two_fallbacks_in_a_row(tracker) + or _second_affirmation_failed(tracker) + ) + + async def deactivate( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + events_so_far: List[Event], + ) -> List[Event]: + if _two_fallbacks_in_a_row(tracker) or _second_affirmation_failed(tracker): + return await self._give_up(output_channel, nlg, tracker, domain) + + return await self._revert_fallback_events( + output_channel, nlg, tracker, domain, events_so_far + ) + _message_clarification(tracker) + + async def _revert_fallback_events( + self, + output_channel: OutputChannel, + nlg: NaturalLanguageGenerator, + tracker: DialogueStateTracker, + domain: Domain, + events_so_far: List[Event], + ) -> List[Event]: + revert_events = [UserUtteranceReverted(), UserUtteranceReverted()] + + temp_tracker = DialogueStateTracker.from_events( + tracker.sender_id, tracker.applied_events() + events_so_far + revert_events + ) + + while temp_tracker.latest_message and not await self.is_done( + output_channel, nlg, temp_tracker, domain, [] + ): + temp_tracker.update(revert_events[-1]) + revert_events.append(UserUtteranceReverted()) + + return revert_events + + async def _give_up( + self, + output_channel: OutputChannel, + nlg: NaturalLanguageGenerator, + tracker: DialogueStateTracker, + domain: Domain, + ) -> List[Event]: + fallback = action.action_from_name( + ACTION_DEFAULT_FALLBACK_NAME, self._action_endpoint, domain.user_actions + ) + + return await fallback.run(output_channel, nlg, tracker, domain) + + +def _last_intent_name(tracker: DialogueStateTracker) -> Optional[Text]: + last_message = tracker.latest_message + if not last_message: + return + + return last_message.intent.get("name") + + +def _two_fallbacks_in_a_row(tracker: DialogueStateTracker) -> bool: + return _last_n_intent_names(tracker, 2) == [ + DEFAULT_NLU_FALLBACK_INTENT_NAME, + DEFAULT_NLU_FALLBACK_INTENT_NAME, + ] + + +def _last_n_intent_names( + tracker: DialogueStateTracker, number_of_last_intent_names: int +) -> List[Text]: + intent_names = [] + for i in range(number_of_last_intent_names): + message = tracker.get_last_event_for(UserUttered, skip=i) + if isinstance(message, UserUttered): + intent_names.append(message.intent.get("name")) + + return intent_names + + +def _user_should_affirm( + tracker: DialogueStateTracker, events_so_far: List[Event] +) -> bool: + form_was_just_activated = any(isinstance(event, Form) for event in events_so_far) + if form_was_just_activated: + return True + + return _last_intent_name(tracker) == DEFAULT_NLU_FALLBACK_INTENT_NAME + + +def _second_affirmation_failed(tracker: DialogueStateTracker) -> bool: + return _last_n_intent_names(tracker, 3) == [ + USER_INTENT_OUT_OF_SCOPE, + DEFAULT_NLU_FALLBACK_INTENT_NAME, + USER_INTENT_OUT_OF_SCOPE, + ] + + +def _message_clarification(tracker: DialogueStateTracker) -> List[Event]: + clarification = copy.deepcopy(tracker.latest_message) + clarification.parse_data["intent"]["confidence"] = 1.0 + clarification.timestamp = time.time() + return [ActionExecuted(ACTION_LISTEN_NAME), clarification] diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index b3998b695338..eb4dcc3f4456 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -576,6 +576,7 @@ def _reset(self) -> None: self.latest_message = UserUttered.empty() self.latest_bot_utterance = BotUttered.empty() self.followup_action = ACTION_LISTEN_NAME + # TODO: Rename to `active_loop` once the `RulePolicy` is finalized self.active_form = {} def _reset_slots(self) -> None: diff --git a/rasa/nlu/classifiers/fallback_classifier.py b/rasa/nlu/classifiers/fallback_classifier.py index 80dd685f064a..c995befeed54 100644 --- a/rasa/nlu/classifiers/fallback_classifier.py +++ b/rasa/nlu/classifiers/fallback_classifier.py @@ -8,7 +8,6 @@ from rasa.nlu.constants import INTENT_RANKING_KEY, INTENT, INTENT_CONFIDENCE_KEY THRESHOLD_KEY = "threshold" -FALLBACK_INTENT_NAME_KEY = "fallback_intent_name" class FallbackClassifier(Component): @@ -18,10 +17,7 @@ class FallbackClassifier(Component): # ## Architecture of the used neural network # If all intent confidence scores are beyond this threshold, set the current # intent to `FALLBACK_INTENT_NAME` - THRESHOLD_KEY: DEFAULT_NLU_FALLBACK_THRESHOLD, - # The intent which is used to signal that the NLU confidence was below the - # threshold. - FALLBACK_INTENT_NAME_KEY: DEFAULT_NLU_FALLBACK_INTENT_NAME, + THRESHOLD_KEY: DEFAULT_NLU_FALLBACK_THRESHOLD } @classmethod @@ -59,7 +55,7 @@ def _should_fallback(self, message: Message) -> bool: def _fallback_intent(self) -> Dict[Text, Union[Text, float]]: return { - "name": self.component_config[FALLBACK_INTENT_NAME_KEY], + "name": DEFAULT_NLU_FALLBACK_INTENT_NAME, # TODO: Re-consider how we represent the confidence here INTENT_CONFIDENCE_KEY: 1.0, } diff --git a/tests/core/actions/test_forms.py b/tests/core/actions/test_forms.py index e9c8b62726ff..4e05c35329bc 100644 --- a/tests/core/actions/test_forms.py +++ b/tests/core/actions/test_forms.py @@ -81,8 +81,8 @@ async def test_activate_and_immediate_deactivate(): assert events == [ Form(form_name), SlotSet(slot_name, slot_value), - Form(None), SlotSet(REQUESTED_SLOT, None), + Form(None), ] @@ -118,8 +118,8 @@ async def test_set_slot_and_deactivate(): ) assert events == [ SlotSet(slot_name, slot_value), - Form(None), SlotSet(REQUESTED_SLOT, None), + Form(None), ] @@ -171,8 +171,8 @@ async def test_validate_slots(): ) assert events == [ SlotSet(slot_name, validated_slot_value), - Form(None), SlotSet(REQUESTED_SLOT, None), + Form(None), ] diff --git a/tests/core/actions/test_loops.py b/tests/core/actions/test_loops.py new file mode 100644 index 000000000000..08677ed5400e --- /dev/null +++ b/tests/core/actions/test_loops.py @@ -0,0 +1,163 @@ +from typing import List, Any, Text + +import pytest +from rasa.core.actions.loops import LoopAction +from rasa.core.channels import CollectingOutputChannel +from rasa.core.domain import Domain +from rasa.core.events import ( + Event, + ActionExecutionRejected, + ActionExecuted, + Form, + SlotSet, +) +from rasa.core.nlg import TemplatedNaturalLanguageGenerator +from rasa.core.trackers import DialogueStateTracker + + +async def test_whole_loop(): + expected_activation_events = [ + ActionExecutionRejected("tada"), + ActionExecuted("test"), + ] + + expected_do_events = [ActionExecuted("do")] + expected_deactivation_events = [SlotSet("deactivated")] + + form_name = "my form" + + class MyLoop(LoopAction): + def name(self) -> Text: + return form_name + + async def activate(self, *args: Any) -> List[Event]: + return expected_activation_events + + async def do(self, *args: Any) -> List[Event]: + events_so_far = args[-1] + assert events_so_far == [Form(form_name), *expected_activation_events] + + return expected_do_events + + async def deactivate(self, *args) -> List[Event]: + events_so_far = args[-1] + assert events_so_far == [ + Form(form_name), + *expected_activation_events, + *expected_do_events, + Form(None), + ] + + return expected_deactivation_events + + async def is_done(self, *args) -> bool: + events_so_far = args[-1] + return events_so_far == [ + Form(form_name), + *expected_activation_events, + *expected_do_events, + ] + + tracker = DialogueStateTracker.from_events("some sender", []) + domain = Domain.empty() + + action = MyLoop() + actual = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + assert actual == [ + Form(form_name), + *expected_activation_events, + *expected_do_events, + Form(None), + *expected_deactivation_events, + ] + + +async def test_loop_without_deactivate(): + expected_activation_events = [ + ActionExecutionRejected("tada"), + ActionExecuted("test"), + ] + + expected_do_events = [ActionExecuted("do")] + form_name = "my form" + + class MyLoop(LoopAction): + def name(self) -> Text: + return form_name + + async def activate(self, *args: Any) -> List[Event]: + return expected_activation_events + + async def do(self, *args: Any) -> List[Event]: + return expected_do_events + + async def deactivate(self, *args) -> List[Event]: + raise ValueError("this shouldn't be called") + + async def is_done(self, *args) -> bool: + return False + + tracker = DialogueStateTracker.from_events("some sender", []) + domain = Domain.empty() + + action = MyLoop() + actual = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + assert actual == [Form(form_name), *expected_activation_events, *expected_do_events] + + +async def test_loop_without_activate_and_without_deactivate(): + expected_do_events = [ActionExecuted("do")] + form_name = "my form" + + class MyLoop(LoopAction): + def name(self) -> Text: + return form_name + + async def activate(self, *args: Any) -> List[Event]: + raise ValueError("this shouldn't be called") + + async def do(self, *args: Any) -> List[Event]: + return expected_do_events + + async def deactivate(self, *args) -> List[Event]: + return [SlotSet("deactivated")] + + async def is_activated(self, *args: Any) -> bool: + return True + + async def is_done(self, *args) -> bool: + return False + + tracker = DialogueStateTracker.from_events("some sender", []) + domain = Domain.empty() + + action = MyLoop() + actual = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + assert actual == [*expected_do_events] + + +async def test_raise_not_implemented_error(): + loop = LoopAction() + with pytest.raises(NotImplementedError): + await loop.do(None, None, None, None, []) + + with pytest.raises(NotImplementedError): + await loop.is_done(None, None, None, None, []) diff --git a/tests/core/actions/test_two_stage_fallback.py b/tests/core/actions/test_two_stage_fallback.py new file mode 100644 index 000000000000..3bb8a23b9dff --- /dev/null +++ b/tests/core/actions/test_two_stage_fallback.py @@ -0,0 +1,320 @@ +from typing import List, Text + +import pytest + +from rasa.constants import DEFAULT_NLU_FALLBACK_INTENT_NAME +from rasa.core.actions.action import ACTION_LISTEN_NAME +from rasa.core.actions.two_stage_fallback import ( + TwoStageFallbackAction, + ACTION_TWO_STAGE_FALLBACK_NAME, +) +from rasa.core.channels import CollectingOutputChannel +from rasa.core.constants import USER_INTENT_OUT_OF_SCOPE +from rasa.core.domain import Domain +from rasa.core.events import ( + ActionExecuted, + UserUttered, + Form, + BotUttered, + UserUtteranceReverted, + Event, +) +from rasa.core.nlg import TemplatedNaturalLanguageGenerator +from rasa.core.trackers import DialogueStateTracker +from rasa.nlu.constants import INTENT_RANKING_KEY + + +def _message_requiring_fallback() -> List[Event]: + return [ + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered( + "hi", + {"name": DEFAULT_NLU_FALLBACK_INTENT_NAME}, + parse_data={ + INTENT_RANKING_KEY: [ + {"name": DEFAULT_NLU_FALLBACK_INTENT_NAME}, + {"name": "greet"}, + {"name": "bye"}, + ] + }, + ), + ] + + +def _two_stage_clarification_request() -> List[Event]: + return [ActionExecuted(ACTION_TWO_STAGE_FALLBACK_NAME), BotUttered("please affirm")] + + +async def test_ask_affirmation(): + tracker = DialogueStateTracker.from_events( + "some-sender", evts=_message_requiring_fallback() + ) + domain = Domain.empty() + action = TwoStageFallbackAction() + + events = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + assert len(events) == 2 + assert events[0] == Form(ACTION_TWO_STAGE_FALLBACK_NAME) + assert isinstance(events[1], BotUttered) + + +async def test_1st_affirmation_is_successful(): + tracker = DialogueStateTracker.from_events( + "some-sender", + evts=[ + # User sends message with low NLU confidence + *_message_requiring_fallback(), + Form(ACTION_TWO_STAGE_FALLBACK_NAME), + # Action asks user to affirm + *_two_stage_clarification_request(), + ActionExecuted(ACTION_LISTEN_NAME), + # User affirms + UserUttered("hi", {"name": "greet", "confidence": 1.0}), + ], + ) + domain = Domain.empty() + action = TwoStageFallbackAction() + + events = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + for events in events: + tracker.update(events, domain) + + applied_events = tracker.applied_events() + assert applied_events == [ + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered("hi", {"name": "greet", "confidence": 1.0}), + ] + + +async def test_give_it_up_after_low_confidence_after_affirm_request(): + tracker = DialogueStateTracker.from_events( + "some-sender", + evts=[ + # User sends message with low NLU confidence + *_message_requiring_fallback(), + Form(ACTION_TWO_STAGE_FALLBACK_NAME), + # Action asks user to affirm + *_two_stage_clarification_request(), + # User's affirms with low NLU confidence again + *_message_requiring_fallback(), + ], + ) + domain = Domain.empty() + action = TwoStageFallbackAction() + + events = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + assert events == [Form(None), UserUtteranceReverted()] + + +async def test_ask_rephrase_after_failed_affirmation(): + rephrase_text = "please rephrase" + tracker = DialogueStateTracker.from_events( + "some-sender", + evts=[ + # User sends message with low NLU confidence + *_message_requiring_fallback(), + Form(ACTION_TWO_STAGE_FALLBACK_NAME), + # Action asks user to affirm + *_two_stage_clarification_request(), + ActionExecuted(ACTION_LISTEN_NAME), + # User denies suggested intents + UserUttered("hi", {"name": USER_INTENT_OUT_OF_SCOPE}), + ], + ) + + domain = Domain.from_yaml( + f""" + responses: + utter_ask_rephrase: + - {rephrase_text} + """ + ) + action = TwoStageFallbackAction() + + events = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + assert len(events) == 1 + assert isinstance(events[0], BotUttered) + + bot_utterance = events[0] + assert isinstance(bot_utterance, BotUttered) + assert bot_utterance.text == rephrase_text + + +async def test_ask_rephrasing_successful(): + tracker = DialogueStateTracker.from_events( + "some-sender", + evts=[ + # User sends message with low NLU confidence + *_message_requiring_fallback(), + Form(ACTION_TWO_STAGE_FALLBACK_NAME), + # Action asks user to affirm + *_two_stage_clarification_request(), + ActionExecuted(ACTION_LISTEN_NAME), + # User denies suggested intents + UserUttered("hi", {"name": USER_INTENT_OUT_OF_SCOPE}), + *_two_stage_clarification_request(), + # Action asks user to rephrase + ActionExecuted(ACTION_LISTEN_NAME), + # User rephrases successfully + UserUttered("hi", {"name": "greet"}), + ], + ) + domain = Domain.empty() + action = TwoStageFallbackAction() + + events = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + for event in events: + tracker.update(event) + + applied_events = tracker.applied_events() + assert applied_events == [ + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered("hi", {"name": "greet"}), + ] + + +async def test_ask_affirm_after_rephrasing(): + tracker = DialogueStateTracker.from_events( + "some-sender", + evts=[ + # User sends message with low NLU confidence + *_message_requiring_fallback(), + Form(ACTION_TWO_STAGE_FALLBACK_NAME), + # Action asks user to affirm + *_two_stage_clarification_request(), + ActionExecuted(ACTION_LISTEN_NAME), + # User denies suggested intents + UserUttered("hi", {"name": USER_INTENT_OUT_OF_SCOPE}), + # Action asks user to rephrase + ActionExecuted(ACTION_TWO_STAGE_FALLBACK_NAME), + BotUttered("please rephrase"), + # User rephrased with low confidence + *_message_requiring_fallback(), + ], + ) + domain = Domain.empty() + action = TwoStageFallbackAction() + + events = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + assert len(events) == 1 + assert isinstance(events[0], BotUttered) + + +async def test_2nd_affirm_successful(): + tracker = DialogueStateTracker.from_events( + "some-sender", + evts=[ + # User sends message with low NLU confidence + *_message_requiring_fallback(), + Form(ACTION_TWO_STAGE_FALLBACK_NAME), + # Action asks user to affirm + *_two_stage_clarification_request(), + ActionExecuted(ACTION_LISTEN_NAME), + # User denies suggested intents + UserUttered("hi", {"name": USER_INTENT_OUT_OF_SCOPE}), + # Action asks user to rephrase + *_two_stage_clarification_request(), + # User rephrased with low confidence + *_message_requiring_fallback(), + *_two_stage_clarification_request(), + # Actions asks user to affirm for the last time + ActionExecuted(ACTION_LISTEN_NAME), + # User affirms successfully + UserUttered("hi", {"name": "greet"}), + ], + ) + domain = Domain.empty() + action = TwoStageFallbackAction() + + events = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + for event in events: + tracker.update(event) + + applied_events = tracker.applied_events() + + assert applied_events == [ + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered("hi", {"name": "greet"}), + ] + + +@pytest.mark.parametrize( + "intent_which_lets_action_give_up", + [USER_INTENT_OUT_OF_SCOPE, DEFAULT_NLU_FALLBACK_INTENT_NAME], +) +async def test_2nd_affirmation_failed(intent_which_lets_action_give_up: Text): + tracker = DialogueStateTracker.from_events( + "some-sender", + evts=[ + # User sends message with low NLU confidence + *_message_requiring_fallback(), + Form(ACTION_TWO_STAGE_FALLBACK_NAME), + # Action asks user to affirm + *_two_stage_clarification_request(), + ActionExecuted(ACTION_LISTEN_NAME), + # User denies suggested intents + UserUttered("hi", {"name": USER_INTENT_OUT_OF_SCOPE}), + # Action asks user to rephrase + *_two_stage_clarification_request(), + # User rephrased with low confidence + *_message_requiring_fallback(), + # Actions asks user to affirm for the last time + *_two_stage_clarification_request(), + ActionExecuted(ACTION_LISTEN_NAME), + # User denies suggested intents for the second time + UserUttered("hi", {"name": intent_which_lets_action_give_up}), + ], + ) + domain = Domain.empty() + action = TwoStageFallbackAction() + + events = await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + assert events == [Form(None), UserUtteranceReverted()] diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index 1e1d1154710e..f3c724b62363 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -28,6 +28,7 @@ ActionSessionStart, ) from rasa.core.actions.forms import FormAction +from rasa.core.actions.two_stage_fallback import ACTION_TWO_STAGE_FALLBACK_NAME from rasa.core.channels import CollectingOutputChannel from rasa.core.domain import Domain, SessionConfig from rasa.core.events import ( @@ -114,7 +115,7 @@ def test_domain_action_instantiation(): instantiated_actions = domain.actions(None) - assert len(instantiated_actions) == 13 + assert len(instantiated_actions) == 14 assert instantiated_actions[0].name() == ACTION_LISTEN_NAME assert instantiated_actions[1].name() == ACTION_RESTART_NAME assert instantiated_actions[2].name() == ACTION_SESSION_START_NAME @@ -123,11 +124,12 @@ def test_domain_action_instantiation(): assert instantiated_actions[5].name() == ACTION_REVERT_FALLBACK_EVENTS_NAME assert instantiated_actions[6].name() == ACTION_DEFAULT_ASK_AFFIRMATION_NAME assert instantiated_actions[7].name() == ACTION_DEFAULT_ASK_REPHRASE_NAME - assert instantiated_actions[8].name() == ACTION_BACK_NAME - assert instantiated_actions[9].name() == RULE_SNIPPET_ACTION_NAME - assert instantiated_actions[10].name() == "my_module.ActionTest" - assert instantiated_actions[11].name() == "utter_test" - assert instantiated_actions[12].name() == "respond_test" + assert instantiated_actions[8].name() == ACTION_TWO_STAGE_FALLBACK_NAME + assert instantiated_actions[9].name() == ACTION_BACK_NAME + assert instantiated_actions[10].name() == RULE_SNIPPET_ACTION_NAME + assert instantiated_actions[11].name() == "my_module.ActionTest" + assert instantiated_actions[12].name() == "utter_test" + assert instantiated_actions[13].name() == "respond_test" async def test_remote_action_runs( diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index 621eb501a79f..a2f3ae04fa21 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -163,7 +163,7 @@ def test_domain_from_template(): assert not domain.is_empty() assert len(domain.intents) == 10 - assert len(domain.action_names) == 14 + assert len(domain.action_names) == 15 def test_avoid_action_repetition(): diff --git a/tests/core/test_dsl.py b/tests/core/test_dsl.py index 9f3b92b45933..1c718bda4c26 100644 --- a/tests/core/test_dsl.py +++ b/tests/core/test_dsl.py @@ -238,7 +238,7 @@ async def test_generate_training_data_with_cycles(default_domain): num_tens = len(training_trackers) - 1 # if new default actions are added the keys of the actions will be changed - assert Counter(y) == {0: 6, 11: num_tens, 13: 1, 1: 2, 12: 3} + assert Counter(y) == {0: 6, 12: num_tens, 14: 1, 1: 2, 13: 3} async def test_generate_training_data_with_unused_checkpoints(tmpdir, default_domain): diff --git a/tests/importers/test_rasa.py b/tests/importers/test_rasa.py index 30efffaf3865..d004fca991cb 100644 --- a/tests/importers/test_rasa.py +++ b/tests/importers/test_rasa.py @@ -22,7 +22,7 @@ async def test_rasa_file_importer(project: Text): assert len(domain.intents) == 7 assert domain.slots == [] assert domain.entities == [] - assert len(domain.action_names) == 16 + assert len(domain.action_names) == 17 assert len(domain.templates) == 6 stories = await importer.get_stories() diff --git a/tests/nlu/classifiers/test_fallback_classifier.py b/tests/nlu/classifiers/test_fallback_classifier.py index 920681b386a8..7e684d672bf2 100644 --- a/tests/nlu/classifiers/test_fallback_classifier.py +++ b/tests/nlu/classifiers/test_fallback_classifier.py @@ -2,11 +2,7 @@ from rasa.constants import DEFAULT_NLU_FALLBACK_INTENT_NAME from rasa.core.constants import DEFAULT_NLU_FALLBACK_THRESHOLD -from rasa.nlu.classifiers.fallback_classifier import ( - FallbackClassifier, - FALLBACK_INTENT_NAME_KEY, - THRESHOLD_KEY, -) +from rasa.nlu.classifiers.fallback_classifier import FallbackClassifier, THRESHOLD_KEY from rasa.nlu.training_data import Message from rasa.nlu.constants import INTENT_RANKING_KEY, INTENT, INTENT_CONFIDENCE_KEY @@ -70,8 +66,5 @@ def test_not_predict_fallback_intent(): def test_default_threshold(): classifier = FallbackClassifier({}) - assert ( - classifier.component_config[FALLBACK_INTENT_NAME_KEY] - == DEFAULT_NLU_FALLBACK_INTENT_NAME - ) + assert classifier.component_config[THRESHOLD_KEY] == DEFAULT_NLU_FALLBACK_THRESHOLD diff --git a/tests/nlu/test_train.py b/tests/nlu/test_train.py index a84f1282443a..6a136e26a86c 100644 --- a/tests/nlu/test_train.py +++ b/tests/nlu/test_train.py @@ -72,7 +72,7 @@ def pipelines_for_tests(): "MitieNLP", "JiebaTokenizer", "MitieFeaturizer", "MitieEntityExtractor" ), ), - ("fallback", as_pipeline("FallbackClassifier")), + ("fallback", as_pipeline("KeywordIntentClassifier", "FallbackClassifier")), ]