From 92476ca86499b7983ec069999ee98005087d4974 Mon Sep 17 00:00:00 2001 From: Tom Bocklisch Date: Fri, 27 Oct 2023 16:52:08 +0200 Subject: [PATCH] added a ton of tests --- rasa/core/policies/flow_policy.py | 8 +- rasa/core/policies/flows/flow_executor.py | 14 +- tests/cli/test_rasa_evaluate_markers.py | 1 - .../core/policies/flows/test_flow_executor.py | 306 ++++++++++++++++-- 4 files changed, 297 insertions(+), 32 deletions(-) diff --git a/rasa/core/policies/flow_policy.py b/rasa/core/policies/flow_policy.py index 9d23638236aa..b437a6c239a1 100644 --- a/rasa/core/policies/flow_policy.py +++ b/rasa/core/policies/flow_policy.py @@ -138,7 +138,9 @@ def predict_action_probabilities( # create executor and predict next action try: - prediction = flow_executor.advance_flows(tracker, domain, flows) + prediction = flow_executor.advance_flows( + tracker, domain.action_names_or_texts, flows + ) return self._create_prediction_result( prediction.action_name, domain, @@ -162,7 +164,9 @@ def predict_action_probabilities( # we retry, with the internal error frame on the stack event = updated_stack.persist_as_event() tracker.update(event) - prediction = flow_executor.advance_flows(tracker, domain, flows) + prediction = flow_executor.advance_flows( + tracker, domain.action_names_or_texts, flows + ) collected_events = [event] + (prediction.events or []) return self._create_prediction_result( prediction.action_name, diff --git a/rasa/core/policies/flows/flow_executor.py b/rasa/core/policies/flows/flow_executor.py index 2e66e2b8420f..9cec8d672145 100644 --- a/rasa/core/policies/flows/flow_executor.py +++ b/rasa/core/policies/flows/flow_executor.py @@ -61,7 +61,6 @@ CollectInformationFlowStep, StaticFlowLink, ) -from rasa.shared.core.domain import Domain from rasa.shared.core.trackers import ( DialogueStateTracker, ) @@ -299,7 +298,7 @@ def _reset_slot(slot_name: Text, dialogue_tracker: DialogueStateTracker) -> None def advance_flows( - tracker: DialogueStateTracker, domain: Domain, flows: FlowsList + tracker: DialogueStateTracker, available_actions: List[str], flows: FlowsList ) -> FlowActionPrediction: """Advance the flows. @@ -307,6 +306,8 @@ def advance_flows( Args: tracker: The tracker to get the next action for. + available_actions: The actions that are available in the domain. + flows: All flows. Returns: The predicted action and the events to run. @@ -317,7 +318,7 @@ def advance_flows( return FlowActionPrediction(None, 0.0) previous_stack = stack.as_dict() - prediction = select_next_action(stack, tracker, domain, flows) + prediction = select_next_action(stack, tracker, available_actions, flows) if previous_stack != stack.as_dict(): # we need to update dialogue stack to persist the state of the executor if not prediction.events: @@ -329,7 +330,7 @@ def advance_flows( def select_next_action( stack: DialogueStack, tracker: DialogueStateTracker, - domain: Domain, + available_actions: List[str], flows: FlowsList, ) -> FlowActionPrediction: """Select the next action to execute. @@ -340,7 +341,10 @@ def select_next_action( advanced. If there are no more flows, the action listen is predicted. Args: + stack: The stack to get the next action for. tracker: The tracker to get the next action for. + available_actions: The actions that are available in the domain. + flows: All flows. Returns: The next action to execute, the events that should be applied to the @@ -390,7 +394,7 @@ def select_next_action( current_flow, stack, tracker, - domain.action_names_or_texts, + available_actions, flows, ) tracker.update_with_events(step_result.events) diff --git a/tests/cli/test_rasa_evaluate_markers.py b/tests/cli/test_rasa_evaluate_markers.py index eefe88aa716d..804afe4e6a7f 100644 --- a/tests/cli/test_rasa_evaluate_markers.py +++ b/tests/cli/test_rasa_evaluate_markers.py @@ -10,7 +10,6 @@ from rasa.shared.core.events import ActionExecuted, SlotSet, UserUttered from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.core.constants import ACTION_SESSION_START_NAME -from rasa.shared.core.domain import Domain from rasa.core.tracker_store import SQLTrackerStore from rasa.cli.evaluate import STATS_SESSION_SUFFIX, STATS_OVERALL_SUFFIX from tests.conftest import write_endpoint_config_to_yaml diff --git a/tests/core/policies/flows/test_flow_executor.py b/tests/core/policies/flows/test_flow_executor.py index b7eb68426877..5a56fc8ff54d 100644 --- a/tests/core/policies/flows/test_flow_executor.py +++ b/tests/core/policies/flows/test_flow_executor.py @@ -1,8 +1,13 @@ import textwrap from typing import List, Optional, Tuple +from unittest.mock import patch import pytest from rasa.core.policies.flows import flow_executor from rasa.core.policies.flows.flow_exceptions import FlowCircuitBreakerTrippedException +from rasa.core.policies.flows.flow_step_result import ( + ContinueFlowWithNextStep, + PauseFlowReturnPrediction, +) from rasa.dialogue_understanding.patterns.collect_information import ( CollectInformationPatternFlowStackFrame, ) @@ -19,6 +24,7 @@ UserFlowStackFrame, ) from rasa.dialogue_understanding.stack.frames.search_frame import SearchStackFrame +from rasa.shared.core.constants import ACTION_SEND_TEXT_NAME from rasa.shared.core.domain import Domain from rasa.shared.core.events import ActionExecuted, Event, SlotSet from rasa.shared.core.flows.flow import ( @@ -472,7 +478,7 @@ def test_trigger_pattern_continue_interrupted_does_not_trigger_if_no_interrupt() assert len(stack.frames) == 1 -def test_trigger_pattern_continue_interrupted_does_not_trigger_if_frame_is_already_finished(): +def test_trigger_pattern_continue_interrupted_does_not_trigger_if_finished(): flows = flows_from_str( """ flows: @@ -509,7 +515,7 @@ def test_trigger_pattern_continue_interrupted_does_not_trigger_if_frame_is_alrea assert len(stack.frames) == 1 -def test_trigger_pattern_continue_interrupted_does_not_trigger_if_frame_is_not_user_frame(): +def test_trigger_pattern_continue_interrupted_does_not_trigger_if_not_user_frame(): flows = flows_from_str( """ flows: @@ -587,7 +593,7 @@ def test_trigger_pattern_completed_does_not_trigger_if_stack_not_empty(): assert len(stack.frames) == 1 -def test_trigger_pattern_completed_does_not_trigger_if_current_frame_is_not_user_frame(): +def test_trigger_pattern_completed_does_not_trigger_if_not_user_frame(): flows = flows_from_str( """ flows: @@ -692,7 +698,7 @@ def test_reset_scoped_slots_resets_set_slots(): foo_flow: name: foo flow steps: - - set_slots: + - set_slots: - foo: bar """ ) @@ -713,7 +719,7 @@ def test_reset_scoped_slots_does_not_reset_set_slots_if_collect_forbids_it(): steps: - collect: foo reset_after_flow_ends: false - - set_slots: + - set_slots: - foo: bar """ ) @@ -723,23 +729,14 @@ def test_reset_scoped_slots_does_not_reset_set_slots_if_collect_forbids_it(): assert events == [] -def test_run_step(): - all_flows = flows_from_str( +def test_run_step_collect(): + flows = flows_from_str( """ flows: my_flow: steps: - id: collect_foo collect: foo - next: - - if: foo is 'foobar' - then: collect_bar - - else: - - id: collect_baz - collect: baz - next: END - - id: collect_bar - collect: bar """ ) @@ -747,13 +744,200 @@ def test_run_step(): flow_id="my_flow", step_id="collect_foo", frame_id="some-frame-id" ) stack = DialogueStack(frames=[user_flow_frame]) - tracker = DialogueStateTracker.from_events( - "test", [stack.persist_as_event(), SlotSet("foo", "foobar")] + tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()]) + step = user_flow_frame.step(flows) + flow = user_flow_frame.flow(flows) + + available_actions = ["utter_ask_foo"] + + result = flow_executor.run_step( + step, flow, stack, tracker, available_actions, flows ) - step = user_flow_frame.step(all_flows) - flow = user_flow_frame.flow(all_flows) - result = flow_executor.run_step(step, flow, stack, tracker, domain, flows) + assert isinstance(result, ContinueFlowWithNextStep) + assert result.events == [] + + +def test_run_step_action(): + flows = flows_from_str( + """ + flows: + my_flow: + steps: + - id: action + action: utter_ask_foo + """ + ) + + user_flow_frame = UserFlowStackFrame( + flow_id="my_flow", step_id="action", frame_id="some-frame-id" + ) + stack = DialogueStack(frames=[user_flow_frame]) + tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()]) + step = user_flow_frame.step(flows) + flow = user_flow_frame.flow(flows) + + available_actions = ["utter_ask_foo"] + + result = flow_executor.run_step( + step, flow, stack, tracker, available_actions, flows + ) + + assert isinstance(result, PauseFlowReturnPrediction) + assert result.action_prediction.action_name == "utter_ask_foo" + + +def test_run_step_action_that_does_not_exist(): + flows = flows_from_str( + """ + flows: + my_flow: + steps: + - id: action + action: utter_ask_foo + """ + ) + + user_flow_frame = UserFlowStackFrame( + flow_id="my_flow", step_id="action", frame_id="some-frame-id" + ) + stack = DialogueStack(frames=[user_flow_frame]) + tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()]) + step = user_flow_frame.step(flows) + flow = user_flow_frame.flow(flows) + + available_actions = [] + + result = flow_executor.run_step( + step, flow, stack, tracker, available_actions, flows + ) + + assert isinstance(result, ContinueFlowWithNextStep) + + +def test_run_step_link(): + flows = flows_from_str( + """ + flows: + my_flow: + steps: + - id: link + link: bar_flow + """ + ) + + user_flow_frame = UserFlowStackFrame( + flow_id="my_flow", step_id="link", frame_id="some-frame-id" + ) + stack = DialogueStack(frames=[user_flow_frame]) + tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()]) + step = user_flow_frame.step(flows) + flow = user_flow_frame.flow(flows) + + available_actions = [] + + result = flow_executor.run_step( + step, flow, stack, tracker, available_actions, flows + ) + + assert isinstance(result, ContinueFlowWithNextStep) + top = stack.top() + assert isinstance(top, UserFlowStackFrame) + assert top.flow_id == "my_flow" + linked_flow = stack.frames[0] + assert isinstance(linked_flow, UserFlowStackFrame) + assert linked_flow.frame_type == FlowStackFrameType.LINK + assert linked_flow.flow_id == "bar_flow" + + +def test_run_step_set_slot(): + flows = flows_from_str( + """ + flows: + my_flow: + steps: + - id: set_slot + set_slots: + - bar: baz + """ + ) + + user_flow_frame = UserFlowStackFrame( + flow_id="my_flow", step_id="set_slot", frame_id="some-frame-id" + ) + stack = DialogueStack(frames=[user_flow_frame]) + tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()]) + step = user_flow_frame.step(flows) + flow = user_flow_frame.flow(flows) + + available_actions = [] + + result = flow_executor.run_step( + step, flow, stack, tracker, available_actions, flows + ) + + assert isinstance(result, ContinueFlowWithNextStep) + assert result.events == [SlotSet("bar", "baz")] + + +def test_run_step_generate_response(): + flows = flows_from_str( + """ + flows: + my_flow: + steps: + - id: generate + generation_prompt: Generate a message! + """ + ) + + user_flow_frame = UserFlowStackFrame( + flow_id="my_flow", step_id="generate", frame_id="some-frame-id" + ) + stack = DialogueStack(frames=[user_flow_frame]) + tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()]) + step = user_flow_frame.step(flows) + flow = user_flow_frame.flow(flows) + available_actions = [] + + # mock the steps `.generate` method to avoid an LLM call + with patch.object(step, "generate", return_value="generated"): + result = flow_executor.run_step( + step, flow, stack, tracker, available_actions, flows + ) + + assert isinstance(result, PauseFlowReturnPrediction) + assert result.action_prediction.action_name == ACTION_SEND_TEXT_NAME + assert result.action_prediction.metadata == {"message": {"text": "generated"}} + + +def test_run_step_end(): + flows = flows_from_str( + """ + flows: + my_flow: + steps: + - id: collect + collect: bar + """ + ) + + user_flow_frame = UserFlowStackFrame( + flow_id="my_flow", step_id="END", frame_id="some-frame-id" + ) + stack = DialogueStack(frames=[user_flow_frame]) + tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()]) + step = user_flow_frame.step(flows) + flow = user_flow_frame.flow(flows) + + available_actions = [] + + result = flow_executor.run_step( + step, flow, stack, tracker, available_actions, flows + ) + + assert isinstance(result, ContinueFlowWithNextStep) + assert result.events == [SlotSet("bar", None)] def test_executor_does_not_get_tripped_if_an_action_is_predicted_in_loop(): @@ -785,7 +969,11 @@ def test_executor_does_not_get_tripped_if_an_action_is_predicted_in_loop(): slots=domain.slots, ) - selection = flow_executor.select_next_action(stack, tracker, domain, flow_with_loop) + available_actions = ["action_listen"] + + selection = flow_executor.select_next_action( + stack, tracker, available_actions, flow_with_loop + ) assert selection.action_name == "action_listen" @@ -819,8 +1007,76 @@ def test_executor_trips_internal_circuit_breaker(): slots=domain.slots, ) + available_actions = [] + with pytest.raises(FlowCircuitBreakerTrippedException): - flow_executor.select_next_action(stack, tracker, domain, flow_with_loop) + flow_executor.select_next_action( + stack, tracker, available_actions, flow_with_loop + ) + + +def test_advance_flows_empty_stack(): + flows = flows_from_str( + """ + flows: + foo_flow: + steps: + - id: "1" + set_slots: + - foo: bar + next: "2" + - id: "2" + set_slots: + - foo: barbar + next: "1" + """ + ) + stack = DialogueStack(frames=[]) + tracker = DialogueStateTracker.from_events( + "test", + evts=[stack.persist_as_event()], + ) + available_actions = [] + prediction = flow_executor.advance_flows(tracker, available_actions, flows) + assert prediction.action_name is None + + +def test_advance_flows_selects_next_action(): + flows = flows_from_str( + """ + flows: + foo_flow: + steps: + - id: "1" + collect: foo + - id: "2" + action: utter_goodbye + """ + ) + stack = DialogueStack( + frames=[UserFlowStackFrame(flow_id="foo_flow", step_id="1", frame_id="some-id")] + ) + tracker = DialogueStateTracker.from_events( + "test", + evts=[stack.persist_as_event()], + ) + available_actions = ["utter_goodbye"] + prediction = flow_executor.advance_flows(tracker, available_actions, flows) + assert prediction.action_name == "utter_goodbye" + assert prediction.events == [ + SlotSet( + "dialogue_stack", + [ + { + "frame_id": "some-id", + "flow_id": "foo_flow", + "step_id": "2", + "frame_type": "regular", + "type": "flow", + } + ], + ) + ] def _run_flow_until_listen( @@ -831,7 +1087,9 @@ def _run_flow_until_listen( events = [] actions = [] while True: - action_prediction = flow_executor.advance_flows(tracker, domain, flows) + action_prediction = flow_executor.advance_flows( + tracker, domain.action_names_or_texts, flows + ) if not action_prediction: break