Skip to content

Commit

Permalink
check if ActionExecutionRejected should be thrown after custom acti…
Browse files Browse the repository at this point in the history
…on for slot validations was called
  • Loading branch information
wochinge committed Oct 16, 2020
1 parent a2dc866 commit b224fb7
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 19 deletions.
11 changes: 11 additions & 0 deletions changelog/6977.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[Forms](forms.mdx) no longer reject their execution before a potential custom
action for validating / extracting slots was executed.
Forms continue to reject in two cases:
- A slot was requested to be filled, but no slot mapping applied to the latest user
message and there was no custom action for potentially extracting other slots.
- A slot was requested to be filled, but the custom action for validating / extracting
slots didn't return any slot event.

Additionally you can also reject the form execution manually by returning a
`ActionExecutionRejected` event within your custom action for validating / extracting
slots.
56 changes: 39 additions & 17 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rasa.core.actions import action
from rasa.core.actions.loops import LoopAction
from rasa.core.channels import OutputChannel
from rasa.shared.core.domain import Domain
from rasa.shared.core.domain import Domain, InvalidDomain

from rasa.core.actions.action import ActionExecutionRejection, RemoteAction
from rasa.shared.core.constants import (
Expand All @@ -15,7 +15,13 @@
LOOP_INTERRUPTED,
)
from rasa.shared.constants import UTTER_PREFIX
from rasa.shared.core.events import Event, SlotSet, ActionExecuted, ActiveLoop
from rasa.shared.core.events import (
Event,
SlotSet,
ActionExecuted,
ActiveLoop,
ActionExecutionRejected,
)
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.utils.endpoints import EndpointConfig
Expand Down Expand Up @@ -347,7 +353,7 @@ def extract_requested_slot(
elif mapping_type == str(SlotMapping.FROM_TEXT):
value = tracker.latest_message.text
else:
raise ValueError("Provided slot mapping type is not supported")
raise InvalidDomain("Provided slot mapping type is not supported")

if value is not None:
logger.debug(
Expand All @@ -361,7 +367,7 @@ def extract_requested_slot(

async def validate_slots(
self,
slot_dict: Dict[Text, Any],
slot_candidates: Dict[Text, Any],
tracker: "DialogueStateTracker",
domain: Domain,
output_channel: OutputChannel,
Expand All @@ -373,7 +379,7 @@ async def validate_slots(
them. Otherwise there is no validation.
Args:
slot_dict: Extracted slots which are candidates to fill the slots required
slot_candidates: Extracted slots which are candidates to fill the slots required
by the form.
tracker: The current conversation tracker.
domain: The current model domain.
Expand All @@ -385,8 +391,10 @@ async def validate_slots(
The validation events including potential bot messages and `SlotSet` events
for the validated slots.
"""

events = [SlotSet(slot_name, value) for slot_name, value in slot_dict.items()]
logger.debug(f"Validating extracted slots: {slot_candidates}")
events = [
SlotSet(slot_name, value) for slot_name, value in slot_candidates.items()
]

validate_name = f"validate_{self.name()}"

Expand Down Expand Up @@ -445,19 +453,30 @@ async def validate(
if slot_to_fill:
slot_values.update(self.extract_requested_slot(tracker, domain))

if not slot_values:
# reject to execute the form action
# if some slot was requested but nothing was extracted
# it will allow other policies to predict another action
raise ActionExecutionRejection(
self.name(),
f"Failed to extract slot {slot_to_fill} with action {self.name()}",
)
logger.debug(f"Validating extracted slots: {slot_values}")
return await self.validate_slots(
validation_events = await self.validate_slots(
slot_values, tracker, domain, output_channel, nlg
)

some_slots_were_validated = any(
isinstance(event, SlotSet) for event in validation_events
)
user_rejected_manually = any(
isinstance(event, ActionExecutionRejected) for event in validation_events
)
if (
slot_to_fill
and not some_slots_were_validated
and not user_rejected_manually
):
# reject to execute the form action
# if some slot was requested but nothing was extracted
# it will allow other policies to predict another action
raise ActionExecutionRejection(
self.name(),
f"Failed to extract slot {slot_to_fill} with action {self.name()}",
)
return validation_events

async def request_next_slot(
self,
tracker: "DialogueStateTracker",
Expand Down Expand Up @@ -666,6 +685,9 @@ async def is_done(
domain: "Domain",
events_so_far: List[Event],
) -> bool:
if any(isinstance(event, ActionExecutionRejected) for event in events_so_far):
return False

# Custom validation actions can decide to terminate the loop early by
# setting the requested slot to `None` or setting `ActiveLoop(None)`.
# We explicitly check only the last occurrences for each possible termination
Expand Down
55 changes: 53 additions & 2 deletions tests/core/actions/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rasa.shared.core.constants import ACTION_LISTEN_NAME, REQUESTED_SLOT
from rasa.core.actions.forms import FormAction
from rasa.core.channels import CollectingOutputChannel
from rasa.shared.core.domain import Domain
from rasa.shared.core.domain import Domain, InvalidDomain
from rasa.shared.core.events import (
ActiveLoop,
SlotSet,
Expand All @@ -19,6 +19,7 @@
BotUttered,
Restarted,
Event,
ActionExecutionRejected,
)
from rasa.core.nlg import TemplatedNaturalLanguageGenerator
from rasa.shared.core.trackers import DialogueStateTracker
Expand Down Expand Up @@ -291,6 +292,16 @@ async def test_action_rejection():
ActiveLoop(None),
],
),
# User rejected manually
(
[{"event": "action_execution_rejected", "name": "my form"}],
[
ActionExecutionRejected("my form"),
SlotSet("num_tables", 5),
SlotSet("num_people", "hi"),
SlotSet(REQUESTED_SLOT, None),
],
),
],
)
async def test_validate_slots(
Expand Down Expand Up @@ -341,6 +352,46 @@ async def test_validate_slots(
assert events == expected_events


async def test_no_slots_extracted_with_custom_slot_mappings():
form_name = "my form"
events = [
ActiveLoop(form_name),
SlotSet(REQUESTED_SLOT, "num_tables"),
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered("off topic"),
]
tracker = DialogueStateTracker.from_events(sender_id="bla", evts=events)

domain = f"""
slots:
num_tables:
type: any
forms:
{form_name}:
num_tables:
- type: from_entity
entity: num_tables
actions:
- validate_{form_name}
"""
domain = Domain.from_yaml(domain)
action_server_url = "http:/my-action-server:5055/webhook"

with aioresponses() as mocked:
mocked.post(action_server_url, payload={"events": []})

action_server = EndpointConfig(action_server_url)
action = FormAction(form_name, action_server)

with pytest.raises(ActionExecutionRejection):
await action.run(
CollectingOutputChannel(),
TemplatedNaturalLanguageGenerator(domain.templates),
tracker,
domain,
)


async def test_validate_slots_on_activation_with_other_action_after_user_utterance():
form_name = "my form"
slot_name = "num_people"
Expand Down Expand Up @@ -810,7 +861,7 @@ def test_invalid_slot_mapping():
{"forms": {form_name: {slot_name: [{"type": "invalid"}]}}}
)

with pytest.raises(ValueError):
with pytest.raises(InvalidDomain):
form.extract_requested_slot(tracker, domain)


Expand Down

0 comments on commit b224fb7

Please sign in to comment.