diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 232b977d2db8..e03ff75df311 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -225,9 +225,9 @@ def action_for_name_or_text( return FormAction(action_name_or_text, action_endpoint) if action_name_or_text.startswith(FLOW_PREFIX): - from rasa.core.actions.flows import FlowTriggerAction + from rasa.core.actions.action_trigger_flow import ActionTriggerFlow - return FlowTriggerAction(action_name_or_text) + return ActionTriggerFlow(action_name_or_text) return RemoteAction(action_name_or_text, action_endpoint) diff --git a/rasa/core/actions/flow_trigger_action.py b/rasa/core/actions/action_trigger_flow.py similarity index 53% rename from rasa/core/actions/flow_trigger_action.py rename to rasa/core/actions/action_trigger_flow.py index 7b4271e27c50..a89481adf516 100644 --- a/rasa/core/actions/flow_trigger_action.py +++ b/rasa/core/actions/action_trigger_flow.py @@ -22,16 +22,23 @@ structlogger = structlog.get_logger(__name__) -class FlowTriggerAction(action.Action): - """Action which implements and executes the form logic.""" +class ActionTriggerFlow(action.Action): + """Action which triggers a flow by putting it on the dialogue stack.""" def __init__(self, flow_action_name: Text) -> None: - """Creates a `FlowTriggerAction`. + """Creates a `ActionTriggerFlow`. Args: flow_action_name: Name of the flow. """ super().__init__() + + if not flow_action_name.startswith(FLOW_PREFIX): + raise ValueError( + f"Flow action name '{flow_action_name}' needs to start with " + f"'{FLOW_PREFIX}'." + ) + self._flow_name = flow_action_name[len(FLOW_PREFIX) :] self._flow_action_name = flow_action_name @@ -39,20 +46,20 @@ def name(self) -> Text: """Return the flow name.""" return self._flow_action_name - async def run( - self, - output_channel: "OutputChannel", - nlg: "NaturalLanguageGenerator", - tracker: "DialogueStateTracker", - domain: "Domain", - metadata: Optional[Dict[Text, Any]] = None, - ) -> List[Event]: - """Trigger the flow.""" + def create_event_to_start_flow(self, tracker: DialogueStateTracker) -> Event: + """Create an event to start the flow. + + Args: + tracker: The tracker to start the flow on. + + Returns: + The event to start the flow.""" stack = DialogueStack.from_tracker(tracker) - if not stack.is_empty(): - frame_type = FlowStackFrameType.INTERRUPT - else: - frame_type = FlowStackFrameType.REGULAR + frame_type = ( + FlowStackFrameType.REGULAR + if stack.is_empty() + else FlowStackFrameType.INTERRUPT + ) stack.push( UserFlowStackFrame( @@ -60,16 +67,36 @@ async def run( frame_type=frame_type, ) ) + return stack.persist_as_event() + def create_events_to_set_flow_slots(self, metadata: Dict[str, Any]) -> List[Event]: + """Create events to set the flow slots. + + Set additional slots to prefill information for the flow. + + Args: + metadata: The metadata to set the slots from. + + Returns: + The events to set the flow slots. + """ slots_to_be_set = metadata.get("slots", {}) if metadata else {} - slot_set_events: List[Event] = [ - SlotSet(key, value) for key, value in slots_to_be_set.items() - ] + return [SlotSet(key, value) for key, value in slots_to_be_set.items()] + + async def run( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + metadata: Optional[Dict[Text, Any]] = None, + ) -> List[Event]: + """Trigger the flow.""" + events: List[Event] = [self.create_event_to_start_flow(tracker)] + events.extend(self.create_events_to_set_flow_slots(metadata)) - events: List[Event] = [ - stack.persist_as_event(), - ] + slot_set_events if tracker.active_loop_name: + # end any active loop to ensure we are progressing the started flow events.append(ActiveLoop(None)) return events diff --git a/tests/core/actions/test_action_trigger_flow.py b/tests/core/actions/test_action_trigger_flow.py new file mode 100644 index 000000000000..652007f7fd38 --- /dev/null +++ b/tests/core/actions/test_action_trigger_flow.py @@ -0,0 +1,92 @@ +import pytest +from rasa.core.actions.action_trigger_flow import ActionTriggerFlow +from rasa.core.channels import CollectingOutputChannel +from rasa.core.nlg import TemplatedNaturalLanguageGenerator +from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack +from rasa.dialogue_understanding.stack.frames.flow_stack_frame import ( + FlowStackFrameType, + UserFlowStackFrame, +) +from rasa.shared.core.constants import DIALOGUE_STACK_SLOT +from rasa.shared.core.domain import Domain +from rasa.shared.core.events import ActiveLoop, SlotSet +from rasa.shared.core.trackers import DialogueStateTracker + + +async def test_action_trigger_flow(): + tracker = DialogueStateTracker.from_events("test", []) + action = ActionTriggerFlow("flow_foo") + channel = CollectingOutputChannel() + nlg = TemplatedNaturalLanguageGenerator({}) + events = await action.run(channel, nlg, tracker, Domain.empty()) + assert len(events) == 1 + event = events[0] + assert isinstance(event, SlotSet) + assert event.key == DIALOGUE_STACK_SLOT + assert len(event.value) == 1 + assert event.value[0]["type"] == UserFlowStackFrame.type() + assert event.value[0]["flow_id"] == "foo" + assert event.value[0]["frame_type"] == FlowStackFrameType.REGULAR.value + + +async def test_action_trigger_flow_with_slots(): + tracker = DialogueStateTracker.from_events("test", []) + action = ActionTriggerFlow("flow_foo") + channel = CollectingOutputChannel() + nlg = TemplatedNaturalLanguageGenerator({}) + events = await action.run( + channel, nlg, tracker, Domain.empty(), metadata={"slots": {"foo": "bar"}} + ) + + event = events[0] + assert isinstance(event, SlotSet) + assert event.key == DIALOGUE_STACK_SLOT + assert len(event.value) == 1 + assert event.value[0]["type"] == UserFlowStackFrame.type() + assert event.value[0]["flow_id"] == "foo" + + assert len(events) == 2 + event = events[1] + assert isinstance(event, SlotSet) + assert event.key == "foo" + assert event.value == "bar" + + +async def test_action_trigger_fails_if_name_is_invalid(): + with pytest.raises(ValueError): + ActionTriggerFlow("foo") + + +async def test_action_trigger_ends_an_active_loop_on_the_tracker(): + tracker = DialogueStateTracker.from_events("test", [ActiveLoop("loop_foo")]) + action = ActionTriggerFlow("flow_foo") + channel = CollectingOutputChannel() + nlg = TemplatedNaturalLanguageGenerator({}) + events = await action.run(channel, nlg, tracker, Domain.empty()) + + assert len(events) == 2 + assert isinstance(events[1], ActiveLoop) + assert events[1].name is None + + +async def test_action_trigger_uses_interrupt_flow_type_if_stack_already_contains_flow(): + user_frame = UserFlowStackFrame( + flow_id="my_flow", step_id="collect_bar", frame_id="some-frame-id" + ) + stack = DialogueStack(frames=[user_frame]) + tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()]) + + action = ActionTriggerFlow("flow_foo") + channel = CollectingOutputChannel() + nlg = TemplatedNaturalLanguageGenerator({}) + + events = await action.run(channel, nlg, tracker, Domain.empty()) + + assert len(events) == 1 + event = events[0] + assert isinstance(event, SlotSet) + assert event.key == DIALOGUE_STACK_SLOT + assert len(event.value) == 2 + assert event.value[1]["type"] == UserFlowStackFrame.type() + assert event.value[1]["flow_id"] == "foo" + assert event.value[1]["frame_type"] == FlowStackFrameType.INTERRUPT.value