diff --git a/changelog/4088.feature.rst b/changelog/4088.feature.rst new file mode 100644 index 000000000000..654e3b72c23e --- /dev/null +++ b/changelog/4088.feature.rst @@ -0,0 +1 @@ +Add story structure validation functionality (e.g. `rasa data validate stories --max-history 5`). diff --git a/data/test_stories/stories_conflicting_1.md b/data/test_stories/stories_conflicting_1.md new file mode 100644 index 000000000000..d772f46ee33a --- /dev/null +++ b/data/test_stories/stories_conflicting_1.md @@ -0,0 +1,15 @@ +## story 1 +* greet + - utter_greet +* greet + - utter_greet +* greet + - utter_greet + +## story 2 +* default + - utter_greet +* greet + - utter_greet +* greet + - utter_default diff --git a/data/test_stories/stories_conflicting_2.md b/data/test_stories/stories_conflicting_2.md new file mode 100644 index 000000000000..001b7087c700 --- /dev/null +++ b/data/test_stories/stories_conflicting_2.md @@ -0,0 +1,14 @@ +## greetings +* greet + - utter_greet +> check_greet + +## happy path +> check_greet +* default + - utter_default + +## problem +> check_greet +* default + - utter_goodbye diff --git a/data/test_stories/stories_conflicting_3.md b/data/test_stories/stories_conflicting_3.md new file mode 100644 index 000000000000..2218f6cea164 --- /dev/null +++ b/data/test_stories/stories_conflicting_3.md @@ -0,0 +1,14 @@ +## greetings +* greet + - utter_greet +> check_greet + +## happy path +> check_greet +* default OR greet + - utter_default + +## problem +> check_greet +* greet + - utter_goodbye diff --git a/data/test_stories/stories_conflicting_4.md b/data/test_stories/stories_conflicting_4.md new file mode 100644 index 000000000000..372c38ff6d15 --- /dev/null +++ b/data/test_stories/stories_conflicting_4.md @@ -0,0 +1,17 @@ +## story 1 +* greet + - utter_greet +* greet + - slot{"cuisine": "German"} + - utter_greet +* greet + - utter_greet + +## story 2 +* greet + - utter_greet +* greet + - slot{"cuisine": "German"} + - utter_greet +* greet + - utter_default diff --git a/data/test_stories/stories_conflicting_5.md b/data/test_stories/stories_conflicting_5.md new file mode 100644 index 000000000000..6865c9db9b4f --- /dev/null +++ b/data/test_stories/stories_conflicting_5.md @@ -0,0 +1,16 @@ +## story 1 +* greet + - utter_greet +* greet + - utter_greet + - slot{"cuisine": "German"} +* greet + - utter_greet + +## story 2 +* greet + - utter_greet +* greet + - utter_greet +* greet + - utter_default diff --git a/data/test_stories/stories_conflicting_6.md b/data/test_stories/stories_conflicting_6.md new file mode 100644 index 000000000000..f58dc258078e --- /dev/null +++ b/data/test_stories/stories_conflicting_6.md @@ -0,0 +1,22 @@ +## story 1 +* greet + - utter_greet + +## story 2 +* greet + - utter_default + +## story 3 +* greet + - utter_default +* greet + +## story 4 +* greet + - utter_default +* default + +## story 5 +* greet + - utter_default +* goodbye diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index c4d8f590c362..2421eea49e05 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -18,7 +18,8 @@ You can run it with the following command: rasa data validate -The script above runs all the validations on your files. Here is the list of options to +The script above runs all the validations on your files, except for story structure validation, +which is omitted unless you provide the ``--max-history`` argument. Here is the list of options to the script: .. program-output:: rasa data validate --help @@ -65,3 +66,53 @@ To use these functions it is necessary to create a `Validator` object and initia stories='data/stories.md') validator.verify_all() + +Test Story Files for Conflicts +------------------------------ + +In addition to the default tests described above, you can also do a more in-depth structural test of your stories. +In particular, you can test if your stories are inconsistent, i.e. if different bot actions follow from the same dialogue history. +If this is not the case, then Rasa cannot learn the correct behaviour. + +Take, for example, the following two stories: + +.. code-block:: md + + ## Story 1 + * greet + - utter_greet + * inform_happy + - utter_happy + - utter_goodbye + + ## Story 2 + * greet + - utter_greet + * inform_happy + - utter_goodbye + +These two stories are inconsistent, because Rasa doesn't know if it should predict ``utter_happy`` or ``utter_goodbye`` +after ``inform_happy``, as there is nothing that would distinguish the dialogue states at ``inform_happy`` in the two +stories and the subsequent actions are different in Story 1 and Story 2. + +This conflict can be automatically identified with our story structure validation tool. +To do this, use ``rasa data validate`` in the command line, as follows: + +.. code-block:: bash + + rasa data validate stories --max-history 3 + > 2019-12-09 09:32:13 INFO rasa.core.validator - Story structure validation... + > 2019-12-09 09:32:13 INFO rasa.core.validator - Assuming max_history = 3 + > Processed Story Blocks: 100% 2/2 [00:00<00:00, 3237.59it/s, # trackers=1] + > 2019-12-09 09:32:13 WARNING rasa.core.validator - CONFLICT after intent 'inform_happy': + > utter_goodbye predicted in 'Story 2' + > utter_happy predicted in 'Story 1' + +Here we specify a ``max-history`` value of 3. +This means, that 3 events (user messages / bot actions) are taken into account for action predictions, but the particular setting does not matter for this example, because regardless of how long of a history you take into account, the conflict always exists. + +.. warning:: + + The ``rasa data validate stories`` script assumes that all your **story names are unique**. + If your stories are in the Markdown format, you may find duplicate names with a command like + ``grep -h "##" data/*.md | uniq -c | grep "^[^1]"``. diff --git a/rasa/cli/data.py b/rasa/cli/data.py index edb7bb28b6b2..f5ed5d0f1da4 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -1,21 +1,22 @@ +import logging import argparse import asyncio -import sys from typing import List from rasa import data from rasa.cli.arguments import data as arguments -from rasa.cli.utils import get_validated_path +import rasa.cli.utils from rasa.constants import DEFAULT_DATA_PATH -from typing import NoReturn +from rasa.validator import Validator +from rasa.importers.rasa import RasaFileImporter + +logger = logging.getLogger(__name__) # noinspection PyProtectedMember def add_subparser( subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] ): - import rasa.nlu.convert as convert - data_parser = subparsers.add_parser( "data", conflict_handler="resolve", @@ -26,6 +27,17 @@ def add_subparser( data_parser.set_defaults(func=lambda _: data_parser.print_help(None)) data_subparsers = data_parser.add_subparsers() + + _add_data_convert_parsers(data_subparsers, parents) + _add_data_split_parsers(data_subparsers, parents) + _add_data_validate_parsers(data_subparsers, parents) + + +def _add_data_convert_parsers( + data_subparsers, parents: List[argparse.ArgumentParser] +) -> None: + from rasa.nlu import convert + convert_parser = data_subparsers.add_parser( "convert", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -45,6 +57,10 @@ def add_subparser( arguments.set_convert_arguments(convert_nlu_parser) + +def _add_data_split_parsers( + data_subparsers, parents: List[argparse.ArgumentParser] +) -> None: split_parser = data_subparsers.add_parser( "split", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -65,21 +81,46 @@ def add_subparser( arguments.set_split_arguments(nlu_split_parser) + +def _add_data_validate_parsers( + data_subparsers, parents: List[argparse.ArgumentParser] +) -> None: validate_parser = data_subparsers.add_parser( "validate", formatter_class=argparse.ArgumentDefaultsHelpFormatter, parents=parents, help="Validates domain and data files to check for possible mistakes.", ) + _append_story_structure_arguments(validate_parser) validate_parser.set_defaults(func=validate_files) arguments.set_validator_arguments(validate_parser) + validate_subparsers = validate_parser.add_subparsers() + story_structure_parser = validate_subparsers.add_parser( + "stories", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + parents=parents, + help="Checks for inconsistencies in the story files.", + ) + _append_story_structure_arguments(story_structure_parser) + story_structure_parser.set_defaults(func=validate_stories) + arguments.set_validator_arguments(story_structure_parser) + + +def _append_story_structure_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--max-history", + type=int, + default=None, + help="Number of turns taken into account for story structure validation.", + ) -def split_nlu_data(args) -> None: + +def split_nlu_data(args: argparse.Namespace) -> None: from rasa.nlu.training_data.loading import load_data from rasa.nlu.training_data.util import get_file_format - data_path = get_validated_path(args.nlu, "nlu", DEFAULT_DATA_PATH) + data_path = rasa.cli.utils.get_validated_path(args.nlu, "nlu", DEFAULT_DATA_PATH) data_path = data.get_nlu_directory(data_path) nlu_data = load_data(data_path) @@ -91,22 +132,53 @@ def split_nlu_data(args) -> None: test.persist(args.out, filename=f"test_data.{fformat}") -def validate_files(args) -> NoReturn: - """Validate all files needed for training a model. - - Fails with a non-zero exit code if there are any errors in the data.""" - from rasa.core.validator import Validator - from rasa.importers.rasa import RasaFileImporter +def validate_files(args: argparse.Namespace, stories_only: bool = False) -> None: + """ + Validates either the story structure or the entire project. + Args: + args: Commandline arguments + stories_only: If `True`, only the story structure is validated. + """ loop = asyncio.get_event_loop() file_importer = RasaFileImporter( domain_path=args.domain, training_data_paths=args.data ) validator = loop.run_until_complete(Validator.from_importer(file_importer)) - domain_is_valid = validator.verify_domain_validity() - if not domain_is_valid: - sys.exit(1) - everything_is_alright = validator.verify_all(not args.fail_on_warnings) - sys.exit(0) if everything_is_alright else sys.exit(1) + if stories_only: + all_good = _validate_story_structure(validator, args) + else: + all_good = ( + _validate_domain(validator) + and _validate_nlu(validator, args) + and _validate_story_structure(validator, args) + ) + + if not all_good: + rasa.cli.utils.print_error_and_exit("Project validation completed with errors.") + + +def validate_stories(args: argparse.Namespace) -> None: + validate_files(args, stories_only=True) + + +def _validate_domain(validator: Validator) -> bool: + return validator.verify_domain_validity() + + +def _validate_nlu(validator: Validator, args: argparse.Namespace) -> bool: + return validator.verify_nlu(not args.fail_on_warnings) + + +def _validate_story_structure(validator: Validator, args: argparse.Namespace) -> bool: + # Check if a valid setting for `max_history` was given + if isinstance(args.max_history, int) and args.max_history < 1: + raise argparse.ArgumentTypeError( + f"The value of `--max-history {args.max_history}` is not a positive integer.", + ) + + return validator.verify_story_structure( + not args.fail_on_warnings, max_history=args.max_history + ) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py new file mode 100644 index 000000000000..2510608a68b8 --- /dev/null +++ b/rasa/core/training/story_conflict.py @@ -0,0 +1,323 @@ +import logging +from collections import defaultdict, namedtuple +from typing import List, Optional, Dict, Text, Tuple, Generator, NamedTuple + +from rasa.core.actions.action import ACTION_LISTEN_NAME +from rasa.core.domain import PREV_PREFIX, Domain +from rasa.core.events import ActionExecuted, Event +from rasa.core.featurizers import MaxHistoryTrackerFeaturizer +from rasa.nlu.constants import INTENT_ATTRIBUTE +from rasa.core.training.generator import TrackerWithCachedStates + +logger = logging.getLogger(__name__) + + +class StoryConflict: + """ + Represents a conflict between two or more stories. + + Here, a conflict means that different actions are supposed to follow from + the same dialogue state, which most policies cannot learn. + + Attributes: + conflicting_actions: A list of actions that all follow from the same state. + conflict_has_prior_events: If `False`, then the conflict occurs without any + prior events (i.e. at the beginning of a dialogue). + """ + + def __init__(self, sliced_states: List[Optional[Dict[Text, float]]],) -> None: + """ + Creates a `StoryConflict` from a given state. + + Args: + sliced_states: The (sliced) dialogue state at which the conflict occurs. + """ + self._sliced_states = sliced_states + self._conflicting_actions = defaultdict( + list + ) # {"action": ["story_1", ...], ...} + + def __hash__(self) -> int: + return hash(str(list(self._sliced_states))) + + def add_conflicting_action(self, action: Text, story_name: Text) -> None: + """Adds another action that follows from the same state. + + Args: + action: Name of the action. + story_name: Name of the story where this action is chosen. + """ + self._conflicting_actions[action] += [story_name] + + @property + def conflicting_actions(self) -> List[Text]: + """List of conflicting actions. + + Returns: + List of conflicting actions. + + """ + return list(self._conflicting_actions.keys()) + + @property + def conflict_has_prior_events(self) -> bool: + """Checks if prior events exist. + + Returns: + `True` if anything has happened before this conflict, otherwise `False`. + """ + return _get_previous_event(self._sliced_states[-1])[0] is not None + + def __str__(self) -> Text: + # Describe where the conflict occurs in the stories + last_event_type, last_event_name = _get_previous_event(self._sliced_states[-1]) + if last_event_type: + conflict_message = f"Story structure conflict after {last_event_type} '{last_event_name}':\n" + else: + conflict_message = ( + f"Story structure conflict at the beginning of stories:\n" + ) + + # List which stories are in conflict with one another + for action, stories in self._conflicting_actions.items(): + conflict_message += ( + f" {self._summarize_conflicting_actions(action, stories)}" + ) + + return conflict_message + + @staticmethod + def _summarize_conflicting_actions(action: Text, stories: List[Text]) -> Text: + """Gives a summarized textual description of where one action occurs. + + Args: + action: The name of the action. + stories: The stories in which the action occurs. + + Returns: + A textural summary. + """ + if len(stories) > 3: + # Four or more stories are present + conflict_description = ( + f"'{stories[0]}', '{stories[1]}', and {len(stories) - 2} other trackers" + ) + elif len(stories) == 3: + conflict_description = f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'" + elif len(stories) == 2: + conflict_description = f"'{stories[0]}' and '{stories[1]}'" + elif len(stories) == 1: + conflict_description = f"'{stories[0]}'" + else: + raise ValueError( + "An internal error occurred while trying to summarise a conflict without stories. " + "Please file a bug report at https://github.com/RasaHQ/rasa." + ) + + return f"{action} predicted in {conflict_description}\n" + + +class TrackerEventStateTuple(NamedTuple): + """Holds a tracker, an event, and sliced states associated with those.""" + + tracker: TrackerWithCachedStates + event: Event + sliced_states: List[Dict[Text, float]] + + @property + def sliced_states_hash(self) -> int: + return hash(str(list(self.sliced_states))) + + +def _get_length_of_longest_story( + trackers: List[TrackerWithCachedStates], domain: Domain +) -> int: + """Returns the longest story in the given trackers. + + Args: + trackers: Trackers to get stories from. + domain: The domain. + + Returns: + The maximal length of any story + """ + return max([len(tracker.past_states(domain)) for tracker in trackers]) + + +def find_story_conflicts( + trackers: List[TrackerWithCachedStates], + domain: Domain, + max_history: Optional[int] = None, +) -> List[StoryConflict]: + """Generates `StoryConflict` objects, describing conflicts in the given trackers. + + Args: + trackers: Trackers in which to search for conflicts. + domain: The domain. + max_history: The maximum history length to be taken into account. + + Returns: + StoryConflict objects. + """ + if not max_history: + max_history = _get_length_of_longest_story(trackers, domain) + + logger.info(f"Considering the preceding {max_history} turns for conflict analysis.") + + # We do this in two steps, to reduce memory consumption: + + # Create a 'state -> list of actions' dict, where the state is + # represented by its hash + conflicting_state_action_mapping = _find_conflicting_states( + trackers, domain, max_history + ) + + # Iterate once more over all states and note the (unhashed) state, + # for which a conflict occurs + conflicts = _build_conflicts_from_states( + trackers, domain, max_history, conflicting_state_action_mapping + ) + + return conflicts + + +def _find_conflicting_states( + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int +) -> Dict[int, Optional[List[Text]]]: + """Identifies all states from which different actions follow. + + Args: + trackers: Trackers that contain the states. + domain: The domain object. + max_history: Number of turns to take into account for the state descriptions. + + Returns: + A dictionary mapping state-hashes to a list of actions that follow from each state. + """ + # Create a 'state -> list of actions' dict, where the state is + # represented by its hash + state_action_mapping = defaultdict(list) + for element in _sliced_states_iterator(trackers, domain, max_history): + hashed_state = element.sliced_states_hash + if element.event.as_story_string() not in state_action_mapping[hashed_state]: + state_action_mapping[hashed_state] += [element.event.as_story_string()] + + # Keep only conflicting `state_action_mapping`s + return { + state_hash: actions + for (state_hash, actions) in state_action_mapping.items() + if len(actions) > 1 + } + + +def _build_conflicts_from_states( + trackers: List[TrackerWithCachedStates], + domain: Domain, + max_history: int, + conflicting_state_action_mapping: Dict[int, Optional[List[Text]]], +) -> List["StoryConflict"]: + """Builds a list of `StoryConflict` objects for each given conflict. + + Args: + trackers: Trackers that contain the states. + domain: The domain object. + max_history: Number of turns to take into account for the state descriptions. + conflicting_state_action_mapping: A dictionary mapping state-hashes to a list of actions + that follow from each state. + + Returns: + A list of `StoryConflict` objects that describe inconsistencies in the story + structure. These objects also contain the history that leads up to the conflict. + """ + # Iterate once more over all states and note the (unhashed) state, + # for which a conflict occurs + conflicts = {} + for element in _sliced_states_iterator(trackers, domain, max_history): + hashed_state = element.sliced_states_hash + + if hashed_state in conflicting_state_action_mapping: + if hashed_state not in conflicts: + conflicts[hashed_state] = StoryConflict(element.sliced_states) + + conflicts[hashed_state].add_conflicting_action( + action=element.event.as_story_string(), + story_name=element.tracker.sender_id, + ) + + # Return list of conflicts that arise from unpredictable actions + # (actions that start the conversation) + return [ + conflict + for (hashed_state, conflict) in conflicts.items() + if conflict.conflict_has_prior_events + ] + + +def _sliced_states_iterator( + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int +) -> Generator[TrackerEventStateTuple, None, None]: + """Creates an iterator over sliced states. + + Iterate over all given trackers and all sliced states within each tracker, + where the slicing is based on `max_history`. + + Args: + trackers: List of trackers. + domain: Domain (used for tracker.past_states). + max_history: Assumed `max_history` value for slicing. + + Yields: + A (tracker, event, sliced_states) triplet. + """ + for tracker in trackers: + states = tracker.past_states(domain) + states = [dict(state) for state in states] + + idx = 0 + for event in tracker.events: + if isinstance(event, ActionExecuted): + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + yield TrackerEventStateTuple(tracker, event, sliced_states) + idx += 1 + + +def _get_previous_event( + state: Optional[Dict[Text, float]] +) -> Tuple[Optional[Text], Optional[Text]]: + """Returns previous event type and name. + + Returns the type and name of the event (action or intent) previous to the + given state. + + Args: + state: Element of sliced states. + + Returns: + Tuple of (type, name) strings of the prior event. + """ + + previous_event_type = None + previous_event_name = None + + if not state: + return previous_event_type, previous_event_name + + # A typical state is, for example, + # `{'prev_action_listen': 1.0, 'intent_greet': 1.0, 'slot_cuisine_0': 1.0}`. + # We need to look out for `prev_` and `intent_` prefixes in the labels. + for turn_label in state: + if ( + turn_label.startswith(PREV_PREFIX) + and turn_label.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME + ): + # The `prev_...` was an action that was NOT `action_listen` + return "action", turn_label.replace(PREV_PREFIX, "") + elif turn_label.startswith(INTENT_ATTRIBUTE + "_"): + # We found an intent, but it is only the previous event if + # the `prev_...` was `prev_action_listen`, so we don't return. + previous_event_type = "intent" + previous_event_name = turn_label.replace(INTENT_ATTRIBUTE + "_", "") + + return previous_event_type, previous_event_name diff --git a/rasa/core/validator.py b/rasa/validator.py similarity index 80% rename from rasa/core/validator.py rename to rasa/validator.py index bb3674935d80..492e4366c21f 100644 --- a/rasa/core/validator.py +++ b/rasa/validator.py @@ -1,13 +1,16 @@ import logging from collections import defaultdict -from typing import List, Set, Text - -from rasa.constants import DOCS_URL_DOMAINS, DOCS_URL_ACTIONS -from rasa.core.constants import UTTER_PREFIX +from typing import Set, Text, Optional from rasa.core.domain import Domain -from rasa.core.training.dsl import ActionExecuted, StoryStep, UserUttered +from rasa.core.training.generator import TrainingDataGenerator from rasa.importers.importer import TrainingDataImporter from rasa.nlu.training_data import TrainingData +from rasa.core.training.structures import StoryGraph +from rasa.core.training.dsl import UserUttered +from rasa.core.training.dsl import ActionExecuted +from rasa.core.constants import UTTER_PREFIX +import rasa.core.training.story_conflict +from rasa.constants import DOCS_URL_DOMAINS, DOCS_URL_ACTIONS from rasa.utils.common import raise_warning logger = logging.getLogger(__name__) @@ -16,22 +19,24 @@ class Validator: """A class used to verify usage of intents and utterances.""" - def __init__(self, domain: Domain, intents: TrainingData, stories: List[StoryStep]): + def __init__( + self, domain: Domain, intents: TrainingData, story_graph: StoryGraph + ) -> None: """Initializes the Validator object. """ self.domain = domain self.intents = intents - self.stories = stories + self.story_graph = story_graph @classmethod async def from_importer(cls, importer: TrainingDataImporter) -> "Validator": """Create an instance from the domain, nlu and story files.""" domain = await importer.get_domain() - stories = await importer.get_stories() + story_graph = await importer.get_stories() intents = await importer.get_nlu_data() - return cls(domain, intents, stories.story_steps) + return cls(domain, intents, story_graph) def verify_intents(self, ignore_warnings: bool = True) -> bool: """Compares list of intents in domain with intents in NLU training data.""" @@ -95,7 +100,7 @@ def verify_intents_in_stories(self, ignore_warnings: bool = True) -> bool: stories_intents = { event.intent["name"] - for story in self.stories + for story in self.story_graph.story_steps for event in story.events if type(event) == UserUttered } @@ -165,7 +170,7 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: utterance_actions = self._gather_utterance_actions() stories_utterances = set() - for story in self.stories: + for story in self.story_graph.story_steps: for event in story.events: if not isinstance(event, ActionExecuted): continue @@ -195,13 +200,49 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: return everything_is_alright - def verify_all(self, ignore_warnings: bool = True) -> bool: + def verify_story_structure( + self, ignore_warnings: bool = True, max_history: Optional[int] = None + ) -> bool: + """Verifies that the bot behaviour in stories is deterministic. + + Args: + ignore_warnings: When `True`, return `True` even if conflicts were found. + max_history: Maximal number of events to take into account for conflict identification. + + Returns: + `False` is a conflict was found and `ignore_warnings` is `False`. + `True` otherwise. + """ + + logger.info("Story structure validation...") + + trackers = TrainingDataGenerator( + self.story_graph, + domain=self.domain, + remove_duplicates=False, + augmentation_factor=0, + ).generate() + + # Create a list of `StoryConflict` objects + conflicts = rasa.core.training.story_conflict.find_story_conflicts( + trackers, self.domain, max_history + ) + + if not conflicts: + logger.info("No story structure conflicts found.") + else: + for conflict in conflicts: + logger.warning(conflict) + + return ignore_warnings or not conflicts + + def verify_nlu(self, ignore_warnings: bool = True) -> bool: """Runs all the validations on intents and utterances.""" logger.info("Validating intents...") intents_are_valid = self.verify_intents_in_stories(ignore_warnings) - logger.info("Validating there is no duplications...") + logger.info("Validating uniqueness of intents and stories...") there_is_no_duplication = self.verify_example_repetition_in_intents( ignore_warnings ) diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 99c3acd4631b..2ac6e97f47e3 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -52,6 +52,8 @@ def run_in_default_project( def do_run(*args): args = ["rasa"] + list(args) - return testdir.run(*args) + result = testdir.run(*args) + os.environ["LOG_LEVEL"] = "INFO" + return result return do_run diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 3021e9ab12e7..07e54e8c47ed 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -1,9 +1,15 @@ +import argparse import os +from unittest.mock import Mock import pytest from collections import namedtuple -from typing import Callable +from typing import Callable, Text + +from _pytest.monkeypatch import MonkeyPatch from _pytest.pytester import RunResult from rasa.cli import data +from rasa.importers.importer import TrainingDataImporter +from rasa.validator import Validator def test_data_split_nlu(run_in_default_project: Callable[..., RunResult]): @@ -60,8 +66,8 @@ def test_data_convert_help(run: Callable[..., RunResult]): def test_data_validate_help(run: Callable[..., RunResult]): output = run("data", "validate", "--help") - help_text = """usage: rasa data validate [-h] [-v] [-vv] [--quiet] [--fail-on-warnings] - [-d DOMAIN] [--data DATA]""" + help_text = """usage: rasa data validate [-h] [-v] [-vv] [--quiet] + [--max-history MAX_HISTORY] [--fail-on-warnings]""" lines = help_text.split("\n") @@ -69,9 +75,37 @@ def test_data_validate_help(run: Callable[..., RunResult]): assert output.outlines[i] == line +def _text_is_part_of_output_error(text: Text, output: RunResult) -> bool: + found_info_string = False + for line in output.errlines: + if text in line: + found_info_string = True + return found_info_string + + +def test_data_validate_stories_with_max_history_zero(monkeypatch: MonkeyPatch): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(help="Rasa commands") + data.add_subparser(subparsers, parents=[]) + + args = parser.parse_args(["data", "validate", "stories", "--max-history", 0]) + + async def mock_from_importer(importer: TrainingDataImporter) -> Validator: + return Mock() + + monkeypatch.setattr("rasa.validator.Validator.from_importer", mock_from_importer) + + with pytest.raises(argparse.ArgumentTypeError): + data.validate_files(args) + + def test_validate_files_exit_early(): with pytest.raises(SystemExit) as pytest_e: - args = {"domain": "data/test_domains/duplicate_intents.yml", "data": None} + args = { + "domain": "data/test_domains/duplicate_intents.yml", + "data": None, + "max_history": None, + } data.validate_files(namedtuple("Args", args.keys())(*args.values())) assert pytest_e.type == SystemExit diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py new file mode 100644 index 000000000000..1a426850a6b9 --- /dev/null +++ b/tests/core/test_story_conflict.py @@ -0,0 +1,160 @@ +from typing import Text, List, Tuple + +from rasa.core.domain import Domain +from rasa.core.training.story_conflict import ( + StoryConflict, + find_story_conflicts, + _get_previous_event, +) +from rasa.core.training.generator import TrainingDataGenerator, TrackerWithCachedStates +from rasa.validator import Validator +from rasa.importers.rasa import RasaFileImporter +from tests.core.conftest import DEFAULT_STORIES_FILE, DEFAULT_DOMAIN_PATH_WITH_SLOTS + + +async def _setup_trackers_for_testing( + domain_path: Text, training_data_file: Text +) -> Tuple[List[TrackerWithCachedStates], Domain]: + importer = RasaFileImporter( + domain_path=domain_path, training_data_paths=[training_data_file], + ) + validator = await Validator.from_importer(importer) + + trackers = TrainingDataGenerator( + validator.story_graph, + domain=validator.domain, + remove_duplicates=False, + augmentation_factor=0, + ).generate() + + return trackers, validator.domain + + +async def test_find_no_conflicts(): + trackers, domain = await _setup_trackers_for_testing( + DEFAULT_DOMAIN_PATH_WITH_SLOTS, DEFAULT_STORIES_FILE + ) + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, domain, 5) + + assert conflicts == [] + + +async def test_find_conflicts_in_short_history(): + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_1.md" + ) + + # `max_history = 3` is too small, so a conflict must arise + conflicts = find_story_conflicts(trackers, domain, 3) + assert len(conflicts) == 1 + + # With `max_history = 4` the conflict should disappear + conflicts = find_story_conflicts(trackers, domain, 4) + assert len(conflicts) == 0 + + +async def test_find_conflicts_checkpoints(): + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_2.md" + ) + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, domain, 5) + + assert len(conflicts) == 1 + assert conflicts[0].conflicting_actions == ["utter_goodbye", "utter_default"] + + +async def test_find_conflicts_or(): + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_3.md" + ) + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, domain, 5) + + assert len(conflicts) == 1 + assert conflicts[0].conflicting_actions == ["utter_default", "utter_goodbye"] + + +async def test_find_conflicts_slots_that_break(): + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_4.md" + ) + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, domain, 5) + + assert len(conflicts) == 1 + assert conflicts[0].conflicting_actions == ["utter_default", "utter_greet"] + + +async def test_find_conflicts_slots_that_dont_break(): + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_5.md" + ) + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, domain, 5) + + assert len(conflicts) == 0 + + +async def test_find_conflicts_multiple_stories(): + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_6.md" + ) + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, domain, 5) + + assert len(conflicts) == 1 + assert "and 2 other trackers" in str(conflicts[0]) + + +async def test_add_conflicting_action(): + sliced_states = [ + None, + {}, + {"intent_greet": 1.0, "prev_action_listen": 1.0}, + {"prev_utter_greet": 1.0, "intent_greet": 1.0}, + ] + conflict = StoryConflict(sliced_states) + + conflict.add_conflicting_action("utter_greet", "xyz") + conflict.add_conflicting_action("utter_default", "uvw") + assert conflict.conflicting_actions == ["utter_greet", "utter_default"] + + +async def test_has_prior_events(): + sliced_states = [ + None, + {}, + {"intent_greet": 1.0, "prev_action_listen": 1.0}, + {"prev_utter_greet": 1.0, "intent_greet": 1.0}, + ] + conflict = StoryConflict(sliced_states) + assert conflict.conflict_has_prior_events + + +async def test_get_previous_event(): + assert _get_previous_event({"prev_utter_greet": 1.0, "intent_greet": 1.0}) == ( + "action", + "utter_greet", + ) + assert _get_previous_event({"intent_greet": 1.0, "prev_utter_greet": 1.0}) == ( + "action", + "utter_greet", + ) + assert _get_previous_event({"intent_greet": 1.0, "prev_action_listen": 1.0}) == ( + "intent", + "greet", + ) + + +async def test_has_no_prior_events(): + sliced_states = [None] + conflict = StoryConflict(sliced_states) + assert not conflict.conflict_has_prior_events diff --git a/tests/core/test_validator.py b/tests/core/test_validator.py index 3395732547ee..c9a9665f18a8 100644 --- a/tests/core/test_validator.py +++ b/tests/core/test_validator.py @@ -1,14 +1,10 @@ import pytest -import logging -from rasa.core.validator import Validator +from rasa.validator import Validator from rasa.importers.rasa import RasaFileImporter from tests.core.conftest import ( - DEFAULT_DOMAIN_PATH_WITH_SLOTS, DEFAULT_STORIES_FILE, DEFAULT_NLU_DATA, ) -from rasa.core.domain import Domain -from rasa.nlu.training_data import TrainingData import rasa.utils.io as io_utils @@ -40,6 +36,33 @@ async def test_verify_valid_utterances(): assert validator.verify_utterances() +async def test_verify_story_structure(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=[DEFAULT_STORIES_FILE], + ) + validator = await Validator.from_importer(importer) + assert validator.verify_story_structure(ignore_warnings=False) + + +async def test_verify_bad_story_structure(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_2.md"], + ) + validator = await Validator.from_importer(importer) + assert not validator.verify_story_structure(ignore_warnings=False) + + +async def test_verify_bad_story_structure_ignore_warnings(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_2.md"], + ) + validator = await Validator.from_importer(importer) + assert validator.verify_story_structure(ignore_warnings=True) + + async def test_fail_on_invalid_utterances(tmpdir): # domain and stories are from different domain and should produce warnings invalid_domain = str(tmpdir / "invalid_domain.yml")