Skip to content

Commit

Permalink
added a ton of tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Oct 27, 2023
1 parent 8f64d6a commit 92476ca
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 32 deletions.
8 changes: 6 additions & 2 deletions rasa/core/policies/flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def predict_action_probabilities(

# create executor and predict next action
try:
prediction = flow_executor.advance_flows(tracker, domain, flows)
prediction = flow_executor.advance_flows(
tracker, domain.action_names_or_texts, flows
)
return self._create_prediction_result(
prediction.action_name,
domain,
Expand All @@ -162,7 +164,9 @@ def predict_action_probabilities(
# we retry, with the internal error frame on the stack
event = updated_stack.persist_as_event()
tracker.update(event)
prediction = flow_executor.advance_flows(tracker, domain, flows)
prediction = flow_executor.advance_flows(
tracker, domain.action_names_or_texts, flows
)
collected_events = [event] + (prediction.events or [])
return self._create_prediction_result(
prediction.action_name,
Expand Down
14 changes: 9 additions & 5 deletions rasa/core/policies/flows/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
CollectInformationFlowStep,
StaticFlowLink,
)
from rasa.shared.core.domain import Domain
from rasa.shared.core.trackers import (
DialogueStateTracker,
)
Expand Down Expand Up @@ -299,14 +298,16 @@ def _reset_slot(slot_name: Text, dialogue_tracker: DialogueStateTracker) -> None


def advance_flows(
tracker: DialogueStateTracker, domain: Domain, flows: FlowsList
tracker: DialogueStateTracker, available_actions: List[str], flows: FlowsList
) -> FlowActionPrediction:
"""Advance the flows.
Either start a new flow or advance the current flow.
Args:
tracker: The tracker to get the next action for.
available_actions: The actions that are available in the domain.
flows: All flows.
Returns:
The predicted action and the events to run.
Expand All @@ -317,7 +318,7 @@ def advance_flows(
return FlowActionPrediction(None, 0.0)

previous_stack = stack.as_dict()
prediction = select_next_action(stack, tracker, domain, flows)
prediction = select_next_action(stack, tracker, available_actions, flows)
if previous_stack != stack.as_dict():
# we need to update dialogue stack to persist the state of the executor
if not prediction.events:
Expand All @@ -329,7 +330,7 @@ def advance_flows(
def select_next_action(
stack: DialogueStack,
tracker: DialogueStateTracker,
domain: Domain,
available_actions: List[str],
flows: FlowsList,
) -> FlowActionPrediction:
"""Select the next action to execute.
Expand All @@ -340,7 +341,10 @@ def select_next_action(
advanced. If there are no more flows, the action listen is predicted.
Args:
stack: The stack to get the next action for.
tracker: The tracker to get the next action for.
available_actions: The actions that are available in the domain.
flows: All flows.
Returns:
The next action to execute, the events that should be applied to the
Expand Down Expand Up @@ -390,7 +394,7 @@ def select_next_action(
current_flow,
stack,
tracker,
domain.action_names_or_texts,
available_actions,
flows,
)
tracker.update_with_events(step_result.events)
Expand Down
1 change: 0 additions & 1 deletion tests/cli/test_rasa_evaluate_markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from rasa.shared.core.events import ActionExecuted, SlotSet, UserUttered
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.constants import ACTION_SESSION_START_NAME
from rasa.shared.core.domain import Domain
from rasa.core.tracker_store import SQLTrackerStore
from rasa.cli.evaluate import STATS_SESSION_SUFFIX, STATS_OVERALL_SUFFIX
from tests.conftest import write_endpoint_config_to_yaml
Expand Down
Loading

0 comments on commit 92476ca

Please sign in to comment.