diff --git a/changelog/6977.improvement.md b/changelog/6977.improvement.md new file mode 100644 index 000000000000..833e4b5e4da8 --- /dev/null +++ b/changelog/6977.improvement.md @@ -0,0 +1,11 @@ +[Forms](forms.mdx) no longer reject their execution before a potential custom +action for validating / extracting slots was executed. +Forms continue to reject in two cases automatically: +- A slot was requested to be filled, but no slot mapping applied to the latest user + message and there was no custom action for potentially extracting other slots. +- A slot was requested to be filled, but the custom action for validating / extracting + slots didn't return any slot event. + +Additionally you can also reject the form execution manually by returning a +`ActionExecutionRejected` event within your custom action for validating / extracting +slots. diff --git a/rasa/core/actions/forms.py b/rasa/core/actions/forms.py index 542cb7a0bb53..fce75169e7d4 100644 --- a/rasa/core/actions/forms.py +++ b/rasa/core/actions/forms.py @@ -6,7 +6,7 @@ from rasa.core.actions import action from rasa.core.actions.loops import LoopAction from rasa.core.channels import OutputChannel -from rasa.shared.core.domain import Domain +from rasa.shared.core.domain import Domain, InvalidDomain from rasa.core.actions.action import ActionExecutionRejection, RemoteAction from rasa.shared.core.constants import ( @@ -15,7 +15,13 @@ LOOP_INTERRUPTED, ) from rasa.shared.constants import UTTER_PREFIX -from rasa.shared.core.events import Event, SlotSet, ActionExecuted, ActiveLoop +from rasa.shared.core.events import ( + Event, + SlotSet, + ActionExecuted, + ActiveLoop, + ActionExecutionRejected, +) from rasa.core.nlg import NaturalLanguageGenerator from rasa.shared.core.trackers import DialogueStateTracker from rasa.utils.endpoints import EndpointConfig @@ -347,7 +353,7 @@ def extract_requested_slot( elif mapping_type == str(SlotMapping.FROM_TEXT): value = tracker.latest_message.text else: - raise ValueError("Provided slot mapping type is not supported") + raise InvalidDomain("Provided slot mapping type is not supported") if value is not None: logger.debug( @@ -361,7 +367,7 @@ def extract_requested_slot( async def validate_slots( self, - slot_dict: Dict[Text, Any], + slot_candidates: Dict[Text, Any], tracker: "DialogueStateTracker", domain: Domain, output_channel: OutputChannel, @@ -373,7 +379,7 @@ async def validate_slots( them. Otherwise there is no validation. Args: - slot_dict: Extracted slots which are candidates to fill the slots required + slot_candidates: Extracted slots which are candidates to fill the slots required by the form. tracker: The current conversation tracker. domain: The current model domain. @@ -385,8 +391,10 @@ async def validate_slots( The validation events including potential bot messages and `SlotSet` events for the validated slots. """ - - events = [SlotSet(slot_name, value) for slot_name, value in slot_dict.items()] + logger.debug(f"Validating extracted slots: {slot_candidates}") + events = [ + SlotSet(slot_name, value) for slot_name, value in slot_candidates.items() + ] validate_name = f"validate_{self.name()}" @@ -445,19 +453,30 @@ async def validate( if slot_to_fill: slot_values.update(self.extract_requested_slot(tracker, domain)) - if not slot_values: - # reject to execute the form action - # if some slot was requested but nothing was extracted - # it will allow other policies to predict another action - raise ActionExecutionRejection( - self.name(), - f"Failed to extract slot {slot_to_fill} with action {self.name()}", - ) - logger.debug(f"Validating extracted slots: {slot_values}") - return await self.validate_slots( + validation_events = await self.validate_slots( slot_values, tracker, domain, output_channel, nlg ) + some_slots_were_validated = any( + isinstance(event, SlotSet) for event in validation_events + ) + user_rejected_manually = any( + isinstance(event, ActionExecutionRejected) for event in validation_events + ) + if ( + slot_to_fill + and not some_slots_were_validated + and not user_rejected_manually + ): + # reject to execute the form action + # if some slot was requested but nothing was extracted + # it will allow other policies to predict another action + raise ActionExecutionRejection( + self.name(), + f"Failed to extract slot {slot_to_fill} with action {self.name()}", + ) + return validation_events + async def request_next_slot( self, tracker: "DialogueStateTracker", @@ -666,6 +685,9 @@ async def is_done( domain: "Domain", events_so_far: List[Event], ) -> bool: + if any(isinstance(event, ActionExecutionRejected) for event in events_so_far): + return False + # Custom validation actions can decide to terminate the loop early by # setting the requested slot to `None` or setting `ActiveLoop(None)`. # We explicitly check only the last occurrences for each possible termination diff --git a/tests/core/actions/test_forms.py b/tests/core/actions/test_forms.py index b661bf84afe2..57c33eee3a06 100644 --- a/tests/core/actions/test_forms.py +++ b/tests/core/actions/test_forms.py @@ -10,7 +10,7 @@ from rasa.shared.core.constants import ACTION_LISTEN_NAME, REQUESTED_SLOT from rasa.core.actions.forms import FormAction from rasa.core.channels import CollectingOutputChannel -from rasa.shared.core.domain import Domain +from rasa.shared.core.domain import Domain, InvalidDomain from rasa.shared.core.events import ( ActiveLoop, SlotSet, @@ -19,6 +19,7 @@ BotUttered, Restarted, Event, + ActionExecutionRejected, ) from rasa.core.nlg import TemplatedNaturalLanguageGenerator from rasa.shared.core.trackers import DialogueStateTracker @@ -291,6 +292,16 @@ async def test_action_rejection(): ActiveLoop(None), ], ), + # User rejected manually + ( + [{"event": "action_execution_rejected", "name": "my form"}], + [ + ActionExecutionRejected("my form"), + SlotSet("num_tables", 5), + SlotSet("num_people", "hi"), + SlotSet(REQUESTED_SLOT, None), + ], + ), ], ) async def test_validate_slots( @@ -341,6 +352,46 @@ async def test_validate_slots( assert events == expected_events +async def test_no_slots_extracted_with_custom_slot_mappings(): + form_name = "my form" + events = [ + ActiveLoop(form_name), + SlotSet(REQUESTED_SLOT, "num_tables"), + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered("off topic"), + ] + tracker = DialogueStateTracker.from_events(sender_id="bla", evts=events) + + domain = f""" + slots: + num_tables: + type: any + forms: + {form_name}: + num_tables: + - type: from_entity + entity: num_tables + actions: + - validate_{form_name} + """ + domain = Domain.from_yaml(domain) + action_server_url = "http:/my-action-server:5055/webhook" + + with aioresponses() as mocked: + mocked.post(action_server_url, payload={"events": []}) + + action_server = EndpointConfig(action_server_url) + action = FormAction(form_name, action_server) + + with pytest.raises(ActionExecutionRejection): + await action.run( + CollectingOutputChannel(), + TemplatedNaturalLanguageGenerator(domain.templates), + tracker, + domain, + ) + + async def test_validate_slots_on_activation_with_other_action_after_user_utterance(): form_name = "my form" slot_name = "num_people" @@ -810,7 +861,7 @@ def test_invalid_slot_mapping(): {"forms": {form_name: {slot_name: [{"type": "invalid"}]}}} ) - with pytest.raises(ValueError): + with pytest.raises(InvalidDomain): form.extract_requested_slot(tracker, domain)