From cb3bc373dd6e36405e2a225fbead764f0fb35b43 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 9 Oct 2019 11:02:01 +0200 Subject: [PATCH 001/209] Copy story_tree.py script --- rasa/utils/story_tree.py | 684 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 684 insertions(+) create mode 100644 rasa/utils/story_tree.py diff --git a/rasa/utils/story_tree.py b/rasa/utils/story_tree.py new file mode 100644 index 000000000000..35bf5c414ce2 --- /dev/null +++ b/rasa/utils/story_tree.py @@ -0,0 +1,684 @@ +"""This script generates a tree diagram from a story file.""" + +import argparse +import os +from termcolor import colored as term_colored +import pydoc +from math import log2 +import re +import ast + +################################################# +# Command line argument specification +################################################# + +arg_parser = argparse.ArgumentParser(description="Represent story tree.") +arg_parser.add_argument("input", type=str, nargs="+", help="Input story file name(s)") +arg_parser.add_argument("--ambiguities", "-a", default="false", choices=["true", "false", "wizard"], + const="true", nargs="?", help="Display ambiguous branches only") +arg_parser.add_argument("--max-depth", "-m", type=int, default=None, help="Maximum depth") +arg_parser.add_argument("--coloring", "-c", default="role", choices=["role", "r", "ambiguities", "a", "depth", "none", + "n"]) +arg_parser.add_argument("--color-code", default="terminal", choices=["terminal", "markdown"]) +arg_parser.add_argument("--branch", "-b", default="", + help="Restrict output to the given branch of the tree, separated by '/'") +arg_parser.add_argument("--stories", "-s", nargs="+", default=None, + help="Restrict output to the given stories.") +arg_parser.add_argument("--labels", "-l", action="store_true", default=False, + help="Show the names of stories at each node that has no siblings") +arg_parser.add_argument("--page", "-p", action="store_true", default=False, + 
help="Use pagination for output if necessary") +arg_parser.add_argument("--prune", default="most-visited", choices=["first", "last", "most-visited"], + const="First", nargs="?", help="Selection criterion for kept branches during pruning") +arg_parser.add_argument("--output", "-o", default="tree", choices=["tree", "stats", "pruned", "ok"], + nargs="+", help="What to display") +arg_parser.add_argument("--merge", nargs=1, help="Merge given story file into main file, avoiding ambiguities") + +################################################# +# Helper functions +################################################# + +color_code = "terminal" + + +def colored(string, color): + # noinspection PyUnresolvedReferences + global color_code + if color_code == "terminal": + return term_colored(string, color) + elif color_code == "markdown": + return f"" + string + "" + elif color_code == "none": + return string + else: + raise ValueError("Invalid color_code. Must be one of \"terminal\", \"markdown\", or \"none\"") + + +def slot_to_dict(string): + regex_slot = r"- slot\{([^\}]*)\}$" # Groups: slots + regex_intent = r"\* ([^_\{]+)(?:\{([^\}]*)})?$" # Groups: name | slots + + match = re.search(regex_slot, string) + if match: + return ast.literal_eval("{" + match.group(1) + "}") + else: + match = re.search(regex_intent, string) + if match: + if match.group(2): + return ast.literal_eval("{" + match.group(2) + "}") + return {} + + +################################################# +# Classes +################################################# + +# =============================================== +# Node +# Represents a node in a tree graph +# =============================================== + +class Node: + + def __init__(self, name="root", parent=None, story=""): + self.count = 1 + self.name = name + self.parent = parent + self.children = [] + self.labels = [story] + self._is_pruned = False + + def add_child(self, child): + """ + Add the child node `child`, unless it exists already, in 
which case an error is raised. + :type child: Node + :param child: Child node + """ + if self.get_child(child.name) is None: + self.children.append(child) + else: + raise ValueError(f"A child with the name {child.name} already exists!") + pass + + def get_child(self, name: str): + """ + Get the child with the given name + :param name: Name of the sought child + :return: Child node + """ + for child in self.children: + if child.name == name: + return child + return None + + def print_string(self, branch="", stories=None, max_depth=None, only_ambiguous=False, show_labels=False, + coloring="role", include_users=True, _depth=0, _has_siblings=False): + """ + Recursively generate a string representation of the tree with this node as root. + :param stories: Restrict output to given stories + :param branch: Restrict output to given branch (overwrites `stories`) + :param max_depth: Go no deeper than this + :param show_labels: Indicate branch labels on non-branch points + :param only_ambiguous: Only output ambiguous branches + :param coloring: Coloring rule ('role', 'depth', 'ambiguities') + :param _depth: Recursion depth - only set in recursion step! + :param include_users: When `only_ambiguous` is `True`, include ambiguous user responses + :param _has_siblings: True, iff this node has siblings - only set in recursion step! 
+ :return: The generated string + """ + + # Abort recursion if max depth is reached + if max_depth and _depth >= max_depth: + return "" + + # Decide how to color the present node + if coloring[0] == "r": # "role" + color = {"S": "yellow", "U": "blue", "W": "green"}.get(self.name[0]) + elif coloring[0] == "d": # "depth" + color = {1: "green", 2: "magenta", 3: "yellow", 4: "cyan", 5: "blue", 6: "grey"}.get(_depth, "grey") + elif coloring[0] == "n": # "none" + color = "none" + elif coloring[0] == "a": # "ambiguities" + if _has_siblings: + color = "red" + elif self.has_descendants_with_siblings: + color = "yellow" + else: + color = "white" + else: + raise ValueError(f"Invalid coloring \"{coloring}\". Must be one of 'roles', 'depth', " + f"'ambiguities' or 'none'.") + + # If only ambiguous nodes should be printed, then print only if there are siblings or descendants with siblings + if (not only_ambiguous) or _has_siblings or self.has_descendants_with_siblings(include_users): + + # Visit count indicator for non-root nodes only + count_str = f" ({self.count})" if _depth > 0 else "" + + # Show branch labels iff visit count is 1 + if show_labels and self.count == 1: + result = "+" + "-" * (2 * _depth) + " " + colored(self.name + count_str, color) \ + + f" <{self.labels[0]}>" + os.linesep + else: + result = "+" + "-" * (2 * _depth) + " " + colored(self.name + count_str, color) \ + + os.linesep + else: + # We show only ambiguous branches, and this node is not root of an ambiguous branch + result = "" + + # Prepare _has_siblings for recursion step + has_siblings = (len(self.children) > 1) + if has_siblings and not include_users: + all_children_are_users = all(child.name.startswith("U:") or child.name.startswith("S:") for child in self.children) + has_siblings = not all_children_are_users + + # Recursion step into all child nodes + if branch: + # Output should be restricted to `branch` + path = branch.split("/") # Split the branch spec into one for each level + sought_name = 
path[0] # First entry is where we should step into now + remain_branch = "/".join(path[1:]) # Remaining entries have to be passed on to recursive call + + for child in self.children: + if child.name == sought_name or sought_name == "*": + result += child.print_string(branch=remain_branch, + stories=stories, + max_depth=max_depth, + only_ambiguous=only_ambiguous, + show_labels=show_labels, + include_users=include_users, + coloring=coloring, + _depth=_depth + 1, + _has_siblings=has_siblings) + else: + # No branch restriction -> step into all child branches unless stories are restricted + for child in self.children: + if stories is None or not set(child.labels).isdisjoint(stories): + result += child.print_string(branch=branch, + stories=stories, + max_depth=max_depth, + only_ambiguous=only_ambiguous, + show_labels=show_labels, + include_users=include_users, + coloring=coloring, + _depth=_depth + 1, + _has_siblings=has_siblings) + return result + + def has_descendants_with_siblings(self, include_users): + """ + Boolean that indicates if there are any descendants that have siblings. 
+ :return: True, iff a descendant node has siblings + """ + if len(self.children) > 1: + if include_users: + return True + else: + all_children_are_users = all(child.name.startswith("U:") or child.name.startswith("S:") for child in self.children) + if all_children_are_users: + return any(child.has_descendants_with_siblings(include_users) for child in self.children) + else: + return True + elif len(self.children) == 1: + return list(self.children)[0].has_descendants_with_siblings(include_users) + else: + return False + + def __str__(self): + return self.print_string() + + def prune(self, keep: str): + """Removes all ambiguous branches""" + if self._is_pruned: + return + if len(self.children) > 0: + if len(self.children) > 1: + if any([child.name.startswith("W:") for child in self.children]): + if keep == "first": + del self.children[1:] + elif keep == "last": + del self.children[:-1] + elif keep == "most-visited": + visit_counts = [len(c.labels) for c in self.children] + keep_idx = visit_counts.index(max(visit_counts)) + # Delete all but the one at `keep_idx` + del self.children[:keep_idx] + if len(self.children) > 1: + del self.children[1:] + else: + raise ValueError("Invalid prune keep criterion.") + for child in self.children: + child.prune(keep) + self._is_pruned = True + + def remove(self, story) -> bool: + """Remove the given story from this node and recursively from all + descendants. 
""" + if story in self.labels: + # Remove the story from internal stats + self.labels = [label for label in self.labels if label != story] + # Recurse through all children + new_children = [] + for child in self.children: + if not child.remove(story): + # Only retain children that did not self-delete + new_children.append(child) + else: + # Delete this child node + self.count -= 1 + del child + + self.children = new_children + + assert (len(self.labels) == 0) == (self.count == 0) + + # If this node had no other stories than the one we just + # deleted, then return True, and False otherwise + return len(self.labels) == 0 + + @property + def leafs(self): + leafs = set() + + # noinspection PyUnusedLocal + def callback_discover_leaf(node, *args): + assert len(node.labels) >= 1, f"Leaf has no story assigned!" + # Leafs may have multiple stories assigned, iff stories have duplicates + # Ignore duplicates iff the tree was pruned + if self._is_pruned: + leafs.add(node.labels[0]) + else: + for story in node.labels: + leafs.add(story) + return {} + + self._depth_first_search({"discover_leaf": callback_discover_leaf}, {}) + + return leafs + + @property + def duplicates(self): + duplicates = [] + + # noinspection PyUnusedLocal + def callback_discover_leaf(node, *args): + assert len(node.labels) >= 1, f"Leaf has no story assigned!" + # Leafs have multiple stories assigned, iff stories have duplicates + if len(node.labels) > 1: + duplicates.append(node.labels) + return {} + + self._depth_first_search({"discover_leaf": callback_discover_leaf}, {}) + + return duplicates + + def stats(self): + """ + Collects statistics about the tree that has this node as a root. + :return: Dict with statistical information + """ + statistics = { + "num_nodes": 0, # Total number of nodes in the tree (steps in all dialogues) + "num_nodes_with_multiple_children": 0, # Number of nodes that have multiple children + "num_leaves": 0, # How many stories are present? 
+ "depth": 0, # How deep is the graph + "ambiguity_depth": 0, # How deep is the deepest branch point? + "ambiguity_chain_length": 0, # How many branch points follow each other (max)? + "ambiguity_level": 0, # How many leaves are connected to root via branch points? + "story_stats": {} # Stats about individual stories + } + + def callback_discover_node(node, depth, flags): + statistics["num_nodes"] += 1 + if len(node.children) > 1: + if any([child.name.startswith("W:") for child in node.children]): + statistics["num_nodes_with_multiple_children"] += 1 + statistics["ambiguity_depth"] = max(statistics["ambiguity_depth"], depth) + statistics["ambiguity_chain_length"] = max(statistics["ambiguity_chain_length"], + flags["ambiguity_chain_length"] + 1) + if flags["linear_so_far"]: + statistics["ambiguity_level"] += node.count + for story in node.labels: + if story in statistics["story_stats"]: + statistics["story_stats"][story]["ambiguity_length"] += 1 + statistics["story_stats"][story]["related_to"].update(node.labels) + else: + statistics["story_stats"][story] = { + "length": depth, + "ambiguity_length": 1, + "related_to": set(node.labels) + } + return { + "linear_so_far": False, + "ambiguity_chain_length": flags["ambiguity_chain_length"] + 1 + } + return {} + + # noinspection PyUnusedLocal + def callback_discover_leaf(node, depth, flags): + statistics["num_leaves"] += 1 + statistics["depth"] = max(statistics["depth"], depth) + story = node.labels[0] + if story in statistics["story_stats"]: + statistics["story_stats"][story]["length"] = depth + return {} + + self._depth_first_search({ + "discover_node": callback_discover_node, + "discover_leaf": callback_discover_leaf + }, flags={"linear_so_far": True, "ambiguity_chain_length": 0}) + + return statistics + + def _depth_first_search(self, callbacks, flags, _depth=0): + new_flags = flags.copy() + if len(self.children) == 0 and "discover_leaf" in callbacks: + new_flags.update(callbacks["discover_leaf"](self, _depth, flags)) 
+ return + for child in self.children: + if "discover_node" in callbacks: + new_flags.update(callbacks["discover_node"](child, _depth + 1, flags)) + # noinspection PyProtectedMember + child._depth_first_search(callbacks, new_flags, _depth + 1) + + +# =============================================== +# Tree +# Represents a tree graph +# =============================================== + +class Tree: + + def __init__(self): + self.root = Node() # Root node, should never change + self.pointer = self.root # Pointer to the currently active node + self.label = "" # Label for active branch + + def add_or_goto(self, name): + """ + If a branch with name `name` is a child of the currently active node, then move `self.pointer` + to that branch and update visit counts and branch name lists. Otherwise, create a new child + branch with this name and move the pointer to it. + :param name: Name of the (new) branch to go to + :return: True, iff a new branch was created + """ + # Check if branch with name `name` exists + for branch in self.pointer.children: + if branch.name == name: + branch.count += 1 # Increase visit count + branch.labels += [self.label] # Append new branch label + self.pointer = branch # Move pointer to this branch + return False + + # Add a new branch + new_branch = Node(name, parent=self.pointer, story=self.label) + self.pointer.add_child(new_branch) + self.pointer = new_branch + return True + + def adding_creates_ambiguity(self, name: str): + """ + Returns True iff adding a branch with this name would result in an ambiguity in this tree, + i.e. another child node exists, which is as Wizard node. 
+ :param name: Name of the branch (user/wizard action) + :return: True iff ambiguous + """ + return name.startswith("W") and any(c.name.startswith("W") for c in self.pointer.children) + + def up(self): + """ + Move the active branch pointer one step towards root + :return: True, iff active branch is not already on root + """ + if self.pointer != self.root: + self.pointer = self.pointer.parent + return True + else: + return False + + def reset(self, story): + """ + Reset the active branch pointer to root and specify a new story label to use in `self.add_or_goto`. + :param story: New story label + """ + self.pointer = self.root + self.label = story + + def remove(self, story=None): + """ + Remove the given story, or the story with the name stored in self.label + :param story: Name of the story + """ + if story: + self.root.remove(story) + else: + self.root.remove(self.label) + + def to_string(self, branch="", max_depth=None, show_labels=False, only_ambiguous=False, coloring="role", + include_users=True, stories=None): + """ + Create a string representation of the tree. 
+ :param stories: Restrict output to given stories + :param branch: Restrict output to given branch (overwrites `stories`) + :param max_depth: Go no deeper than this + :param show_labels: Indicate branch labels on non-branch points + :param only_ambiguous: Only output ambiguous branches + :param coloring: Coloring rule ('role', 'depth', 'ambiguities') + :param include_users: When `only_ambiguous` is `True`, include ambiguous user responses + :return: The generated string + """ + return self.root.print_string(branch=branch, max_depth=max_depth, include_users=include_users, + show_labels=show_labels, only_ambiguous=only_ambiguous, coloring=coloring, + stories=stories) + + def __str__(self): + return self.root.print_string() + + def prune(self, *args): + self.root.prune(*args) + + @property + def leafs(self): + return self.root.leafs + + @property + def duplicates(self): + return self.root.duplicates + + def stats(self) -> dict: + """ + Compute statistics about this tree. + :return: The generated dict with statistical information + """ + return self.root.stats() + + +################################################# +# Main +################################################# + +if __name__ == '__main__': + + def main(): + + stats = None + + # Read command line arguments + args = arg_parser.parse_args() + story_file_names = args.input # Input file name + global color_code + color_code = args.color_code # "terminal" / "markdown" + + # Generate the story tree + n = 0 + tree = Tree() + slots = {} + for story_file_name in story_file_names: + with open(story_file_name, "r") as story_file: + for line in story_file: + if line.startswith("##"): + n += 1 + tree.reset(story=line[2:].strip()) + slots.clear() + else: + if line.lstrip().startswith("*"): + name = "U: " + elif line.lstrip().startswith("- slot"): + name = "S: " + else: + name = "W: " + + if name in ["U: ", "S: "]: + # Slots might have been updated -> keep track of it + new_slots = slot_to_dict(line) + copy_slots = 
slots.copy() + copy_slots.update(new_slots) + if copy_slots.items() == slots.items(): + # Setting this slot does not change anything + if name == "S: ": + # Ignore redundant slot lines + name = "" + line = "" + else: + slots.update(new_slots) + + name += line.strip()[2:] + if line.strip(): + tree.add_or_goto(name) + + # Merge other story file (only take in stories that don't create ambiguities) + if args.merge: + successful_merge = [] # Stories that got merged in successfully + total_num_merge = 0 # Total number of stories that should have been merged + for story_file_name in args.merge: + with open(story_file_name, "r") as story_file: + active_story = "" + for line in story_file: + if line.startswith("##"): + if active_story: + # The previous story was merged all the way and + # thus `active_story` was not set to `""`. In this + # case, we remember the name of the story that + # merged successfully + successful_merge.append(active_story) + total_num_merge += 1 + n += 1 + active_story = line[2:].strip() + tree.reset(story=active_story) + else: + if active_story: + if line.lstrip().startswith("*"): + name = "U: " + elif line.lstrip().startswith("- slot"): + name = "S: " + else: + name = "W: " + name += line.strip()[2:] + if line.strip(): + if tree.adding_creates_ambiguity(name): + # Merging `active_story` would create ambiguity + tree.remove() + active_story = "" + else: + tree.add_or_goto(name) + + # Display the tree if required + if "tree" in args.output: + _print = pydoc.pager if args.page else print + _print(tree.to_string(only_ambiguous=(args.ambiguities in ["true", "wizard"]), + include_users=(args.ambiguities != "wizard"), + max_depth=args.max_depth, + show_labels=args.labels, + coloring=args.coloring, + branch=args.branch, + stories=args.stories)) + + # Display statistics if required + if "stats" in args.output: + stats = tree.stats() + duplicates = tree.duplicates + + print() + print(colored("Text summary:", "cyan")) + if duplicates: + print(f"The input 
contains {stats['num_leaves']} stories, but there are some duplicates (see below).") + else: + print(f"The input contains {stats['num_leaves']} unique stories.") + print(f"The longest story is {stats['depth']} nodes deep.") + print( + f"{stats['num_nodes_with_multiple_children']} / {stats['num_nodes']} = " + f"{100.0 * stats['num_nodes_with_multiple_children'] / stats['num_nodes']:.2f}% of all nodes " + f"have multiple children.") + print(f"The deepest branch point occurs after {stats['ambiguity_depth']} steps.") + print(f"We encounter up to {stats['ambiguity_chain_length']} branch points in a single story.") + print(f"{stats['ambiguity_level']} / {stats['num_leaves']} = " + f"{100.0 * stats['ambiguity_level'] / stats['num_leaves']:.2f}% of all stories are ambiguous.") + if args.merge: + # noinspection PyUnboundLocalVariable + print(f"Successfully merged {len(successful_merge)} out of {total_num_merge} stories.") + + if duplicates: + print() + print(colored("Duplicate stories:", "cyan")) + for d in duplicates: + print(d) + + print() + print(colored("Statistics table:", "cyan")) + print(f"num stories: {stats['num_leaves']}") + print(f"max turns: {stats['depth']}") + print(f"num nodes: {stats['num_nodes']}") + print(f"branch-points: {stats['num_nodes_with_multiple_children']} " + f"({100.0 * stats['num_nodes_with_multiple_children'] / stats['num_nodes']:.2f}%)") + print(f"ambiguity depth: {stats['ambiguity_depth']}") + print(f"ambiguity length: {stats['ambiguity_chain_length']}") + print(f"ambiguity level: {stats['ambiguity_level']} " + f"({100.0 * stats['ambiguity_level'] / stats['num_leaves']:.2f}%)") + if stats['ambiguity_level'] > 0.0: + print(f"ambiguity log: log2({stats['ambiguity_level']}) = {log2(stats['ambiguity_level']):.2f} ") + + tree.prune(args.prune) + pruned_stats = tree.stats() + print() + print(colored("After pruning:", "cyan")) + print(f"num stories: {pruned_stats['num_leaves']} " + f"({stats['num_leaves'] - pruned_stats['num_leaves']} fewer)") + 
print(f"max turns: {pruned_stats['depth']} " + f"({stats['depth'] - pruned_stats['depth']} fewer)") + print(f"num nodes: {pruned_stats['num_nodes']} " + f"({stats['num_nodes'] - pruned_stats['num_nodes']} fewer)") + + if len(stats["story_stats"]) > 0: + print() + print(colored("Most ambiguous stories:", "cyan")) + print(f"{'Story':>15}", f"{'# relations':>14}", f"{'# branchings':>14}", f"{'# turns':>14}") + for story, values in sorted(stats["story_stats"].items(), + key=lambda kv: [ + len(kv[1]["related_to"]), + kv[1]["ambiguity_length"], + kv[1]["length"]], + reverse=True)[:12]: + print(f"{story:>15} {len(values['related_to']):>14} {values['ambiguity_length']:>14} " + f"{values['length']:>14}") + + print() + + # Display remaining stories after pruning, if required + if "pruned" in args.output: + if args.merge: + for story in sorted(successful_merge): + print(story) + else: + tree.prune(args.prune) + for story in sorted(tree.leafs): + print(story) + + if "ok" in args.output: + if not stats: + stats = tree.stats() + if stats['num_nodes_with_multiple_children'] > 0: + print("False") + else: + print("True") + + main() From f0fe7b276758698bc17c62ca21e19f3c493d806e Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 9 Oct 2019 14:57:07 +0200 Subject: [PATCH 002/209] Hack in simple story-tree validation (not featurized) --- rasa/core/validator.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 1feb3834d85e..51336f80ada1 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -7,6 +7,7 @@ from rasa.core.training.dsl import StoryStep from rasa.core.training.dsl import UserUttered from rasa.core.training.dsl import ActionExecuted +from rasa.core.training.dsl import SlotSet from rasa.core.constants import UTTER_PREFIX logger = logging.getLogger(__name__) @@ -162,6 +163,33 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: return everything_is_alright + def verify_story_structure(self, ignore_warnings: bool = True) -> bool: + """Verifies that bot behaviour in stories is deterministic.""" + + from rasa.utils.story_tree import Tree + # Generate the story tree + n = 0 + tree = Tree() + slots = {} + for story in self.stories: + n += 1 + tree.reset(story=story.block_name) + for event in story.events: + print(event) + if isinstance(event, ActionExecuted): + tree.add_or_goto("W: " + event.as_story_string()) + elif isinstance(event, UserUttered): + tree.add_or_goto("U: " + event.as_story_string()) + elif isinstance(event, SlotSet): + tree.add_or_goto("S: " + event.as_story_string()) + else: + logger.error("JJJ: event is neither action, nor a slot, nor a user utterance") + + stats = tree.stats() + logger.info(tree.to_string(show_labels=True)) + + return True + def verify_all(self, ignore_warnings: bool = True) -> bool: """Runs all the validations on intents and utterances.""" @@ -169,5 +197,9 @@ def verify_all(self, ignore_warnings: bool = True) -> bool: intents_are_valid = self.verify_intents_in_stories(ignore_warnings) logger.info("Validating utterances...") - stories_are_valid = 
self.verify_utterances_in_stories(ignore_warnings) - return intents_are_valid and stories_are_valid + utterances_are_valid = self.verify_utterances_in_stories(ignore_warnings) + + logger.info("Validating story-structure...") + stories_are_valid = self.verify_story_structure(ignore_warnings) + + return intents_are_valid and utterances_are_valid and stories_are_valid From 00de88ba6f0543f847a476227ebce2364b697ad3 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 10 Oct 2019 16:29:44 +0200 Subject: [PATCH 003/209] Begin tracker use (draft) --- rasa/core/validator.py | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 51336f80ada1..85791d3dc2d4 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -2,9 +2,10 @@ import asyncio from typing import List, Set, Text from rasa.core.domain import Domain +from rasa.core.training.generator import TrainingDataGenerator from rasa.importers.importer import TrainingDataImporter from rasa.nlu.training_data import TrainingData -from rasa.core.training.dsl import StoryStep +from rasa.core.training.structures import StoryGraph from rasa.core.training.dsl import UserUttered from rasa.core.training.dsl import ActionExecuted from rasa.core.training.dsl import SlotSet @@ -16,22 +17,22 @@ class Validator(object): """A class used to verify usage of intents and utterances.""" - def __init__(self, domain: Domain, intents: TrainingData, stories: List[StoryStep]): + def __init__(self, domain: Domain, intents: TrainingData, story_graph: StoryGraph): """Initializes the Validator object. 
""" self.domain = domain self.intents = intents - self.stories = stories + self.story_graph = story_graph @classmethod async def from_importer(cls, importer: TrainingDataImporter) -> "Validator": """Create an instance from the domain, nlu and story files.""" domain = await importer.get_domain() - stories = await importer.get_stories() + story_graph = await importer.get_stories() intents = await importer.get_nlu_data() - return cls(domain, intents, stories.story_steps) + return cls(domain, intents, story_graph) def verify_intents(self, ignore_warnings: bool = True) -> bool: """Compares list of intents in domain with intents in NLU training data.""" @@ -68,7 +69,7 @@ def verify_intents_in_stories(self, ignore_warnings: bool = True) -> bool: stories_intents = { event.intent["name"] - for story in self.stories + for story in self.story_graph.story_steps for event in story.events if type(event) == UserUttered } @@ -134,7 +135,7 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: utterance_actions = self._gather_utterance_actions() stories_utterances = set() - for story in self.stories: + for story in self.story_graph.story_steps: for event in story.events: if not isinstance(event, ActionExecuted): continue @@ -166,18 +167,33 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: def verify_story_structure(self, ignore_warnings: bool = True) -> bool: """Verifies that bot behaviour in stories is deterministic.""" + return True + from rasa.utils.story_tree import Tree # Generate the story tree - n = 0 tree = Tree() - slots = {} - for story in self.stories: - n += 1 + trackers = TrainingDataGenerator( + self.story_graph, + domain=self.domain, + remove_duplicates=False, + augmentation_factor=0).generate() + for story, tracker in zip(self.story_graph.story_steps, trackers): + if story.block_name not in tracker.sender_id.split(" > "): + # tracker = search(trackers) + logger.error(f"Story <{story.block_name}> not in tracker 
with id <{tracker.sender_id}>") tree.reset(story=story.block_name) + states = tracker.past_states(self.domain) + idx = 0 + print("\nTracker:") + print(tracker) + print("\nState list:") + print(states) for event in story.events: print(event) if isinstance(event, ActionExecuted): tree.add_or_goto("W: " + event.as_story_string()) + idx += 1 + states[: idx + 1] elif isinstance(event, UserUttered): tree.add_or_goto("U: " + event.as_story_string()) elif isinstance(event, SlotSet): @@ -186,7 +202,8 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: logger.error("JJJ: event is neither action, nor a slot, nor a user utterance") stats = tree.stats() - logger.info(tree.to_string(show_labels=True)) + logger.info(tree.to_string(show_labels=True, only_ambiguous=False)) + logger.info(stats) return True From a548667b5b6758ca66a10fbd61feb1a8da3379ed Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 10 Oct 2019 16:59:46 +0200 Subject: [PATCH 004/209] Begin tracker use (draft) --- rasa/core/validator.py | 52 ++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 85791d3dc2d4..4e045400748f 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -6,6 +6,7 @@ from rasa.importers.importer import TrainingDataImporter from rasa.nlu.training_data import TrainingData from rasa.core.training.structures import StoryGraph +from rasa.core.featurizers import MaxHistoryTrackerFeaturizer from rasa.core.training.dsl import UserUttered from rasa.core.training.dsl import ActionExecuted from rasa.core.training.dsl import SlotSet @@ -167,39 +168,40 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: def verify_story_structure(self, ignore_warnings: bool = True) -> bool: """Verifies that bot behaviour in stories is deterministic.""" - return True + max_history = 2 - from rasa.utils.story_tree import Tree # Generate the 
story tree + from rasa.utils.story_tree import Tree tree = Tree() trackers = TrainingDataGenerator( self.story_graph, domain=self.domain, remove_duplicates=False, augmentation_factor=0).generate() - for story, tracker in zip(self.story_graph.story_steps, trackers): - if story.block_name not in tracker.sender_id.split(" > "): - # tracker = search(trackers) - logger.error(f"Story <{story.block_name}> not in tracker with id <{tracker.sender_id}>") - tree.reset(story=story.block_name) - states = tracker.past_states(self.domain) - idx = 0 - print("\nTracker:") - print(tracker) - print("\nState list:") - print(states) - for event in story.events: - print(event) - if isinstance(event, ActionExecuted): - tree.add_or_goto("W: " + event.as_story_string()) - idx += 1 - states[: idx + 1] - elif isinstance(event, UserUttered): - tree.add_or_goto("U: " + event.as_story_string()) - elif isinstance(event, SlotSet): - tree.add_or_goto("S: " + event.as_story_string()) - else: - logger.error("JJJ: event is neither action, nor a slot, nor a user utterance") + for story in self.story_graph.story_steps: + for tracker in trackers: + if story.block_name in tracker.sender_id.split(" > "): + tree.reset(story=tracker.sender_id) + states = tracker.past_states(self.domain) + states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 + idx = 0 + for event in story.events: + print(event) + if isinstance(event, ActionExecuted): + # only actions which can be + # predicted at a story's start + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + print(sliced_states) + idx += 1 + tree.add_or_goto("W: " + event.as_story_string()) + elif isinstance(event, UserUttered): + tree.add_or_goto("U: " + event.as_story_string()) + elif isinstance(event, SlotSet): + tree.add_or_goto("S: " + event.as_story_string()) + else: + logger.error("JJJ: event is neither action, nor a slot, nor a user utterance") stats = tree.stats() 
logger.info(tree.to_string(show_labels=True, only_ambiguous=False)) From bc4465e3662fd362d8eb7734c9b89964ec884186 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 1 Nov 2019 11:24:53 +0100 Subject: [PATCH 005/209] Use dialogue state for tree --- rasa/core/validator.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 4e045400748f..a958d5dbfb69 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -176,7 +176,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: trackers = TrainingDataGenerator( self.story_graph, domain=self.domain, - remove_duplicates=False, + remove_duplicates=False, # ToDo: Q&A: Why don't we deduplicate the graph here? augmentation_factor=0).generate() for story in self.story_graph.story_steps: for tracker in trackers: @@ -186,25 +186,29 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 idx = 0 for event in story.events: - print(event) + # print(event) if isinstance(event, ActionExecuted): - # only actions which can be - # predicted at a story's start sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - print(sliced_states) + # print(sliced_states) idx += 1 - tree.add_or_goto("W: " + event.as_story_string()) + tree.add_or_goto("W: " + str(sliced_states) + f" [{event.as_story_string()}]") elif isinstance(event, UserUttered): - tree.add_or_goto("U: " + event.as_story_string()) + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + # print(sliced_states) + idx += 1 + tree.add_or_goto("U: " + str(sliced_states) + f" [{event.as_story_string()}]") + # tree.add_or_goto("U: " + event.as_story_string()) elif isinstance(event, SlotSet): tree.add_or_goto("S: " + event.as_story_string()) else: 
logger.error("JJJ: event is neither action, nor a slot, nor a user utterance") stats = tree.stats() - logger.info(tree.to_string(show_labels=True, only_ambiguous=False)) + logger.info(tree.to_string(show_labels=True, only_ambiguous=True)) logger.info(stats) return True From e8743c0cd413bbd9164575475f6024d788fffd2f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 1 Nov 2019 11:34:00 +0100 Subject: [PATCH 006/209] Distinguish name and state in story_tree --- rasa/core/validator.py | 10 ++++++++-- rasa/utils/story_tree.py | 29 +++++++++++++++-------------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index a958d5dbfb69..3f01bda52324 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -193,14 +193,20 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: ) # print(sliced_states) idx += 1 - tree.add_or_goto("W: " + str(sliced_states) + f" [{event.as_story_string()}]") + tree.add_or_goto( + "W: " + str(sliced_states) + f" [{event.as_story_string()}]", + event.as_story_string() + ) elif isinstance(event, UserUttered): sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) # print(sliced_states) idx += 1 - tree.add_or_goto("U: " + str(sliced_states) + f" [{event.as_story_string()}]") + tree.add_or_goto( + "U: " + str(sliced_states) + f" [{event.as_story_string()}]", + event.as_story_string() + ) # tree.add_or_goto("U: " + event.as_story_string()) elif isinstance(event, SlotSet): tree.add_or_goto("S: " + event.as_story_string()) diff --git a/rasa/utils/story_tree.py b/rasa/utils/story_tree.py index 35bf5c414ce2..ef4a959d85ab 100644 --- a/rasa/utils/story_tree.py +++ b/rasa/utils/story_tree.py @@ -80,8 +80,9 @@ def slot_to_dict(string): class Node: - def __init__(self, name="root", parent=None, story=""): + def __init__(self, state="root", name="root", parent=None, story=""): self.count = 1 + self.state = state 
self.name = name self.parent = parent self.children = [] @@ -94,20 +95,20 @@ def add_child(self, child): :type child: Node :param child: Child node """ - if self.get_child(child.name) is None: + if self.get_child(child.state) is None: self.children.append(child) else: - raise ValueError(f"A child with the name {child.name} already exists!") + raise ValueError(f"A child with the name {child.state} already exists!") pass - def get_child(self, name: str): + def get_child(self, state: str): """ Get the child with the given name - :param name: Name of the sought child + :param state: Name of the sought child :return: Child node """ for child in self.children: - if child.name == name: + if child.state == state: return child return None @@ -133,7 +134,7 @@ def print_string(self, branch="", stories=None, max_depth=None, only_ambiguous=F # Decide how to color the present node if coloring[0] == "r": # "role" - color = {"S": "yellow", "U": "blue", "W": "green"}.get(self.name[0]) + color = {"S": "yellow", "U": "blue", "W": "green"}.get(self.state[0]) elif coloring[0] == "d": # "depth" color = {1: "green", 2: "magenta", 3: "yellow", 4: "cyan", 5: "blue", 6: "grey"}.get(_depth, "grey") elif coloring[0] == "n": # "none" @@ -394,36 +395,36 @@ def __init__(self): self.pointer = self.root # Pointer to the currently active node self.label = "" # Label for active branch - def add_or_goto(self, name): + def add_or_goto(self, state, name): """ If a branch with name `name` is a child of the currently active node, then move `self.pointer` to that branch and update visit counts and branch name lists. Otherwise, create a new child branch with this name and move the pointer to it. 
- :param name: Name of the (new) branch to go to + :param state: Name of the (new) branch to go to :return: True, iff a new branch was created """ # Check if branch with name `name` exists for branch in self.pointer.children: - if branch.name == name: + if branch.state == state: branch.count += 1 # Increase visit count branch.labels += [self.label] # Append new branch label self.pointer = branch # Move pointer to this branch return False # Add a new branch - new_branch = Node(name, parent=self.pointer, story=self.label) + new_branch = Node(state, name, parent=self.pointer, story=self.label) self.pointer.add_child(new_branch) self.pointer = new_branch return True - def adding_creates_ambiguity(self, name: str): + def adding_creates_ambiguity(self, state: str): """ Returns True iff adding a branch with this name would result in an ambiguity in this tree, i.e. another child node exists, which is as Wizard node. - :param name: Name of the branch (user/wizard action) + :param state: State of the branch (user/wizard action) :return: True iff ambiguous """ - return name.startswith("W") and any(c.name.startswith("W") for c in self.pointer.children) + return state.startswith("W") and any(c.state.startswith("W") for c in self.pointer.children) def up(self): """ From 33225bdcddfb11d34c2777755c2c0fc1f09335ca Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 1 Nov 2019 11:40:34 +0100 Subject: [PATCH 007/209] Add separate 'kind' property to tree --- rasa/core/validator.py | 12 +++++++++--- rasa/utils/story_tree.py | 17 ++++++++++------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 3f01bda52324..2475b286e8e3 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -194,7 +194,8 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: # print(sliced_states) idx += 1 tree.add_or_goto( - "W: " + str(sliced_states) + f" [{event.as_story_string()}]", + "W", + str(sliced_states) + f" [{event.as_story_string()}]", event.as_story_string() ) elif isinstance(event, UserUttered): @@ -204,12 +205,17 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: # print(sliced_states) idx += 1 tree.add_or_goto( - "U: " + str(sliced_states) + f" [{event.as_story_string()}]", + "U", + str(sliced_states) + f" [{event.as_story_string()}]", event.as_story_string() ) # tree.add_or_goto("U: " + event.as_story_string()) elif isinstance(event, SlotSet): - tree.add_or_goto("S: " + event.as_story_string()) + tree.add_or_goto( + "S", + event.as_story_string(), + event.as_story_string() + ) else: logger.error("JJJ: event is neither action, nor a slot, nor a user utterance") diff --git a/rasa/utils/story_tree.py b/rasa/utils/story_tree.py index ef4a959d85ab..859ad9e7612c 100644 --- a/rasa/utils/story_tree.py +++ b/rasa/utils/story_tree.py @@ -80,8 +80,9 @@ def slot_to_dict(string): class Node: - def __init__(self, state="root", name="root", parent=None, story=""): + def __init__(self, kind="R", state="root", name="root", parent=None, story=""): self.count = 1 + self.kind = kind self.state = state self.name = name self.parent = parent @@ -134,7 +135,7 @@ def print_string(self, branch="", stories=None, max_depth=None, only_ambiguous=F # Decide how to color the present node if coloring[0] == "r": # "role" - color = {"S": 
"yellow", "U": "blue", "W": "green"}.get(self.state[0]) + color = {"S": "yellow", "U": "blue", "W": "green"}.get(self.kind) elif coloring[0] == "d": # "depth" color = {1: "green", 2: "magenta", 3: "yellow", 4: "cyan", 5: "blue", 6: "grey"}.get(_depth, "grey") elif coloring[0] == "n": # "none" @@ -170,7 +171,7 @@ def print_string(self, branch="", stories=None, max_depth=None, only_ambiguous=F # Prepare _has_siblings for recursion step has_siblings = (len(self.children) > 1) if has_siblings and not include_users: - all_children_are_users = all(child.name.startswith("U:") or child.name.startswith("S:") for child in self.children) + all_children_are_users = all(child.kind == "U" or child.kind == "S" for child in self.children) has_siblings = not all_children_are_users # Recursion step into all child nodes @@ -215,7 +216,7 @@ def has_descendants_with_siblings(self, include_users): if include_users: return True else: - all_children_are_users = all(child.name.startswith("U:") or child.name.startswith("S:") for child in self.children) + all_children_are_users = all(child.kind == "U" or child.kind == "S" for child in self.children) if all_children_are_users: return any(child.has_descendants_with_siblings(include_users) for child in self.children) else: @@ -395,12 +396,14 @@ def __init__(self): self.pointer = self.root # Pointer to the currently active node self.label = "" # Label for active branch - def add_or_goto(self, state, name): + def add_or_goto(self, kind, state, name): """ If a branch with name `name` is a child of the currently active node, then move `self.pointer` to that branch and update visit counts and branch name lists. Otherwise, create a new child branch with this name and move the pointer to it. 
- :param state: Name of the (new) branch to go to + :param kind: U/S/W for user/slot/wizard + :param state: State string of the (new) branch to go to + :param name: Name of the (new) branch to go to :return: True, iff a new branch was created """ # Check if branch with name `name` exists @@ -412,7 +415,7 @@ def add_or_goto(self, state, name): return False # Add a new branch - new_branch = Node(state, name, parent=self.pointer, story=self.label) + new_branch = Node(kind, state, name, parent=self.pointer, story=self.label) self.pointer.add_child(new_branch) self.pointer = new_branch return True From 59c467289b8c7a78676153692c4eb13756aacdfe Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 1 Nov 2019 12:59:24 +0100 Subject: [PATCH 008/209] Add sensible output --- rasa/core/validator.py | 18 +++++++++++++--- rasa/utils/story_tree.py | 45 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 2475b286e8e3..3dd23548a85f 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -219,9 +219,21 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: else: logger.error("JJJ: event is neither action, nor a slot, nor a user utterance") - stats = tree.stats() - logger.info(tree.to_string(show_labels=True, only_ambiguous=True)) - logger.info(stats) + conflicts = tree.conflicts + if conflicts: + output = "Ambiguous stories:\n" + for conflict in conflicts: + stories = [] + for a in conflict["ambiguity"]: + stories += a["stories"] + lead = conflict["leading_steps"][1:] + output += f"The stories {stories} all start with {lead}, but then \n" + for a in conflict["ambiguity"]: + if len(a["stories"]) > 1: + output += f"* stories {a['stories']} continue with {a['action']}\n" + else: + output += f"* story {a['stories']} continues with {a['action']}\n" + logger.warning(output) return True diff --git a/rasa/utils/story_tree.py 
b/rasa/utils/story_tree.py index 859ad9e7612c..4ba4ba160ceb 100644 --- a/rasa/utils/story_tree.py +++ b/rasa/utils/story_tree.py @@ -235,7 +235,7 @@ def prune(self, keep: str): return if len(self.children) > 0: if len(self.children) > 1: - if any([child.name.startswith("W:") for child in self.children]): + if any([child.kind == "W" for child in self.children]): if keep == "first": del self.children[1:] elif keep == "last": @@ -253,6 +253,43 @@ def prune(self, keep: str): child.prune(keep) self._is_pruned = True + def conflicts(self, keep: str, _depth=0, _leading_steps=[]): + """Return list of conflict points""" + if self._is_pruned: + return None + + conflicts = [] + if len(self.children) > 0: + if len(self.children) > 1: + if any([child.kind == "W" for child in self.children]): + conflicts.append({ + "ambiguity": [{"stories": c.labels, "action": c.name} for c in self.children if c.kind == "W"], + "conflict_step": _depth, + "leading_steps": _leading_steps + [self.name], + }) + if keep == "first": + del self.children[1:] + elif keep == "last": + del self.children[:-1] + elif keep == "most-visited": + visit_counts = [len(c.labels) for c in self.children] + keep_idx = visit_counts.index(max(visit_counts)) + # Delete all but the one at `keep_idx` + del self.children[:keep_idx] + if len(self.children) > 1: + del self.children[1:] + else: + raise ValueError("Invalid prune keep criterion.") + + for child in self.children: + child_conflicts = child.conflicts(keep, _depth+1, _leading_steps + [self.name]) + if child_conflicts: + conflicts += child_conflicts + + self._is_pruned = True + + return conflicts + def remove(self, story) -> bool: """Remove the given story from this node and recursively from all descendants. 
""" @@ -333,7 +370,7 @@ def stats(self): def callback_discover_node(node, depth, flags): statistics["num_nodes"] += 1 if len(node.children) > 1: - if any([child.name.startswith("W:") for child in node.children]): + if any([child.kind == "W" for child in node.children]): statistics["num_nodes_with_multiple_children"] += 1 statistics["ambiguity_depth"] = max(statistics["ambiguity_depth"], depth) statistics["ambiguity_chain_length"] = max(statistics["ambiguity_chain_length"], @@ -489,6 +526,10 @@ def leafs(self): def duplicates(self): return self.root.duplicates + @property + def conflicts(self): + return self.root.conflicts("most-visited") + def stats(self) -> dict: """ Compute statistics about this tree. From e771f21b1587a74e7d7c3cb1564399317a8c760e Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 1 Nov 2019 14:59:53 +0100 Subject: [PATCH 009/209] Discard the tree --- rasa/core/validator.py | 104 +++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 45 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 3dd23548a85f..7475fc01f8af 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -168,7 +168,7 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: def verify_story_structure(self, ignore_warnings: bool = True) -> bool: """Verifies that bot behaviour in stories is deterministic.""" - max_history = 2 + max_history = 1 # Generate the story tree from rasa.utils.story_tree import Tree @@ -178,62 +178,76 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: domain=self.domain, remove_duplicates=False, # ToDo: Q&A: Why don't we deduplicate the graph here? 
augmentation_factor=0).generate() + rules = {} + conflicts = {} for story in self.story_graph.story_steps: for tracker in trackers: if story.block_name in tracker.sender_id.split(" > "): - tree.reset(story=tracker.sender_id) + states = tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 idx = 0 for event in story.events: - # print(event) if isinstance(event, ActionExecuted): sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - # print(sliced_states) - idx += 1 - tree.add_or_goto( - "W", - str(sliced_states) + f" [{event.as_story_string()}]", - event.as_story_string() - ) - elif isinstance(event, UserUttered): - sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( - states[: idx + 1], max_history - ) - # print(sliced_states) + h = hash(str(sliced_states)) + if h in rules and rules[h]['action'] != event.as_story_string(): + print(f"CONFLICT in between " + f"story '{tracker.sender_id}' with action '{event.as_story_string()}' " + f"and story '{rules[h]['tracker']}' with action '{rules[h]['action']}'.") + if h not in conflicts: + conflicts[h] = {tracker.sender_id, rules[h]['tracker']} + else: + conflicts[h] += {tracker.sender_id, rules[h]['tracker']} + else: + rules[h] = { + "tracker": tracker.sender_id, + "action": event.as_story_string() + } + idx += 1 + + print(conflicts) + + for state_hash, tracker_ids in conflicts.items(): + print(f" -- CONFLICT -- ") + if len(tracker_ids) == 1: + tracker_id = tracker_ids.pop() + print(f"The tracker '{tracker_id}' is inconsistent with itself:") + + # find the right tracker + tracker = None + for t in trackers: + if t.sender_id == tracker_id: + tracker = t + break + + assert tracker + + description = "" + for story in self.story_graph.story_steps: + if story.block_name in tracker.sender_id.split(" > "): + description += f"Story '{story.block_name}':\n" + states = tracker.past_states(self.domain) + 
states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 + idx = 0 + for event in story.events: + if isinstance(event, UserUttered): + description += f"* {event.as_story_string()}" + elif isinstance(event, ActionExecuted): + description += f" - {event.as_story_string()}" + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + h = hash(str(sliced_states)) + if h == state_hash: + description += " <-- CONFLICT" idx += 1 - tree.add_or_goto( - "U", - str(sliced_states) + f" [{event.as_story_string()}]", - event.as_story_string() - ) - # tree.add_or_goto("U: " + event.as_story_string()) - elif isinstance(event, SlotSet): - tree.add_or_goto( - "S", - event.as_story_string(), - event.as_story_string() - ) - else: - logger.error("JJJ: event is neither action, nor a slot, nor a user utterance") - - conflicts = tree.conflicts - if conflicts: - output = "Ambiguous stories:\n" - for conflict in conflicts: - stories = [] - for a in conflict["ambiguity"]: - stories += a["stories"] - lead = conflict["leading_steps"][1:] - output += f"The stories {stories} all start with {lead}, but then \n" - for a in conflict["ambiguity"]: - if len(a["stories"]) > 1: - output += f"* stories {a['stories']} continue with {a['action']}\n" - else: - output += f"* story {a['stories']} continues with {a['action']}\n" - logger.warning(output) + description += "\n" + print(description) + elif len(tracker_ids) == 2: + print(f"The trackers {tracker_ids} contain inconsistent states:") return True From ba49d542612f5e4e306a0502615dbd739745fb7c Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 1 Nov 2019 15:22:05 +0100 Subject: [PATCH 010/209] Enable single-story-per-tracker output --- rasa/core/validator.py | 62 ++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 7475fc01f8af..ee6488e685b7 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -194,41 +194,30 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: ) h = hash(str(sliced_states)) if h in rules and rules[h]['action'] != event.as_story_string(): - print(f"CONFLICT in between " - f"story '{tracker.sender_id}' with action '{event.as_story_string()}' " - f"and story '{rules[h]['tracker']}' with action '{rules[h]['action']}'.") + # print(f"CONFLICT in between " + # f"story '{tracker.sender_id}' with action '{event.as_story_string()}' " + # f"and story '{rules[h]['tracker']}' with action '{rules[h]['action']}'.") if h not in conflicts: - conflicts[h] = {tracker.sender_id, rules[h]['tracker']} + conflicts[h] = {tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']} else: - conflicts[h] += {tracker.sender_id, rules[h]['tracker']} + conflicts[h] += {tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']} else: rules[h] = { - "tracker": tracker.sender_id, + "tracker": tracker, "action": event.as_story_string() } idx += 1 - print(conflicts) - - for state_hash, tracker_ids in conflicts.items(): + for state_hash, tracker_dict in conflicts.items(): print(f" -- CONFLICT -- ") - if len(tracker_ids) == 1: - tracker_id = tracker_ids.pop() - print(f"The tracker '{tracker_id}' is inconsistent with itself:") - - # find the right tracker - tracker = None - for t in trackers: - if t.sender_id == tracker_id: - tracker = t - break - - assert tracker + if len(tracker_dict) == 1: + tracker = list(tracker_dict.values())[0] + print(f"The tracker '{tracker.sender_id}' is inconsistent with itself:") description = "" 
for story in self.story_graph.story_steps: if story.block_name in tracker.sender_id.split(" > "): - description += f"Story '{story.block_name}':\n" + description += f"\nStory '{story.block_name}':\n" states = tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 idx = 0 @@ -246,8 +235,33 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: idx += 1 description += "\n" print(description) - elif len(tracker_ids) == 2: - print(f"The trackers {tracker_ids} contain inconsistent states:") + elif len(tracker_dict) == 2: + print(f"The trackers {set(tracker_dict.keys())} contain inconsistent states:") + trackers = list(tracker_dict.values()) + description = "" + for story in self.story_graph.story_steps: + for tracker in trackers: + if story.block_name in tracker.sender_id.split(" > "): + description += f"\nStory '{story.block_name}':\n" + states = tracker.past_states(self.domain) + states = [dict(state) for state in + states] # ToDo: Check against rasa/core/featurizers.py:318 + idx = 0 + for event in story.events: + if isinstance(event, UserUttered): + description += f"* {event.as_story_string()}" + elif isinstance(event, ActionExecuted): + description += f" - {event.as_story_string()}" + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + h = hash(str(sliced_states)) + if h == state_hash: + description += " <-- CONFLICT" + idx += 1 + description += "\n" + + print(description) return True From 408775a9825f931127f9c26dc76ad7e2a265dd5c Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 1 Nov 2019 15:43:05 +0100 Subject: [PATCH 011/209] Fix invalid += for dict --- rasa/core/validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index ee6488e685b7..dfca7179e569 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -200,7 +200,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: if h not in conflicts: conflicts[h] = {tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']} else: - conflicts[h] += {tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']} + conflicts[h].update({tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']}) else: rules[h] = { "tracker": tracker, From 9d0485f28bb1e8fb803c8eff21524626689423cc Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 1 Nov 2019 16:07:01 +0100 Subject: [PATCH 012/209] Start using trackers only for conflict finding --- rasa/core/validator.py | 80 +++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 33 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index dfca7179e569..9ffe00bebceb 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -180,33 +180,36 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: augmentation_factor=0).generate() rules = {} conflicts = {} - for story in self.story_graph.story_steps: - for tracker in trackers: - if story.block_name in tracker.sender_id.split(" > "): - - states = tracker.past_states(self.domain) - states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 - idx = 0 - for event in story.events: - if isinstance(event, ActionExecuted): - sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( - states[: idx + 1], max_history - ) - h = hash(str(sliced_states)) - if h in rules and rules[h]['action'] != event.as_story_string(): - # 
print(f"CONFLICT in between " - # f"story '{tracker.sender_id}' with action '{event.as_story_string()}' " - # f"and story '{rules[h]['tracker']}' with action '{rules[h]['action']}'.") - if h not in conflicts: - conflicts[h] = {tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']} - else: - conflicts[h].update({tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']}) - else: - rules[h] = { - "tracker": tracker, - "action": event.as_story_string() - } - idx += 1 + for tracker in trackers: + print(tracker.sender_id) + states = tracker.past_states(self.domain) + states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 + idx = 0 + for event in tracker.events: + if isinstance(event, ActionExecuted): + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + h = hash(str(sliced_states)) + if h in rules and rules[h]['action'] != event.as_story_string(): + # print(f"CONFLICT in between " + # f"story '{tracker.sender_id}' with action '{event.as_story_string()}' " + # f"and story '{rules[h]['tracker']}' with action '{rules[h]['action']}'.") + if h not in conflicts: + conflicts[h] = {tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']} + else: + conflicts[h].update({tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']}) + else: + rules[h] = { + "tracker": tracker, + "action": event.as_story_string() + } + elif isinstance(event, UserUttered): + pass + else: + raise ValueError(f"Event has type {type(event)}") + idx += 1 + print() for state_hash, tracker_dict in conflicts.items(): print(f" -- CONFLICT -- ") @@ -217,7 +220,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: description = "" for story in self.story_graph.story_steps: if story.block_name in tracker.sender_id.split(" > "): - description += f"\nStory '{story.block_name}':\n" + description += f" ~~ 
'{story.block_name}' ~~\n" states = tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 idx = 0 @@ -238,11 +241,18 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: elif len(tracker_dict) == 2: print(f"The trackers {set(tracker_dict.keys())} contain inconsistent states:") trackers = list(tracker_dict.values()) - description = "" - for story in self.story_graph.story_steps: - for tracker in trackers: + story_blocks = {} + for tracker in trackers: + print(tracker.sender_id) + for story in self.story_graph.story_steps: if story.block_name in tracker.sender_id.split(" > "): - description += f"\nStory '{story.block_name}':\n" + block_id = 0 + for i, s in enumerate(tracker.sender_id.split(" > ")): + if story.block_name == s: + block_id = i + break + + description = f"~~ '{story.block_name}' ~~\n" states = tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 @@ -261,7 +271,11 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: idx += 1 description += "\n" - print(description) + story_blocks[block_id] = description + + for _, block in story_blocks.items(): + # print(i) + print(block, end="") return True From 2a3e140c568c9d2deb3430c666bb1323d1c49a9f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 1 Nov 2019 16:23:45 +0100 Subject: [PATCH 013/209] Fix print problem --- rasa/core/validator.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 9ffe00bebceb..c64d924eb77b 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -212,7 +212,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: print() for state_hash, tracker_dict in conflicts.items(): - print(f" -- CONFLICT -- ") + print(f" -- CONFLICT [{state_hash}] -- ") if len(tracker_dict) == 1: tracker = list(tracker_dict.values())[0] print(f"The tracker '{tracker.sender_id}' is inconsistent with itself:") @@ -243,6 +243,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: trackers = list(tracker_dict.values()) story_blocks = {} for tracker in trackers: + print() print(tracker.sender_id) for story in self.story_graph.story_steps: if story.block_name in tracker.sender_id.split(" > "): @@ -273,9 +274,9 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: story_blocks[block_id] = description - for _, block in story_blocks.items(): - # print(i) - print(block, end="") + for _, block in story_blocks.items(): + # print(i) + print(block, end="") return True From fb2a5522ea94b12b410eac1b6980a40b2be7733a Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 1 Nov 2019 17:24:51 +0100 Subject: [PATCH 014/209] Fix bad position of idx --- rasa/core/validator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index c64d924eb77b..8905a63d3d5c 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -190,7 +190,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - h = hash(str(sliced_states)) + h = hash(str(sorted(list(sliced_states)))) if h in rules and rules[h]['action'] != event.as_story_string(): # print(f"CONFLICT in between " # f"story '{tracker.sender_id}' with action '{event.as_story_string()}' " @@ -218,12 +218,13 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: print(f"The tracker '{tracker.sender_id}' is inconsistent with itself:") description = "" + idx = 0 for story in self.story_graph.story_steps: if story.block_name in tracker.sender_id.split(" > "): description += f" ~~ '{story.block_name}' ~~\n" states = tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 - idx = 0 + for event in story.events: if isinstance(event, UserUttered): description += f"* {event.as_story_string()}" @@ -232,7 +233,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - h = hash(str(sliced_states)) + h = hash(str(sorted(list(sliced_states)))) if h == state_hash: description += " <-- CONFLICT" idx += 1 @@ -245,6 +246,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: for tracker in trackers: print() print(tracker.sender_id) + idx = 0 for story in self.story_graph.story_steps: if story.block_name in tracker.sender_id.split(" > "): block_id = 0 @@ -257,7 +259,6 @@ def 
verify_story_structure(self, ignore_warnings: bool = True) -> bool: states = tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 - idx = 0 for event in story.events: if isinstance(event, UserUttered): description += f"* {event.as_story_string()}" @@ -266,7 +267,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - h = hash(str(sliced_states)) + h = hash(str(sorted(list(sliced_states)))) if h == state_hash: description += " <-- CONFLICT" idx += 1 From 343d98a26ffc6ffe98e21d991351e158b03fbcf3 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 4 Nov 2019 11:26:06 +0100 Subject: [PATCH 015/209] Fix rule collection --- rasa/core/validator.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 8905a63d3d5c..3a142b97a46f 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -179,37 +179,38 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: remove_duplicates=False, # ToDo: Q&A: Why don't we deduplicate the graph here? 
augmentation_factor=0).generate() rules = {} - conflicts = {} for tracker in trackers: print(tracker.sender_id) states = tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 + idx = 0 for event in tracker.events: if isinstance(event, ActionExecuted): sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - h = hash(str(sorted(list(sliced_states)))) - if h in rules and rules[h]['action'] != event.as_story_string(): - # print(f"CONFLICT in between " - # f"story '{tracker.sender_id}' with action '{event.as_story_string()}' " - # f"and story '{rules[h]['tracker']}' with action '{rules[h]['action']}'.") - if h not in conflicts: - conflicts[h] = {tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']} - else: - conflicts[h].update({tracker.sender_id: tracker, rules[h]['tracker'].sender_id: rules[h]['tracker']}) + h = str(sorted(list(sliced_states))) + if h in rules: + known_actions = [info["action"] for info in rules[h]] + if event.as_story_string() not in known_actions: + # print(f"{h} >> {event.as_story_string()} and {known_actions}") + rules[h] += [{ + "tracker": tracker, + "action": event.as_story_string() + }] else: - rules[h] = { + # print(f"{h} >> {event.as_story_string()}") + rules[h] = [{ "tracker": tracker, "action": event.as_story_string() - } - elif isinstance(event, UserUttered): - pass - else: - raise ValueError(f"Event has type {type(event)}") - idx += 1 + }] + idx += 1 print() + for state_hash, info in rules.items(): + if len(info) > 1: + print(f"{state_hash}: {[i['action'] for i in info]}") + exit() for state_hash, tracker_dict in conflicts.items(): print(f" -- CONFLICT [{state_hash}] -- ") From 30252877d374ac792f458bd4c671e461ba15e75b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 4 Nov 2019 11:47:33 +0100 Subject: [PATCH 016/209] Recreate output --- rasa/core/validator.py | 114 ++++++++++++++++++++--------------------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 3a142b97a46f..8a064fbb2c44 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -180,7 +180,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: augmentation_factor=0).generate() rules = {} for tracker in trackers: - print(tracker.sender_id) + # print(tracker.sender_id) states = tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 @@ -190,7 +190,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - h = str(sorted(list(sliced_states))) + h = hash(str(sorted(list(sliced_states)))) if h in rules: known_actions = [info["action"] for info in rules[h]] if event.as_story_string() not in known_actions: @@ -207,59 +207,26 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: }] idx += 1 print() + result = True for state_hash, info in rules.items(): if len(info) > 1: - print(f"{state_hash}: {[i['action'] for i in info]}") - exit() - - for state_hash, tracker_dict in conflicts.items(): - print(f" -- CONFLICT [{state_hash}] -- ") - if len(tracker_dict) == 1: - tracker = list(tracker_dict.values())[0] - print(f"The tracker '{tracker.sender_id}' is inconsistent with itself:") - - description = "" - idx = 0 - for story in self.story_graph.story_steps: - if story.block_name in tracker.sender_id.split(" > "): - description += f" ~~ '{story.block_name}' ~~\n" - states = tracker.past_states(self.domain) - states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 - - for event in story.events: - if isinstance(event, 
UserUttered): - description += f"* {event.as_story_string()}" - elif isinstance(event, ActionExecuted): - description += f" - {event.as_story_string()}" - sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( - states[: idx + 1], max_history - ) - h = hash(str(sorted(list(sliced_states)))) - if h == state_hash: - description += " <-- CONFLICT" - idx += 1 - description += "\n" - print(description) - elif len(tracker_dict) == 2: - print(f"The trackers {set(tracker_dict.keys())} contain inconsistent states:") - trackers = list(tracker_dict.values()) - story_blocks = {} - for tracker in trackers: - print() - print(tracker.sender_id) + result = False + conflicting_trackers = {i['tracker'].sender_id: i['tracker'] for i in info} + # print(f"CONFLICT {state_hash}: Ambiguity of choice between actions {[i['action'] for i in info]} " + # f"when learning from trackers {conflicting_tracker_names}.") + + if len(conflicting_trackers) == 1: + tracker = list(conflicting_trackers.values())[0] + print(f"The tracker '{tracker.sender_id}' is inconsistent with itself:") + description = "" idx = 0 for story in self.story_graph.story_steps: if story.block_name in tracker.sender_id.split(" > "): - block_id = 0 - for i, s in enumerate(tracker.sender_id.split(" > ")): - if story.block_name == s: - block_id = i - break - - description = f"~~ '{story.block_name}' ~~\n" + description += f" ~~ '{story.block_name}' ~~\n" states = tracker.past_states(self.domain) - states = [dict(state) for state in - states] # ToDo: Check against rasa/core/featurizers.py:318 + # ToDo: Check against rasa/core/featurizers.py:318 + states = [dict(state) for state in states] + for event in story.events: if isinstance(event, UserUttered): description += f"* {event.as_story_string()}" @@ -273,14 +240,47 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: description += " <-- CONFLICT" idx += 1 description += "\n" - - story_blocks[block_id] = description - - for _, block in 
story_blocks.items(): - # print(i) - print(block, end="") - - return True + print(description) + elif len(conflicting_trackers) == 2: + print(f"The trackers '{list(conflicting_trackers.keys())}' are inconsistent:") + story_blocks = {} + for tracker in list(conflicting_trackers.values()): + print() + print(tracker.sender_id) + idx = 0 + for story in self.story_graph.story_steps: + if story.block_name in tracker.sender_id.split(" > "): + block_id = 0 + for i, s in enumerate(tracker.sender_id.split(" > ")): + if story.block_name == s: + block_id = i + break + + description = f"~~ '{story.block_name}' ~~\n" + states = tracker.past_states(self.domain) + states = [dict(state) for state in + states] # ToDo: Check against rasa/core/featurizers.py:318 + for event in story.events: + if isinstance(event, UserUttered): + description += f"* {event.as_story_string()}" + elif isinstance(event, ActionExecuted): + description += f" - {event.as_story_string()}" + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + h = hash(str(sorted(list(sliced_states)))) + if h == state_hash: + description += " <-- CONFLICT" + idx += 1 + description += "\n" + + story_blocks[block_id] = description + + for _, block in story_blocks.items(): + # print(i) + print(block, end="") + + return result def verify_all(self, ignore_warnings: bool = True) -> bool: """Runs all the validations on intents and utterances.""" From f32806d5010b9ac1f951ac0dd0abf184175d4da5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 4 Nov 2019 11:55:09 +0100 Subject: [PATCH 017/209] Dont sort states, since not sortable --- rasa/core/validator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 8a064fbb2c44..262efce41c3b 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -190,7 +190,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - h = hash(str(sorted(list(sliced_states)))) + h = hash(str(list(sliced_states))) if h in rules: known_actions = [info["action"] for info in rules[h]] if event.as_story_string() not in known_actions: @@ -235,7 +235,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - h = hash(str(sorted(list(sliced_states)))) + h = hash(str(list(sliced_states))) if h == state_hash: description += " <-- CONFLICT" idx += 1 @@ -268,7 +268,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - h = hash(str(sorted(list(sliced_states)))) + h = hash(str(list(sliced_states))) if h == state_hash: description += " <-- CONFLICT" idx += 1 From 7605845ee182fb0b6b45aaa8a2f54f919a5dd415 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 4 Nov 2019 16:54:00 +0100 Subject: [PATCH 018/209] Fix forms and slots handling --- rasa/core/validator.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 262efce41c3b..8c66260337fd 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -168,7 +168,7 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: def verify_story_structure(self, ignore_warnings: bool = True) -> bool: """Verifies that bot behaviour in stories is deterministic.""" - max_history = 1 + max_history = 3 # Generate the story tree from rasa.utils.story_tree import Tree @@ -217,7 +217,7 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: if len(conflicting_trackers) == 1: tracker = list(conflicting_trackers.values())[0] - print(f"The tracker '{tracker.sender_id}' is inconsistent with itself:") + print(f"\nThe tracker '{tracker.sender_id}' is inconsistent with itself:") description = "" idx = 0 for story in self.story_graph.story_steps: @@ -242,40 +242,48 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: description += "\n" print(description) elif len(conflicting_trackers) == 2: - print(f"The trackers '{list(conflicting_trackers.keys())}' are inconsistent:") + print(f"\nThe trackers '{list(conflicting_trackers.keys())}' are inconsistent:") story_blocks = {} for tracker in list(conflicting_trackers.values()): print() - print(tracker.sender_id) + # print(tracker.sender_id) idx = 0 for story in self.story_graph.story_steps: if story.block_name in tracker.sender_id.split(" > "): - block_id = 0 + block_id = -1 for i, s in enumerate(tracker.sender_id.split(" > ")): if story.block_name == s: + if i in story_blocks: + print(f"DUPLICATE BLOCK {story}") # ToDo: Resolve this OR-problem block_id = i break + assert block_id >= 0 + description = f"~~ '{story.block_name}' ~~\n" states = 
tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 for event in story.events: if isinstance(event, UserUttered): - description += f"* {event.as_story_string()}" + description += f"* {event.as_story_string().rstrip()}" elif isinstance(event, ActionExecuted): - description += f" - {event.as_story_string()}" + description += f" - {event.as_story_string().rstrip()}" sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) h = hash(str(list(sliced_states))) if h == state_hash: description += " <-- CONFLICT" + else: + # Slots and Forms + description += f" - {event.as_story_string().rstrip()}" idx += 1 description += "\n" story_blocks[block_id] = description + print(tracker.sender_id) for _, block in story_blocks.items(): # print(i) print(block, end="") From 5c55573dc1c479a75d5180b78a560af46f1b2543 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Tue, 5 Nov 2019 07:20:19 +0100 Subject: [PATCH 019/209] Minimal output --- rasa/core/validator.py | 109 +++++++++++------------------------------ 1 file changed, 28 insertions(+), 81 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 8c66260337fd..e82aa0af2d00 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -1,7 +1,8 @@ import logging -import asyncio -from typing import List, Set, Text -from rasa.core.domain import Domain +from typing import Set, Text + +from rasa.core.actions.action import ACTION_LISTEN_NAME +from rasa.core.domain import Domain, PREV_PREFIX from rasa.core.training.generator import TrainingDataGenerator from rasa.importers.importer import TrainingDataImporter from rasa.nlu.training_data import TrainingData @@ -9,7 +10,6 @@ from rasa.core.featurizers import MaxHistoryTrackerFeaturizer from rasa.core.training.dsl import UserUttered from rasa.core.training.dsl import ActionExecuted -from rasa.core.training.dsl import SlotSet from 
rasa.core.constants import UTTER_PREFIX logger = logging.getLogger(__name__) @@ -168,7 +168,7 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: def verify_story_structure(self, ignore_warnings: bool = True) -> bool: """Verifies that bot behaviour in stories is deterministic.""" - max_history = 3 + max_history = 5 # Generate the story tree from rasa.utils.story_tree import Tree @@ -211,82 +211,29 @@ def verify_story_structure(self, ignore_warnings: bool = True) -> bool: for state_hash, info in rules.items(): if len(info) > 1: result = False - conflicting_trackers = {i['tracker'].sender_id: i['tracker'] for i in info} - # print(f"CONFLICT {state_hash}: Ambiguity of choice between actions {[i['action'] for i in info]} " - # f"when learning from trackers {conflicting_tracker_names}.") - - if len(conflicting_trackers) == 1: - tracker = list(conflicting_trackers.values())[0] - print(f"\nThe tracker '{tracker.sender_id}' is inconsistent with itself:") - description = "" - idx = 0 - for story in self.story_graph.story_steps: - if story.block_name in tracker.sender_id.split(" > "): - description += f" ~~ '{story.block_name}' ~~\n" - states = tracker.past_states(self.domain) - # ToDo: Check against rasa/core/featurizers.py:318 - states = [dict(state) for state in states] - - for event in story.events: - if isinstance(event, UserUttered): - description += f"* {event.as_story_string()}" - elif isinstance(event, ActionExecuted): - description += f" - {event.as_story_string()}" - sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( - states[: idx + 1], max_history - ) - h = hash(str(list(sliced_states))) - if h == state_hash: - description += " <-- CONFLICT" - idx += 1 - description += "\n" - print(description) - elif len(conflicting_trackers) == 2: - print(f"\nThe trackers '{list(conflicting_trackers.keys())}' are inconsistent:") - story_blocks = {} - for tracker in list(conflicting_trackers.values()): - print() - # 
print(tracker.sender_id) - idx = 0 - for story in self.story_graph.story_steps: - if story.block_name in tracker.sender_id.split(" > "): - block_id = -1 - for i, s in enumerate(tracker.sender_id.split(" > ")): - if story.block_name == s: - if i in story_blocks: - print(f"DUPLICATE BLOCK {story}") # ToDo: Resolve this OR-problem - block_id = i - break - - assert block_id >= 0 - - description = f"~~ '{story.block_name}' ~~\n" - states = tracker.past_states(self.domain) - states = [dict(state) for state in - states] # ToDo: Check against rasa/core/featurizers.py:318 - for event in story.events: - if isinstance(event, UserUttered): - description += f"* {event.as_story_string().rstrip()}" - elif isinstance(event, ActionExecuted): - description += f" - {event.as_story_string().rstrip()}" - sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( - states[: idx + 1], max_history - ) - h = hash(str(list(sliced_states))) - if h == state_hash: - description += " <-- CONFLICT" - else: - # Slots and Forms - description += f" - {event.as_story_string().rstrip()}" - idx += 1 - description += "\n" - - story_blocks[block_id] = description - - print(tracker.sender_id) - for _, block in story_blocks.items(): - # print(i) - print(block, end="") + tracker = info[0]["tracker"] + states = tracker.past_states(self.domain) + states = [dict(state) for state in states] + last_event_string = None + idx = 0 + for event in tracker.events: + if isinstance(event, ActionExecuted): + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + h = hash(str(list(sliced_states))) + if h == state_hash: + for k, v in sliced_states[-1].items(): + if k.startswith(PREV_PREFIX): + if k[len(PREV_PREFIX):] != ACTION_LISTEN_NAME: + last_event_string = f"action '{k[len(PREV_PREFIX):]}'" + elif k.startswith("intent_") and not last_event_string: + last_event_string = f"intent '{k[len('intent_'):]}'" + break + idx += 1 + print(f"CONFLICT after 
{last_event_string}: ") + for i in info: + print(f" '{i['action']}' predicted in '{i['tracker'].sender_id}'") return result From bfea1ce37b085a13ea85c756161ff795062532a3 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Tue, 5 Nov 2019 09:34:06 +0100 Subject: [PATCH 020/209] Introduce rasa data validate stories --- rasa/cli/data.py | 27 +++++++++++++++++++++++++++ rasa/core/validator.py | 9 ++------- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 0f8009e035fe..1bea12f42011 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -73,6 +73,16 @@ def add_subparser( validate_parser.set_defaults(func=validate_files) arguments.set_validator_arguments(validate_parser) + validate_subparsers = validate_parser.add_subparsers() + story_structure_parser = validate_subparsers.add_parser( + "stories", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + parents=parents, + help="Checks for inconsistencies in the story files.", + ) + story_structure_parser.set_defaults(func=validate_stories) + arguments.set_validator_arguments(story_structure_parser) + def split_nlu_data(args): from rasa.nlu.training_data.loading import load_data @@ -105,3 +115,20 @@ def validate_files(args): validator = loop.run_until_complete(Validator.from_importer(file_importer)) everything_is_alright = validator.verify_all(not args.fail_on_warnings) sys.exit(0) if everything_is_alright else sys.exit(1) + + +def validate_stories(args): + """Validate all files needed for training a model. 
+ + Fails with a non-zero exit code if there are any errors in the data.""" + from rasa.core.validator import Validator + from rasa.importers.rasa import RasaFileImporter + + loop = asyncio.get_event_loop() + file_importer = RasaFileImporter( + domain_path=args.domain, training_data_paths=args.data + ) + + validator = loop.run_until_complete(Validator.from_importer(file_importer)) + everything_is_alright = validator.verify_story_structure(not args.fail_on_warnings, max_history=5) + sys.exit(0) if everything_is_alright else sys.exit(1) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index e82aa0af2d00..9e755bee8ce5 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -165,11 +165,9 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: return everything_is_alright - def verify_story_structure(self, ignore_warnings: bool = True) -> bool: + def verify_story_structure(self, ignore_warnings: bool = True, max_history: int = 5) -> bool: """Verifies that bot behaviour in stories is deterministic.""" - max_history = 5 - # Generate the story tree from rasa.utils.story_tree import Tree tree = Tree() @@ -246,7 +244,4 @@ def verify_all(self, ignore_warnings: bool = True) -> bool: logger.info("Validating utterances...") utterances_are_valid = self.verify_utterances_in_stories(ignore_warnings) - logger.info("Validating story-structure...") - stories_are_valid = self.verify_story_structure(ignore_warnings) - - return intents_are_valid and utterances_are_valid and stories_are_valid + return intents_are_valid and utterances_are_valid From 178c67303d8cd1109eadf58369a0acefb1d48b85 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 5 Nov 2019 09:39:03 +0100 Subject: [PATCH 021/209] Inform user about max_history --- rasa/core/validator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 9e755bee8ce5..47f0765851fb 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -168,6 +168,8 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: def verify_story_structure(self, ignore_warnings: bool = True, max_history: int = 5) -> bool: """Verifies that bot behaviour in stories is deterministic.""" + print(f"Assuming max_history = {max_history}") + # Generate the story tree from rasa.utils.story_tree import Tree tree = Tree() @@ -233,6 +235,9 @@ def verify_story_structure(self, ignore_warnings: bool = True, max_history: int for i in info: print(f" '{i['action']}' predicted in '{i['tracker'].sender_id}'") + if result: + print("No conflicts found.") + return result def verify_all(self, ignore_warnings: bool = True) -> bool: From 4488439946af5507f3bf1965f9108fdead2bfddd Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 5 Nov 2019 09:46:28 +0100 Subject: [PATCH 022/209] Add max_history parameter --- rasa/cli/data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 1bea12f42011..9f730f7d022d 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -80,6 +80,8 @@ def add_subparser( parents=parents, help="Checks for inconsistencies in the story files.", ) + story_structure_parser.add_argument("--max-history", type=int, default=5, + help="Assume this max_history setting for validation.") story_structure_parser.set_defaults(func=validate_stories) arguments.set_validator_arguments(story_structure_parser) @@ -130,5 +132,5 @@ def validate_stories(args): ) validator = loop.run_until_complete(Validator.from_importer(file_importer)) - everything_is_alright = validator.verify_story_structure(not args.fail_on_warnings, max_history=5) + everything_is_alright = validator.verify_story_structure(not args.fail_on_warnings, max_history=args.max_history) sys.exit(0) if everything_is_alright else sys.exit(1) From 9e8036f1cb4d3de87f090c434dcaa193e7ccea41 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 5 Nov 2019 16:48:33 +0100 Subject: [PATCH 023/209] Enable `rasa data validate --stories` --- rasa/cli/data.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 9f730f7d022d..7de3fbb80fc3 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -11,7 +11,7 @@ # noinspection PyProtectedMember def add_subparser( - subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] + subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] ): import rasa.nlu.convert as convert @@ -58,7 +58,7 @@ def add_subparser( parents=parents, formatter_class=argparse.ArgumentDefaultsHelpFormatter, help="Performs a split of your NLU data into training and test data " - "according to the specified percentages.", + "according to the specified percentages.", ) nlu_split_parser.set_defaults(func=split_nlu_data) @@ -70,6 +70,10 @@ def add_subparser( parents=parents, help="Validates domain and data files to check for possible mistakes.", ) + validate_parser.add_argument("--stories", action="store_true", default=False, + help="Also validate that stories are consistent.") + validate_parser.add_argument("--max-history", type=int, default=5, + help="Assume this max_history setting for story structure validation.") validate_parser.set_defaults(func=validate_files) arguments.set_validator_arguments(validate_parser) @@ -116,6 +120,10 @@ def validate_files(args): validator = loop.run_until_complete(Validator.from_importer(file_importer)) everything_is_alright = validator.verify_all(not args.fail_on_warnings) + if args.stories: + everything_is_alright = everything_is_alright and \ + validator.verify_story_structure(not args.fail_on_warnings, + max_history=args.max_history) sys.exit(0) if everything_is_alright else sys.exit(1) From 3f1d689a353995866ab94439fa7fd27858199a9b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 6 Nov 2019 17:03:28 +0100 Subject: [PATCH 024/209] Implement verify_story_names --- rasa/cli/data.py | 5 ++++- rasa/core/validator.py | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 7de3fbb80fc3..068b89a26892 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -140,5 +140,8 @@ def validate_stories(args): ) validator = loop.run_until_complete(Validator.from_importer(file_importer)) - everything_is_alright = validator.verify_story_structure(not args.fail_on_warnings, max_history=args.max_history) + everything_is_alright = ( + validator.verify_story_names(not args.fail_on_warnings) and + validator.verify_story_structure(not args.fail_on_warnings, max_history=args.max_history) + ) sys.exit(0) if everything_is_alright else sys.exit(1) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index c53295ce91f2..53fa4c63e2e9 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -188,6 +188,17 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: return everything_is_alright + def verify_story_names(self, ignore_warnings: bool = True): + """Verify that story names are unique.""" + names = set() + for step in self.story_graph.story_steps: + if step.block_name in names: + logger.warning("Found duplicate story names") + return ignore_warnings + names.add(step.block_name) + logger.info("All story names are unique") + return True + def verify_story_structure(self, ignore_warnings: bool = True, max_history: int = 5) -> bool: """Verifies that bot behaviour in stories is deterministic.""" @@ -273,7 +284,9 @@ def verify_all(self, ignore_warnings: bool = True) -> bool: there_is_no_duplication = self.verify_example_repetition_in_intents( ignore_warnings ) + all_story_names_unique = self.verify_story_names(ignore_warnings) logger.info("Validating utterances...") stories_are_valid = self.verify_utterances_in_stories(ignore_warnings) - return 
intents_are_valid and stories_are_valid and there_is_no_duplication + return (intents_are_valid and stories_are_valid and + there_is_no_duplication and all_story_names_unique) From b62c8e9ee22b0fb690f662dad85c0e13e2e0cbe3 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 6 Nov 2019 17:57:59 +0100 Subject: [PATCH 025/209] Implement deduplicate_story_names for rasa data clean --- rasa/cli/data.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 068b89a26892..1472d157453f 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -89,6 +89,15 @@ def add_subparser( story_structure_parser.set_defaults(func=validate_stories) arguments.set_validator_arguments(story_structure_parser) + split_parser = data_subparsers.add_parser( + "clean", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + parents=parents, + help="[Experimental] Ensures that story names are unique.", + ) + split_parser.set_defaults(func=deduplicate_story_names) + arguments.set_validator_arguments(split_parser) + def split_nlu_data(args): from rasa.nlu.training_data.loading import load_data @@ -145,3 +154,42 @@ def validate_stories(args): validator.verify_story_structure(not args.fail_on_warnings, max_history=args.max_history) ) sys.exit(0) if everything_is_alright else sys.exit(1) + + +def deduplicate_story_names(args): + """Changes story names so as to make them unique. 
+ --EXPERIMENTAL-- """ + + from rasa.importers.rasa import RasaFileImporter + + loop = asyncio.get_event_loop() + file_importer = RasaFileImporter( + domain_path=args.domain, training_data_paths=args.data + ) + + story_file_names, _ = data.get_core_nlu_files(args.data) + names = set() + for file_name in story_file_names: + if file_name.endswith(".new"): + continue + with open(file_name, "r") as in_file: + with open(file_name + ".new", "w+") as out_file: + for line in in_file: + if line.startswith("## "): + new_name = line[3:].rstrip() + if new_name in names: + first = new_name + k = 1 + while new_name in names: + new_name = first + f" ({k})" + k += 1 + print(f"- replacing {first} with {new_name}") + names.add(new_name) + out_file.write(f"## {new_name}\n") + else: + out_file.write(line.rstrip() + "\n") + + # story_files, _ = data.get_core_nlu_files(args.data) + # story_steps = loop.run_until_complete(file_importer.get_story_steps()) + # for step in story_steps: + # print(step.block_name) From 0a193e7f80f35ff14a2e049f2a0dae34880cd24a Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Thu, 7 Nov 2019 09:47:30 +0100 Subject: [PATCH 026/209] Remove story_tree.py script since its not used --- rasa/core/validator.py | 3 - rasa/utils/story_tree.py | 729 --------------------------------------- 2 files changed, 732 deletions(-) delete mode 100644 rasa/utils/story_tree.py diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 53fa4c63e2e9..9a0956e01028 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -204,9 +204,6 @@ def verify_story_structure(self, ignore_warnings: bool = True, max_history: int print(f"Assuming max_history = {max_history}") - # Generate the story tree - from rasa.utils.story_tree import Tree - tree = Tree() trackers = TrainingDataGenerator( self.story_graph, domain=self.domain, diff --git a/rasa/utils/story_tree.py b/rasa/utils/story_tree.py deleted file mode 100644 index 4ba4ba160ceb..000000000000 --- a/rasa/utils/story_tree.py +++ /dev/null @@ -1,729 +0,0 @@ -"""This script generates a tree diagram from a story file.""" - -import argparse -import os -from termcolor import colored as term_colored -import pydoc -from math import log2 -import re -import ast - -################################################# -# Command line argument specification -################################################# - -arg_parser = argparse.ArgumentParser(description="Represent story tree.") -arg_parser.add_argument("input", type=str, nargs="+", help="Input story file name(s)") -arg_parser.add_argument("--ambiguities", "-a", default="false", choices=["true", "false", "wizard"], - const="true", nargs="?", help="Display ambiguous branches only") -arg_parser.add_argument("--max-depth", "-m", type=int, default=None, help="Maximum depth") -arg_parser.add_argument("--coloring", "-c", default="role", choices=["role", "r", "ambiguities", "a", "depth", "none", - "n"]) -arg_parser.add_argument("--color-code", default="terminal", choices=["terminal", "markdown"]) -arg_parser.add_argument("--branch", "-b", default="", - 
help="Restrict output to the given branch of the tree, separated by '/'") -arg_parser.add_argument("--stories", "-s", nargs="+", default=None, - help="Restrict output to the given stories.") -arg_parser.add_argument("--labels", "-l", action="store_true", default=False, - help="Show the names of stories at each node that has no siblings") -arg_parser.add_argument("--page", "-p", action="store_true", default=False, - help="Use pagination for output if necessary") -arg_parser.add_argument("--prune", default="most-visited", choices=["first", "last", "most-visited"], - const="First", nargs="?", help="Selection criterion for kept branches during pruning") -arg_parser.add_argument("--output", "-o", default="tree", choices=["tree", "stats", "pruned", "ok"], - nargs="+", help="What to display") -arg_parser.add_argument("--merge", nargs=1, help="Merge given story file into main file, avoiding ambiguities") - -################################################# -# Helper functions -################################################# - -color_code = "terminal" - - -def colored(string, color): - # noinspection PyUnresolvedReferences - global color_code - if color_code == "terminal": - return term_colored(string, color) - elif color_code == "markdown": - return f"" + string + "" - elif color_code == "none": - return string - else: - raise ValueError("Invalid color_code. 
Must be one of \"terminal\", \"markdown\", or \"none\"") - - -def slot_to_dict(string): - regex_slot = r"- slot\{([^\}]*)\}$" # Groups: slots - regex_intent = r"\* ([^_\{]+)(?:\{([^\}]*)})?$" # Groups: name | slots - - match = re.search(regex_slot, string) - if match: - return ast.literal_eval("{" + match.group(1) + "}") - else: - match = re.search(regex_intent, string) - if match: - if match.group(2): - return ast.literal_eval("{" + match.group(2) + "}") - return {} - - -################################################# -# Classes -################################################# - -# =============================================== -# Node -# Represents a node in a tree graph -# =============================================== - -class Node: - - def __init__(self, kind="R", state="root", name="root", parent=None, story=""): - self.count = 1 - self.kind = kind - self.state = state - self.name = name - self.parent = parent - self.children = [] - self.labels = [story] - self._is_pruned = False - - def add_child(self, child): - """ - Add the child node `child`, unless it exists already, in which case an error is raised. - :type child: Node - :param child: Child node - """ - if self.get_child(child.state) is None: - self.children.append(child) - else: - raise ValueError(f"A child with the name {child.state} already exists!") - pass - - def get_child(self, state: str): - """ - Get the child with the given name - :param state: Name of the sought child - :return: Child node - """ - for child in self.children: - if child.state == state: - return child - return None - - def print_string(self, branch="", stories=None, max_depth=None, only_ambiguous=False, show_labels=False, - coloring="role", include_users=True, _depth=0, _has_siblings=False): - """ - Recursively generate a string representation of the tree with this node as root. 
- :param stories: Restrict output to given stories - :param branch: Restrict output to given branch (overwrites `stories`) - :param max_depth: Go no deeper than this - :param show_labels: Indicate branch labels on non-branch points - :param only_ambiguous: Only output ambiguous branches - :param coloring: Coloring rule ('role', 'depth', 'ambiguities') - :param _depth: Recursion depth - only set in recursion step! - :param include_users: When `only_ambiguous` is `True`, include ambiguous user responses - :param _has_siblings: True, iff this node has siblings - only set in recursion step! - :return: The generated string - """ - - # Abort recursion if max depth is reached - if max_depth and _depth >= max_depth: - return "" - - # Decide how to color the present node - if coloring[0] == "r": # "role" - color = {"S": "yellow", "U": "blue", "W": "green"}.get(self.kind) - elif coloring[0] == "d": # "depth" - color = {1: "green", 2: "magenta", 3: "yellow", 4: "cyan", 5: "blue", 6: "grey"}.get(_depth, "grey") - elif coloring[0] == "n": # "none" - color = "none" - elif coloring[0] == "a": # "ambiguities" - if _has_siblings: - color = "red" - elif self.has_descendants_with_siblings: - color = "yellow" - else: - color = "white" - else: - raise ValueError(f"Invalid coloring \"{coloring}\". 
Must be one of 'roles', 'depth', " - f"'ambiguities' or 'none'.") - - # If only ambiguous nodes should be printed, then print only if there are siblings or descendants with siblings - if (not only_ambiguous) or _has_siblings or self.has_descendants_with_siblings(include_users): - - # Visit count indicator for non-root nodes only - count_str = f" ({self.count})" if _depth > 0 else "" - - # Show branch labels iff visit count is 1 - if show_labels and self.count == 1: - result = "+" + "-" * (2 * _depth) + " " + colored(self.name + count_str, color) \ - + f" <{self.labels[0]}>" + os.linesep - else: - result = "+" + "-" * (2 * _depth) + " " + colored(self.name + count_str, color) \ - + os.linesep - else: - # We show only ambiguous branches, and this node is not root of an ambiguous branch - result = "" - - # Prepare _has_siblings for recursion step - has_siblings = (len(self.children) > 1) - if has_siblings and not include_users: - all_children_are_users = all(child.kind == "U" or child.kind == "S" for child in self.children) - has_siblings = not all_children_are_users - - # Recursion step into all child nodes - if branch: - # Output should be restricted to `branch` - path = branch.split("/") # Split the branch spec into one for each level - sought_name = path[0] # First entry is where we should step into now - remain_branch = "/".join(path[1:]) # Remaining entries have to be passed on to recursive call - - for child in self.children: - if child.name == sought_name or sought_name == "*": - result += child.print_string(branch=remain_branch, - stories=stories, - max_depth=max_depth, - only_ambiguous=only_ambiguous, - show_labels=show_labels, - include_users=include_users, - coloring=coloring, - _depth=_depth + 1, - _has_siblings=has_siblings) - else: - # No branch restriction -> step into all child branches unless stories are restricted - for child in self.children: - if stories is None or not set(child.labels).isdisjoint(stories): - result += 
child.print_string(branch=branch, - stories=stories, - max_depth=max_depth, - only_ambiguous=only_ambiguous, - show_labels=show_labels, - include_users=include_users, - coloring=coloring, - _depth=_depth + 1, - _has_siblings=has_siblings) - return result - - def has_descendants_with_siblings(self, include_users): - """ - Boolean that indicates if there are any descendants that have siblings. - :return: True, iff a descendant node has siblings - """ - if len(self.children) > 1: - if include_users: - return True - else: - all_children_are_users = all(child.kind == "U" or child.kind == "S" for child in self.children) - if all_children_are_users: - return any(child.has_descendants_with_siblings(include_users) for child in self.children) - else: - return True - elif len(self.children) == 1: - return list(self.children)[0].has_descendants_with_siblings(include_users) - else: - return False - - def __str__(self): - return self.print_string() - - def prune(self, keep: str): - """Removes all ambiguous branches""" - if self._is_pruned: - return - if len(self.children) > 0: - if len(self.children) > 1: - if any([child.kind == "W" for child in self.children]): - if keep == "first": - del self.children[1:] - elif keep == "last": - del self.children[:-1] - elif keep == "most-visited": - visit_counts = [len(c.labels) for c in self.children] - keep_idx = visit_counts.index(max(visit_counts)) - # Delete all but the one at `keep_idx` - del self.children[:keep_idx] - if len(self.children) > 1: - del self.children[1:] - else: - raise ValueError("Invalid prune keep criterion.") - for child in self.children: - child.prune(keep) - self._is_pruned = True - - def conflicts(self, keep: str, _depth=0, _leading_steps=[]): - """Return list of conflict points""" - if self._is_pruned: - return None - - conflicts = [] - if len(self.children) > 0: - if len(self.children) > 1: - if any([child.kind == "W" for child in self.children]): - conflicts.append({ - "ambiguity": [{"stories": c.labels, 
"action": c.name} for c in self.children if c.kind == "W"], - "conflict_step": _depth, - "leading_steps": _leading_steps + [self.name], - }) - if keep == "first": - del self.children[1:] - elif keep == "last": - del self.children[:-1] - elif keep == "most-visited": - visit_counts = [len(c.labels) for c in self.children] - keep_idx = visit_counts.index(max(visit_counts)) - # Delete all but the one at `keep_idx` - del self.children[:keep_idx] - if len(self.children) > 1: - del self.children[1:] - else: - raise ValueError("Invalid prune keep criterion.") - - for child in self.children: - child_conflicts = child.conflicts(keep, _depth+1, _leading_steps + [self.name]) - if child_conflicts: - conflicts += child_conflicts - - self._is_pruned = True - - return conflicts - - def remove(self, story) -> bool: - """Remove the given story from this node and recursively from all - descendants. """ - if story in self.labels: - # Remove the story from internal stats - self.labels = [label for label in self.labels if label != story] - # Recurse through all children - new_children = [] - for child in self.children: - if not child.remove(story): - # Only retain children that did not self-delete - new_children.append(child) - else: - # Delete this child node - self.count -= 1 - del child - - self.children = new_children - - assert (len(self.labels) == 0) == (self.count == 0) - - # If this node had no other stories than the one we just - # deleted, then return True, and False otherwise - return len(self.labels) == 0 - - @property - def leafs(self): - leafs = set() - - # noinspection PyUnusedLocal - def callback_discover_leaf(node, *args): - assert len(node.labels) >= 1, f"Leaf has no story assigned!" 
- # Leafs may have multiple stories assigned, iff stories have duplicates - # Ignore duplicates iff the tree was pruned - if self._is_pruned: - leafs.add(node.labels[0]) - else: - for story in node.labels: - leafs.add(story) - return {} - - self._depth_first_search({"discover_leaf": callback_discover_leaf}, {}) - - return leafs - - @property - def duplicates(self): - duplicates = [] - - # noinspection PyUnusedLocal - def callback_discover_leaf(node, *args): - assert len(node.labels) >= 1, f"Leaf has no story assigned!" - # Leafs have multiple stories assigned, iff stories have duplicates - if len(node.labels) > 1: - duplicates.append(node.labels) - return {} - - self._depth_first_search({"discover_leaf": callback_discover_leaf}, {}) - - return duplicates - - def stats(self): - """ - Collects statistics about the tree that has this node as a root. - :return: Dict with statistical information - """ - statistics = { - "num_nodes": 0, # Total number of nodes in the tree (steps in all dialogues) - "num_nodes_with_multiple_children": 0, # Number of nodes that have multiple children - "num_leaves": 0, # How many stories are present? - "depth": 0, # How deep is the graph - "ambiguity_depth": 0, # How deep is the deepest branch point? - "ambiguity_chain_length": 0, # How many branch points follow each other (max)? - "ambiguity_level": 0, # How many leaves are connected to root via branch points? 
- "story_stats": {} # Stats about individual stories - } - - def callback_discover_node(node, depth, flags): - statistics["num_nodes"] += 1 - if len(node.children) > 1: - if any([child.kind == "W" for child in node.children]): - statistics["num_nodes_with_multiple_children"] += 1 - statistics["ambiguity_depth"] = max(statistics["ambiguity_depth"], depth) - statistics["ambiguity_chain_length"] = max(statistics["ambiguity_chain_length"], - flags["ambiguity_chain_length"] + 1) - if flags["linear_so_far"]: - statistics["ambiguity_level"] += node.count - for story in node.labels: - if story in statistics["story_stats"]: - statistics["story_stats"][story]["ambiguity_length"] += 1 - statistics["story_stats"][story]["related_to"].update(node.labels) - else: - statistics["story_stats"][story] = { - "length": depth, - "ambiguity_length": 1, - "related_to": set(node.labels) - } - return { - "linear_so_far": False, - "ambiguity_chain_length": flags["ambiguity_chain_length"] + 1 - } - return {} - - # noinspection PyUnusedLocal - def callback_discover_leaf(node, depth, flags): - statistics["num_leaves"] += 1 - statistics["depth"] = max(statistics["depth"], depth) - story = node.labels[0] - if story in statistics["story_stats"]: - statistics["story_stats"][story]["length"] = depth - return {} - - self._depth_first_search({ - "discover_node": callback_discover_node, - "discover_leaf": callback_discover_leaf - }, flags={"linear_so_far": True, "ambiguity_chain_length": 0}) - - return statistics - - def _depth_first_search(self, callbacks, flags, _depth=0): - new_flags = flags.copy() - if len(self.children) == 0 and "discover_leaf" in callbacks: - new_flags.update(callbacks["discover_leaf"](self, _depth, flags)) - return - for child in self.children: - if "discover_node" in callbacks: - new_flags.update(callbacks["discover_node"](child, _depth + 1, flags)) - # noinspection PyProtectedMember - child._depth_first_search(callbacks, new_flags, _depth + 1) - - -# 
=============================================== -# Tree -# Represents a tree graph -# =============================================== - -class Tree: - - def __init__(self): - self.root = Node() # Root node, should never change - self.pointer = self.root # Pointer to the currently active node - self.label = "" # Label for active branch - - def add_or_goto(self, kind, state, name): - """ - If a branch with name `name` is a child of the currently active node, then move `self.pointer` - to that branch and update visit counts and branch name lists. Otherwise, create a new child - branch with this name and move the pointer to it. - :param kind: U/S/W for user/slot/wizard - :param state: State string of the (new) branch to go to - :param name: Name of the (new) branch to go to - :return: True, iff a new branch was created - """ - # Check if branch with name `name` exists - for branch in self.pointer.children: - if branch.state == state: - branch.count += 1 # Increase visit count - branch.labels += [self.label] # Append new branch label - self.pointer = branch # Move pointer to this branch - return False - - # Add a new branch - new_branch = Node(kind, state, name, parent=self.pointer, story=self.label) - self.pointer.add_child(new_branch) - self.pointer = new_branch - return True - - def adding_creates_ambiguity(self, state: str): - """ - Returns True iff adding a branch with this name would result in an ambiguity in this tree, - i.e. another child node exists, which is as Wizard node. 
- :param state: State of the branch (user/wizard action) - :return: True iff ambiguous - """ - return state.startswith("W") and any(c.state.startswith("W") for c in self.pointer.children) - - def up(self): - """ - Move the active branch pointer one step towards root - :return: True, iff active branch is not already on root - """ - if self.pointer != self.root: - self.pointer = self.pointer.parent - return True - else: - return False - - def reset(self, story): - """ - Reset the active branch pointer to root and specify a new story label to use in `self.add_or_goto`. - :param story: New story label - """ - self.pointer = self.root - self.label = story - - def remove(self, story=None): - """ - Remove the given story, or the story with the name stored in self.label - :param story: Name of the story - """ - if story: - self.root.remove(story) - else: - self.root.remove(self.label) - - def to_string(self, branch="", max_depth=None, show_labels=False, only_ambiguous=False, coloring="role", - include_users=True, stories=None): - """ - Create a string representation of the tree. 
- :param stories: Restrict output to given stories - :param branch: Restrict output to given branch (overwrites `stories`) - :param max_depth: Go no deeper than this - :param show_labels: Indicate branch labels on non-branch points - :param only_ambiguous: Only output ambiguous branches - :param coloring: Coloring rule ('role', 'depth', 'ambiguities') - :param include_users: When `only_ambiguous` is `True`, include ambiguous user responses - :return: The generated string - """ - return self.root.print_string(branch=branch, max_depth=max_depth, include_users=include_users, - show_labels=show_labels, only_ambiguous=only_ambiguous, coloring=coloring, - stories=stories) - - def __str__(self): - return self.root.print_string() - - def prune(self, *args): - self.root.prune(*args) - - @property - def leafs(self): - return self.root.leafs - - @property - def duplicates(self): - return self.root.duplicates - - @property - def conflicts(self): - return self.root.conflicts("most-visited") - - def stats(self) -> dict: - """ - Compute statistics about this tree. 
- :return: The generated dict with statistical information - """ - return self.root.stats() - - -################################################# -# Main -################################################# - -if __name__ == '__main__': - - def main(): - - stats = None - - # Read command line arguments - args = arg_parser.parse_args() - story_file_names = args.input # Input file name - global color_code - color_code = args.color_code # "terminal" / "markdown" - - # Generate the story tree - n = 0 - tree = Tree() - slots = {} - for story_file_name in story_file_names: - with open(story_file_name, "r") as story_file: - for line in story_file: - if line.startswith("##"): - n += 1 - tree.reset(story=line[2:].strip()) - slots.clear() - else: - if line.lstrip().startswith("*"): - name = "U: " - elif line.lstrip().startswith("- slot"): - name = "S: " - else: - name = "W: " - - if name in ["U: ", "S: "]: - # Slots might have been updated -> keep track of it - new_slots = slot_to_dict(line) - copy_slots = slots.copy() - copy_slots.update(new_slots) - if copy_slots.items() == slots.items(): - # Setting this slot does not change anything - if name == "S: ": - # Ignore redundant slot lines - name = "" - line = "" - else: - slots.update(new_slots) - - name += line.strip()[2:] - if line.strip(): - tree.add_or_goto(name) - - # Merge other story file (only take in stories that don't create ambiguities) - if args.merge: - successful_merge = [] # Stories that got merged in successfully - total_num_merge = 0 # Total number of stories that should have been merged - for story_file_name in args.merge: - with open(story_file_name, "r") as story_file: - active_story = "" - for line in story_file: - if line.startswith("##"): - if active_story: - # The previous story was merged all the way and - # thus `active_story` was not set to `""`. 
In this - # case, we remember the name of the story that - # merged successfully - successful_merge.append(active_story) - total_num_merge += 1 - n += 1 - active_story = line[2:].strip() - tree.reset(story=active_story) - else: - if active_story: - if line.lstrip().startswith("*"): - name = "U: " - elif line.lstrip().startswith("- slot"): - name = "S: " - else: - name = "W: " - name += line.strip()[2:] - if line.strip(): - if tree.adding_creates_ambiguity(name): - # Merging `active_story` would create ambiguity - tree.remove() - active_story = "" - else: - tree.add_or_goto(name) - - # Display the tree if required - if "tree" in args.output: - _print = pydoc.pager if args.page else print - _print(tree.to_string(only_ambiguous=(args.ambiguities in ["true", "wizard"]), - include_users=(args.ambiguities != "wizard"), - max_depth=args.max_depth, - show_labels=args.labels, - coloring=args.coloring, - branch=args.branch, - stories=args.stories)) - - # Display statistics if required - if "stats" in args.output: - stats = tree.stats() - duplicates = tree.duplicates - - print() - print(colored("Text summary:", "cyan")) - if duplicates: - print(f"The input contains {stats['num_leaves']} stories, but there are some duplicates (see below).") - else: - print(f"The input contains {stats['num_leaves']} unique stories.") - print(f"The longest story is {stats['depth']} nodes deep.") - print( - f"{stats['num_nodes_with_multiple_children']} / {stats['num_nodes']} = " - f"{100.0 * stats['num_nodes_with_multiple_children'] / stats['num_nodes']:.2f}% of all nodes " - f"have multiple children.") - print(f"The deepest branch point occurs after {stats['ambiguity_depth']} steps.") - print(f"We encounter up to {stats['ambiguity_chain_length']} branch points in a single story.") - print(f"{stats['ambiguity_level']} / {stats['num_leaves']} = " - f"{100.0 * stats['ambiguity_level'] / stats['num_leaves']:.2f}% of all stories are ambiguous.") - if args.merge: - # noinspection 
PyUnboundLocalVariable - print(f"Successfully merged {len(successful_merge)} out of {total_num_merge} stories.") - - if duplicates: - print() - print(colored("Duplicate stories:", "cyan")) - for d in duplicates: - print(d) - - print() - print(colored("Statistics table:", "cyan")) - print(f"num stories: {stats['num_leaves']}") - print(f"max turns: {stats['depth']}") - print(f"num nodes: {stats['num_nodes']}") - print(f"branch-points: {stats['num_nodes_with_multiple_children']} " - f"({100.0 * stats['num_nodes_with_multiple_children'] / stats['num_nodes']:.2f}%)") - print(f"ambiguity depth: {stats['ambiguity_depth']}") - print(f"ambiguity length: {stats['ambiguity_chain_length']}") - print(f"ambiguity level: {stats['ambiguity_level']} " - f"({100.0 * stats['ambiguity_level'] / stats['num_leaves']:.2f}%)") - if stats['ambiguity_level'] > 0.0: - print(f"ambiguity log: log2({stats['ambiguity_level']}) = {log2(stats['ambiguity_level']):.2f} ") - - tree.prune(args.prune) - pruned_stats = tree.stats() - print() - print(colored("After pruning:", "cyan")) - print(f"num stories: {pruned_stats['num_leaves']} " - f"({stats['num_leaves'] - pruned_stats['num_leaves']} fewer)") - print(f"max turns: {pruned_stats['depth']} " - f"({stats['depth'] - pruned_stats['depth']} fewer)") - print(f"num nodes: {pruned_stats['num_nodes']} " - f"({stats['num_nodes'] - pruned_stats['num_nodes']} fewer)") - - if len(stats["story_stats"]) > 0: - print() - print(colored("Most ambiguous stories:", "cyan")) - print(f"{'Story':>15}", f"{'# relations':>14}", f"{'# branchings':>14}", f"{'# turns':>14}") - for story, values in sorted(stats["story_stats"].items(), - key=lambda kv: [ - len(kv[1]["related_to"]), - kv[1]["ambiguity_length"], - kv[1]["length"]], - reverse=True)[:12]: - print(f"{story:>15} {len(values['related_to']):>14} {values['ambiguity_length']:>14} " - f"{values['length']:>14}") - - print() - - # Display remaining stories after pruning, if required - if "pruned" in args.output: - if 
args.merge: - for story in sorted(successful_merge): - print(story) - else: - tree.prune(args.prune) - for story in sorted(tree.leafs): - print(story) - - if "ok" in args.output: - if not stats: - stats = tree.stats() - if stats['num_nodes_with_multiple_children'] > 0: - print("False") - else: - print("True") - - main() From a0c8adce217d5479cfdd13877ecadf558b52552a Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 7 Nov 2019 10:40:07 +0100 Subject: [PATCH 027/209] Use logger instead of print --- rasa/core/validator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 9a0956e01028..823140e4eb28 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -202,7 +202,8 @@ def verify_story_names(self, ignore_warnings: bool = True): def verify_story_structure(self, ignore_warnings: bool = True, max_history: int = 5) -> bool: """Verifies that bot behaviour in stories is deterministic.""" - print(f"Assuming max_history = {max_history}") + logger.info("Story structure validation...") + logger.info(f"Assuming max_history = {max_history}") trackers = TrainingDataGenerator( self.story_graph, @@ -211,7 +212,6 @@ def verify_story_structure(self, ignore_warnings: bool = True, max_history: int augmentation_factor=0).generate() rules = {} for tracker in trackers: - # print(tracker.sender_id) states = tracker.past_states(self.domain) states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 @@ -237,7 +237,6 @@ def verify_story_structure(self, ignore_warnings: bool = True, max_history: int "action": event.as_story_string() }] idx += 1 - print() result = True for state_hash, info in rules.items(): if len(info) > 1: @@ -262,12 +261,13 @@ def verify_story_structure(self, ignore_warnings: bool = True, max_history: int last_event_string = f"intent '{k[len('intent_'):]}'" break idx += 1 - print(f"CONFLICT after {last_event_string}: ") + 
conflict_string = f"CONFLICT after {last_event_string}:\n" for i in info: - print(f" '{i['action']}' predicted in '{i['tracker'].sender_id}'") + conflict_string += f" '{i['action']}' predicted in '{i['tracker'].sender_id}'\n" + logger.warning(conflict_string) if result: - print("No conflicts found.") + logger.info("No story structure conflicts found.") return result From 9f124e9e45d4929e1fd918e8756b295db7f03726 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 7 Nov 2019 10:50:40 +0100 Subject: [PATCH 028/209] Reduce cognitive complexity --- rasa/cli/data.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 1472d157453f..7333e5d3a204 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -172,22 +172,23 @@ def deduplicate_story_names(args): for file_name in story_file_names: if file_name.endswith(".new"): continue - with open(file_name, "r") as in_file: - with open(file_name + ".new", "w+") as out_file: - for line in in_file: - if line.startswith("## "): - new_name = line[3:].rstrip() - if new_name in names: - first = new_name - k = 1 - while new_name in names: - new_name = first + f" ({k})" - k += 1 - print(f"- replacing {first} with {new_name}") - names.add(new_name) - out_file.write(f"## {new_name}\n") - else: - out_file.write(line.rstrip() + "\n") + with open(file_name, "r") as in_file, \ + open(file_name + ".new", "w+") as out_file: + for line in in_file: + line = line.rstrip() + if line.startswith("## "): + new_name = line[3:] + if new_name in names: + first = new_name + k = 1 + while new_name in names: + new_name = first + f" ({k})" + k += 1 + print(f"- replacing {first} with {new_name}") + names.add(new_name) + out_file.write(f"## {new_name}\n") + else: + out_file.write(line + "\n") # story_files, _ = data.get_core_nlu_files(args.data) # story_steps = loop.run_until_complete(file_importer.get_story_steps()) From 
5454a3bf130b67951e94ac46c9a4ac8b857ef468 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 7 Nov 2019 10:59:17 +0100 Subject: [PATCH 029/209] Split add_subparser to reduce lines of code --- rasa/cli/data.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 7333e5d3a204..e2c315d7772d 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -13,8 +13,6 @@ def add_subparser( subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] ): - import rasa.nlu.convert as convert - data_parser = subparsers.add_parser( "data", conflict_handler="resolve", @@ -25,6 +23,16 @@ def add_subparser( data_parser.set_defaults(func=lambda _: data_parser.print_help(None)) data_subparsers = data_parser.add_subparsers() + + _add_data_convert_parsers(data_subparsers, parents) + _add_data_split_parsers(data_subparsers, parents) + _add_data_validate_parsers(data_subparsers, parents) + + +def _add_data_convert_parsers( + data_subparsers, parents: List[argparse.ArgumentParser] +): + import rasa.nlu.convert as convert convert_parser = data_subparsers.add_parser( "convert", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -44,6 +52,10 @@ def add_subparser( arguments.set_convert_arguments(convert_nlu_parser) + +def _add_data_split_parsers( + data_subparsers, parents: List[argparse.ArgumentParser] +): split_parser = data_subparsers.add_parser( "split", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -64,6 +76,10 @@ def add_subparser( arguments.set_split_arguments(nlu_split_parser) + +def _add_data_validate_parsers( + data_subparsers, parents: List[argparse.ArgumentParser] +): validate_parser = data_subparsers.add_parser( "validate", formatter_class=argparse.ArgumentDefaultsHelpFormatter, From 762b14f0cc2d67aeb711fcef272913dfa81f8380 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 11 Nov 2019 13:29:30 +0100 Subject: [PATCH 030/209] Improve logging text --- rasa/core/validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 823140e4eb28..31de44d32d28 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -277,7 +277,7 @@ def verify_all(self, ignore_warnings: bool = True) -> bool: logger.info("Validating intents...") intents_are_valid = self.verify_intents_in_stories(ignore_warnings) - logger.info("Validating there is no duplications...") + logger.info("Validating uniqueness of intents and stories...") there_is_no_duplication = self.verify_example_repetition_in_intents( ignore_warnings ) From 4cc621a6adedf28dbe4b63c80cd49a0815db77b8 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 11 Nov 2019 13:59:55 +0100 Subject: [PATCH 031/209] Use MESSAGE_INTENT_ATTRIBUTE --- rasa/core/validator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 31de44d32d28..ac95b5a6f5ef 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -5,6 +5,7 @@ from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.training.generator import TrainingDataGenerator from rasa.importers.importer import TrainingDataImporter +from rasa.nlu.constants import MESSAGE_INTENT_ATTRIBUTE from rasa.nlu.training_data import TrainingData from rasa.core.training.structures import StoryGraph from rasa.core.featurizers import MaxHistoryTrackerFeaturizer @@ -257,8 +258,8 @@ def verify_story_structure(self, ignore_warnings: bool = True, max_history: int if k.startswith(PREV_PREFIX): if k[len(PREV_PREFIX):] != ACTION_LISTEN_NAME: last_event_string = f"action '{k[len(PREV_PREFIX):]}'" - elif k.startswith("intent_") and not last_event_string: - last_event_string = f"intent '{k[len('intent_'):]}'" + elif k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_") and not last_event_string: + 
last_event_string = f"intent '{k[len(MESSAGE_INTENT_ATTRIBUTE + '_'):]}'" break idx += 1 conflict_string = f"CONFLICT after {last_event_string}:\n" From 1ff31faf42ea03a5a2fa8a9f54f4fb978ebc516b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 11 Nov 2019 15:07:22 +0100 Subject: [PATCH 032/209] Output all conflicting stories --- rasa/core/validator.py | 104 ++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 44 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index ac95b5a6f5ef..646f1f8a2e17 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -200,6 +200,18 @@ def verify_story_names(self, ignore_warnings: bool = True): logger.info("All story names are unique") return True + @staticmethod + def _last_event_string(sliced_states): + last_event_string = None + for k, v in sliced_states[-1].items(): + if k.startswith(PREV_PREFIX): + if k[len(PREV_PREFIX):] != ACTION_LISTEN_NAME: + last_event_string = f"action '{k[len(PREV_PREFIX):]}'" + elif k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_") and not last_event_string: + last_event_string = f"intent '{k[len(MESSAGE_INTENT_ATTRIBUTE + '_'):]}'" + + return last_event_string + def verify_story_structure(self, ignore_warnings: bool = True, max_history: int = 5) -> bool: """Verifies that bot behaviour in stories is deterministic.""" @@ -224,53 +236,57 @@ def verify_story_structure(self, ignore_warnings: bool = True, max_history: int ) h = hash(str(list(sliced_states))) if h in rules: - known_actions = [info["action"] for info in rules[h]] - if event.as_story_string() not in known_actions: - # print(f"{h} >> {event.as_story_string()} and {known_actions}") - rules[h] += [{ - "tracker": tracker, - "action": event.as_story_string() - }] + if event.as_story_string() not in rules[h]: + rules[h] += [event.as_story_string()] else: - # print(f"{h} >> {event.as_story_string()}") - rules[h] = [{ - "tracker": tracker, - "action": event.as_story_string() - }] + 
rules[h] = [event.as_story_string()] idx += 1 - result = True - for state_hash, info in rules.items(): - if len(info) > 1: - result = False - tracker = info[0]["tracker"] - states = tracker.past_states(self.domain) - states = [dict(state) for state in states] - last_event_string = None - idx = 0 - for event in tracker.events: - if isinstance(event, ActionExecuted): - sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( - states[: idx + 1], max_history - ) - h = hash(str(list(sliced_states))) - if h == state_hash: - for k, v in sliced_states[-1].items(): - if k.startswith(PREV_PREFIX): - if k[len(PREV_PREFIX):] != ACTION_LISTEN_NAME: - last_event_string = f"action '{k[len(PREV_PREFIX):]}'" - elif k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_") and not last_event_string: - last_event_string = f"intent '{k[len(MESSAGE_INTENT_ATTRIBUTE + '_'):]}'" - break - idx += 1 - conflict_string = f"CONFLICT after {last_event_string}:\n" - for i in info: - conflict_string += f" '{i['action']}' predicted in '{i['tracker'].sender_id}'\n" - logger.warning(conflict_string) - - if result: - logger.info("No story structure conflicts found.") - return result + # Keep only conflicting rules + rules = {state: actions for (state, actions) in rules.items() if len(actions) > 1} + + conflicts = {} + + for tracker in trackers: + states = tracker.past_states(self.domain) + states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 + + idx = 0 + for event in tracker.events: + if isinstance(event, ActionExecuted): + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + h = hash(str(list(sliced_states))) + if h in rules: + # Get the last event + last_event_string = self._last_event_string(sliced_states) + + # Fill `conflicts` dict + # {hash: {last_event: {action_1: [stories...], action_2: [stories...], ...}, ...}, ...} + if h not in conflicts: + conflicts[h] = {last_event_string: {event.as_story_string(): 
[tracker.sender_id]}} + else: + if last_event_string not in conflicts[h]: + conflicts[h][last_event_string] = {event.as_story_string(): [tracker.sender_id]} + else: + if event.as_story_string() not in conflicts[h][last_event_string]: + conflicts[h][last_event_string][event.as_story_string()] = [tracker.sender_id] + else: + conflicts[h][last_event_string][event.as_story_string()].append(tracker.sender_id) + idx += 1 + + if len(conflicts) == 0: + logger.info("No story structure conflicts found.") + else: + for conflict in list(conflicts.values()): + for state, actions_and_stories in conflict.items(): + conflict_string = f"CONFLICT after {state}:\n" + for action, stories in actions_and_stories.items(): + conflict_string += f" {action} predicted in {stories}\n" + logger.warning(conflict_string) + + return len(conflicts) > 0 def verify_all(self, ignore_warnings: bool = True) -> bool: """Runs all the validations on intents and utterances.""" From 6426cc18c677ceebfad2369688a6baf02752970a Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 13 Nov 2019 11:20:29 +0100 Subject: [PATCH 033/209] Check story structure on `data validate` if error-free otherwise --- rasa/cli/data.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index cc3a47287938..9c85c87d4949 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -27,6 +27,7 @@ def add_subparser( _add_data_convert_parsers(data_subparsers, parents) _add_data_split_parsers(data_subparsers, parents) _add_data_validate_parsers(data_subparsers, parents) + _add_data_clean_parsers(data_subparsers, parents) def _add_data_convert_parsers( @@ -86,8 +87,6 @@ def _add_data_validate_parsers( parents=parents, help="Validates domain and data files to check for possible mistakes.", ) - validate_parser.add_argument("--stories", action="store_true", default=False, - help="Also validate that stories are consistent.") validate_parser.add_argument("--max-history", type=int, default=5, help="Assume this max_history setting for story structure validation.") validate_parser.set_defaults(func=validate_files) @@ -105,14 +104,19 @@ def _add_data_validate_parsers( story_structure_parser.set_defaults(func=validate_stories) arguments.set_validator_arguments(story_structure_parser) - split_parser = data_subparsers.add_parser( + +def _add_data_clean_parsers( + data_subparsers, parents: List[argparse.ArgumentParser] +): + + clean_parser = data_subparsers.add_parser( "clean", formatter_class=argparse.ArgumentDefaultsHelpFormatter, parents=parents, help="[Experimental] Ensures that story names are unique.", ) - split_parser.set_defaults(func=deduplicate_story_names) - arguments.set_validator_arguments(split_parser) + clean_parser.set_defaults(func=deduplicate_story_names) + arguments.set_validator_arguments(clean_parser) def split_nlu_data(args): @@ -149,9 +153,10 @@ def validate_files(args): sys.exit(1) everything_is_alright = validator.verify_all(not args.fail_on_warnings) - if args.stories: - 
everything_is_alright = everything_is_alright and \ - validator.verify_story_structure(not args.fail_on_warnings, + if everything_is_alright: + # Only run story structure validation if everything else is fine + # since this might take a while + everything_is_alright = validator.verify_story_structure(not args.fail_on_warnings, max_history=args.max_history) sys.exit(0) if everything_is_alright else sys.exit(1) From 379165bec7d9aac8e393606169f261563879b400 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 13 Nov 2019 11:32:45 +0100 Subject: [PATCH 034/209] Setup --prompt flag for `rasa data validate stories` --- rasa/cli/data.py | 8 +++++++- rasa/core/validator.py | 9 ++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 9c85c87d4949..e88b14ce18aa 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -101,6 +101,8 @@ def _add_data_validate_parsers( ) story_structure_parser.add_argument("--max-history", type=int, default=5, help="Assume this max_history setting for validation.") + story_structure_parser.add_argument("--prompt", action="store_true", default=False, + help="Ask how conflicts should be fixed") story_structure_parser.set_defaults(func=validate_stories) arguments.set_validator_arguments(story_structure_parser) @@ -176,7 +178,11 @@ def validate_stories(args): validator = loop.run_until_complete(Validator.from_importer(file_importer)) everything_is_alright = ( validator.verify_story_names(not args.fail_on_warnings) and - validator.verify_story_structure(not args.fail_on_warnings, max_history=args.max_history) + validator.verify_story_structure( + not args.fail_on_warnings, + max_history=args.max_history, + prompt=args.prompt + ) ) sys.exit(0) if everything_is_alright else sys.exit(1) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index b87ba7f30e08..3f9d306a035c 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -206,7 +206,10 @@ def 
_last_event_string(sliced_states): return last_event_string - def verify_story_structure(self, ignore_warnings: bool = True, max_history: int = 5) -> bool: + def verify_story_structure(self, + ignore_warnings: bool = True, + max_history: int = 5, + prompt: bool = False) -> bool: """Verifies that bot behaviour in stories is deterministic.""" logger.info("Story structure validation...") @@ -280,6 +283,10 @@ def verify_story_structure(self, ignore_warnings: bool = True, max_history: int conflict_string += f" {action} predicted in {stories}\n" logger.warning(conflict_string) + # Fix the conflict if required + if prompt: + raise NotImplementedError + return len(conflicts) > 0 def verify_all(self, ignore_warnings: bool = True) -> bool: From 0aa8c772b68c7399f51602dcd3dc2524cbd48fca Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 13 Nov 2019 11:46:06 +0100 Subject: [PATCH 035/209] List duplicate story names --- rasa/core/validator.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 3f9d306a035c..8c27da80d95b 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -185,14 +185,31 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: def verify_story_names(self, ignore_warnings: bool = True): """Verify that story names are unique.""" - names = set() + + # Tally story names, e.g. 
{"story_1": 3, "story_2": 1, ...} + name_tally = {} for step in self.story_graph.story_steps: - if step.block_name in names: - logger.warning("Found duplicate story names") - return ignore_warnings - names.add(step.block_name) - logger.info("All story names are unique") - return True + if step.block_name in name_tally: + name_tally[step.block_name] += 1 + else: + name_tally[step.block_name] = 1 + + # Find story names that appear more than once + # and construct a warning message + result = True + message = "" + for name, count in name_tally.items(): + if count > 1: + if result: + message = f"Found duplicate story names:\n" + result = False + message += f" '{name}' appears {count}x\n" + + if result: + logger.info("All story names are unique") + else: + logger.error(message) + return result @staticmethod def _last_event_string(sliced_states): From b5591fe0207808c0f6ce2f6b6f6702039efd036b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 13 Nov 2019 11:52:32 +0100 Subject: [PATCH 036/209] Make output of story names more user friendly --- rasa/core/validator.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 8c27da80d95b..0f17812fd456 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -297,6 +297,14 @@ def verify_story_structure(self, for state, actions_and_stories in conflict.items(): conflict_string = f"CONFLICT after {state}:\n" for action, stories in actions_and_stories.items(): + if len(stories) == 1: + stories = f"'{stories[0]}'" + elif len(stories) == 2: + stories = f"'{stories[0]}' and '{stories[1]}'" + elif len(stories) == 3: + stories = f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'" + elif len(stories) >= 4: + stories = f"'{stories[0]}' and {len(stories) - 1} other stories" conflict_string += f" {action} predicted in {stories}\n" logger.warning(conflict_string) From fb6c162eb3896a854ee0b8c158b9424bc2ee63ed Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 13 Nov 2019 14:54:32 +0100 Subject: [PATCH 037/209] Introduce StoryConflict --- rasa/core/story_conflict.py | 72 +++++++++++++++++++++++++++++++++++++ rasa/core/validator.py | 34 ++++-------------- 2 files changed, 79 insertions(+), 27 deletions(-) create mode 100644 rasa/core/story_conflict.py diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py new file mode 100644 index 000000000000..3619c8feac7e --- /dev/null +++ b/rasa/core/story_conflict.py @@ -0,0 +1,72 @@ + +from typing import List, Optional, Dict, Text + +from rasa.core.actions.action import ACTION_LISTEN_NAME +from rasa.core.domain import PREV_PREFIX +from rasa.core.events import Event +from rasa.nlu.constants import MESSAGE_INTENT_ATTRIBUTE +from rasa.core.training.generator import TrackerWithCachedStates + + +class StoryConflict: + + def __init__( + self, + sliced_states: List[Optional[Dict[Text, float]]], + tracker: TrackerWithCachedStates, + event + ): + self.sliced_states = sliced_states + self.hash = hash(str(list(sliced_states))) + self.tracker = tracker, + self.event = event + self._conflicting_actions = {} # {"action": ["story_1", ...], ...} + + def events_prior_to_conflict(self): + raise NotImplementedError + + @staticmethod + def _get_prev_event(state) -> [Event, None]: + if not state: + return None + result = None + for k in state: + if k.startswith(PREV_PREFIX): + if k[len(PREV_PREFIX):] != ACTION_LISTEN_NAME: + result = ("action", k[len(PREV_PREFIX):]) + elif k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_") and not result: + result = ("intent", k[len(MESSAGE_INTENT_ATTRIBUTE + '_'):]) + return result + + def add_conflicting_action(self, action: Text, story_name: Text): + if action not in self._conflicting_actions: + self._conflicting_actions[action] = [story_name] + else: + self._conflicting_actions[action] += [story_name] + + @property + def conflicting_actions(self): + return list(self._conflicting_actions.keys()) + + def __str__(self): + last_event_type, 
last_event_name = self._get_prev_event(self.sliced_states[-1]) + conflict_string = f"CONFLICT after {last_event_type} '{last_event_name}':\n" + # for state in self.sliced_states: + # if state: + # event_type, event_name = self._get_prev_event(state) + # if event_type == "intent": + # conflict_string += f"* {event_name}\n" + # else: + # conflict_string += f" - {event_name}\n" + for action, stories in self._conflicting_actions.items(): + if len(stories) == 1: + stories = f"'{stories[0]}'" + elif len(stories) == 2: + stories = f"'{stories[0]}' and '{stories[1]}'" + elif len(stories) == 3: + stories = f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'" + elif len(stories) >= 4: + stories = f"'{stories[0]}' and {len(stories) - 1} other stories" + conflict_string += f" {action} predicted in {stories}\n" + + return conflict_string diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 0f17812fd456..fe3af280a4bf 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -12,6 +12,7 @@ from rasa.core.training.dsl import UserUttered from rasa.core.training.dsl import ActionExecuted from rasa.core.constants import UTTER_PREFIX +from rasa.core.story_conflict import StoryConflict logger = logging.getLogger(__name__) @@ -273,40 +274,19 @@ def verify_story_structure(self, ) h = hash(str(list(sliced_states))) if h in rules: - # Get the last event - last_event_string = self._last_event_string(sliced_states) - - # Fill `conflicts` dict - # {hash: {last_event: {action_1: [stories...], action_2: [stories...], ...}, ...}, ...} if h not in conflicts: - conflicts[h] = {last_event_string: {event.as_story_string(): [tracker.sender_id]}} - else: - if last_event_string not in conflicts[h]: - conflicts[h][last_event_string] = {event.as_story_string(): [tracker.sender_id]} - else: - if event.as_story_string() not in conflicts[h][last_event_string]: - conflicts[h][last_event_string][event.as_story_string()] = [tracker.sender_id] - else: - 
conflicts[h][last_event_string][event.as_story_string()].append(tracker.sender_id) + conflicts[h] = StoryConflict(sliced_states, tracker, event) + conflicts[h].add_conflicting_action( + action=event.as_story_string(), + story_name=tracker.sender_id + ) idx += 1 if len(conflicts) == 0: logger.info("No story structure conflicts found.") else: for conflict in list(conflicts.values()): - for state, actions_and_stories in conflict.items(): - conflict_string = f"CONFLICT after {state}:\n" - for action, stories in actions_and_stories.items(): - if len(stories) == 1: - stories = f"'{stories[0]}'" - elif len(stories) == 2: - stories = f"'{stories[0]}' and '{stories[1]}'" - elif len(stories) == 3: - stories = f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'" - elif len(stories) >= 4: - stories = f"'{stories[0]}' and {len(stories) - 1} other stories" - conflict_string += f" {action} predicted in {stories}\n" - logger.warning(conflict_string) + logger.warning(conflict) # Fix the conflict if required if prompt: From 51570c7c6bb45694c8be9a5d6f6320a1b3e52c31 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 13 Nov 2019 16:58:11 +0100 Subject: [PATCH 038/209] Enable --prompt flag (dummy) --- rasa/core/story_conflict.py | 11 +++++++++++ rasa/core/validator.py | 8 +++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 3619c8feac7e..27b7ef9619e2 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -48,6 +48,17 @@ def add_conflicting_action(self, action: Text, story_name: Text): def conflicting_actions(self): return list(self._conflicting_actions.keys()) + def story_prior_to_conflict(self): + result = "" + for state in self.sliced_states: + if state: + event_type, event_name = self._get_prev_event(state) + if event_type == "intent": + result += f"* {event_name}\n" + else: + result += f" - {event_name}\n" + return result + def __str__(self): last_event_type, last_event_name = self._get_prev_event(self.sliced_states[-1]) conflict_string = f"CONFLICT after {last_event_type} '{last_event_name}':\n" diff --git a/rasa/core/validator.py b/rasa/core/validator.py index fe3af280a4bf..43a626d87fe8 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -1,6 +1,7 @@ import logging from collections import defaultdict from typing import Set, Text +import questionary from rasa.core.domain import Domain, PREV_PREFIX from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.training.generator import TrainingDataGenerator @@ -290,7 +291,12 @@ def verify_story_structure(self, # Fix the conflict if required if prompt: - raise NotImplementedError + print(conflict.story_prior_to_conflict()) + correct_action = questionary.select( + message="How should your bot respond at this point?", + choices=conflict.conflicting_actions + ).ask() + print(correct_action) return len(conflicts) > 0 From 0224ab372b8936cf42a7d3b279bbb8eeebd98bfc Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 13 Nov 2019 16:58:31 +0100 Subject: [PATCH 039/209] Let deduplicate_story_names replace the files --- rasa/cli/data.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index e88b14ce18aa..29f3a546733d 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -191,6 +191,8 @@ def deduplicate_story_names(args): """Changes story names so as to make them unique. --EXPERIMENTAL-- """ + # ToDo: Make this work with multiple story files + from rasa.importers.rasa import RasaFileImporter loop = asyncio.get_event_loop() @@ -198,11 +200,16 @@ def deduplicate_story_names(args): domain_path=args.domain, training_data_paths=args.data ) + import shutil + story_file_names, _ = data.get_core_nlu_files(args.data) names = set() for file_name in story_file_names: - if file_name.endswith(".new"): + if file_name.endswith(".bak"): continue + + shutil.copy2(file_name, file_name + ".bak") + with open(file_name, "r") as in_file, \ open(file_name + ".new", "w+") as out_file: for line in in_file: @@ -221,6 +228,8 @@ def deduplicate_story_names(args): else: out_file.write(line + "\n") + shutil.move(file_name + ".new", file_name) + # story_files, _ = data.get_core_nlu_files(args.data) # story_steps = loop.run_until_complete(file_importer.get_story_steps()) # for step in story_steps: From 0c380d035f6b18a8fb78bf35a86bfe62d0269c8f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 13 Nov 2019 17:01:16 +0100 Subject: [PATCH 040/209] Remove unused code --- rasa/core/story_conflict.py | 7 ------- rasa/core/validator.py | 16 +--------------- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 27b7ef9619e2..6213f28f0a00 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -62,13 +62,6 @@ def story_prior_to_conflict(self): def __str__(self): last_event_type, last_event_name = self._get_prev_event(self.sliced_states[-1]) conflict_string = f"CONFLICT after {last_event_type} '{last_event_name}':\n" - # for state in self.sliced_states: - # if state: - # event_type, event_name = self._get_prev_event(state) - # if event_type == "intent": - # conflict_string += f"* {event_name}\n" - # else: - # conflict_string += f" - {event_name}\n" for action, stories in self._conflicting_actions.items(): if len(stories) == 1: stories = f"'{stories[0]}'" diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 43a626d87fe8..e503eec5b0a4 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -2,11 +2,9 @@ from collections import defaultdict from typing import Set, Text import questionary -from rasa.core.domain import Domain, PREV_PREFIX -from rasa.core.actions.action import ACTION_LISTEN_NAME +from rasa.core.domain import Domain from rasa.core.training.generator import TrainingDataGenerator from rasa.importers.importer import TrainingDataImporter -from rasa.nlu.constants import MESSAGE_INTENT_ATTRIBUTE from rasa.nlu.training_data import TrainingData from rasa.core.training.structures import StoryGraph from rasa.core.featurizers import MaxHistoryTrackerFeaturizer @@ -213,18 +211,6 @@ def verify_story_names(self, ignore_warnings: bool = True): logger.error(message) return result - @staticmethod - def _last_event_string(sliced_states): - last_event_string = None - for k, v in sliced_states[-1].items(): - if k.startswith(PREV_PREFIX): - 
if k[len(PREV_PREFIX):] != ACTION_LISTEN_NAME: - last_event_string = f"action '{k[len(PREV_PREFIX):]}'" - elif k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_") and not last_event_string: - last_event_string = f"intent '{k[len(MESSAGE_INTENT_ATTRIBUTE + '_'):]}'" - - return last_event_string - def verify_story_structure(self, ignore_warnings: bool = True, max_history: int = 5, From 26e292741261da8a73e35d7f228bcaeeae78e02f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 13 Nov 2019 17:29:55 +0100 Subject: [PATCH 041/209] Improve prompt --- rasa/core/story_conflict.py | 5 +++++ rasa/core/validator.py | 11 ++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 6213f28f0a00..fa383dd2bbf4 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -21,6 +21,7 @@ def __init__( self.tracker = tracker, self.event = event self._conflicting_actions = {} # {"action": ["story_1", ...], ...} + self.correct_response = None def events_prior_to_conflict(self): raise NotImplementedError @@ -48,6 +49,10 @@ def add_conflicting_action(self, action: Text, story_name: Text): def conflicting_actions(self): return list(self._conflicting_actions.keys()) + @property + def conflicting_actions_with_counts(self): + return [f"{a} [{len(s)}x]" for (a, s) in self._conflicting_actions.items()] + def story_prior_to_conflict(self): result = "" for state in self.sliced_states: diff --git a/rasa/core/validator.py b/rasa/core/validator.py index e503eec5b0a4..cef38364027b 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -277,12 +277,17 @@ def verify_story_structure(self, # Fix the conflict if required if prompt: + print("A conflict occurs after the following sequence of events:") print(conflict.story_prior_to_conflict()) - correct_action = questionary.select( + keep = "KEEP AS IS" + correct_response = questionary.select( message="How should your bot respond at this 
point?", - choices=conflict.conflicting_actions + choices=[keep] + conflict.conflicting_actions_with_counts ).ask() - print(correct_action) + if correct_response != keep: + # Remove the story count ending, e.g. " [42x]" + conflict.correct_response = correct_response.rsplit(" ", 1)[0] + print(conflict.correct_response) return len(conflicts) > 0 From 3e69cda19d45578547aab1933bfd8eb7d3172a2b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 13 Nov 2019 17:53:09 +0100 Subject: [PATCH 042/209] Add stories_to_correct --- rasa/core/story_conflict.py | 11 +++++++++++ rasa/core/validator.py | 5 ++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index fa383dd2bbf4..0a825d083c11 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -53,6 +53,17 @@ def conflicting_actions(self): def conflicting_actions_with_counts(self): return [f"{a} [{len(s)}x]" for (a, s) in self._conflicting_actions.items()] + @property + def incorrect_stories(self): + if self.correct_response: + incorrect_stories = [] + for stories in [s for (a, s) in self._conflicting_actions.items() if a != self.correct_response]: + for story in stories: + incorrect_stories.append(story) + return incorrect_stories + else: + return [] + def story_prior_to_conflict(self): result = "" for state in self.sliced_states: diff --git a/rasa/core/validator.py b/rasa/core/validator.py index cef38364027b..54b22c3fcbdf 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -287,7 +287,10 @@ def verify_story_structure(self, if correct_response != keep: # Remove the story count ending, e.g. 
" [42x]" conflict.correct_response = correct_response.rsplit(" ", 1)[0] - print(conflict.correct_response) + + for conflict in list(conflicts.values()): + if conflict.correct_response: + print(f"Fixing {conflict.incorrect_stories} with {conflict.correct_response}...") return len(conflicts) > 0 From e80892837e4f7d52600c3ff2bdc543d07f216327 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 14 Nov 2019 16:49:23 +0100 Subject: [PATCH 043/209] Fix return value --- rasa/core/validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 54b22c3fcbdf..2a8faaccd8b2 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -292,7 +292,7 @@ def verify_story_structure(self, if conflict.correct_response: print(f"Fixing {conflict.incorrect_stories} with {conflict.correct_response}...") - return len(conflicts) > 0 + return len(conflicts) == 0 def verify_all(self, ignore_warnings: bool = True) -> bool: """Runs all the validations on intents and utterances.""" From ecf76ff496e4e09bf11aa6f90e282dfd790dacac Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 15 Nov 2019 10:41:52 +0100 Subject: [PATCH 044/209] Make --max-history a necessary parameter --- rasa/cli/data.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 29f3a546733d..ea74be440cbf 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -1,3 +1,4 @@ +import logging import argparse import asyncio import sys @@ -8,6 +9,7 @@ from rasa.cli.utils import get_validated_path from rasa.constants import DEFAULT_DATA_PATH +logger = logging.getLogger(__name__) # noinspection PyProtectedMember def add_subparser( @@ -87,7 +89,7 @@ def _add_data_validate_parsers( parents=parents, help="Validates domain and data files to check for possible mistakes.", ) - validate_parser.add_argument("--max-history", type=int, default=5, + validate_parser.add_argument("--max-history", type=int, default=None, help="Assume this max_history setting for story structure validation.") validate_parser.set_defaults(func=validate_files) arguments.set_validator_arguments(validate_parser) @@ -99,7 +101,7 @@ def _add_data_validate_parsers( parents=parents, help="Checks for inconsistencies in the story files.", ) - story_structure_parser.add_argument("--max-history", type=int, default=5, + story_structure_parser.add_argument("--max-history", type=int, help="Assume this max_history setting for validation.") story_structure_parser.add_argument("--prompt", action="store_true", default=False, help="Ask how conflicts should be fixed") @@ -155,7 +157,10 @@ def validate_files(args): sys.exit(1) everything_is_alright = validator.verify_all(not args.fail_on_warnings) - if everything_is_alright: + if not args.max_history: + logger.info("Will not test for inconsistencies in stories since " + "you did not provide --max-history.") + if everything_is_alright and args.max_history: # Only run story structure validation if everything else is fine # since this might take a while everything_is_alright = 
validator.verify_story_structure(not args.fail_on_warnings, @@ -170,20 +175,27 @@ def validate_stories(args): from rasa.core.validator import Validator from rasa.importers.rasa import RasaFileImporter + if not isinstance(args.max_history, int) or args.max_history < 1: + logger.error("You have to provide a positive integer for --max-history.") + sys.exit(1) + loop = asyncio.get_event_loop() file_importer = RasaFileImporter( domain_path=args.domain, training_data_paths=args.data ) validator = loop.run_until_complete(Validator.from_importer(file_importer)) - everything_is_alright = ( - validator.verify_story_names(not args.fail_on_warnings) and - validator.verify_story_structure( - not args.fail_on_warnings, - max_history=args.max_history, - prompt=args.prompt - ) - ) + + # First check for duplicate story names + story_names_unique = validator.verify_story_names(not args.fail_on_warnings) + + # If names are unique, look for inconsistencies + everything_is_alright = validator.verify_story_structure( + not args.fail_on_warnings, + max_history=args.max_history, + prompt=args.prompt + ) if story_names_unique else False + sys.exit(0) if everything_is_alright else sys.exit(1) From a1d1f5a23d37736900f78fcb56ac2739a5c3006c Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 15 Nov 2019 10:45:53 +0100 Subject: [PATCH 045/209] Refer to trackers as trackers, not stories --- rasa/core/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 0a825d083c11..36904dd79274 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -86,7 +86,7 @@ def __str__(self): elif len(stories) == 3: stories = f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'" elif len(stories) >= 4: - stories = f"'{stories[0]}' and {len(stories) - 1} other stories" + stories = f"'{stories[0]}' and {len(stories) - 1} other trackers" conflict_string += f" {action} predicted in {stories}\n" return conflict_string From 8e329d26daf796245f778df13afd55527fb17be9 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 21 Nov 2019 13:18:13 +0100 Subject: [PATCH 046/209] Fix check for story name duplicates --- rasa/cli/data.py | 11 ++++++++++- rasa/core/story_conflict.py | 9 ++++++--- rasa/core/training/dsl.py | 7 +++++++ rasa/core/training/structures.py | 3 +++ rasa/core/validator.py | 4 +++- 5 files changed, 29 insertions(+), 5 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index ea74be440cbf..1b432d1e6af1 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__name__) + # noinspection PyProtectedMember def add_subparser( subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] @@ -184,10 +185,18 @@ def validate_stories(args): domain_path=args.domain, training_data_paths=args.data ) + # This loads the stories and thus fills `STORY_NAME_TALLY` (see next code block) validator = loop.run_until_complete(Validator.from_importer(file_importer)) # First check for duplicate story names - story_names_unique = validator.verify_story_names(not args.fail_on_warnings) + from rasa.core.training.structures import STORY_NAME_TALLY # ToDo: Avoid global variable + 
duplicate_story_names = {name: count for (name, count) in STORY_NAME_TALLY.items() if count > 1} + story_names_unique = len(duplicate_story_names) == 0 + if not story_names_unique: + msg = "Found duplicate story names:\n" + for (name, count) in duplicate_story_names.items(): + msg += f" Story name '{name}' appears {count}x\n" + logger.error(msg) # If names are unique, look for inconsistencies everything_is_alright = validator.verify_story_structure( diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 36904dd79274..008fd7e9d99a 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -29,8 +29,8 @@ def events_prior_to_conflict(self): @staticmethod def _get_prev_event(state) -> [Event, None]: if not state: - return None - result = None + return None, None + result = (None, None) for k in state: if k.startswith(PREV_PREFIX): if k[len(PREV_PREFIX):] != ACTION_LISTEN_NAME: @@ -77,7 +77,10 @@ def story_prior_to_conflict(self): def __str__(self): last_event_type, last_event_name = self._get_prev_event(self.sliced_states[-1]) - conflict_string = f"CONFLICT after {last_event_type} '{last_event_name}':\n" + if last_event_type: + conflict_string = f"CONFLICT after {last_event_type} '{last_event_name}':\n" + else: + conflict_string = f"CONFLICT at the beginning of stories:\n" for action, stories in self._conflicting_actions.items(): if len(stories) == 1: stories = f"'{stories[0]}'" diff --git a/rasa/core/training/dsl.py b/rasa/core/training/dsl.py index bd02b61683c6..8c2d37e94c58 100644 --- a/rasa/core/training/dsl.py +++ b/rasa/core/training/dsl.py @@ -20,6 +20,7 @@ GENERATED_CHECKPOINT_PREFIX, GENERATED_HASH_LENGTH, FORM_PREFIX, + STORY_NAME_TALLY, ) from rasa.nlu.training_data.formats import MarkdownReader from rasa.core.domain import Domain @@ -374,6 +375,12 @@ def new_story_part(self, name): self._add_current_stories_to_result() self.current_step_builder = StoryStepBuilder(name) + # Tally names of stories, so we can identify 
duplicate names + if name not in STORY_NAME_TALLY: + STORY_NAME_TALLY[name] = 1 + else: + STORY_NAME_TALLY[name] += 1 + def add_checkpoint(self, name: Text, conditions: Optional[Dict[Text, Any]]) -> None: # Ensure story part already has a name diff --git a/rasa/core/training/structures.py b/rasa/core/training/structures.py index 685c3899dabe..6eaf838a961c 100644 --- a/rasa/core/training/structures.py +++ b/rasa/core/training/structures.py @@ -38,6 +38,9 @@ # will get increased with each new instance STEP_COUNT = 1 +# Tally over story names +STORY_NAME_TALLY = {} + class StoryStringHelper: """A helper class to mark story steps that are inside a form with `form: ` diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 2a8faaccd8b2..02f7d9503fed 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -188,7 +188,9 @@ def verify_story_names(self, ignore_warnings: bool = True): # Tally story names, e.g. {"story_1": 3, "story_2": 1, ...} name_tally = {} + print(self.story_graph.as_story_string()) for step in self.story_graph.story_steps: + # print(step.block_name) if step.block_name in name_tally: name_tally[step.block_name] += 1 else: @@ -223,7 +225,7 @@ def verify_story_structure(self, trackers = TrainingDataGenerator( self.story_graph, domain=self.domain, - remove_duplicates=False, # ToDo: Q&A: Why don't we deduplicate the graph here? + remove_duplicates=False, # ToDo: Q&A: Why not remove_duplicates=True? augmentation_factor=0).generate() rules = {} for tracker in trackers: From b51bdb7178ba63ea75bd620d9c4d75c94819858b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Thu, 21 Nov 2019 14:30:17 +0100 Subject: [PATCH 047/209] Remove conflicts that arise from unpredictable actions --- rasa/cli/data.py | 2 +- rasa/core/story_conflict.py | 4 ++++ rasa/core/validator.py | 3 +++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 1b432d1e6af1..31887ba9ff85 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -195,7 +195,7 @@ def validate_stories(args): if not story_names_unique: msg = "Found duplicate story names:\n" for (name, count) in duplicate_story_names.items(): - msg += f" Story name '{name}' appears {count}x\n" + msg += f" '{name}' appears {count}x\n" logger.error(msg) # If names are unique, look for inconsistencies diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 008fd7e9d99a..cd218e538721 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -64,6 +64,10 @@ def incorrect_stories(self): else: return [] + @property + def has_prior_events(self): + return self._get_prev_event(self.sliced_states[-1])[0] is not None + def story_prior_to_conflict(self): result = "" for state in self.sliced_states: diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 02f7d9503fed..52f99c5f45fe 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -271,6 +271,9 @@ def verify_story_structure(self, ) idx += 1 + # Remove conflicts that arise from unpredictable actions + conflicts = {h: c for (h, c) in conflicts.items() if c.has_prior_events} + if len(conflicts) == 0: logger.info("No story structure conflicts found.") else: From c5e0b66984b9fc9a935335e3602c60cfdb011403 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Thu, 21 Nov 2019 14:32:29 +0100 Subject: [PATCH 048/209] Remove debug print statements --- rasa/core/validator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 52f99c5f45fe..6b31c1280279 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -188,9 +188,7 @@ def verify_story_names(self, ignore_warnings: bool = True): # Tally story names, e.g. {"story_1": 3, "story_2": 1, ...} name_tally = {} - print(self.story_graph.as_story_string()) for step in self.story_graph.story_steps: - # print(step.block_name) if step.block_name in name_tally: name_tally[step.block_name] += 1 else: From bec21a8332d33c43d7002172558df9f59b19dca2 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 21 Nov 2019 14:44:59 +0100 Subject: [PATCH 049/209] Fix missing [0] --- rasa/core/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index cd218e538721..9dc9cfd4fb95 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -35,7 +35,7 @@ def _get_prev_event(state) -> [Event, None]: if k.startswith(PREV_PREFIX): if k[len(PREV_PREFIX):] != ACTION_LISTEN_NAME: result = ("action", k[len(PREV_PREFIX):]) - elif k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_") and not result: + elif k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_") and not result[0]: result = ("intent", k[len(MESSAGE_INTENT_ATTRIBUTE + '_'):]) return result From fa4bcc1a55138bb60919a4324f83dec06ff3d5bc Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 11:01:29 +0100 Subject: [PATCH 050/209] Move finding of conflicts to StoryConflict class --- rasa/core/story_conflict.py | 49 +++++++++++++++++++++++++- rasa/core/validator.py | 69 ++++--------------------------------- 2 files changed, 54 insertions(+), 64 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 9dc9cfd4fb95..ca2ac45f7393 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -3,7 +3,8 @@ from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.domain import PREV_PREFIX -from rasa.core.events import Event +from rasa.core.events import Event, ActionExecuted +from rasa.core.featurizers import MaxHistoryTrackerFeaturizer from rasa.nlu.constants import MESSAGE_INTENT_ATTRIBUTE from rasa.core.training.generator import TrackerWithCachedStates @@ -23,6 +24,52 @@ def __init__( self._conflicting_actions = {} # {"action": ["story_1", ...], ...} self.correct_response = None + @staticmethod + def find_conflicts(trackers, domain, max_history: int): + + # Create a 'state -> list of actions' dict, where the state is represented by its hash + rules = {} + for tracker, event, sliced_states in StoryConflict._sliced_states_stream(trackers, domain, max_history): + h = hash(str(list(sliced_states))) + if h in rules: + if event.as_story_string() not in rules[h]: + rules[h] += [event.as_story_string()] + else: + rules[h] = [event.as_story_string()] + + # Keep only conflicting rules + rules = {state: actions for (state, actions) in rules.items() if len(actions) > 1} + + # Iterate once more over all states and note the (unhashed) state, tracker, and event for which a conflict occurs + conflicts = {} + for tracker, event, sliced_states in StoryConflict._sliced_states_stream(trackers, domain, max_history): + h = hash(str(list(sliced_states))) + if h in rules: + if h not in conflicts: + conflicts[h] = StoryConflict(sliced_states, tracker, event) + 
conflicts[h].add_conflicting_action( + action=event.as_story_string(), + story_name=tracker.sender_id + ) + + # Remove conflicts that arise from unpredictable actions + return [c for (h, c) in conflicts.items() if c.has_prior_events] + + @staticmethod + def _sliced_states_stream(trackers, domain, max_history): + for tracker in trackers: + states = tracker.past_states(domain) + states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 + + idx = 0 + for event in tracker.events: + if isinstance(event, ActionExecuted): + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + yield tracker, event, sliced_states + idx += 1 + def events_prior_to_conflict(self): raise NotImplementedError diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 7ec34a99aba4..45e0ea6845ec 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -226,75 +226,18 @@ def verify_story_structure(self, domain=self.domain, remove_duplicates=False, # ToDo: Q&A: Why not remove_duplicates=True? 
augmentation_factor=0).generate() - rules = {} - for tracker in trackers: - states = tracker.past_states(self.domain) - states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 - - idx = 0 - for event in tracker.events: - if isinstance(event, ActionExecuted): - sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( - states[: idx + 1], max_history - ) - h = hash(str(list(sliced_states))) - if h in rules: - if event.as_story_string() not in rules[h]: - rules[h] += [event.as_story_string()] - else: - rules[h] = [event.as_story_string()] - idx += 1 - - # Keep only conflicting rules - rules = {state: actions for (state, actions) in rules.items() if len(actions) > 1} - - conflicts = {} - - for tracker in trackers: - states = tracker.past_states(self.domain) - states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 - - idx = 0 - for event in tracker.events: - if isinstance(event, ActionExecuted): - sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( - states[: idx + 1], max_history - ) - h = hash(str(list(sliced_states))) - if h in rules: - if h not in conflicts: - conflicts[h] = StoryConflict(sliced_states, tracker, event) - conflicts[h].add_conflicting_action( - action=event.as_story_string(), - story_name=tracker.sender_id - ) - idx += 1 - - # Remove conflicts that arise from unpredictable actions - conflicts = {h: c for (h, c) in conflicts.items() if c.has_prior_events} + + # Create a list of `StoryConflict` objects + conflicts = StoryConflict.find_conflicts(trackers, self.domain, max_history) if len(conflicts) == 0: logger.info("No story structure conflicts found.") else: - for conflict in list(conflicts.values()): + for conflict in conflicts: logger.warning(conflict) - # Fix the conflict if required - if prompt: - print("A conflict occurs after the following sequence of events:") - print(conflict.story_prior_to_conflict()) - keep = "KEEP AS IS" - correct_response = 
questionary.select( - message="How should your bot respond at this point?", - choices=[keep] + conflict.conflicting_actions_with_counts - ).ask() - if correct_response != keep: - # Remove the story count ending, e.g. " [42x]" - conflict.correct_response = correct_response.rsplit(" ", 1)[0] - - for conflict in list(conflicts.values()): - if conflict.correct_response: - print(f"Fixing {conflict.incorrect_stories} with {conflict.correct_response}...") + # For code stub to fix the conflict in the command line, + # see commit 3fdc08a030dbd85c15b4f5d7e8b5ad6a254eefb4 return len(conflicts) == 0 From ce86ede10792c52e61a58df6590019a2c3934511 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 11:04:21 +0100 Subject: [PATCH 051/209] Respect ignore_warnings --- rasa/cli/data.py | 3 +-- rasa/core/validator.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 31887ba9ff85..f2b8f41a4557 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -201,8 +201,7 @@ def validate_stories(args): # If names are unique, look for inconsistencies everything_is_alright = validator.verify_story_structure( not args.fail_on_warnings, - max_history=args.max_history, - prompt=args.prompt + max_history=args.max_history ) if story_names_unique else False sys.exit(0) if everything_is_alright else sys.exit(1) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 45e0ea6845ec..bad785b0e839 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -214,8 +214,7 @@ def verify_story_names(self, ignore_warnings: bool = True): def verify_story_structure(self, ignore_warnings: bool = True, - max_history: int = 5, - prompt: bool = False) -> bool: + max_history: int = 5) -> bool: """Verifies that bot behaviour in stories is deterministic.""" logger.info("Story structure validation...") @@ -239,7 +238,7 @@ def verify_story_structure(self, # For code stub to fix the conflict in the command line, # see 
commit 3fdc08a030dbd85c15b4f5d7e8b5ad6a254eefb4 - return len(conflicts) == 0 + return ignore_warnings or len(conflicts) == 0 def verify_all(self, ignore_warnings: bool = True) -> bool: """Runs all the validations on intents and utterances.""" From fe46fb9c0187677d3d6d84fc4ef3f73c6824e573 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 11:21:15 +0100 Subject: [PATCH 052/209] Apply BLACK formatting --- rasa/core/story_conflict.py | 49 +++++++++++++++++++++++-------------- rasa/core/validator.py | 22 +++++++++-------- 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index ca2ac45f7393..bf2a7928fc02 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -1,4 +1,3 @@ - from typing import List, Optional, Dict, Text from rasa.core.actions.action import ACTION_LISTEN_NAME @@ -10,16 +9,15 @@ class StoryConflict: - def __init__( - self, - sliced_states: List[Optional[Dict[Text, float]]], - tracker: TrackerWithCachedStates, - event + self, + sliced_states: List[Optional[Dict[Text, float]]], + tracker: TrackerWithCachedStates, + event, ): self.sliced_states = sliced_states self.hash = hash(str(list(sliced_states))) - self.tracker = tracker, + self.tracker = (tracker,) self.event = event self._conflicting_actions = {} # {"action": ["story_1", ...], ...} self.correct_response = None @@ -27,9 +25,12 @@ def __init__( @staticmethod def find_conflicts(trackers, domain, max_history: int): - # Create a 'state -> list of actions' dict, where the state is represented by its hash + # Create a 'state -> list of actions' dict, where the state is + # represented by its hash rules = {} - for tracker, event, sliced_states in StoryConflict._sliced_states_stream(trackers, domain, max_history): + for tracker, event, sliced_states in StoryConflict._sliced_states_stream( + trackers, domain, max_history + ): h = hash(str(list(sliced_states))) if h in rules: if 
event.as_story_string() not in rules[h]: @@ -38,18 +39,22 @@ def find_conflicts(trackers, domain, max_history: int): rules[h] = [event.as_story_string()] # Keep only conflicting rules - rules = {state: actions for (state, actions) in rules.items() if len(actions) > 1} + rules = { + state: actions for (state, actions) in rules.items() if len(actions) > 1 + } - # Iterate once more over all states and note the (unhashed) state, tracker, and event for which a conflict occurs + # Iterate once more over all states and note the (unhashed) state, + # tracker, and event for which a conflict occurs conflicts = {} - for tracker, event, sliced_states in StoryConflict._sliced_states_stream(trackers, domain, max_history): + for tracker, event, sliced_states in StoryConflict._sliced_states_stream( + trackers, domain, max_history + ): h = hash(str(list(sliced_states))) if h in rules: if h not in conflicts: conflicts[h] = StoryConflict(sliced_states, tracker, event) conflicts[h].add_conflicting_action( - action=event.as_story_string(), - story_name=tracker.sender_id + action=event.as_story_string(), story_name=tracker.sender_id ) # Remove conflicts that arise from unpredictable actions @@ -59,7 +64,9 @@ def find_conflicts(trackers, domain, max_history: int): def _sliced_states_stream(trackers, domain, max_history): for tracker in trackers: states = tracker.past_states(domain) - states = [dict(state) for state in states] # ToDo: Check against rasa/core/featurizers.py:318 + states = [ + dict(state) for state in states + ] # ToDo: Check against rasa/core/featurizers.py:318 idx = 0 for event in tracker.events: @@ -80,10 +87,10 @@ def _get_prev_event(state) -> [Event, None]: result = (None, None) for k in state: if k.startswith(PREV_PREFIX): - if k[len(PREV_PREFIX):] != ACTION_LISTEN_NAME: - result = ("action", k[len(PREV_PREFIX):]) + if k[len(PREV_PREFIX) :] != ACTION_LISTEN_NAME: + result = ("action", k[len(PREV_PREFIX) :]) elif k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_") and not 
result[0]: - result = ("intent", k[len(MESSAGE_INTENT_ATTRIBUTE + '_'):]) + result = ("intent", k[len(MESSAGE_INTENT_ATTRIBUTE + "_") :]) return result def add_conflicting_action(self, action: Text, story_name: Text): @@ -104,7 +111,11 @@ def conflicting_actions_with_counts(self): def incorrect_stories(self): if self.correct_response: incorrect_stories = [] - for stories in [s for (a, s) in self._conflicting_actions.items() if a != self.correct_response]: + for stories in [ + s + for (a, s) in self._conflicting_actions.items() + if a != self.correct_response + ]: for story in stories: incorrect_stories.append(story) return incorrect_stories diff --git a/rasa/core/validator.py b/rasa/core/validator.py index bad785b0e839..69080ce64d17 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -1,15 +1,12 @@ import logging import warnings -import asyncio from collections import defaultdict from typing import Set, Text -import questionary from rasa.core.domain import Domain from rasa.core.training.generator import TrainingDataGenerator from rasa.importers.importer import TrainingDataImporter from rasa.nlu.training_data import TrainingData from rasa.core.training.structures import StoryGraph -from rasa.core.featurizers import MaxHistoryTrackerFeaturizer from rasa.core.training.dsl import UserUttered from rasa.core.training.dsl import ActionExecuted from rasa.core.constants import UTTER_PREFIX @@ -212,9 +209,9 @@ def verify_story_names(self, ignore_warnings: bool = True): logger.error(message) return result - def verify_story_structure(self, - ignore_warnings: bool = True, - max_history: int = 5) -> bool: + def verify_story_structure( + self, ignore_warnings: bool = True, max_history: int = 5 + ) -> bool: """Verifies that bot behaviour in stories is deterministic.""" logger.info("Story structure validation...") @@ -223,8 +220,9 @@ def verify_story_structure(self, trackers = TrainingDataGenerator( self.story_graph, domain=self.domain, - remove_duplicates=False, # 
ToDo: Q&A: Why not remove_duplicates=True? - augmentation_factor=0).generate() + remove_duplicates=False, # ToDo: Q&A: Why not remove_duplicates=True? + augmentation_factor=0, + ).generate() # Create a list of `StoryConflict` objects conflicts = StoryConflict.find_conflicts(trackers, self.domain, max_history) @@ -254,8 +252,12 @@ def verify_all(self, ignore_warnings: bool = True) -> bool: logger.info("Validating utterances...") stories_are_valid = self.verify_utterances_in_stories(ignore_warnings) - return (intents_are_valid and stories_are_valid and - there_is_no_duplication and all_story_names_unique) + return ( + intents_are_valid + and stories_are_valid + and there_is_no_duplication + and all_story_names_unique + ) def verify_domain_validity(self) -> bool: """Checks whether the domain returned by the importer is empty, indicating an invalid domain.""" From 93acaa0c525ea969d133ce850ec68db83a150529 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 12:35:39 +0100 Subject: [PATCH 053/209] Add doc strings --- rasa/core/story_conflict.py | 67 ++++++++++++++++++++++++++++++++----- rasa/core/validator.py | 2 +- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index bf2a7928fc02..199dc2ea6bd1 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -2,7 +2,7 @@ from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.domain import PREV_PREFIX -from rasa.core.events import Event, ActionExecuted +from rasa.core.events import ActionExecuted from rasa.core.featurizers import MaxHistoryTrackerFeaturizer from rasa.nlu.constants import MESSAGE_INTENT_ATTRIBUTE from rasa.core.training.generator import TrackerWithCachedStates @@ -24,11 +24,20 @@ def __init__( @staticmethod def find_conflicts(trackers, domain, max_history: int): + """ + Generate a list of StoryConflict objects, describing + conflicts in the given trackers. 
+ :param trackers: Trackers in which to search for conflicts + :param domain: The domain + :param max_history: The maximum history length to be + taken into account + :return: List of conflicts + """ # Create a 'state -> list of actions' dict, where the state is # represented by its hash rules = {} - for tracker, event, sliced_states in StoryConflict._sliced_states_stream( + for tracker, event, sliced_states in StoryConflict._sliced_states_iterator( trackers, domain, max_history ): h = hash(str(list(sliced_states))) @@ -46,7 +55,7 @@ def find_conflicts(trackers, domain, max_history: int): # Iterate once more over all states and note the (unhashed) state, # tracker, and event for which a conflict occurs conflicts = {} - for tracker, event, sliced_states in StoryConflict._sliced_states_stream( + for tracker, event, sliced_states in StoryConflict._sliced_states_iterator( trackers, domain, max_history ): h = hash(str(list(sliced_states))) @@ -61,7 +70,15 @@ def find_conflicts(trackers, domain, max_history: int): return [c for (h, c) in conflicts.items() if c.has_prior_events] @staticmethod - def _sliced_states_stream(trackers, domain, max_history): + def _sliced_states_iterator(trackers, domain, max_history): + """ + Iterate over all given trackers and all sliced states within + each tracker, where the slicing is based on `max_history` + :param trackers: List of trackers + :param domain: Domain (used for tracker.past_states) + :param max_history: Assumed `max_history` value for slicing + :return: Yields (tracker, event, sliced_states) triplet + """ for tracker in trackers: states = tracker.past_states(domain) states = [ @@ -77,11 +94,16 @@ def _sliced_states_stream(trackers, domain, max_history): yield tracker, event, sliced_states idx += 1 - def events_prior_to_conflict(self): - raise NotImplementedError - @staticmethod - def _get_prev_event(state) -> [Event, None]: + def _get_prev_event( + state: Optional[Dict[Text, float]] + ) -> [Optional[Text], Optional[Text]]: + 
""" + Returns the type and name of the event (action or intent) previous to the + given state + :param state: Element of sliced states + :return: (type, name) strings of the prior event + """ if not state: return None, None result = (None, None) @@ -94,6 +116,12 @@ def _get_prev_event(state) -> [Event, None]: return result def add_conflicting_action(self, action: Text, story_name: Text): + """ + Add another action that follows from the same state + :param action: Name of the action + :param story_name: Name of the story where this action + is chosen + """ if action not in self._conflicting_actions: self._conflicting_actions[action] = [story_name] else: @@ -101,14 +129,25 @@ def add_conflicting_action(self, action: Text, story_name: Text): @property def conflicting_actions(self): + """ + Returns the list of conflicting actions + """ return list(self._conflicting_actions.keys()) @property def conflicting_actions_with_counts(self): + """ + Returns a list of strings, describing what action + occurs how often + """ return [f"{a} [{len(s)}x]" for (a, s) in self._conflicting_actions.items()] @property def incorrect_stories(self): + """ + Returns a list of stories that have not yet been + corrected. + """ if self.correct_response: incorrect_stories = [] for stories in [ @@ -124,9 +163,17 @@ def incorrect_stories(self): @property def has_prior_events(self): + """ + Returns True iff anything has happened before this + conflict. + """ return self._get_prev_event(self.sliced_states[-1])[0] is not None def story_prior_to_conflict(self): + """ + Generates a story string, describing the events that + lead up to the conflict. 
+ """ result = "" for state in self.sliced_states: if state: @@ -138,12 +185,16 @@ def story_prior_to_conflict(self): return result def __str__(self): + # Describe where the conflict occurs in the stories last_event_type, last_event_name = self._get_prev_event(self.sliced_states[-1]) if last_event_type: conflict_string = f"CONFLICT after {last_event_type} '{last_event_name}':\n" else: conflict_string = f"CONFLICT at the beginning of stories:\n" + + # List which stories are in conflict with one another for action, stories in self._conflicting_actions.items(): + # Summarize if necessary if len(stories) == 1: stories = f"'{stories[0]}'" elif len(stories) == 2: diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 69080ce64d17..aec98d95bbd7 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -228,7 +228,7 @@ def verify_story_structure( conflicts = StoryConflict.find_conflicts(trackers, self.domain, max_history) if len(conflicts) == 0: - logger.info("No story structure conflicts found.") + logger.info("No story structure conflicts found") else: for conflict in conflicts: logger.warning(conflict) From 44649d41f9a766af4d79068744fc6bd628a6cbfc Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 12:48:13 +0100 Subject: [PATCH 054/209] Declare types --- rasa/core/story_conflict.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 199dc2ea6bd1..ec32f3956bda 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -1,8 +1,8 @@ from typing import List, Optional, Dict, Text from rasa.core.actions.action import ACTION_LISTEN_NAME -from rasa.core.domain import PREV_PREFIX -from rasa.core.events import ActionExecuted +from rasa.core.domain import PREV_PREFIX, Domain +from rasa.core.events import ActionExecuted, Event from rasa.core.featurizers import MaxHistoryTrackerFeaturizer from rasa.nlu.constants import MESSAGE_INTENT_ATTRIBUTE from rasa.core.training.generator import TrackerWithCachedStates @@ -23,7 +23,9 @@ def __init__( self.correct_response = None @staticmethod - def find_conflicts(trackers, domain, max_history: int): + def find_conflicts( + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + ) -> List: """ Generate a list of StoryConflict objects, describing conflicts in the given trackers. 
@@ -70,7 +72,9 @@ def find_conflicts(trackers, domain, max_history: int): return [c for (h, c) in conflicts.items() if c.has_prior_events] @staticmethod - def _sliced_states_iterator(trackers, domain, max_history): + def _sliced_states_iterator( + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + ) -> (TrackerWithCachedStates, Event, List[Optional[Dict[Text, float]]]): """ Iterate over all given trackers and all sliced states within each tracker, where the slicing is based on `max_history` @@ -128,14 +132,14 @@ def add_conflicting_action(self, action: Text, story_name: Text): self._conflicting_actions[action] += [story_name] @property - def conflicting_actions(self): + def conflicting_actions(self) -> List[Text]: """ Returns the list of conflicting actions """ return list(self._conflicting_actions.keys()) @property - def conflicting_actions_with_counts(self): + def conflicting_actions_with_counts(self) -> List[Text]: """ Returns a list of strings, describing what action occurs how often @@ -143,9 +147,9 @@ def conflicting_actions_with_counts(self): return [f"{a} [{len(s)}x]" for (a, s) in self._conflicting_actions.items()] @property - def incorrect_stories(self): + def incorrect_stories(self) -> List[Text]: """ - Returns a list of stories that have not yet been + Returns a list of story names that have not yet been corrected. """ if self.correct_response: @@ -162,14 +166,14 @@ def incorrect_stories(self): return [] @property - def has_prior_events(self): + def has_prior_events(self) -> bool: """ Returns True iff anything has happened before this conflict. """ return self._get_prev_event(self.sliced_states[-1])[0] is not None - def story_prior_to_conflict(self): + def story_prior_to_conflict(self) -> Text: """ Generates a story string, describing the events that lead up to the conflict. From 5aa3b5394087deb1428e27a2f58b4bcb7cac4039 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 12:53:56 +0100 Subject: [PATCH 055/209] Remove verify_story_names, as it does not work --- rasa/core/validator.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index aec98d95bbd7..2a1625a32060 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -181,34 +181,6 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: return everything_is_alright - def verify_story_names(self, ignore_warnings: bool = True): - """Verify that story names are unique.""" - - # Tally story names, e.g. {"story_1": 3, "story_2": 1, ...} - name_tally = {} - for step in self.story_graph.story_steps: - if step.block_name in name_tally: - name_tally[step.block_name] += 1 - else: - name_tally[step.block_name] = 1 - - # Find story names that appear more than once - # and construct a warning message - result = True - message = "" - for name, count in name_tally.items(): - if count > 1: - if result: - message = f"Found duplicate story names:\n" - result = False - message += f" '{name}' appears {count}x\n" - - if result: - logger.info("All story names are unique") - else: - logger.error(message) - return result - def verify_story_structure( self, ignore_warnings: bool = True, max_history: int = 5 ) -> bool: @@ -248,7 +220,6 @@ def verify_all(self, ignore_warnings: bool = True) -> bool: there_is_no_duplication = self.verify_example_repetition_in_intents( ignore_warnings ) - all_story_names_unique = self.verify_story_names(ignore_warnings) logger.info("Validating utterances...") stories_are_valid = self.verify_utterances_in_stories(ignore_warnings) @@ -256,7 +227,6 @@ def verify_all(self, ignore_warnings: bool = True) -> bool: intents_are_valid and stories_are_valid and there_is_no_duplication - and all_story_names_unique ) def verify_domain_validity(self) -> bool: From 48814db7599aa8ebc6e50e3703ca45b9206c3e9d Mon Sep 17 00:00:00 2001 From: "Johannes 
E. M. Mosig" Date: Wed, 11 Dec 2019 14:28:25 +0100 Subject: [PATCH 056/209] Drop writing bak files when deduplicating --- rasa/cli/data.py | 98 +++++++++++++++----------------- rasa/core/training/structures.py | 2 +- 2 files changed, 47 insertions(+), 53 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index f2b8f41a4557..2250eff4ba28 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -14,7 +14,7 @@ # noinspection PyProtectedMember def add_subparser( - subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] + subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] ): data_parser = subparsers.add_parser( "data", @@ -33,10 +33,9 @@ def add_subparser( _add_data_clean_parsers(data_subparsers, parents) -def _add_data_convert_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] -): +def _add_data_convert_parsers(data_subparsers, parents: List[argparse.ArgumentParser]): import rasa.nlu.convert as convert + convert_parser = data_subparsers.add_parser( "convert", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -57,9 +56,7 @@ def _add_data_convert_parsers( arguments.set_convert_arguments(convert_nlu_parser) -def _add_data_split_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] -): +def _add_data_split_parsers(data_subparsers, parents: List[argparse.ArgumentParser]): split_parser = data_subparsers.add_parser( "split", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -74,24 +71,26 @@ def _add_data_split_parsers( parents=parents, formatter_class=argparse.ArgumentDefaultsHelpFormatter, help="Performs a split of your NLU data into training and test data " - "according to the specified percentages.", + "according to the specified percentages.", ) nlu_split_parser.set_defaults(func=split_nlu_data) arguments.set_split_arguments(nlu_split_parser) -def _add_data_validate_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] -): +def 
_add_data_validate_parsers(data_subparsers, parents: List[argparse.ArgumentParser]): validate_parser = data_subparsers.add_parser( "validate", formatter_class=argparse.ArgumentDefaultsHelpFormatter, parents=parents, help="Validates domain and data files to check for possible mistakes.", ) - validate_parser.add_argument("--max-history", type=int, default=None, - help="Assume this max_history setting for story structure validation.") + validate_parser.add_argument( + "--max-history", + type=int, + default=None, + help="Assume this max_history setting for story structure validation.", + ) validate_parser.set_defaults(func=validate_files) arguments.set_validator_arguments(validate_parser) @@ -102,17 +101,22 @@ def _add_data_validate_parsers( parents=parents, help="Checks for inconsistencies in the story files.", ) - story_structure_parser.add_argument("--max-history", type=int, - help="Assume this max_history setting for validation.") - story_structure_parser.add_argument("--prompt", action="store_true", default=False, - help="Ask how conflicts should be fixed") + story_structure_parser.add_argument( + "--max-history", + type=int, + help="Assume this max_history setting for validation.", + ) + story_structure_parser.add_argument( + "--prompt", + action="store_true", + default=False, + help="Ask how conflicts should be fixed", + ) story_structure_parser.set_defaults(func=validate_stories) arguments.set_validator_arguments(story_structure_parser) -def _add_data_clean_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] -): +def _add_data_clean_parsers(data_subparsers, parents: List[argparse.ArgumentParser]): clean_parser = data_subparsers.add_parser( "clean", @@ -159,13 +163,16 @@ def validate_files(args): everything_is_alright = validator.verify_all(not args.fail_on_warnings) if not args.max_history: - logger.info("Will not test for inconsistencies in stories since " - "you did not provide --max-history.") + logger.info( + "Will not test for 
inconsistencies in stories since " + "you did not provide --max-history." + ) if everything_is_alright and args.max_history: # Only run story structure validation if everything else is fine # since this might take a while - everything_is_alright = validator.verify_story_structure(not args.fail_on_warnings, - max_history=args.max_history) + everything_is_alright = validator.verify_story_structure( + not args.fail_on_warnings, max_history=args.max_history + ) sys.exit(0) if everything_is_alright else sys.exit(1) @@ -188,9 +195,12 @@ def validate_stories(args): # This loads the stories and thus fills `STORY_NAME_TALLY` (see next code block) validator = loop.run_until_complete(Validator.from_importer(file_importer)) - # First check for duplicate story names - from rasa.core.training.structures import STORY_NAME_TALLY # ToDo: Avoid global variable - duplicate_story_names = {name: count for (name, count) in STORY_NAME_TALLY.items() if count > 1} + # Check for duplicate story names + from rasa.core.training.structures import STORY_NAME_TALLY + + duplicate_story_names = { + name: count for (name, count) in STORY_NAME_TALLY.items() if count > 1 + } story_names_unique = len(duplicate_story_names) == 0 if not story_names_unique: msg = "Found duplicate story names:\n" @@ -199,10 +209,12 @@ def validate_stories(args): logger.error(msg) # If names are unique, look for inconsistencies - everything_is_alright = validator.verify_story_structure( - not args.fail_on_warnings, - max_history=args.max_history - ) if story_names_unique else False + if story_names_unique: + everything_is_alright = validator.verify_story_structure( + not args.fail_on_warnings, max_history=args.max_history + ) + else: + everything_is_alright = False sys.exit(0) if everything_is_alright else sys.exit(1) @@ -211,27 +223,14 @@ def deduplicate_story_names(args): """Changes story names so as to make them unique. 
--EXPERIMENTAL-- """ - # ToDo: Make this work with multiple story files - - from rasa.importers.rasa import RasaFileImporter - - loop = asyncio.get_event_loop() - file_importer = RasaFileImporter( - domain_path=args.domain, training_data_paths=args.data - ) - import shutil story_file_names, _ = data.get_core_nlu_files(args.data) names = set() for file_name in story_file_names: - if file_name.endswith(".bak"): - continue - - shutil.copy2(file_name, file_name + ".bak") - - with open(file_name, "r") as in_file, \ - open(file_name + ".new", "w+") as out_file: + with open(file_name, "r") as in_file, open( + file_name + ".new", "w+" + ) as out_file: for line in in_file: line = line.rstrip() if line.startswith("## "): @@ -249,8 +248,3 @@ def deduplicate_story_names(args): out_file.write(line + "\n") shutil.move(file_name + ".new", file_name) - - # story_files, _ = data.get_core_nlu_files(args.data) - # story_steps = loop.run_until_complete(file_importer.get_story_steps()) - # for step in story_steps: - # print(step.block_name) diff --git a/rasa/core/training/structures.py b/rasa/core/training/structures.py index 6eaf838a961c..0066218c6b75 100644 --- a/rasa/core/training/structures.py +++ b/rasa/core/training/structures.py @@ -38,7 +38,7 @@ # will get increased with each new instance STEP_COUNT = 1 -# Tally over story names +# Tally over story names, filled when reading stories STORY_NAME_TALLY = {} From 4a36d1eb912427f5c849b786c5f5d0f074b9ae53 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 14:35:13 +0100 Subject: [PATCH 057/209] Clean up deduplicate_story_names --- rasa/cli/data.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 2250eff4ba28..3f529332d60b 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -220,31 +220,33 @@ def validate_stories(args): def deduplicate_story_names(args): - """Changes story names so as to make them unique. 
- --EXPERIMENTAL-- """ + """ + Changes story names so as to make them unique. + """ import shutil story_file_names, _ = data.get_core_nlu_files(args.data) - names = set() - for file_name in story_file_names: - with open(file_name, "r") as in_file, open( - file_name + ".new", "w+" - ) as out_file: + names = set() # Set of names we have already encountered + for in_file_name in story_file_names: + out_file_name = in_file_name + ".new" + with open(in_file_name, "r") as in_file, open(out_file_name, "w+") as out_file: for line in in_file: line = line.rstrip() if line.startswith("## "): - new_name = line[3:] - if new_name in names: - first = new_name + name = line[3:] + # Check if we have already encountered a story with this name + if name in names: + # Find a unique name + old_name = name k = 1 - while new_name in names: - new_name = first + f" ({k})" + while name in names: + name = old_name + f" ({k})" k += 1 - print(f"- replacing {first} with {new_name}") - names.add(new_name) - out_file.write(f"## {new_name}\n") + print(f"- replacing {old_name} with {name}") + names.add(name) + out_file.write(f"## {name}\n") else: out_file.write(line + "\n") - shutil.move(file_name + ".new", file_name) + shutil.move(in_file_name + ".new", in_file_name) From 4d984e6f64b15457af478fd29ac5a39f1ba0d9c7 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 14:37:23 +0100 Subject: [PATCH 058/209] Add some comments --- rasa/cli/data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 3f529332d60b..94fe449f38bb 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -183,10 +183,12 @@ def validate_stories(args): from rasa.core.validator import Validator from rasa.importers.rasa import RasaFileImporter + # Check if a valid setting for `max_history` was given if not isinstance(args.max_history, int) or args.max_history < 1: logger.error("You have to provide a positive integer for --max-history.") sys.exit(1) + # Prepare story and domain file import loop = asyncio.get_event_loop() file_importer = RasaFileImporter( domain_path=args.domain, training_data_paths=args.data @@ -208,7 +210,7 @@ def validate_stories(args): msg += f" '{name}' appears {count}x\n" logger.error(msg) - # If names are unique, look for inconsistencies + # If names are unique, look for story conflicts if story_names_unique: everything_is_alright = validator.verify_story_structure( not args.fail_on_warnings, max_history=args.max_history From add4a2b3b4b0e5a8483e89f30818f9f9eb71091b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 15:03:56 +0100 Subject: [PATCH 059/209] Write first test for StoryConflict class --- tests/core/test_storyconflict.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/core/test_storyconflict.py diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py new file mode 100644 index 000000000000..5019a34656be --- /dev/null +++ b/tests/core/test_storyconflict.py @@ -0,0 +1,25 @@ +from rasa.core.story_conflict import StoryConflict +from rasa.core.training.generator import TrainingDataGenerator +from rasa.core.validator import Validator +from rasa.importers.rasa import RasaFileImporter +from tests.core.conftest import DEFAULT_STORIES_FILE, DEFAULT_DOMAIN_PATH_WITH_SLOTS + + +async def test_find_no_conflicts(): + importer = RasaFileImporter( + domain_path=DEFAULT_DOMAIN_PATH_WITH_SLOTS, + training_data_paths=[DEFAULT_STORIES_FILE], + ) + validator = await Validator.from_importer(importer) + + trackers = TrainingDataGenerator( + validator.story_graph, + domain=validator.domain, + remove_duplicates=False, + augmentation_factor=0, + ).generate() + + # Create a list of `StoryConflict` objects + conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 5) + + assert conflicts == [] From f9b6d16fa3f0277712181c6f5c6da7e3cf5f4d46 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 15:35:39 +0100 Subject: [PATCH 060/209] Add warning about non-markdown file cleaning --- rasa/cli/data.py | 7 +++++++ tests/core/test_storyconflict.py | 20 ++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 94fe449f38bb..d8b53fee49da 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -224,6 +224,8 @@ def validate_stories(args): def deduplicate_story_names(args): """ Changes story names so as to make them unique. 
+ + WARNING: Only works for markdown files at the moment """ import shutil @@ -231,6 +233,11 @@ def deduplicate_story_names(args): story_file_names, _ = data.get_core_nlu_files(args.data) names = set() # Set of names we have already encountered for in_file_name in story_file_names: + if not in_file_name.endswith(".md"): + logger.warning( + f"Support for cleaning non-markdown file '{in_file_name}' is not yet implemented" + ) + continue out_file_name = in_file_name + ".new" with open(in_file_name, "r") as in_file, open(out_file_name, "w+") as out_file: for line in in_file: diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py index 5019a34656be..078a5cf447f4 100644 --- a/tests/core/test_storyconflict.py +++ b/tests/core/test_storyconflict.py @@ -23,3 +23,23 @@ async def test_find_no_conflicts(): conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 5) assert conflicts == [] + + +async def test_find_conflicts(): + importer = RasaFileImporter( + domain_path=DEFAULT_DOMAIN_PATH_WITH_SLOTS, + training_data_paths=[DEFAULT_STORIES_FILE], + ) + validator = await Validator.from_importer(importer) + + trackers = TrainingDataGenerator( + validator.story_graph, + domain=validator.domain, + remove_duplicates=False, + augmentation_factor=0, + ).generate() + + # Create a list of `StoryConflict` objects + conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 1) + + assert conflicts == [] From bbb10b700d84dbbc50166daa38656836409db09a Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 15:52:29 +0100 Subject: [PATCH 061/209] Fix test to yield actual conflicts --- data/test_stories/stories_conflicting_1.md | 14 ++++++++++++++ tests/core/test_storyconflict.py | 9 +++++---- 2 files changed, 19 insertions(+), 4 deletions(-) create mode 100644 data/test_stories/stories_conflicting_1.md diff --git a/data/test_stories/stories_conflicting_1.md b/data/test_stories/stories_conflicting_1.md new file mode 100644 index 000000000000..001b7087c700 --- /dev/null +++ b/data/test_stories/stories_conflicting_1.md @@ -0,0 +1,14 @@ +## greetings +* greet + - utter_greet +> check_greet + +## happy path +> check_greet +* default + - utter_default + +## problem +> check_greet +* default + - utter_goodbye diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py index 078a5cf447f4..24ce4ba29415 100644 --- a/tests/core/test_storyconflict.py +++ b/tests/core/test_storyconflict.py @@ -27,8 +27,8 @@ async def test_find_no_conflicts(): async def test_find_conflicts(): importer = RasaFileImporter( - domain_path=DEFAULT_DOMAIN_PATH_WITH_SLOTS, - training_data_paths=[DEFAULT_STORIES_FILE], + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_1.md"], ) validator = await Validator.from_importer(importer) @@ -40,6 +40,7 @@ async def test_find_conflicts(): ).generate() # Create a list of `StoryConflict` objects - conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 1) + conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 5) - assert conflicts == [] + assert len(conflicts) == 1 + assert conflicts[0].conflicting_actions == ["utter_goodbye", "utter_default"] From 3398dcbbc61ed522fc459201314a8dba5c1bec22 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 16:04:03 +0100 Subject: [PATCH 062/209] Remove unnecessary arguments --- rasa/core/story_conflict.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index ec32f3956bda..c8a931ef69fb 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -12,13 +12,9 @@ class StoryConflict: def __init__( self, sliced_states: List[Optional[Dict[Text, float]]], - tracker: TrackerWithCachedStates, - event, ): self.sliced_states = sliced_states self.hash = hash(str(list(sliced_states))) - self.tracker = (tracker,) - self.event = event self._conflicting_actions = {} # {"action": ["story_1", ...], ...} self.correct_response = None @@ -55,7 +51,7 @@ def find_conflicts( } # Iterate once more over all states and note the (unhashed) state, - # tracker, and event for which a conflict occurs + # for which a conflict occurs conflicts = {} for tracker, event, sliced_states in StoryConflict._sliced_states_iterator( trackers, domain, max_history @@ -63,7 +59,7 @@ def find_conflicts( h = hash(str(list(sliced_states))) if h in rules: if h not in conflicts: - conflicts[h] = StoryConflict(sliced_states, tracker, event) + conflicts[h] = StoryConflict(sliced_states) conflicts[h].add_conflicting_action( action=event.as_story_string(), story_name=tracker.sender_id ) From ebab3a37607f8e0cfb1d07eacf68985fafdf83f1 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 16:15:04 +0100 Subject: [PATCH 063/209] Add more tests --- rasa/core/story_conflict.py | 3 ++- tests/core/test_storyconflict.py | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index c8a931ef69fb..4e723fcc6fa4 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -159,7 +159,8 @@ def incorrect_stories(self) -> List[Text]: incorrect_stories.append(story) return incorrect_stories else: - return [] + # Return all stories + return [v[0] for v in self._conflicting_actions.values()] @property def has_prior_events(self) -> bool: diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py index 24ce4ba29415..4f86cca2d0d1 100644 --- a/tests/core/test_storyconflict.py +++ b/tests/core/test_storyconflict.py @@ -44,3 +44,38 @@ async def test_find_conflicts(): assert len(conflicts) == 1 assert conflicts[0].conflicting_actions == ["utter_goodbye", "utter_default"] + + +async def test_add_conflicting_action(): + + sliced_states = [ + None, + {}, + {'intent_greet': 1.0, 'prev_action_listen': 1.0}, + {'prev_utter_greet': 1.0, 'intent_greet': 1.0} + ] + conflict = StoryConflict(sliced_states) + + conflict.add_conflicting_action("utter_greet", "xyz") + conflict.add_conflicting_action("utter_default", "uvw") + assert conflict.conflicting_actions == ["utter_greet", "utter_default"] + assert conflict.incorrect_stories == ["xyz", "uvw"] + + +async def test_has_prior_events(): + + sliced_states = [ + None, + {}, + {'intent_greet': 1.0, 'prev_action_listen': 1.0}, + {'prev_utter_greet': 1.0, 'intent_greet': 1.0} + ] + conflict = StoryConflict(sliced_states) + assert conflict.has_prior_events + + +async def test_has_no_prior_events(): + + sliced_states = [None] + conflict = StoryConflict(sliced_states) + assert not conflict.has_prior_events From f631ac465f27f1f471c6c3a2793c1a38d0daccb2 Mon Sep 17 00:00:00 2001 
From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 16:18:02 +0100 Subject: [PATCH 064/209] Add more tests --- tests/core/test_storyconflict.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py index 4f86cca2d0d1..35f25527a57e 100644 --- a/tests/core/test_storyconflict.py +++ b/tests/core/test_storyconflict.py @@ -79,3 +79,16 @@ async def test_has_no_prior_events(): sliced_states = [None] conflict = StoryConflict(sliced_states) assert not conflict.has_prior_events + + +async def test_story_prior_to_conflict(): + + story = "* greet\n - utter_greet\n" + sliced_states = [ + None, + {}, + {'intent_greet': 1.0, 'prev_action_listen': 1.0}, + {'prev_utter_greet': 1.0, 'intent_greet': 1.0} + ] + conflict = StoryConflict(sliced_states) + assert conflict.story_prior_to_conflict() == story From d44b0f25381cb3aa1852861281ba1f074dc5d1da Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 16:22:39 +0100 Subject: [PATCH 065/209] Reformat with BLACK --- tests/core/test_storyconflict.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py index 35f25527a57e..0b625bdc7911 100644 --- a/tests/core/test_storyconflict.py +++ b/tests/core/test_storyconflict.py @@ -51,8 +51,8 @@ async def test_add_conflicting_action(): sliced_states = [ None, {}, - {'intent_greet': 1.0, 'prev_action_listen': 1.0}, - {'prev_utter_greet': 1.0, 'intent_greet': 1.0} + {"intent_greet": 1.0, "prev_action_listen": 1.0}, + {"prev_utter_greet": 1.0, "intent_greet": 1.0}, ] conflict = StoryConflict(sliced_states) @@ -67,8 +67,8 @@ async def test_has_prior_events(): sliced_states = [ None, {}, - {'intent_greet': 1.0, 'prev_action_listen': 1.0}, - {'prev_utter_greet': 1.0, 'intent_greet': 1.0} + {"intent_greet": 1.0, "prev_action_listen": 1.0}, + {"prev_utter_greet": 1.0, "intent_greet": 1.0}, ] conflict = 
StoryConflict(sliced_states) assert conflict.has_prior_events @@ -87,8 +87,8 @@ async def test_story_prior_to_conflict(): sliced_states = [ None, {}, - {'intent_greet': 1.0, 'prev_action_listen': 1.0}, - {'prev_utter_greet': 1.0, 'intent_greet': 1.0} + {"intent_greet": 1.0, "prev_action_listen": 1.0}, + {"prev_utter_greet": 1.0, "intent_greet": 1.0}, ] conflict = StoryConflict(sliced_states) assert conflict.story_prior_to_conflict() == story From 68055ce819d4bb51ac40ccbe80ed47b6bdec7489 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 17:28:33 +0100 Subject: [PATCH 066/209] Add more tests --- tests/core/test_validator.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/core/test_validator.py b/tests/core/test_validator.py index 4e098fbdfd13..2ca0318c5f2c 100644 --- a/tests/core/test_validator.py +++ b/tests/core/test_validator.py @@ -40,6 +40,24 @@ async def test_verify_valid_utterances(): assert validator.verify_utterances() +async def test_verify_story_structure(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=[DEFAULT_STORIES_FILE], + ) + validator = await Validator.from_importer(importer) + assert validator.verify_story_structure(ignore_warnings=False) + + +async def test_verify_bad_story_structure(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_1.md"], + ) + validator = await Validator.from_importer(importer) + assert not validator.verify_story_structure(ignore_warnings=False) + + async def test_fail_on_invalid_utterances(tmpdir): # domain and stories are from different domain and should produce warnings invalid_domain = str(tmpdir / "invalid_domain.yml") From 9062f4435540733363b69a1f310e0107501c0215 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 17:29:52 +0100 Subject: [PATCH 067/209] Optimize imports --- tests/core/test_validator.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/core/test_validator.py b/tests/core/test_validator.py index 2ca0318c5f2c..feb90a851b78 100644 --- a/tests/core/test_validator.py +++ b/tests/core/test_validator.py @@ -1,14 +1,10 @@ import pytest -import logging from rasa.core.validator import Validator from rasa.importers.rasa import RasaFileImporter from tests.core.conftest import ( - DEFAULT_DOMAIN_PATH_WITH_SLOTS, DEFAULT_STORIES_FILE, DEFAULT_NLU_DATA, ) -from rasa.core.domain import Domain -from rasa.nlu.training_data import TrainingData import rasa.utils.io as io_utils From b3f5a13c1b01e4d3f73d70d75de5ca681ac251a7 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 17:43:19 +0100 Subject: [PATCH 068/209] Add more tests --- tests/cli/test_rasa_data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 6c89717940b1..644a02ec70c5 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -75,3 +75,9 @@ def test_validate_files_exit_early(): assert pytest_e.type == SystemExit assert pytest_e.value.code == 1 + + +def test_clean(run_in_default_project: Callable[..., RunResult]): + # Nothing to be cleaned in the init project + output = run_in_default_project("data", "clean") + assert output.outlines == [] From cc8d850c736d514ddf0ec9d9d1349eddccb9b650 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 11 Dec 2019 18:15:54 +0100 Subject: [PATCH 069/209] Add more tests --- rasa/cli/data.py | 6 ++++-- tests/cli/test_rasa_data.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index d8b53fee49da..10e45bdee207 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -228,6 +228,8 @@ def deduplicate_story_names(args): WARNING: Only works for markdown files at the moment """ + logger.info("Replacing duplicate story names...") + import shutil story_file_names, _ = data.get_core_nlu_files(args.data) @@ -241,7 +243,7 @@ def deduplicate_story_names(args): out_file_name = in_file_name + ".new" with open(in_file_name, "r") as in_file, open(out_file_name, "w+") as out_file: for line in in_file: - line = line.rstrip() + line = line.strip() if line.startswith("## "): name = line[3:] # Check if we have already encountered a story with this name @@ -252,7 +254,7 @@ def deduplicate_story_names(args): while name in names: name = old_name + f" ({k})" k += 1 - print(f"- replacing {old_name} with {name}") + logger.info(f"- replacing {old_name} with {name}") names.add(name) out_file.write(f"## {name}\n") else: diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 644a02ec70c5..d50c751796c7 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -77,7 +77,41 @@ def test_validate_files_exit_early(): assert pytest_e.value.code == 1 -def test_clean(run_in_default_project: Callable[..., RunResult]): +def test_clean_init(run_in_default_project: Callable[..., RunResult]): # Nothing to be cleaned in the init project output = run_in_default_project("data", "clean") assert output.outlines == [] + assert output.errlines == [] + + +def test_clean(run: Callable[..., RunResult]): + os.mkdir("data") + with open("data/stories.md", "w+") as file: + file.write("## story\n" + "* greet\n" + " - utter_greet\n" + "\n" + "## story\n" + "* bye\n" + " - 
utter_bye\n") + + with open("domain.yml", "w+") as file: + file.write("actions:\n" + "- utter_greet\n" + "- utter_bye\n" + "intents:\n" + "- greet\n" + "- bye\n" + "templates:\n" + " utter_greet:\n" + " - text: \"hi\"\n" + " utter_bye:\n" + " - text: \"bye\"\n") + + output = run("data", "clean") + # One replacement + headline + assert len(output.errlines) == 2 + + output = run("data", "validate", "stories", "--max-history", "3") + # No errors: + assert "No story structure conflicts found" in output.errlines[-1] From 0c66536a474a04e44926f0ae5d2676cd24611060 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 18:23:57 +0100 Subject: [PATCH 070/209] Apply BLACK formatting --- tests/cli/test_rasa_data.py | 40 ++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index d50c751796c7..fee2c62313db 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -87,26 +87,30 @@ def test_clean_init(run_in_default_project: Callable[..., RunResult]): def test_clean(run: Callable[..., RunResult]): os.mkdir("data") with open("data/stories.md", "w+") as file: - file.write("## story\n" - "* greet\n" - " - utter_greet\n" - "\n" - "## story\n" - "* bye\n" - " - utter_bye\n") + file.write( + "## story\n" + "* greet\n" + " - utter_greet\n" + "\n" + "## story\n" + "* bye\n" + " - utter_bye\n" + ) with open("domain.yml", "w+") as file: - file.write("actions:\n" - "- utter_greet\n" - "- utter_bye\n" - "intents:\n" - "- greet\n" - "- bye\n" - "templates:\n" - " utter_greet:\n" - " - text: \"hi\"\n" - " utter_bye:\n" - " - text: \"bye\"\n") + file.write( + "actions:\n" + "- utter_greet\n" + "- utter_bye\n" + "intents:\n" + "- greet\n" + "- bye\n" + "templates:\n" + " utter_greet:\n" + ' - text: "hi"\n' + " utter_bye:\n" + ' - text: "bye"\n' + ) output = run("data", "clean") # One replacement + headline From c2ec80d1ad6d1a3d17eb33c46ce9f211ed8b7148 
Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 11 Dec 2019 19:20:54 +0100 Subject: [PATCH 071/209] Explain rasa data validate stories in docs --- docs/user-guide/validate-files.rst | 50 +++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index 90a07c03a882..209564fd0214 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -18,7 +18,7 @@ You can run it with the following command: rasa data validate -The script above runs all the validations on your files. Here is the list of options to +The script above runs most of the validations on your files. Here is the list of options to the script: .. program-output:: rasa data validate --help @@ -60,3 +60,51 @@ To use these functions it is necessary to create a `Validator` object and initia stories='data/stories.md') validator.verify_all() + + +Test Story Files for Conflicts +------------------------------ + +In addition to the default tests described above, you can also do a more in-depth structural test of your stories. +In particular, you can test if your stories are inconsistent, i.e. if different bot actions follow after the same dialogue history. +Here is a more detailed explanation. + +The purpose of Rasa Core is to predict the correct next bot action, given the dialogue state, that is the history of intents, entities, slots, and actions. +Crucially, Rasa Core assumes that for any given dialogue state, exactly one next action is the correct one. +If your stories don’t reflect that, Rasa Core cannot learn the correct behaviour. + +Take, for example, the following two stories: + +.. 
code-block:: markdown + + ## Story 1 + * greet + - utter_greet + * inform_happy + - utter_happy + - utter_goodbye + + ## Story 2 + * greet + - utter_greet + * inform_happy + - utter_goodbye + +These two stories are inconsistent, because Rasa Core cannot know if it should predict `utter_happy` or `utter_goodbye` after `inform_happy`, as there is nothing that would distinguish the dialogue states at `inform_happy` in the two stories and the subsequent actions are different in Story 1 and Story 2. + +This conflict can now be automatically identified with our new story structure tool. +Just use `rasa data validate` in the command line, as follows: + +.. code-block:: bash + + rasa data validate stories --max-history 3 + > 2019-12-09 09:32:13 INFO rasa.core.validator - Story structure validation... + > 2019-12-09 09:32:13 INFO rasa.core.validator - Assuming max_history = 3 + > Processed Story Blocks: 100% 2/2 [00:00<00:00, 3237.59it/s, # trackers=1] + > 2019-12-09 09:32:13 WARNING rasa.core.validator - CONFLICT after intent 'inform_happy': + > utter_goodbye predicted in 'Story 2' + > utter_happy predicted in 'Story 1' + +Here we specify a `max-history` value of 3. +This means, that 3 events (user / bot actions) are taken into account for action prediction, but the particular setting does not matter for this example, because regardless of how long of a history you take into account, the conflict always exists. + From 9c26332ab7c66243b2c9ddcfb7ff3174b1075ba8 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 13 Dec 2019 10:07:23 +0100 Subject: [PATCH 072/209] Revert check for duplicates and data clean --- rasa/cli/data.py | 79 ++------------------------------ rasa/core/training/dsl.py | 7 --- rasa/core/training/structures.py | 3 -- tests/cli/test_rasa_data.py | 44 ------------------ 4 files changed, 4 insertions(+), 129 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index e18b4455d3ff..930b6f8a412f 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -31,7 +31,6 @@ def add_subparser( _add_data_convert_parsers(data_subparsers, parents) _add_data_split_parsers(data_subparsers, parents) _add_data_validate_parsers(data_subparsers, parents) - _add_data_clean_parsers(data_subparsers, parents) def _add_data_convert_parsers(data_subparsers, parents: List[argparse.ArgumentParser]): @@ -117,18 +116,6 @@ def _add_data_validate_parsers(data_subparsers, parents: List[argparse.ArgumentP arguments.set_validator_arguments(story_structure_parser) -def _add_data_clean_parsers(data_subparsers, parents: List[argparse.ArgumentParser]): - - clean_parser = data_subparsers.add_parser( - "clean", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - parents=parents, - help="[Experimental] Ensures that story names are unique.", - ) - clean_parser.set_defaults(func=deduplicate_story_names) - arguments.set_validator_arguments(clean_parser) - - def split_nlu_data(args) -> None: from rasa.nlu.training_data.loading import load_data from rasa.nlu.training_data.util import get_file_format @@ -195,70 +182,12 @@ def validate_stories(args): domain_path=args.domain, training_data_paths=args.data ) - # This loads the stories and thus fills `STORY_NAME_TALLY` (see next code block) + # Loads the stories validator = loop.run_until_complete(Validator.from_importer(file_importer)) - # Check for duplicate story names - from rasa.core.training.structures import STORY_NAME_TALLY - - duplicate_story_names = { - name: count for (name, count) in STORY_NAME_TALLY.items() if 
count > 1 - } - story_names_unique = len(duplicate_story_names) == 0 - if not story_names_unique: - msg = "Found duplicate story names:\n" - for (name, count) in duplicate_story_names.items(): - msg += f" '{name}' appears {count}x\n" - logger.error(msg) - # If names are unique, look for story conflicts - if story_names_unique: - everything_is_alright = validator.verify_story_structure( - not args.fail_on_warnings, max_history=args.max_history - ) - else: - everything_is_alright = False + everything_is_alright = validator.verify_story_structure( + not args.fail_on_warnings, max_history=args.max_history + ) sys.exit(0) if everything_is_alright else sys.exit(1) - - -def deduplicate_story_names(args): - """ - Changes story names so as to make them unique. - - WARNING: Only works for markdown files at the moment - """ - - logger.info("Replacing duplicate story names...") - - import shutil - - story_file_names, _ = data.get_core_nlu_files(args.data) - names = set() # Set of names we have already encountered - for in_file_name in story_file_names: - if not in_file_name.endswith(".md"): - logger.warning( - f"Support for cleaning non-markdown file '{in_file_name}' is not yet implemented" - ) - continue - out_file_name = in_file_name + ".new" - with open(in_file_name, "r") as in_file, open(out_file_name, "w+") as out_file: - for line in in_file: - line = line.strip() - if line.startswith("## "): - name = line[3:] - # Check if we have already encountered a story with this name - if name in names: - # Find a unique name - old_name = name - k = 1 - while name in names: - name = old_name + f" ({k})" - k += 1 - logger.info(f"- replacing {old_name} with {name}") - names.add(name) - out_file.write(f"## {name}\n") - else: - out_file.write(line + "\n") - - shutil.move(in_file_name + ".new", in_file_name) diff --git a/rasa/core/training/dsl.py b/rasa/core/training/dsl.py index c16f33412230..681083f64284 100644 --- a/rasa/core/training/dsl.py +++ b/rasa/core/training/dsl.py @@ -20,7 
+20,6 @@ GENERATED_CHECKPOINT_PREFIX, GENERATED_HASH_LENGTH, FORM_PREFIX, - STORY_NAME_TALLY, ) from rasa.nlu.training_data.formats import MarkdownReader from rasa.core.domain import Domain @@ -392,12 +391,6 @@ def new_story_part(self, name): self._add_current_stories_to_result() self.current_step_builder = StoryStepBuilder(name) - # Tally names of stories, so we can identify duplicate names - if name not in STORY_NAME_TALLY: - STORY_NAME_TALLY[name] = 1 - else: - STORY_NAME_TALLY[name] += 1 - def add_checkpoint(self, name: Text, conditions: Optional[Dict[Text, Any]]) -> None: # Ensure story part already has a name diff --git a/rasa/core/training/structures.py b/rasa/core/training/structures.py index 3f3413cbc1d5..a4146a20efc6 100644 --- a/rasa/core/training/structures.py +++ b/rasa/core/training/structures.py @@ -39,9 +39,6 @@ # will get increased with each new instance STEP_COUNT = 1 -# Tally over story names, filled when reading stories -STORY_NAME_TALLY = {} - class StoryStringHelper: """A helper class to mark story steps that are inside a form with `form: ` diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 84198eb382a2..3021e9ab12e7 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -76,47 +76,3 @@ def test_validate_files_exit_early(): assert pytest_e.type == SystemExit assert pytest_e.value.code == 1 - - -def test_clean_init(run_in_default_project: Callable[..., RunResult]): - # Nothing to be cleaned in the init project - output = run_in_default_project("data", "clean") - assert output.outlines == [] - assert output.errlines == [] - - -def test_clean(run: Callable[..., RunResult]): - os.mkdir("data") - with open("data/stories.md", "w+") as file: - file.write( - "## story\n" - "* greet\n" - " - utter_greet\n" - "\n" - "## story\n" - "* bye\n" - " - utter_bye\n" - ) - - with open("domain.yml", "w+") as file: - file.write( - "actions:\n" - "- utter_greet\n" - "- utter_bye\n" - "intents:\n" - "- greet\n" - "- 
bye\n" - "templates:\n" - " utter_greet:\n" - ' - text: "hi"\n' - " utter_bye:\n" - ' - text: "bye"\n' - ) - - output = run("data", "clean") - # One replacement + headline - assert len(output.errlines) == 2 - - output = run("data", "validate", "stories", "--max-history", "3") - # No errors: - assert "No story structure conflicts found" in output.errlines[-1] From ebd48f7c1ca81849427c87ba0e87a83510a77707 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 10:36:50 +0100 Subject: [PATCH 073/209] Fix test_data_validate_help --- tests/cli/test_rasa_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 3021e9ab12e7..010098363e6c 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -60,8 +60,8 @@ def test_data_convert_help(run: Callable[..., RunResult]): def test_data_validate_help(run: Callable[..., RunResult]): output = run("data", "validate", "--help") - help_text = """usage: rasa data validate [-h] [-v] [-vv] [--quiet] [--fail-on-warnings] - [-d DOMAIN] [--data DATA]""" + help_text = """usage: rasa data validate [-h] [-v] [-vv] [--quiet] + [--max-history MAX_HISTORY] [--fail-on-warnings]""" lines = help_text.split("\n") From 1e9705b1b6eb36669904aa1d55e7e05ee805fefd Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 13 Dec 2019 10:44:51 +0100 Subject: [PATCH 074/209] Add more tests --- data/test_stories/stories_conflicting_1.md | 19 ++++++++-------- data/test_stories/stories_conflicting_2.md | 14 ++++++++++++ tests/core/test_storyconflict.py | 25 +++++++++++++++++++++- 3 files changed, 48 insertions(+), 10 deletions(-) create mode 100644 data/test_stories/stories_conflicting_2.md diff --git a/data/test_stories/stories_conflicting_1.md b/data/test_stories/stories_conflicting_1.md index 001b7087c700..d772f46ee33a 100644 --- a/data/test_stories/stories_conflicting_1.md +++ b/data/test_stories/stories_conflicting_1.md @@ -1,14 +1,15 @@ -## greetings +## story 1 +* greet + - utter_greet +* greet + - utter_greet * greet - utter_greet -> check_greet -## happy path -> check_greet +## story 2 * default + - utter_greet +* greet + - utter_greet +* greet - utter_default - -## problem -> check_greet -* default - - utter_goodbye diff --git a/data/test_stories/stories_conflicting_2.md b/data/test_stories/stories_conflicting_2.md new file mode 100644 index 000000000000..001b7087c700 --- /dev/null +++ b/data/test_stories/stories_conflicting_2.md @@ -0,0 +1,14 @@ +## greetings +* greet + - utter_greet +> check_greet + +## happy path +> check_greet +* default + - utter_default + +## problem +> check_greet +* default + - utter_goodbye diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py index 0b625bdc7911..c67c73fd4aab 100644 --- a/tests/core/test_storyconflict.py +++ b/tests/core/test_storyconflict.py @@ -25,7 +25,7 @@ async def test_find_no_conflicts(): assert conflicts == [] -async def test_find_conflicts(): +async def test_find_conflicts_in_short_history(): importer = RasaFileImporter( domain_path="data/test_domains/default.yml", training_data_paths=["data/test_stories/stories_conflicting_1.md"], @@ -39,6 +39,29 @@ async def test_find_conflicts(): augmentation_factor=0, ).generate() + # `max_history = 3` is too small, so a conflict must arise + 
conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 3) + assert len(conflicts) == 1 + + # With `max_history = 4` the conflict should disappear + conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 4) + assert len(conflicts) == 0 + + +async def test_find_conflicts_checkpoints(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_2.md"], + ) + validator = await Validator.from_importer(importer) + + trackers = TrainingDataGenerator( + validator.story_graph, + domain=validator.domain, + remove_duplicates=False, + augmentation_factor=0, + ).generate() + # Create a list of `StoryConflict` objects conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 5) From 69df7b4efd8527ddf4af41656241819fa263ad79 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 11:13:35 +0100 Subject: [PATCH 075/209] Let data validate check stories even if other tests unsuccessful --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 930b6f8a412f..8abd1cf2c1a7 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -155,7 +155,7 @@ def validate_files(args) -> NoReturn: "Will not test for inconsistencies in stories since " "you did not provide --max-history." ) - if everything_is_alright and args.max_history: + if args.max_history: # Only run story structure validation if everything else is fine # since this might take a while everything_is_alright = validator.verify_story_structure( From 63d56999684fd4a08527e240e834793ac1571a92 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 13 Dec 2019 12:12:35 +0100 Subject: [PATCH 076/209] Fix test_verify_bad_story_structure --- tests/core/test_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_validator.py b/tests/core/test_validator.py index 6f02eebe0ac5..4180d999eba3 100644 --- a/tests/core/test_validator.py +++ b/tests/core/test_validator.py @@ -48,7 +48,7 @@ async def test_verify_story_structure(): async def test_verify_bad_story_structure(): importer = RasaFileImporter( domain_path="data/test_domains/default.yml", - training_data_paths=["data/test_stories/stories_conflicting_1.md"], + training_data_paths=["data/test_stories/stories_conflicting_2.md"], ) validator = await Validator.from_importer(importer) assert not validator.verify_story_structure(ignore_warnings=False) From c1310a934fbfb9c8a93271fd47d8ec5294bfee8b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 12:13:33 +0100 Subject: [PATCH 077/209] Apply BLACK --- rasa/core/story_conflict.py | 3 +-- rasa/core/validator.py | 6 +----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 4e723fcc6fa4..076a6f2d7160 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -10,8 +10,7 @@ class StoryConflict: def __init__( - self, - sliced_states: List[Optional[Dict[Text, float]]], + self, sliced_states: List[Optional[Dict[Text, float]]], ): self.sliced_states = sliced_states self.hash = hash(str(list(sliced_states))) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index b3ebb2089807..df8fbb810ac5 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -230,11 +230,7 @@ def verify_all(self, ignore_warnings: bool = True) -> bool: logger.info("Validating utterances...") stories_are_valid = self.verify_utterances_in_stories(ignore_warnings) - return ( - intents_are_valid - and stories_are_valid - and there_is_no_duplication - ) + return 
intents_are_valid and stories_are_valid and there_is_no_duplication def verify_domain_validity(self) -> bool: """Checks whether the domain returned by the importer is empty, indicating an invalid domain.""" From 3ffc9456c4e4ccd1564821cf8299605becb8c35b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 13:29:21 +0100 Subject: [PATCH 078/209] Simplify code --- rasa/core/story_conflict.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 076a6f2d7160..3ad4353fc047 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -31,6 +31,24 @@ def find_conflicts( :return: List of conflicts """ + # We do this in two steps, to reduce memory consumption: + + # Create a 'state -> list of actions' dict, where the state is + # represented by its hash + rules = StoryConflict._find_conflicting_states(trackers, domain, max_history) + + # Iterate once more over all states and note the (unhashed) state, + # for which a conflict occurs + conflicts = StoryConflict._build_conflicts_from_states( + trackers, domain, max_history, rules + ) + + return conflicts + + @staticmethod + def _find_conflicting_states( + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + ) -> Dict[Text, Optional[List[Text]]]: # Create a 'state -> list of actions' dict, where the state is # represented by its hash rules = {} @@ -45,10 +63,17 @@ def find_conflicts( rules[h] = [event.as_story_string()] # Keep only conflicting rules - rules = { + return { state: actions for (state, actions) in rules.items() if len(actions) > 1 } + @staticmethod + def _build_conflicts_from_states( + trackers: List["TrackerWithCachedStates"], + domain: Domain, + max_history: int, + rules: Dict[Text, Optional[List[Text]]], + ): # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs conflicts = {} @@ -69,7 +94,7 @@ def 
find_conflicts( @staticmethod def _sliced_states_iterator( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int - ) -> (TrackerWithCachedStates, Event, List[Optional[Dict[Text, float]]]): + ) -> (TrackerWithCachedStates, Event, List[Dict[Text, float]]): """ Iterate over all given trackers and all sliced states within each tracker, where the slicing is based on `max_history` From 8ab741a4307b7bc6ab940f2fc2d32d145d927fec Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 13:32:53 +0100 Subject: [PATCH 079/209] Fix Pygments lexer type --- docs/user-guide/validate-files.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index 2929a7750af7..6bdd41b0d216 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -80,7 +80,7 @@ If your stories don’t reflect that, Rasa Core cannot learn the correct behavio Take, for example, the following two stories: -.. code-block:: markdown +.. code-block:: md ## Story 1 * greet From 4f755f597c15eeabb3f25f4b3c297ce98e6e299f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 13:50:45 +0100 Subject: [PATCH 080/209] Add warning about uniqueness of story names --- docs/user-guide/validate-files.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index 6bdd41b0d216..189a1bd85c91 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -113,3 +113,7 @@ Just use `rasa data validate` in the command line, as follows: Here we specify a `max-history` value of 3. This means, that 3 events (user / bot actions) are taken into account for action prediction, but the particular setting does not matter for this example, because regardless of how long of a history you take into account, the conflict always exists. +.. 
warning:: + The `rasa data validate stories` script assumes that all your **story names are unique**. + If your stories are in the Markdown format, you may find duplicate names with a command like + `grep -h "##" data/*.md | uniq -c | grep "^[^1]"`. From d645cd5f5bde4d62b365c8f2b7fadef201f789cf Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 15:24:25 +0100 Subject: [PATCH 081/209] Clarify code of _get_prev_event --- rasa/core/story_conflict.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 3ad4353fc047..d3e3f5f238b2 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -128,16 +128,22 @@ def _get_prev_event( :param state: Element of sliced states :return: (type, name) strings of the prior event """ - if not state: - return None, None - result = (None, None) + prev_event_type = None + prev_event_name = None + for k in state: - if k.startswith(PREV_PREFIX): - if k[len(PREV_PREFIX) :] != ACTION_LISTEN_NAME: - result = ("action", k[len(PREV_PREFIX) :]) - elif k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_") and not result[0]: - result = ("intent", k[len(MESSAGE_INTENT_ATTRIBUTE + "_") :]) - return result + if ( + k.startswith(PREV_PREFIX) + and k[len(PREV_PREFIX) :] != ACTION_LISTEN_NAME + ): + prev_event_type = "action" + prev_event_name = k[len(PREV_PREFIX) :] + + if not prev_event_type and k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_"): + prev_event_type = "intent" + prev_event_name = k[len(MESSAGE_INTENT_ATTRIBUTE + "_") :] + + return prev_event_type, prev_event_name def add_conflicting_action(self, action: Text, story_name: Text): """ From b7c0dd2be825a6ca6ae6bce6fe8c96fcfa025b67 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 13 Dec 2019 15:32:56 +0100 Subject: [PATCH 082/209] Clarify code of _build_conflicts_from_states --- rasa/core/story_conflict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index d3e3f5f238b2..3b8a889b5abb 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -81,9 +81,11 @@ def _build_conflicts_from_states( trackers, domain, max_history ): h = hash(str(list(sliced_states))) + + if h in rules and h not in conflicts: + conflicts[h] = StoryConflict(sliced_states) + if h in rules: - if h not in conflicts: - conflicts[h] = StoryConflict(sliced_states) conflicts[h].add_conflicting_action( action=event.as_story_string(), story_name=tracker.sender_id ) From 6c53b6469e6a4e350dea7548e7bed09876a0401a Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 15:40:10 +0100 Subject: [PATCH 083/209] Clarify code of incorrect_stories --- rasa/core/story_conflict.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 3b8a889b5abb..c602ed86cf9c 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -180,20 +180,21 @@ def incorrect_stories(self) -> List[Text]: Returns a list of story names that have not yet been corrected. 
""" - if self.correct_response: - incorrect_stories = [] - for stories in [ - s - for (a, s) in self._conflicting_actions.items() - if a != self.correct_response - ]: - for story in stories: - incorrect_stories.append(story) - return incorrect_stories - else: + if not self.correct_response: # Return all stories return [v[0] for v in self._conflicting_actions.values()] + incorrect_stories = [] + story_lists_with_uncorrected_responses = [ + s + for (a, s) in self._conflicting_actions.items() + if a != self.correct_response + ] + for stories in story_lists_with_uncorrected_responses: + for story in stories: + incorrect_stories.append(story) + return incorrect_stories + @property def has_prior_events(self) -> bool: """ From cb471c658894c5fa85a93138f739c9109ba4002e Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 15:41:53 +0100 Subject: [PATCH 084/209] Clarify code of story_prior_to_conflict --- rasa/core/story_conflict.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index c602ed86cf9c..edc2aef54119 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -210,12 +210,15 @@ def story_prior_to_conflict(self) -> Text: """ result = "" for state in self.sliced_states: - if state: - event_type, event_name = self._get_prev_event(state) - if event_type == "intent": - result += f"* {event_name}\n" - else: - result += f" - {event_name}\n" + if not state: + continue + + event_type, event_name = self._get_prev_event(state) + if event_type == "intent": + result += f"* {event_name}\n" + else: + result += f" - {event_name}\n" + return result def __str__(self): From 9d3a4259f3eb9eca734a985899fb8fe7ed9d0f50 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 13 Dec 2019 17:20:30 +0100 Subject: [PATCH 085/209] Clarify code of __str__ --- rasa/core/story_conflict.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index edc2aef54119..925dfcb931a9 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -232,14 +232,17 @@ def __str__(self): # List which stories are in conflict with one another for action, stories in self._conflicting_actions.items(): # Summarize if necessary - if len(stories) == 1: - stories = f"'{stories[0]}'" - elif len(stories) == 2: - stories = f"'{stories[0]}' and '{stories[1]}'" - elif len(stories) == 3: - stories = f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'" - elif len(stories) >= 4: - stories = f"'{stories[0]}' and {len(stories) - 1} other trackers" - conflict_string += f" {action} predicted in {stories}\n" + story_desc = { + 1: "'{}'", + 2: "'{}' and '{}'", + 3: "'{}', '{}', and '{}'", + }.get(len(stories)) + if story_desc: + story_desc = story_desc.format(*stories) + else: + # Four or more stories are present + story_desc = f"'{stories[0]}' and {len(stories) - 1} other trackers" + + conflict_string += f" {action} predicted in {story_desc}\n" return conflict_string From 81d366c96dd0e7f4b94a9692fe8f2e68ce7b63c7 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 13 Dec 2019 17:24:58 +0100 Subject: [PATCH 086/209] Clarify code of _find_conflicting_states --- rasa/core/story_conflict.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 925dfcb931a9..293390f1768b 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -56,11 +56,10 @@ def _find_conflicting_states( trackers, domain, max_history ): h = hash(str(list(sliced_states))) - if h in rules: - if event.as_story_string() not in rules[h]: - rules[h] += [event.as_story_string()] - else: + if h not in rules: rules[h] = [event.as_story_string()] + elif h in rules and event.as_story_string() not in rules[h]: + rules[h] += [event.as_story_string()] # Keep only conflicting rules return { From 37f0ba5f37c93c56ee6006ae2284118bae1a62a3 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 13 Dec 2019 18:07:10 +0100 Subject: [PATCH 087/209] Remove ToDo --- rasa/core/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 293390f1768b..cfb69a434641 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -108,7 +108,7 @@ def _sliced_states_iterator( states = tracker.past_states(domain) states = [ dict(state) for state in states - ] # ToDo: Check against rasa/core/featurizers.py:318 + ] # From rasa/core/featurizers.py:318 idx = 0 for event in tracker.events: From 9a36c1596ebab0bca708272f2b17e68feeeac1fa Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Sun, 15 Dec 2019 19:56:48 +0100 Subject: [PATCH 088/209] Fix _get_prev_event --- rasa/core/story_conflict.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index cfb69a434641..863afae58330 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -132,6 +132,9 @@ def _get_prev_event( prev_event_type = None prev_event_name = None + if not state: + return prev_event_type, prev_event_name + for k in state: if ( k.startswith(PREV_PREFIX) From 0cd92d0407546597e6b3f94c16e81e9360d19d2c Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Sun, 15 Dec 2019 19:59:23 +0100 Subject: [PATCH 089/209] Fix _find_conflicting_states --- rasa/core/story_conflict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 863afae58330..7e3befc24ee7 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -48,7 +48,7 @@ def find_conflicts( @staticmethod def _find_conflicting_states( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int - ) -> Dict[Text, Optional[List[Text]]]: + ) -> Dict[int, Optional[List[Text]]]: # Create a 'state -> list of actions' dict, where the state is # represented by its hash rules = {} @@ -63,7 +63,9 @@ def _find_conflicting_states( # Keep only conflicting rules return { - state: actions for (state, actions) in rules.items() if len(actions) > 1 + state_hash: actions + for (state_hash, actions) in rules.items() + if len(actions) > 1 } @staticmethod From 87e3d755e141caeefd69c43d0e5131b12327153b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Sun, 15 Dec 2019 21:54:44 +0100 Subject: [PATCH 090/209] Fix _build_conflicts_from_states --- rasa/core/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 7e3befc24ee7..3a63a84ab12e 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -73,7 +73,7 @@ def _build_conflicts_from_states( trackers: List["TrackerWithCachedStates"], domain: Domain, max_history: int, - rules: Dict[Text, Optional[List[Text]]], + rules: Dict[int, Optional[List[Text]]], ): # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs From a9414f785a5ef178c5a4f9d5fd0e37352ffdba0f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 3 Jan 2020 15:01:55 +0100 Subject: [PATCH 091/209] Update docs/user-guide/validate-files.rst Co-Authored-By: Tom Bocklisch --- docs/user-guide/validate-files.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index 189a1bd85c91..e5a7b67100a4 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -75,7 +75,7 @@ In particular, you can test if your stories are inconsistent, i.e. if different Here is a more detailed explanation. The purpose of Rasa Core is to predict the correct next bot action, given the dialogue state, that is the history of intents, entities, slots, and actions. -Crucially, Rasa Core assumes that for any given dialogue state, exactly one next action is the correct one. +Crucially, Rasa assumes that for any given dialogue state, exactly one next action is the correct one. If your stories don’t reflect that, Rasa Core cannot learn the correct behaviour. Take, for example, the following two stories: From 9f3c42a1a7fb798b9e3dd2b130a1c29ea8a49424 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 3 Jan 2020 15:02:24 +0100 Subject: [PATCH 092/209] Rasa Core -> Rasa Co-Authored-By: Tom Bocklisch --- docs/user-guide/validate-files.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index e5a7b67100a4..6e8c4069bbc1 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -76,7 +76,7 @@ Here is a more detailed explanation. The purpose of Rasa Core is to predict the correct next bot action, given the dialogue state, that is the history of intents, entities, slots, and actions. Crucially, Rasa assumes that for any given dialogue state, exactly one next action is the correct one. -If your stories don’t reflect that, Rasa Core cannot learn the correct behaviour. +If your stories don’t reflect that, Rasa cannot learn the correct behaviour. Take, for example, the following two stories: From d7ff1f151258b2dfb3b92e79d40e0eb319f38aa6 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 3 Jan 2020 15:02:46 +0100 Subject: [PATCH 093/209] Replace Rasa Core -> Rasa Co-Authored-By: Tom Bocklisch --- docs/user-guide/validate-files.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index 6e8c4069bbc1..db8b3338ac41 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -95,7 +95,9 @@ Take, for example, the following two stories: * inform_happy - utter_goodbye -These two stories are inconsistent, because Rasa Core cannot know if it should predict `utter_happy` or `utter_goodbye` after `inform_happy`, as there is nothing that would distinguish the dialogue states at `inform_happy` in the two stories and the subsequent actions are different in Story 1 and Story 2. 
+These two stories are inconsistent, because Rasa cannot know if it should predict ``utter_happy`` or ``utter_goodbye`` +after ``inform_happy``, as there is nothing that would distinguish the dialogue states at ``inform_happy`` in the two +stories and the subsequent actions are different in Story 1 and Story 2. This conflict can now be automatically identified with our new story structure tool. Just use `rasa data validate` in the command line, as follows: From fd9f1a4b86d033081c404e053ff6473b63bb4d83 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 3 Jan 2020 15:03:03 +0100 Subject: [PATCH 094/209] Add correct quotes Co-Authored-By: Tom Bocklisch --- docs/user-guide/validate-files.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index db8b3338ac41..1592947f643d 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -100,7 +100,7 @@ after ``inform_happy``, as there is nothing that would distinguish the dialogue stories and the subsequent actions are different in Story 1 and Story 2. This conflict can now be automatically identified with our new story structure tool. -Just use `rasa data validate` in the command line, as follows: +Just use ``rasa data validate`` in the command line, as follows: .. code-block:: bash From dc9913644674785ebd771df144b4d10494a4b1c5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 3 Jan 2020 15:04:33 +0100 Subject: [PATCH 095/209] Avoid sys.exit(0) if everything ok Co-Authored-By: Tom Bocklisch --- rasa/cli/data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 8abd1cf2c1a7..21a15e3960c7 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -190,4 +190,6 @@ def validate_stories(args): not args.fail_on_warnings, max_history=args.max_history ) - sys.exit(0) if everything_is_alright else sys.exit(1) + if not everything_is_alright: + print_error("Story validation completed with errors.") + sys.exit(1) From a03123598eab5eb5013818eb3a76e3d06ef5cfb5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 3 Jan 2020 17:14:09 +0100 Subject: [PATCH 096/209] Define StoryConflict.__hash__ --- rasa/core/story_conflict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 3a63a84ab12e..529b046ccb41 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -13,10 +13,12 @@ def __init__( self, sliced_states: List[Optional[Dict[Text, float]]], ): self.sliced_states = sliced_states - self.hash = hash(str(list(sliced_states))) self._conflicting_actions = {} # {"action": ["story_1", ...], ...} self.correct_response = None + def __hash__(self): + return hash(str(list(self.sliced_states))) + @staticmethod def find_conflicts( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int From d375c13670c532bd464381e9b16df84b844fd39d Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 3 Jan 2020 17:26:48 +0100 Subject: [PATCH 097/209] Move static functions outside of StoryConflict object --- rasa/core/story_conflict.py | 274 ++++++++++++++++++------------------ rasa/core/validator.py | 4 +- 2 files changed, 139 insertions(+), 139 deletions(-) diff --git a/rasa/core/story_conflict.py b/rasa/core/story_conflict.py index 529b046ccb41..2c847bd5e79d 100644 --- a/rasa/core/story_conflict.py +++ b/rasa/core/story_conflict.py @@ -19,140 +19,6 @@ def __init__( def __hash__(self): return hash(str(list(self.sliced_states))) - @staticmethod - def find_conflicts( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int - ) -> List: - """ - Generate a list of StoryConflict objects, describing - conflicts in the given trackers. - :param trackers: Trackers in which to search for conflicts - :param domain: The domain - :param max_history: The maximum history length to be - taken into account - :return: List of conflicts - """ - - # We do this in two steps, to reduce memory consumption: - - # Create a 'state -> list of actions' dict, where the state is - # represented by its hash - rules = StoryConflict._find_conflicting_states(trackers, domain, max_history) - - # Iterate once more over all states and note the (unhashed) state, - # for which a conflict occurs - conflicts = StoryConflict._build_conflicts_from_states( - trackers, domain, max_history, rules - ) - - return conflicts - - @staticmethod - def _find_conflicting_states( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int - ) -> Dict[int, Optional[List[Text]]]: - # Create a 'state -> list of actions' dict, where the state is - # represented by its hash - rules = {} - for tracker, event, sliced_states in StoryConflict._sliced_states_iterator( - trackers, domain, max_history - ): - h = hash(str(list(sliced_states))) - if h not in rules: - rules[h] = [event.as_story_string()] - elif h in rules and event.as_story_string() not in rules[h]: - rules[h] 
+= [event.as_story_string()] - - # Keep only conflicting rules - return { - state_hash: actions - for (state_hash, actions) in rules.items() - if len(actions) > 1 - } - - @staticmethod - def _build_conflicts_from_states( - trackers: List["TrackerWithCachedStates"], - domain: Domain, - max_history: int, - rules: Dict[int, Optional[List[Text]]], - ): - # Iterate once more over all states and note the (unhashed) state, - # for which a conflict occurs - conflicts = {} - for tracker, event, sliced_states in StoryConflict._sliced_states_iterator( - trackers, domain, max_history - ): - h = hash(str(list(sliced_states))) - - if h in rules and h not in conflicts: - conflicts[h] = StoryConflict(sliced_states) - - if h in rules: - conflicts[h].add_conflicting_action( - action=event.as_story_string(), story_name=tracker.sender_id - ) - - # Remove conflicts that arise from unpredictable actions - return [c for (h, c) in conflicts.items() if c.has_prior_events] - - @staticmethod - def _sliced_states_iterator( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int - ) -> (TrackerWithCachedStates, Event, List[Dict[Text, float]]): - """ - Iterate over all given trackers and all sliced states within - each tracker, where the slicing is based on `max_history` - :param trackers: List of trackers - :param domain: Domain (used for tracker.past_states) - :param max_history: Assumed `max_history` value for slicing - :return: Yields (tracker, event, sliced_states) triplet - """ - for tracker in trackers: - states = tracker.past_states(domain) - states = [ - dict(state) for state in states - ] # From rasa/core/featurizers.py:318 - - idx = 0 - for event in tracker.events: - if isinstance(event, ActionExecuted): - sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( - states[: idx + 1], max_history - ) - yield tracker, event, sliced_states - idx += 1 - - @staticmethod - def _get_prev_event( - state: Optional[Dict[Text, float]] - ) -> [Optional[Text], 
Optional[Text]]: - """ - Returns the type and name of the event (action or intent) previous to the - given state - :param state: Element of sliced states - :return: (type, name) strings of the prior event - """ - prev_event_type = None - prev_event_name = None - - if not state: - return prev_event_type, prev_event_name - - for k in state: - if ( - k.startswith(PREV_PREFIX) - and k[len(PREV_PREFIX) :] != ACTION_LISTEN_NAME - ): - prev_event_type = "action" - prev_event_name = k[len(PREV_PREFIX) :] - - if not prev_event_type and k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_"): - prev_event_type = "intent" - prev_event_name = k[len(MESSAGE_INTENT_ATTRIBUTE + "_") :] - - return prev_event_type, prev_event_name - def add_conflicting_action(self, action: Text, story_name: Text): """ Add another action that follows from the same state @@ -207,7 +73,7 @@ def has_prior_events(self) -> bool: Returns True iff anything has happened before this conflict. """ - return self._get_prev_event(self.sliced_states[-1])[0] is not None + return _get_prev_event(self.sliced_states[-1])[0] is not None def story_prior_to_conflict(self) -> Text: """ @@ -219,7 +85,7 @@ def story_prior_to_conflict(self) -> Text: if not state: continue - event_type, event_name = self._get_prev_event(state) + event_type, event_name = _get_prev_event(state) if event_type == "intent": result += f"* {event_name}\n" else: @@ -229,7 +95,7 @@ def story_prior_to_conflict(self) -> Text: def __str__(self): # Describe where the conflict occurs in the stories - last_event_type, last_event_name = self._get_prev_event(self.sliced_states[-1]) + last_event_type, last_event_name = _get_prev_event(self.sliced_states[-1]) if last_event_type: conflict_string = f"CONFLICT after {last_event_type} '{last_event_name}':\n" else: @@ -252,3 +118,137 @@ def __str__(self): conflict_string += f" {action} predicted in {story_desc}\n" return conflict_string + + +def find_story_conflicts( + trackers: List[TrackerWithCachedStates], domain: Domain, 
max_history: int +) -> List[StoryConflict]: + """ + Generate a list of StoryConflict objects, describing + conflicts in the given trackers. + :param trackers: Trackers in which to search for conflicts + :param domain: The domain + :param max_history: The maximum history length to be + taken into account + :return: List of conflicts + """ + + # We do this in two steps, to reduce memory consumption: + + # Create a 'state -> list of actions' dict, where the state is + # represented by its hash + rules = _find_conflicting_states(trackers, domain, max_history) + + # Iterate once more over all states and note the (unhashed) state, + # for which a conflict occurs + conflicts = _build_conflicts_from_states( + trackers, domain, max_history, rules + ) + + return conflicts + + +def _find_conflicting_states( + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int +) -> Dict[int, Optional[List[Text]]]: + # Create a 'state -> list of actions' dict, where the state is + # represented by its hash + rules = {} + for tracker, event, sliced_states in _sliced_states_iterator( + trackers, domain, max_history + ): + h = hash(str(list(sliced_states))) + if h not in rules: + rules[h] = [event.as_story_string()] + elif h in rules and event.as_story_string() not in rules[h]: + rules[h] += [event.as_story_string()] + + # Keep only conflicting rules + return { + state_hash: actions + for (state_hash, actions) in rules.items() + if len(actions) > 1 + } + + +def _build_conflicts_from_states( + trackers: List["TrackerWithCachedStates"], + domain: Domain, + max_history: int, + rules: Dict[int, Optional[List[Text]]], +): + # Iterate once more over all states and note the (unhashed) state, + # for which a conflict occurs + conflicts = {} + for tracker, event, sliced_states in _sliced_states_iterator( + trackers, domain, max_history + ): + h = hash(str(list(sliced_states))) + + if h in rules and h not in conflicts: + conflicts[h] = StoryConflict(sliced_states) + + if h in rules: 
+ conflicts[h].add_conflicting_action( + action=event.as_story_string(), story_name=tracker.sender_id + ) + + # Remove conflicts that arise from unpredictable actions + return [c for (h, c) in conflicts.items() if c.has_prior_events] + + +def _sliced_states_iterator( + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int +) -> (TrackerWithCachedStates, Event, List[Dict[Text, float]]): + """ + Iterate over all given trackers and all sliced states within + each tracker, where the slicing is based on `max_history` + :param trackers: List of trackers + :param domain: Domain (used for tracker.past_states) + :param max_history: Assumed `max_history` value for slicing + :return: Yields (tracker, event, sliced_states) triplet + """ + for tracker in trackers: + states = tracker.past_states(domain) + states = [ + dict(state) for state in states + ] # From rasa/core/featurizers.py:318 + + idx = 0 + for event in tracker.events: + if isinstance(event, ActionExecuted): + sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( + states[: idx + 1], max_history + ) + yield tracker, event, sliced_states + idx += 1 + + +def _get_prev_event( + state: Optional[Dict[Text, float]] +) -> [Optional[Text], Optional[Text]]: + """ + Returns the type and name of the event (action or intent) previous to the + given state + :param state: Element of sliced states + :return: (type, name) strings of the prior event + """ + prev_event_type = None + prev_event_name = None + + if not state: + return prev_event_type, prev_event_name + + for k in state: + if ( + k.startswith(PREV_PREFIX) + and k[len(PREV_PREFIX) :] != ACTION_LISTEN_NAME + ): + prev_event_type = "action" + prev_event_name = k[len(PREV_PREFIX) :] + + if not prev_event_type and k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_"): + prev_event_type = "intent" + prev_event_name = k[len(MESSAGE_INTENT_ATTRIBUTE + "_") :] + + return prev_event_type, prev_event_name diff --git a/rasa/core/validator.py 
b/rasa/core/validator.py index df8fbb810ac5..91e0d650785d 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -10,7 +10,7 @@ from rasa.core.training.dsl import UserUttered from rasa.core.training.dsl import ActionExecuted from rasa.core.constants import UTTER_PREFIX -from rasa.core.story_conflict import StoryConflict +from rasa.core.story_conflict import find_story_conflicts logger = logging.getLogger(__name__) @@ -204,7 +204,7 @@ def verify_story_structure( ).generate() # Create a list of `StoryConflict` objects - conflicts = StoryConflict.find_conflicts(trackers, self.domain, max_history) + conflicts = find_story_conflicts(trackers, self.domain, max_history) if len(conflicts) == 0: logger.info("No story structure conflicts found") From 67365c2a68b094da3556448137b67c328bc8148d Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 3 Jan 2020 17:29:57 +0100 Subject: [PATCH 098/209] Use raise argparse.ArgumentError in validate_stories --- rasa/cli/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 8abd1cf2c1a7..f26abf6b2643 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -173,8 +173,7 @@ def validate_stories(args): # Check if a valid setting for `max_history` was given if not isinstance(args.max_history, int) or args.max_history < 1: - logger.error("You have to provide a positive integer for --max-history.") - sys.exit(1) + raise argparse.ArgumentError("You have to provide a positive integer for --max-history.") # Prepare story and domain file import loop = asyncio.get_event_loop() From 8a40e004364325be2a03e6490f657c96604f8c5d Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 3 Jan 2020 17:31:24 +0100 Subject: [PATCH 099/209] Import missing print_error --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 499f38877504..ed2bece2b840 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -6,7 +6,7 @@ from rasa import data from rasa.cli.arguments import data as arguments -from rasa.cli.utils import get_validated_path +from rasa.cli.utils import get_validated_path, print_error from rasa.constants import DEFAULT_DATA_PATH from typing import NoReturn From 35cca6ffc2e8a09ce4d1b51bd3d8d0c9b55fe185 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 3 Jan 2020 17:32:55 +0100 Subject: [PATCH 100/209] Move story_conflict.py to core.training --- rasa/core/{ => training}/story_conflict.py | 0 rasa/core/validator.py | 2 +- tests/core/test_storyconflict.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename rasa/core/{ => training}/story_conflict.py (100%) diff --git a/rasa/core/story_conflict.py b/rasa/core/training/story_conflict.py similarity index 100% rename from rasa/core/story_conflict.py rename to rasa/core/training/story_conflict.py diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 91e0d650785d..ceaeffdf3243 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -10,7 +10,7 @@ from rasa.core.training.dsl import UserUttered from rasa.core.training.dsl import ActionExecuted from rasa.core.constants import UTTER_PREFIX -from rasa.core.story_conflict import find_story_conflicts +from rasa.core.training.story_conflict import find_story_conflicts logger = logging.getLogger(__name__) diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py index c67c73fd4aab..f7be6fc1b0e6 100644 --- a/tests/core/test_storyconflict.py +++ b/tests/core/test_storyconflict.py @@ -1,4 +1,4 @@ -from rasa.core.story_conflict import StoryConflict +from rasa.core.training.story_conflict import 
StoryConflict from rasa.core.training.generator import TrainingDataGenerator from rasa.core.validator import Validator from rasa.importers.rasa import RasaFileImporter From 3ee76e241c2dcdc978568ec1a2b63bb628e85d5c Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 3 Jan 2020 17:35:06 +0100 Subject: [PATCH 101/209] Fix tests to use find_story_conflicts --- tests/core/test_storyconflict.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py index f7be6fc1b0e6..090e8016e332 100644 --- a/tests/core/test_storyconflict.py +++ b/tests/core/test_storyconflict.py @@ -1,4 +1,4 @@ -from rasa.core.training.story_conflict import StoryConflict +from rasa.core.training.story_conflict import StoryConflict, find_story_conflicts from rasa.core.training.generator import TrainingDataGenerator from rasa.core.validator import Validator from rasa.importers.rasa import RasaFileImporter @@ -20,7 +20,7 @@ async def test_find_no_conflicts(): ).generate() # Create a list of `StoryConflict` objects - conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 5) + conflicts = find_story_conflicts(trackers, validator.domain, 5) assert conflicts == [] @@ -40,11 +40,11 @@ async def test_find_conflicts_in_short_history(): ).generate() # `max_history = 3` is too small, so a conflict must arise - conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 3) + conflicts = find_story_conflicts(trackers, validator.domain, 3) assert len(conflicts) == 1 # With `max_history = 4` the conflict should disappear - conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 4) + conflicts = find_story_conflicts(trackers, validator.domain, 4) assert len(conflicts) == 0 @@ -63,7 +63,7 @@ async def test_find_conflicts_checkpoints(): ).generate() # Create a list of `StoryConflict` objects - conflicts = StoryConflict.find_conflicts(trackers, validator.domain, 5) + conflicts = 
find_story_conflicts(trackers, validator.domain, 5) assert len(conflicts) == 1 assert conflicts[0].conflicting_actions == ["utter_goodbye", "utter_default"] From bc181fbc633a7d9919c8d1304c527190b4e92eb6 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 3 Jan 2020 17:39:58 +0100 Subject: [PATCH 102/209] Fix arguments of argparse.ArgumentError --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index ed2bece2b840..5c244b38eaa0 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -173,7 +173,7 @@ def validate_stories(args): # Check if a valid setting for `max_history` was given if not isinstance(args.max_history, int) or args.max_history < 1: - raise argparse.ArgumentError("You have to provide a positive integer for --max-history.") + raise argparse.ArgumentError(args.max_history, "You have to provide a positive integer for --max-history.") # Prepare story and domain file import loop = asyncio.get_event_loop() From 7e77e0fc2e34b00329027f203a0427874d034d45 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 10 Jan 2020 16:16:08 +0100 Subject: [PATCH 103/209] Remove empty line Co-Authored-By: Tobias Wochinger --- docs/user-guide/validate-files.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index 1592947f643d..c05f936b0a5e 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -66,7 +66,6 @@ To use these functions it is necessary to create a `Validator` object and initia validator.verify_all() - Test Story Files for Conflicts ------------------------------ From 6b70c51eb16d57748fd8774293a0e44316253f21 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 10 Jan 2020 16:17:17 +0100 Subject: [PATCH 104/209] Specify return type Co-Authored-By: Tobias Wochinger --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 5c244b38eaa0..4dce739fca20 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -33,7 +33,7 @@ def add_subparser( _add_data_validate_parsers(data_subparsers, parents) -def _add_data_convert_parsers(data_subparsers, parents: List[argparse.ArgumentParser]): +def _add_data_convert_parsers(data_subparsers, parents: List[argparse.ArgumentParser]) -> None: import rasa.nlu.convert as convert convert_parser = data_subparsers.add_parser( From 3d0400ded1d17b804c4da9cb797a409dbbbc9e15 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 10 Jan 2020 16:17:43 +0100 Subject: [PATCH 105/209] Specify return type Co-Authored-By: Tobias Wochinger --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 4dce739fca20..ffbaaf5e15b5 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -56,7 +56,7 @@ def _add_data_convert_parsers(data_subparsers, parents: List[argparse.ArgumentPa arguments.set_convert_arguments(convert_nlu_parser) -def _add_data_split_parsers(data_subparsers, parents: List[argparse.ArgumentParser]): +def _add_data_split_parsers(data_subparsers, parents: List[argparse.ArgumentParser]) -> None: split_parser = data_subparsers.add_parser( "split", formatter_class=argparse.ArgumentDefaultsHelpFormatter, From 8a801844f0527acd4d88ac46cfc9037adc9fb984 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 10 Jan 2020 16:18:01 +0100 Subject: [PATCH 106/209] Specify return type Co-Authored-By: Tobias Wochinger --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index ffbaaf5e15b5..dd51a6ad9a6e 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -78,7 +78,7 @@ def _add_data_split_parsers(data_subparsers, parents: List[argparse.ArgumentPars arguments.set_split_arguments(nlu_split_parser) -def _add_data_validate_parsers(data_subparsers, parents: List[argparse.ArgumentParser]): +def _add_data_validate_parsers(data_subparsers, parents: List[argparse.ArgumentParser]) -> None: validate_parser = data_subparsers.add_parser( "validate", formatter_class=argparse.ArgumentDefaultsHelpFormatter, From 7965ee6d94be7b3520df2ee51a5857d58d418fd5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 10 Jan 2020 16:18:33 +0100 Subject: [PATCH 107/209] Clarify error message Co-Authored-By: Tobias Wochinger --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index dd51a6ad9a6e..e03404ecce50 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -153,7 +153,7 @@ def validate_files(args) -> NoReturn: if not args.max_history: logger.info( "Will not test for inconsistencies in stories since " - "you did not provide --max-history." + "you did not provide a value for `--max-history`." ) if args.max_history: # Only run story structure validation if everything else is fine From cc230aaee8a9443252c2d8fe76f92e7c1e02242e Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 10:34:06 +0100 Subject: [PATCH 108/209] Clarify text --- docs/user-guide/validate-files.rst | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index c05f936b0a5e..cbf8bc4b7a33 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -70,12 +70,8 @@ Test Story Files for Conflicts ------------------------------ In addition to the default tests described above, you can also do a more in-depth structural test of your stories. -In particular, you can test if your stories are inconsistent, i.e. if different bot actions follow after the same dialogue history. -Here is a more detailed explanation. - -The purpose of Rasa Core is to predict the correct next bot action, given the dialogue state, that is the history of intents, entities, slots, and actions. -Crucially, Rasa assumes that for any given dialogue state, exactly one next action is the correct one. -If your stories don’t reflect that, Rasa cannot learn the correct behaviour. +In particular, you can test if your stories are inconsistent, i.e. if different bot actions follow from the same dialogue history. +If this is not the case, then Rasa cannot learn the correct behaviour. Take, for example, the following two stories: From d2e37725ef5d8668d374f9f72b4f926016336441 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 10:54:48 +0100 Subject: [PATCH 109/209] Fix typo Co-Authored-By: Tobias Wochinger --- docs/user-guide/validate-files.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index c05f936b0a5e..78360086e4b8 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -112,7 +112,7 @@ Just use ``rasa data validate`` in the command line, as follows: > utter_happy predicted in 'Story 1' Here we specify a `max-history` value of 3. -This means, that 3 events (user / bot actions) are taken into account for action prediction, but the particular setting does not matter for this example, because regardless of how long of a history you take into account, the conflict always exists. +This means, that 3 events (user / bot actions) are taken into account for action predictions, but the particular setting does not matter for this example, because regardless of how long of a history you take into account, the conflict always exists. .. warning:: The `rasa data validate stories` script assumes that all your **story names are unique**. From 25238dbe5bc24f510dcc1f275536253c9e063ea7 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 10:57:57 +0100 Subject: [PATCH 110/209] Clarify code in conflicting_actions_with_counts --- rasa/core/training/story_conflict.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 2c847bd5e79d..be7629c6e6fb 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -44,7 +44,10 @@ def conflicting_actions_with_counts(self) -> List[Text]: Returns a list of strings, describing what action occurs how often """ - return [f"{a} [{len(s)}x]" for (a, s) in self._conflicting_actions.items()] + return [ + f"{action} [{len(stories)}x]" + for (action, stories) in self._conflicting_actions.items() + ] @property def incorrect_stories(self) -> List[Text]: @@ -141,9 +144,7 @@ def find_story_conflicts( # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs - conflicts = _build_conflicts_from_states( - trackers, domain, max_history, rules - ) + conflicts = _build_conflicts_from_states(trackers, domain, max_history, rules) return conflicts @@ -210,9 +211,7 @@ def _sliced_states_iterator( """ for tracker in trackers: states = tracker.past_states(domain) - states = [ - dict(state) for state in states - ] # From rasa/core/featurizers.py:318 + states = [dict(state) for state in states] # From rasa/core/featurizers.py:318 idx = 0 for event in tracker.events: @@ -240,10 +239,7 @@ def _get_prev_event( return prev_event_type, prev_event_name for k in state: - if ( - k.startswith(PREV_PREFIX) - and k[len(PREV_PREFIX) :] != ACTION_LISTEN_NAME - ): + if k.startswith(PREV_PREFIX) and k[len(PREV_PREFIX) :] != ACTION_LISTEN_NAME: prev_event_type = "action" prev_event_name = k[len(PREV_PREFIX) :] From 62dccd7f75f6b742ff2a33e2fe6ef58fa9daf4ed Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 11:06:52 +0100 Subject: [PATCH 111/209] Remove unused code --- rasa/core/training/story_conflict.py | 50 ---------------------------- tests/core/test_storyconflict.py | 14 -------- 2 files changed, 64 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index be7629c6e6fb..8aadafb06fea 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -38,38 +38,6 @@ def conflicting_actions(self) -> List[Text]: """ return list(self._conflicting_actions.keys()) - @property - def conflicting_actions_with_counts(self) -> List[Text]: - """ - Returns a list of strings, describing what action - occurs how often - """ - return [ - f"{action} [{len(stories)}x]" - for (action, stories) in self._conflicting_actions.items() - ] - - @property - def incorrect_stories(self) -> List[Text]: - """ - Returns a list of story names that have not yet been - corrected. - """ - if not self.correct_response: - # Return all stories - return [v[0] for v in self._conflicting_actions.values()] - - incorrect_stories = [] - story_lists_with_uncorrected_responses = [ - s - for (a, s) in self._conflicting_actions.items() - if a != self.correct_response - ] - for stories in story_lists_with_uncorrected_responses: - for story in stories: - incorrect_stories.append(story) - return incorrect_stories - @property def has_prior_events(self) -> bool: """ @@ -78,24 +46,6 @@ def has_prior_events(self) -> bool: """ return _get_prev_event(self.sliced_states[-1])[0] is not None - def story_prior_to_conflict(self) -> Text: - """ - Generates a story string, describing the events that - lead up to the conflict. 
- """ - result = "" - for state in self.sliced_states: - if not state: - continue - - event_type, event_name = _get_prev_event(state) - if event_type == "intent": - result += f"* {event_name}\n" - else: - result += f" - {event_name}\n" - - return result - def __str__(self): # Describe where the conflict occurs in the stories last_event_type, last_event_name = _get_prev_event(self.sliced_states[-1]) diff --git a/tests/core/test_storyconflict.py b/tests/core/test_storyconflict.py index 090e8016e332..be236fdaf6b1 100644 --- a/tests/core/test_storyconflict.py +++ b/tests/core/test_storyconflict.py @@ -82,7 +82,6 @@ async def test_add_conflicting_action(): conflict.add_conflicting_action("utter_greet", "xyz") conflict.add_conflicting_action("utter_default", "uvw") assert conflict.conflicting_actions == ["utter_greet", "utter_default"] - assert conflict.incorrect_stories == ["xyz", "uvw"] async def test_has_prior_events(): @@ -102,16 +101,3 @@ async def test_has_no_prior_events(): sliced_states = [None] conflict = StoryConflict(sliced_states) assert not conflict.has_prior_events - - -async def test_story_prior_to_conflict(): - - story = "* greet\n - utter_greet\n" - sliced_states = [ - None, - {}, - {"intent_greet": 1.0, "prev_action_listen": 1.0}, - {"prev_utter_greet": 1.0, "intent_greet": 1.0}, - ] - conflict = StoryConflict(sliced_states) - assert conflict.story_prior_to_conflict() == story From 29f572d50f52c8c6107289bc7f2480a0608247a3 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 11:11:16 +0100 Subject: [PATCH 112/209] Declare output types --- rasa/core/training/story_conflict.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 8aadafb06fea..3a26ebb92dc1 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -11,15 +11,15 @@ class StoryConflict: def __init__( self, sliced_states: List[Optional[Dict[Text, float]]], - ): + ) -> None: self.sliced_states = sliced_states self._conflicting_actions = {} # {"action": ["story_1", ...], ...} self.correct_response = None - def __hash__(self): + def __hash__(self) -> int: return hash(str(list(self.sliced_states))) - def add_conflicting_action(self, action: Text, story_name: Text): + def add_conflicting_action(self, action: Text, story_name: Text) -> None: """ Add another action that follows from the same state :param action: Name of the action @@ -46,7 +46,7 @@ def has_prior_events(self) -> bool: """ return _get_prev_event(self.sliced_states[-1])[0] is not None - def __str__(self): + def __str__(self) -> Text: # Describe where the conflict occurs in the stories last_event_type, last_event_name = _get_prev_event(self.sliced_states[-1]) if last_event_type: @@ -127,7 +127,7 @@ def _build_conflicts_from_states( domain: Domain, max_history: int, rules: Dict[int, Optional[List[Text]]], -): +) -> List["StoryConflict"]: # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs conflicts = {} From 6e26d084b4125f1fde62807e4aecb16e9e73fb56 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 13:09:39 +0100 Subject: [PATCH 113/209] Reformat doc strings --- rasa/core/training/story_conflict.py | 72 +++++++++++++++++----------- 1 file changed, 43 insertions(+), 29 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 3a26ebb92dc1..17ead743ba8d 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -20,11 +20,11 @@ def __hash__(self) -> int: return hash(str(list(self.sliced_states))) def add_conflicting_action(self, action: Text, story_name: Text) -> None: - """ - Add another action that follows from the same state - :param action: Name of the action - :param story_name: Name of the story where this action - is chosen + """Adds another action that follows from the same state. + + Args: + action: Name of the action. + story_name: Name of the story where this action is chosen. """ if action not in self._conflicting_actions: self._conflicting_actions[action] = [story_name] @@ -33,16 +33,20 @@ def add_conflicting_action(self, action: Text, story_name: Text) -> None: @property def conflicting_actions(self) -> List[Text]: - """ - Returns the list of conflicting actions + """List of conflicting actions. + + Returns: + List of conflicting actions. + """ return list(self._conflicting_actions.keys()) @property def has_prior_events(self) -> bool: - """ - Returns True iff anything has happened before this - conflict. + """Checks if prior events exist. + + Returns: + True if anything has happened before this conflict, otherwise False. """ return _get_prev_event(self.sliced_states[-1])[0] is not None @@ -76,14 +80,14 @@ def __str__(self) -> Text: def find_story_conflicts( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> List[StoryConflict]: - """ - Generate a list of StoryConflict objects, describing - conflicts in the given trackers. 
- :param trackers: Trackers in which to search for conflicts - :param domain: The domain - :param max_history: The maximum history length to be - taken into account - :return: List of conflicts + """Generates a list of StoryConflict objects, describing conflicts in the given trackers. + + Args: + trackers: Trackers in which to search for conflicts + domain: The domain + max_history: The maximum history length to be taken into account + Returns: + List of conflicts """ # We do this in two steps, to reduce memory consumption: @@ -151,13 +155,18 @@ def _build_conflicts_from_states( def _sliced_states_iterator( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> (TrackerWithCachedStates, Event, List[Dict[Text, float]]): - """ - Iterate over all given trackers and all sliced states within - each tracker, where the slicing is based on `max_history` - :param trackers: List of trackers - :param domain: Domain (used for tracker.past_states) - :param max_history: Assumed `max_history` value for slicing - :return: Yields (tracker, event, sliced_states) triplet + """Creates an iterator over sliced states. + + Iterate over all given trackers and all sliced states within each tracker, + where the slicing is based on `max_history`. + + Args: + trackers: List of trackers + domain: Domain (used for tracker.past_states) + max_history: Assumed `max_history` value for slicing + + Yields: + A (tracker, event, sliced_states) triplet """ for tracker in trackers: states = tracker.past_states(domain) @@ -176,11 +185,16 @@ def _sliced_states_iterator( def _get_prev_event( state: Optional[Dict[Text, float]] ) -> [Optional[Text], Optional[Text]]: - """ + """Returns previous event type and name. + Returns the type and name of the event (action or intent) previous to the - given state - :param state: Element of sliced states - :return: (type, name) strings of the prior event + given state. + + Args: + state: Element of sliced states. 
+ + Returns: + Tuple of (type, name) strings of the prior event. """ prev_event_type = None prev_event_name = None From e917ea62d28abdde41319f39c891566dbc66a318 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 14:59:14 +0100 Subject: [PATCH 114/209] Rename variables for clarity --- rasa/core/training/story_conflict.py | 34 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 17ead743ba8d..064166ae13bd 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -94,11 +94,11 @@ def find_story_conflicts( # Create a 'state -> list of actions' dict, where the state is # represented by its hash - rules = _find_conflicting_states(trackers, domain, max_history) + state_action_dict = _find_conflicting_states(trackers, domain, max_history) # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs - conflicts = _build_conflicts_from_states(trackers, domain, max_history, rules) + conflicts = _build_conflicts_from_states(trackers, domain, max_history, state_action_dict) return conflicts @@ -108,20 +108,20 @@ def _find_conflicting_states( ) -> Dict[int, Optional[List[Text]]]: # Create a 'state -> list of actions' dict, where the state is # represented by its hash - rules = {} + state_action_dict = {} for tracker, event, sliced_states in _sliced_states_iterator( trackers, domain, max_history ): - h = hash(str(list(sliced_states))) - if h not in rules: - rules[h] = [event.as_story_string()] - elif h in rules and event.as_story_string() not in rules[h]: - rules[h] += [event.as_story_string()] + hashed_state = hash(str(list(sliced_states))) + if hashed_state not in state_action_dict: + state_action_dict[hashed_state] = [event.as_story_string()] + elif hashed_state in state_action_dict and event.as_story_string() not in state_action_dict[hashed_state]: + 
state_action_dict[hashed_state] += [event.as_story_string()] - # Keep only conflicting rules + # Keep only conflicting `state_action_dict`s return { state_hash: actions - for (state_hash, actions) in rules.items() + for (state_hash, actions) in state_action_dict.items() if len(actions) > 1 } @@ -138,18 +138,18 @@ def _build_conflicts_from_states( for tracker, event, sliced_states in _sliced_states_iterator( trackers, domain, max_history ): - h = hash(str(list(sliced_states))) + hashed_state = hash(str(list(sliced_states))) - if h in rules and h not in conflicts: - conflicts[h] = StoryConflict(sliced_states) + if hashed_state in rules and hashed_state not in conflicts: + conflicts[hashed_state] = StoryConflict(sliced_states) - if h in rules: - conflicts[h].add_conflicting_action( + if hashed_state in rules: + conflicts[hashed_state].add_conflicting_action( action=event.as_story_string(), story_name=tracker.sender_id ) # Remove conflicts that arise from unpredictable actions - return [c for (h, c) in conflicts.items() if c.has_prior_events] + return [conflict for (hashed_state, conflict) in conflicts.items() if conflict.has_prior_events] def _sliced_states_iterator( @@ -170,7 +170,7 @@ def _sliced_states_iterator( """ for tracker in trackers: states = tracker.past_states(domain) - states = [dict(state) for state in states] # From rasa/core/featurizers.py:318 + states = [dict(state) for state in states] idx = 0 for event in tracker.events: From 8141bcc59782e9e45720a372001b9c6e11798af2 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 15:10:07 +0100 Subject: [PATCH 115/209] Use `defaultdict(list)` to simplify code --- rasa/core/training/story_conflict.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 064166ae13bd..c06ab78ee652 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -1,3 +1,4 @@ +from collections import defaultdict from typing import List, Optional, Dict, Text from rasa.core.actions.action import ACTION_LISTEN_NAME @@ -13,7 +14,7 @@ def __init__( self, sliced_states: List[Optional[Dict[Text, float]]], ) -> None: self.sliced_states = sliced_states - self._conflicting_actions = {} # {"action": ["story_1", ...], ...} + self._conflicting_actions = defaultdict(list) # {"action": ["story_1", ...], ...} self.correct_response = None def __hash__(self) -> int: @@ -26,10 +27,7 @@ def add_conflicting_action(self, action: Text, story_name: Text) -> None: action: Name of the action. story_name: Name of the story where this action is chosen. 
""" - if action not in self._conflicting_actions: - self._conflicting_actions[action] = [story_name] - else: - self._conflicting_actions[action] += [story_name] + self._conflicting_actions[action] += [story_name] @property def conflicting_actions(self) -> List[Text]: @@ -108,14 +106,12 @@ def _find_conflicting_states( ) -> Dict[int, Optional[List[Text]]]: # Create a 'state -> list of actions' dict, where the state is # represented by its hash - state_action_dict = {} + state_action_dict = defaultdict(list) for tracker, event, sliced_states in _sliced_states_iterator( trackers, domain, max_history ): hashed_state = hash(str(list(sliced_states))) - if hashed_state not in state_action_dict: - state_action_dict[hashed_state] = [event.as_story_string()] - elif hashed_state in state_action_dict and event.as_story_string() not in state_action_dict[hashed_state]: + if event.as_story_string() not in state_action_dict[hashed_state]: state_action_dict[hashed_state] += [event.as_story_string()] # Keep only conflicting `state_action_dict`s From 53c2b2ac2a7c41c84f17f666ff3a8b05caba33c2 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 15:17:31 +0100 Subject: [PATCH 116/209] Declare types in `validate_*` functions --- rasa/cli/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index e03404ecce50..5161634b1478 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -132,7 +132,7 @@ def split_nlu_data(args) -> None: test.persist(args.out, filename=f"test_data.{fformat}") -def validate_files(args) -> NoReturn: +def validate_files(args: "argparse.Namespace") -> NoReturn: """Validate all files needed for training a model. 
Fails with a non-zero exit code if there are any errors in the data.""" @@ -164,7 +164,7 @@ def validate_files(args) -> NoReturn: sys.exit(0) if everything_is_alright else sys.exit(1) -def validate_stories(args): +def validate_stories(args: "argparse.Namespace") -> NoReturn: """Validate all files needed for training a model. Fails with a non-zero exit code if there are any errors in the data.""" From c55a276071480669bb1e3e8a1dc013603aae493e Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 15:49:01 +0100 Subject: [PATCH 117/209] Define and use TrackerEventStateTuple --- rasa/core/training/story_conflict.py | 29 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index c06ab78ee652..2510eae34ddd 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -1,4 +1,4 @@ -from collections import defaultdict +from collections import defaultdict, namedtuple from typing import List, Optional, Dict, Text from rasa.core.actions.action import ACTION_LISTEN_NAME @@ -75,6 +75,9 @@ def __str__(self) -> Text: return conflict_string +TrackerEventStateTuple = namedtuple("TrackerEventStateTuple", "tracker event sliced_states") + + def find_story_conflicts( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> List[StoryConflict]: @@ -107,12 +110,10 @@ def _find_conflicting_states( # Create a 'state -> list of actions' dict, where the state is # represented by its hash state_action_dict = defaultdict(list) - for tracker, event, sliced_states in _sliced_states_iterator( - trackers, domain, max_history - ): - hashed_state = hash(str(list(sliced_states))) - if event.as_story_string() not in state_action_dict[hashed_state]: - state_action_dict[hashed_state] += [event.as_story_string()] + for element in _sliced_states_iterator(trackers, domain, max_history): + hashed_state = 
hash(str(list(element.sliced_states))) + if element.event.as_story_string() not in state_action_dict[hashed_state]: + state_action_dict[hashed_state] += [element.event.as_story_string()] # Keep only conflicting `state_action_dict`s return { @@ -131,17 +132,15 @@ def _build_conflicts_from_states( # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs conflicts = {} - for tracker, event, sliced_states in _sliced_states_iterator( - trackers, domain, max_history - ): - hashed_state = hash(str(list(sliced_states))) + for element in _sliced_states_iterator(trackers, domain, max_history): + hashed_state = hash(str(list(element.sliced_states))) if hashed_state in rules and hashed_state not in conflicts: - conflicts[hashed_state] = StoryConflict(sliced_states) + conflicts[hashed_state] = StoryConflict(element.sliced_states) if hashed_state in rules: conflicts[hashed_state].add_conflicting_action( - action=event.as_story_string(), story_name=tracker.sender_id + action=element.event.as_story_string(), story_name=element.tracker.sender_id ) # Remove conflicts that arise from unpredictable actions @@ -150,7 +149,7 @@ def _build_conflicts_from_states( def _sliced_states_iterator( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int -) -> (TrackerWithCachedStates, Event, List[Dict[Text, float]]): +) -> TrackerEventStateTuple: """Creates an iterator over sliced states. Iterate over all given trackers and all sliced states within each tracker, @@ -174,7 +173,7 @@ def _sliced_states_iterator( sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history( states[: idx + 1], max_history ) - yield tracker, event, sliced_states + yield TrackerEventStateTuple(tracker, event, sliced_states) idx += 1 From 700e028f26a8fcc321ad1a8268d26c3bee1d056d Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 15:50:14 +0100 Subject: [PATCH 118/209] Rename function to _get_previous_event --- rasa/core/training/story_conflict.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 2510eae34ddd..4be063e4284e 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -46,11 +46,11 @@ def has_prior_events(self) -> bool: Returns: True if anything has happened before this conflict, otherwise False. """ - return _get_prev_event(self.sliced_states[-1])[0] is not None + return _get_previous_event(self.sliced_states[-1])[0] is not None def __str__(self) -> Text: # Describe where the conflict occurs in the stories - last_event_type, last_event_name = _get_prev_event(self.sliced_states[-1]) + last_event_type, last_event_name = _get_previous_event(self.sliced_states[-1]) if last_event_type: conflict_string = f"CONFLICT after {last_event_type} '{last_event_name}':\n" else: @@ -177,7 +177,7 @@ def _sliced_states_iterator( idx += 1 -def _get_prev_event( +def _get_previous_event( state: Optional[Dict[Text, float]] ) -> [Optional[Text], Optional[Text]]: """Returns previous event type and name. From c4c3f8ccc0e4e3bf7346bd75c8e57998e3eedf08 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 15:51:15 +0100 Subject: [PATCH 119/209] Add fullstops to doc strings --- rasa/core/training/story_conflict.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 4be063e4284e..ae48fe861dda 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -84,11 +84,11 @@ def find_story_conflicts( """Generates a list of StoryConflict objects, describing conflicts in the given trackers. 
Args: - trackers: Trackers in which to search for conflicts - domain: The domain - max_history: The maximum history length to be taken into account + trackers: Trackers in which to search for conflicts. + domain: The domain. + max_history: The maximum history length to be taken into account. Returns: - List of conflicts + List of conflicts. """ # We do this in two steps, to reduce memory consumption: @@ -156,12 +156,12 @@ def _sliced_states_iterator( where the slicing is based on `max_history`. Args: - trackers: List of trackers - domain: Domain (used for tracker.past_states) - max_history: Assumed `max_history` value for slicing + trackers: List of trackers. + domain: Domain (used for tracker.past_states). + max_history: Assumed `max_history` value for slicing. Yields: - A (tracker, event, sliced_states) triplet + A (tracker, event, sliced_states) triplet. """ for tracker in trackers: states = tracker.past_states(domain) From 8f107f79041b0b55dad1b17d9e7c14956b6d6013 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 16:04:43 +0100 Subject: [PATCH 120/209] Simplify StoryConflict.__str__ --- rasa/core/training/story_conflict.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index ae48fe861dda..0a0610b05f1b 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -52,27 +52,26 @@ def __str__(self) -> Text: # Describe where the conflict occurs in the stories last_event_type, last_event_name = _get_previous_event(self.sliced_states[-1]) if last_event_type: - conflict_string = f"CONFLICT after {last_event_type} '{last_event_name}':\n" + conflict_message = f"CONFLICT after {last_event_type} '{last_event_name}':\n" else: - conflict_string = f"CONFLICT at the beginning of stories:\n" + conflict_message = f"CONFLICT at the beginning of stories:\n" # List which stories are in conflict with one another for action, stories in self._conflicting_actions.items(): # Summarize if necessary - story_desc = { - 1: "'{}'", - 2: "'{}' and '{}'", - 3: "'{}', '{}', and '{}'", - }.get(len(stories)) - if story_desc: - story_desc = story_desc.format(*stories) - else: + if len(stories) > 3: # Four or more stories are present - story_desc = f"'{stories[0]}' and {len(stories) - 1} other trackers" + conflict_description = f"'{stories[0]}' and {len(stories) - 1} other trackers" + else: + conflict_description = { + 1: "'{}'", + 2: "'{}' and '{}'", + 3: "'{}', '{}', and '{}'", + }.get(len(stories)).format(*stories) - conflict_string += f" {action} predicted in {story_desc}\n" + conflict_message += f" {action} predicted in {conflict_description}\n" - return conflict_string + return conflict_message TrackerEventStateTuple = namedtuple("TrackerEventStateTuple", "tracker event sliced_states") From a9f56215de09772df1b61c256da8562f519684af Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 16:14:35 +0100 Subject: [PATCH 121/209] Declare return type --- rasa/core/validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index ceaeffdf3243..7640db90f906 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -18,7 +18,7 @@ class Validator: """A class used to verify usage of intents and utterances.""" - def __init__(self, domain: Domain, intents: TrainingData, story_graph: StoryGraph): + def __init__(self, domain: Domain, intents: TrainingData, story_graph: StoryGraph) -> None: """Initializes the Validator object. """ self.domain = domain From 6a196fa37e45868de3ce61ead6a99026af66da17 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 16:15:28 +0100 Subject: [PATCH 122/209] Rename test_story_conflict.py --- tests/core/{test_storyconflict.py => test_story_conflict.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/core/{test_storyconflict.py => test_story_conflict.py} (100%) diff --git a/tests/core/test_storyconflict.py b/tests/core/test_story_conflict.py similarity index 100% rename from tests/core/test_storyconflict.py rename to tests/core/test_story_conflict.py From cea9208e8e920861ea384b5e33d5468ca59d6185 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 13 Jan 2020 16:39:25 +0100 Subject: [PATCH 123/209] Remove empty lines --- tests/core/test_story_conflict.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index be236fdaf6b1..8d876e688988 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -70,7 +70,6 @@ async def test_find_conflicts_checkpoints(): async def test_add_conflicting_action(): - sliced_states = [ None, {}, @@ -85,7 +84,6 @@ async def test_add_conflicting_action(): async def test_has_prior_events(): - sliced_states = [ None, {}, @@ -97,7 +95,6 @@ async def test_has_prior_events(): async def test_has_no_prior_events(): - sliced_states = [None] conflict = StoryConflict(sliced_states) assert not conflict.has_prior_events From 848720795dbd1cf71bc19460840c36e203ced043 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 16:40:11 +0100 Subject: [PATCH 124/209] Apply BLACK formatting --- rasa/cli/data.py | 17 ++++++++--- rasa/core/training/story_conflict.py | 43 ++++++++++++++++++---------- rasa/core/validator.py | 4 ++- 3 files changed, 44 insertions(+), 20 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 5161634b1478..c5787958f54d 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -33,7 +33,9 @@ def add_subparser( _add_data_validate_parsers(data_subparsers, parents) -def _add_data_convert_parsers(data_subparsers, parents: List[argparse.ArgumentParser]) -> None: +def _add_data_convert_parsers( + data_subparsers, parents: List[argparse.ArgumentParser] +) -> None: import rasa.nlu.convert as convert convert_parser = data_subparsers.add_parser( @@ -56,7 +58,9 @@ def _add_data_convert_parsers(data_subparsers, parents: List[argparse.ArgumentPa arguments.set_convert_arguments(convert_nlu_parser) -def _add_data_split_parsers(data_subparsers, parents: List[argparse.ArgumentParser]) -> None: +def _add_data_split_parsers( + data_subparsers, 
parents: List[argparse.ArgumentParser] +) -> None: split_parser = data_subparsers.add_parser( "split", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -78,7 +82,9 @@ def _add_data_split_parsers(data_subparsers, parents: List[argparse.ArgumentPars arguments.set_split_arguments(nlu_split_parser) -def _add_data_validate_parsers(data_subparsers, parents: List[argparse.ArgumentParser]) -> None: +def _add_data_validate_parsers( + data_subparsers, parents: List[argparse.ArgumentParser] +) -> None: validate_parser = data_subparsers.add_parser( "validate", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -173,7 +179,10 @@ def validate_stories(args: "argparse.Namespace") -> NoReturn: # Check if a valid setting for `max_history` was given if not isinstance(args.max_history, int) or args.max_history < 1: - raise argparse.ArgumentError(args.max_history, "You have to provide a positive integer for --max-history.") + raise argparse.ArgumentError( + args.max_history, + "You have to provide a positive integer for --max-history.", + ) # Prepare story and domain file import loop = asyncio.get_event_loop() diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 0a0610b05f1b..49654833eaf0 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -10,11 +10,11 @@ class StoryConflict: - def __init__( - self, sliced_states: List[Optional[Dict[Text, float]]], - ) -> None: + def __init__(self, sliced_states: List[Optional[Dict[Text, float]]],) -> None: self.sliced_states = sliced_states - self._conflicting_actions = defaultdict(list) # {"action": ["story_1", ...], ...} + self._conflicting_actions = defaultdict( + list + ) # {"action": ["story_1", ...], ...} self.correct_response = None def __hash__(self) -> int: @@ -52,7 +52,9 @@ def __str__(self) -> Text: # Describe where the conflict occurs in the stories last_event_type, last_event_name = _get_previous_event(self.sliced_states[-1]) if 
last_event_type: - conflict_message = f"CONFLICT after {last_event_type} '{last_event_name}':\n" + conflict_message = ( + f"CONFLICT after {last_event_type} '{last_event_name}':\n" + ) else: conflict_message = f"CONFLICT at the beginning of stories:\n" @@ -61,20 +63,24 @@ def __str__(self) -> Text: # Summarize if necessary if len(stories) > 3: # Four or more stories are present - conflict_description = f"'{stories[0]}' and {len(stories) - 1} other trackers" + conflict_description = ( + f"'{stories[0]}' and {len(stories) - 1} other trackers" + ) else: - conflict_description = { - 1: "'{}'", - 2: "'{}' and '{}'", - 3: "'{}', '{}', and '{}'", - }.get(len(stories)).format(*stories) + conflict_description = ( + {1: "'{}'", 2: "'{}' and '{}'", 3: "'{}', '{}', and '{}'",} + .get(len(stories)) + .format(*stories) + ) conflict_message += f" {action} predicted in {conflict_description}\n" return conflict_message -TrackerEventStateTuple = namedtuple("TrackerEventStateTuple", "tracker event sliced_states") +TrackerEventStateTuple = namedtuple( + "TrackerEventStateTuple", "tracker event sliced_states" +) def find_story_conflicts( @@ -98,7 +104,9 @@ def find_story_conflicts( # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs - conflicts = _build_conflicts_from_states(trackers, domain, max_history, state_action_dict) + conflicts = _build_conflicts_from_states( + trackers, domain, max_history, state_action_dict + ) return conflicts @@ -139,11 +147,16 @@ def _build_conflicts_from_states( if hashed_state in rules: conflicts[hashed_state].add_conflicting_action( - action=element.event.as_story_string(), story_name=element.tracker.sender_id + action=element.event.as_story_string(), + story_name=element.tracker.sender_id, ) # Remove conflicts that arise from unpredictable actions - return [conflict for (hashed_state, conflict) in conflicts.items() if conflict.has_prior_events] + return [ + conflict + for (hashed_state, conflict) in 
conflicts.items() + if conflict.has_prior_events + ] def _sliced_states_iterator( diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 7640db90f906..154c224be446 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -18,7 +18,9 @@ class Validator: """A class used to verify usage of intents and utterances.""" - def __init__(self, domain: Domain, intents: TrainingData, story_graph: StoryGraph) -> None: + def __init__( + self, domain: Domain, intents: TrainingData, story_graph: StoryGraph + ) -> None: """Initializes the Validator object. """ self.domain = domain From f27812fcd54360c3c9829b862c16c650189cabe5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 16:53:03 +0100 Subject: [PATCH 125/209] Add tests for story conflicts --- data/test_stories/stories_conflicting_3.md | 14 ++++ data/test_stories/stories_conflicting_4.md | 17 +++++ data/test_stories/stories_conflicting_5.md | 16 +++++ data/test_stories/stories_conflicting_6.md | 22 ++++++ tests/core/test_story_conflict.py | 84 ++++++++++++++++++++++ 5 files changed, 153 insertions(+) create mode 100644 data/test_stories/stories_conflicting_3.md create mode 100644 data/test_stories/stories_conflicting_4.md create mode 100644 data/test_stories/stories_conflicting_5.md create mode 100644 data/test_stories/stories_conflicting_6.md diff --git a/data/test_stories/stories_conflicting_3.md b/data/test_stories/stories_conflicting_3.md new file mode 100644 index 000000000000..2218f6cea164 --- /dev/null +++ b/data/test_stories/stories_conflicting_3.md @@ -0,0 +1,14 @@ +## greetings +* greet + - utter_greet +> check_greet + +## happy path +> check_greet +* default OR greet + - utter_default + +## problem +> check_greet +* greet + - utter_goodbye diff --git a/data/test_stories/stories_conflicting_4.md b/data/test_stories/stories_conflicting_4.md new file mode 100644 index 000000000000..372c38ff6d15 --- /dev/null +++ b/data/test_stories/stories_conflicting_4.md @@ -0,0 +1,17 
@@ +## story 1 +* greet + - utter_greet +* greet + - slot{"cuisine": "German"} + - utter_greet +* greet + - utter_greet + +## story 2 +* greet + - utter_greet +* greet + - slot{"cuisine": "German"} + - utter_greet +* greet + - utter_default diff --git a/data/test_stories/stories_conflicting_5.md b/data/test_stories/stories_conflicting_5.md new file mode 100644 index 000000000000..6865c9db9b4f --- /dev/null +++ b/data/test_stories/stories_conflicting_5.md @@ -0,0 +1,16 @@ +## story 1 +* greet + - utter_greet +* greet + - utter_greet + - slot{"cuisine": "German"} +* greet + - utter_greet + +## story 2 +* greet + - utter_greet +* greet + - utter_greet +* greet + - utter_default diff --git a/data/test_stories/stories_conflicting_6.md b/data/test_stories/stories_conflicting_6.md new file mode 100644 index 000000000000..f58dc258078e --- /dev/null +++ b/data/test_stories/stories_conflicting_6.md @@ -0,0 +1,22 @@ +## story 1 +* greet + - utter_greet + +## story 2 +* greet + - utter_default + +## story 3 +* greet + - utter_default +* greet + +## story 4 +* greet + - utter_default +* default + +## story 5 +* greet + - utter_default +* goodbye diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index 8d876e688988..172784f46086 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -69,6 +69,90 @@ async def test_find_conflicts_checkpoints(): assert conflicts[0].conflicting_actions == ["utter_goodbye", "utter_default"] +async def test_find_conflicts_or(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_3.md"], + ) + validator = await Validator.from_importer(importer) + + trackers = TrainingDataGenerator( + validator.story_graph, + domain=validator.domain, + remove_duplicates=False, + augmentation_factor=0, + ).generate() + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, 
validator.domain, 5) + + assert len(conflicts) == 1 + assert conflicts[0].conflicting_actions == ["utter_default", "utter_goodbye"] + + +async def test_find_conflicts_slots(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_4.md"], + ) + validator = await Validator.from_importer(importer) + + trackers = TrainingDataGenerator( + validator.story_graph, + domain=validator.domain, + remove_duplicates=False, + augmentation_factor=0, + ).generate() + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, validator.domain, 5) + + assert len(conflicts) == 1 + assert conflicts[0].conflicting_actions == ["utter_default", "utter_greet"] + + +async def test_find_conflicts_slots_2(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_5.md"], + ) + validator = await Validator.from_importer(importer) + + trackers = TrainingDataGenerator( + validator.story_graph, + domain=validator.domain, + remove_duplicates=False, + augmentation_factor=0, + ).generate() + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, validator.domain, 5) + + assert len(conflicts) == 0 + + +async def test_find_conflicts_multiple_stories(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_6.md"], + ) + validator = await Validator.from_importer(importer) + + trackers = TrainingDataGenerator( + validator.story_graph, + domain=validator.domain, + remove_duplicates=False, + augmentation_factor=0, + ).generate() + + # Create a list of `StoryConflict` objects + conflicts = find_story_conflicts(trackers, validator.domain, 5) + + assert len(conflicts) == 1 + print(conflicts[0]) + assert "and 3 other trackers" in str(conflicts[0]) + + async def test_add_conflicting_action(): 
sliced_states = [ None, From 86aefe5fe08ca8d6703cdd3ff59cc93e749d181d Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 17:30:51 +0100 Subject: [PATCH 126/209] Define _setup_trackers_for_testing --- tests/core/test_story_conflict.py | 108 +++++++++--------------------- 1 file changed, 30 insertions(+), 78 deletions(-) diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index 172784f46086..5d6461f86496 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -5,10 +5,9 @@ from tests.core.conftest import DEFAULT_STORIES_FILE, DEFAULT_DOMAIN_PATH_WITH_SLOTS -async def test_find_no_conflicts(): +async def _setup_trackers_for_testing(domain_path, training_data_file): importer = RasaFileImporter( - domain_path=DEFAULT_DOMAIN_PATH_WITH_SLOTS, - training_data_paths=[DEFAULT_STORIES_FILE], + domain_path=domain_path, training_data_paths=[training_data_file], ) validator = await Validator.from_importer(importer) @@ -19,137 +18,90 @@ async def test_find_no_conflicts(): augmentation_factor=0, ).generate() + return trackers, validator.domain + + +async def test_find_no_conflicts(): + trackers, domain = await _setup_trackers_for_testing( + DEFAULT_DOMAIN_PATH_WITH_SLOTS, DEFAULT_STORIES_FILE + ) + # Create a list of `StoryConflict` objects - conflicts = find_story_conflicts(trackers, validator.domain, 5) + conflicts = find_story_conflicts(trackers, domain, 5) assert conflicts == [] async def test_find_conflicts_in_short_history(): - importer = RasaFileImporter( - domain_path="data/test_domains/default.yml", - training_data_paths=["data/test_stories/stories_conflicting_1.md"], + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_1.md" ) - validator = await Validator.from_importer(importer) - - trackers = TrainingDataGenerator( - validator.story_graph, - domain=validator.domain, - remove_duplicates=False, - 
augmentation_factor=0, - ).generate() # `max_history = 3` is too small, so a conflict must arise - conflicts = find_story_conflicts(trackers, validator.domain, 3) + conflicts = find_story_conflicts(trackers, domain, 3) assert len(conflicts) == 1 # With `max_history = 4` the conflict should disappear - conflicts = find_story_conflicts(trackers, validator.domain, 4) + conflicts = find_story_conflicts(trackers, domain, 4) assert len(conflicts) == 0 async def test_find_conflicts_checkpoints(): - importer = RasaFileImporter( - domain_path="data/test_domains/default.yml", - training_data_paths=["data/test_stories/stories_conflicting_2.md"], + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_2.md" ) - validator = await Validator.from_importer(importer) - - trackers = TrainingDataGenerator( - validator.story_graph, - domain=validator.domain, - remove_duplicates=False, - augmentation_factor=0, - ).generate() # Create a list of `StoryConflict` objects - conflicts = find_story_conflicts(trackers, validator.domain, 5) + conflicts = find_story_conflicts(trackers, domain, 5) assert len(conflicts) == 1 assert conflicts[0].conflicting_actions == ["utter_goodbye", "utter_default"] async def test_find_conflicts_or(): - importer = RasaFileImporter( - domain_path="data/test_domains/default.yml", - training_data_paths=["data/test_stories/stories_conflicting_3.md"], + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_3.md" ) - validator = await Validator.from_importer(importer) - - trackers = TrainingDataGenerator( - validator.story_graph, - domain=validator.domain, - remove_duplicates=False, - augmentation_factor=0, - ).generate() # Create a list of `StoryConflict` objects - conflicts = find_story_conflicts(trackers, validator.domain, 5) + conflicts = find_story_conflicts(trackers, domain, 5) assert len(conflicts) == 1 assert 
conflicts[0].conflicting_actions == ["utter_default", "utter_goodbye"] async def test_find_conflicts_slots(): - importer = RasaFileImporter( - domain_path="data/test_domains/default.yml", - training_data_paths=["data/test_stories/stories_conflicting_4.md"], + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_4.md" ) - validator = await Validator.from_importer(importer) - - trackers = TrainingDataGenerator( - validator.story_graph, - domain=validator.domain, - remove_duplicates=False, - augmentation_factor=0, - ).generate() # Create a list of `StoryConflict` objects - conflicts = find_story_conflicts(trackers, validator.domain, 5) + conflicts = find_story_conflicts(trackers, domain, 5) assert len(conflicts) == 1 assert conflicts[0].conflicting_actions == ["utter_default", "utter_greet"] async def test_find_conflicts_slots_2(): - importer = RasaFileImporter( - domain_path="data/test_domains/default.yml", - training_data_paths=["data/test_stories/stories_conflicting_5.md"], + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_5.md" ) - validator = await Validator.from_importer(importer) - - trackers = TrainingDataGenerator( - validator.story_graph, - domain=validator.domain, - remove_duplicates=False, - augmentation_factor=0, - ).generate() # Create a list of `StoryConflict` objects - conflicts = find_story_conflicts(trackers, validator.domain, 5) + conflicts = find_story_conflicts(trackers, domain, 5) assert len(conflicts) == 0 async def test_find_conflicts_multiple_stories(): - importer = RasaFileImporter( - domain_path="data/test_domains/default.yml", - training_data_paths=["data/test_stories/stories_conflicting_6.md"], + trackers, domain = await _setup_trackers_for_testing( + "data/test_domains/default.yml", "data/test_stories/stories_conflicting_6.md" ) - validator = await Validator.from_importer(importer) - - 
trackers = TrainingDataGenerator( - validator.story_graph, - domain=validator.domain, - remove_duplicates=False, - augmentation_factor=0, - ).generate() # Create a list of `StoryConflict` objects - conflicts = find_story_conflicts(trackers, validator.domain, 5) + conflicts = find_story_conflicts(trackers, domain, 5) assert len(conflicts) == 1 - print(conflicts[0]) assert "and 3 other trackers" in str(conflicts[0]) From 1123db6ef6f5da01119f447cd2bf036577234a9f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 17:40:30 +0100 Subject: [PATCH 127/209] Clarify output message --- rasa/core/validator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 154c224be446..32fb45759efb 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -196,7 +196,9 @@ def verify_story_structure( """Verifies that bot behaviour in stories is deterministic.""" logger.info("Story structure validation...") - logger.info(f"Assuming max_history = {max_history}") + logger.info( + f" Considering the preceding {max_history} turns for conflict analysis." + ) trackers = TrainingDataGenerator( self.story_graph, From 01c60f7df641ad519d7650c65d47d7346820fc5b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 17:41:11 +0100 Subject: [PATCH 128/209] Add missing tick-marks --- docs/user-guide/validate-files.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index a8df730f6eaa..033104976676 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -107,10 +107,11 @@ Just use ``rasa data validate`` in the command line, as follows: > utter_goodbye predicted in 'Story 2' > utter_happy predicted in 'Story 1' -Here we specify a `max-history` value of 3. +Here we specify a ``max-history`` value of 3. 
This means, that 3 events (user / bot actions) are taken into account for action predictions, but the particular setting does not matter for this example, because regardless of how long of a history you take into account, the conflict always exists. .. warning:: - The `rasa data validate stories` script assumes that all your **story names are unique**. + + The ``rasa data validate stories`` script assumes that all your **story names are unique**. If your stories are in the Markdown format, you may find duplicate names with a command like - `grep -h "##" data/*.md | uniq -c | grep "^[^1]"`. + ``grep -h "##" data/*.md | uniq -c | grep "^[^1]"``. From 863a46298e3a83885c3716604204f87f57e407b0 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 17:49:22 +0100 Subject: [PATCH 129/209] Clarify help strings --- rasa/cli/data.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index c5787958f54d..b911250043d3 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -95,7 +95,7 @@ def _add_data_validate_parsers( "--max-history", type=int, default=None, - help="Assume this max_history setting for story structure validation.", + help="Number of turns taken into account for story structure validation.", ) validate_parser.set_defaults(func=validate_files) arguments.set_validator_arguments(validate_parser) @@ -110,13 +110,7 @@ def _add_data_validate_parsers( story_structure_parser.add_argument( "--max-history", type=int, - help="Assume this max_history setting for validation.", - ) - story_structure_parser.add_argument( - "--prompt", - action="store_true", - default=False, - help="Ask how conflicts should be fixed", + help="Number of turns taken into account for story structure validation.", ) story_structure_parser.set_defaults(func=validate_stories) arguments.set_validator_arguments(story_structure_parser) From 0ea731546fb4c8381c488d6ab186810d41873b83 Mon Sep 17 00:00:00 2001 From: "Johannes 
E. M. Mosig" Date: Mon, 13 Jan 2020 17:50:20 +0100 Subject: [PATCH 130/209] Use else --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index b911250043d3..d537f61c4436 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -155,7 +155,7 @@ def validate_files(args: "argparse.Namespace") -> NoReturn: "Will not test for inconsistencies in stories since " "you did not provide a value for `--max-history`." ) - if args.max_history: + else: # Only run story structure validation if everything else is fine # since this might take a while everything_is_alright = validator.verify_story_structure( From adc45c5c7a4fb1b97bf233860113a02cf701aaa6 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 13 Jan 2020 17:54:46 +0100 Subject: [PATCH 131/209] Avoid exit on successful validate --- rasa/cli/data.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index d537f61c4436..fe87c3cdf4b2 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -132,7 +132,7 @@ def split_nlu_data(args) -> None: test.persist(args.out, filename=f"test_data.{fformat}") -def validate_files(args: "argparse.Namespace") -> NoReturn: +def validate_files(args: "argparse.Namespace") -> None: """Validate all files needed for training a model. Fails with a non-zero exit code if there are any errors in the data.""" @@ -161,10 +161,13 @@ def validate_files(args: "argparse.Namespace") -> NoReturn: everything_is_alright = validator.verify_story_structure( not args.fail_on_warnings, max_history=args.max_history ) - sys.exit(0) if everything_is_alright else sys.exit(1) + + if not everything_is_alright: + print_error("Project validation completed with errors.") + sys.exit(1) -def validate_stories(args: "argparse.Namespace") -> NoReturn: +def validate_stories(args: "argparse.Namespace") -> None: """Validate all files needed for training a model. 
Fails with a non-zero exit code if there are any errors in the data.""" From 846537d7538aa0cc8424fee9ad260e769b3f8e46 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Tue, 14 Jan 2020 09:36:55 +0100 Subject: [PATCH 132/209] Dirty bugfix for sanic-plugins-framework dependency --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index e052b6c45e78..45821b050062 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,6 +41,7 @@ terminaltables==3.1.0 sanic==19.9.0 sanic-cors==0.9.9.post1 sanic-jwt==1.3.2 +sanic-plugins-framework==0.8.2 # needed because of https://github.com/huge-success/sanic/issues/1729 multidict==4.6.1 aiohttp==3.5.4 From ceb0abc3bb3c63512951975d53edea4d53cdf8dd Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 16 Jan 2020 11:59:11 +0100 Subject: [PATCH 133/209] Clarify documentation of `rasa data validate` --- docs/user-guide/validate-files.rst | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index 033104976676..88b27e87baed 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -18,7 +18,8 @@ You can run it with the following command: rasa data validate -The script above runs most of the validations on your files. Here is the list of options to +The script above runs all the validations on your files, except for story structure validation, +which is omitted unless you provide the `--max-history` argument. Here is the list of options to the script: .. program-output:: rasa data validate --help @@ -94,8 +95,8 @@ These two stories are inconsistent, because Rasa cannot know if it should predic after ``inform_happy``, as there is nothing that would distinguish the dialogue states at ``inform_happy`` in the two stories and the subsequent actions are different in Story 1 and Story 2. 
-This conflict can now be automatically identified with our new story structure tool. -Just use ``rasa data validate`` in the command line, as follows: +This conflict can be automatically identified with our story structure validation tool. +To do this, use ``rasa data validate`` in the command line, as follows: .. code-block:: bash @@ -108,7 +109,7 @@ Just use ``rasa data validate`` in the command line, as follows: > utter_happy predicted in 'Story 1' Here we specify a ``max-history`` value of 3. -This means, that 3 events (user / bot actions) are taken into account for action predictions, but the particular setting does not matter for this example, because regardless of how long of a history you take into account, the conflict always exists. +This means, that 3 events (user messages / bot actions) are taken into account for action predictions, but the particular setting does not matter for this example, because regardless of how long of a history you take into account, the conflict always exists. .. warning:: From 7c23e8050d09ed4c4cb187ad75d79ab4f5314ae5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 16 Jan 2020 13:02:14 +0100 Subject: [PATCH 134/209] Make logger messages consistent --- rasa/core/validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 32fb45759efb..32d5e3b5d6e7 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -197,7 +197,7 @@ def verify_story_structure( logger.info("Story structure validation...") logger.info( - f" Considering the preceding {max_history} turns for conflict analysis." + f"Considering the preceding {max_history} turns for conflict analysis." 
) trackers = TrainingDataGenerator( @@ -211,7 +211,7 @@ def verify_story_structure( conflicts = find_story_conflicts(trackers, self.domain, max_history) if len(conflicts) == 0: - logger.info("No story structure conflicts found") + logger.info("No story structure conflicts found.") else: for conflict in conflicts: logger.warning(conflict) From a45a9b07f749e38ea791f34fc7bc3b0a6f602775 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 16 Jan 2020 13:06:41 +0100 Subject: [PATCH 135/209] Remove outdated ToDo string --- rasa/core/validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 32d5e3b5d6e7..b450c9a1af34 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -203,7 +203,7 @@ def verify_story_structure( trackers = TrainingDataGenerator( self.story_graph, domain=self.domain, - remove_duplicates=False, # ToDo: Q&A: Why not remove_duplicates=True? + remove_duplicates=False, augmentation_factor=0, ).generate() From 0e4bfeed0f3e06e304e7d31c50237787a1991620 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 16 Jan 2020 13:14:41 +0100 Subject: [PATCH 136/209] Remove irrelevant comment --- rasa/core/validator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index b450c9a1af34..502c1236384f 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -216,9 +216,6 @@ def verify_story_structure( for conflict in conflicts: logger.warning(conflict) - # For code stub to fix the conflict in the command line, - # see commit 3fdc08a030dbd85c15b4f5d7e8b5ad6a254eefb4 - return ignore_warnings or len(conflicts) == 0 def verify_all(self, ignore_warnings: bool = True) -> bool: From e76e90bb8b443e2f313ce5fa4b798912a20981d1 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Thu, 16 Jan 2020 13:40:50 +0100 Subject: [PATCH 137/209] Add test_verify_bad_story_structure_ignore_warnings --- tests/core/test_validator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/core/test_validator.py b/tests/core/test_validator.py index 4180d999eba3..72d2cd15d24a 100644 --- a/tests/core/test_validator.py +++ b/tests/core/test_validator.py @@ -54,6 +54,15 @@ async def test_verify_bad_story_structure(): assert not validator.verify_story_structure(ignore_warnings=False) +async def test_verify_bad_story_structure_ignore_warnings(): + importer = RasaFileImporter( + domain_path="data/test_domains/default.yml", + training_data_paths=["data/test_stories/stories_conflicting_2.md"], + ) + validator = await Validator.from_importer(importer) + assert validator.verify_story_structure(ignore_warnings=True) + + async def test_fail_on_invalid_utterances(tmpdir): # domain and stories are from different domain and should produce warnings invalid_domain = str(tmpdir / "invalid_domain.yml") From 485537ae6e96fc30050900a57bf31c71ecf55c78 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 10:35:34 +0100 Subject: [PATCH 138/209] Fix consequences of renaming INTENT_ATTRIBUTE --- rasa/core/training/story_conflict.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 49654833eaf0..308d1504936e 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -5,7 +5,7 @@ from rasa.core.domain import PREV_PREFIX, Domain from rasa.core.events import ActionExecuted, Event from rasa.core.featurizers import MaxHistoryTrackerFeaturizer -from rasa.nlu.constants import MESSAGE_INTENT_ATTRIBUTE +from rasa.nlu.constants import INTENT_ATTRIBUTE from rasa.core.training.generator import TrackerWithCachedStates @@ -214,8 +214,8 @@ def _get_previous_event( prev_event_type = "action" prev_event_name = k[len(PREV_PREFIX) :] - if not prev_event_type and k.startswith(MESSAGE_INTENT_ATTRIBUTE + "_"): + if not prev_event_type and k.startswith(INTENT_ATTRIBUTE + "_"): prev_event_type = "intent" - prev_event_name = k[len(MESSAGE_INTENT_ATTRIBUTE + "_") :] + prev_event_name = k[len(INTENT_ATTRIBUTE + "_") :] return prev_event_type, prev_event_name From 6b617c0682474583d4080075e600826c16c6a01c Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 10:47:00 +0100 Subject: [PATCH 139/209] Fix return type annotation --- rasa/core/training/story_conflict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 308d1504936e..3b965f080707 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -1,5 +1,5 @@ from collections import defaultdict, namedtuple -from typing import List, Optional, Dict, Text +from typing import List, Optional, Dict, Text, Tuple from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.domain import PREV_PREFIX, Domain @@ -191,7 +191,7 @@ def _sliced_states_iterator( def _get_previous_event( state: Optional[Dict[Text, float]] -) -> [Optional[Text], Optional[Text]]: +) -> Tuple[Optional[Text], Optional[Text]]: """Returns previous event type and name. Returns the type and name of the event (action or intent) previous to the From 07b3d8ab2f683d3a7fc066d63ecaa01aa94f6586 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 11:02:21 +0100 Subject: [PATCH 140/209] Add test for _get_previous_event --- rasa/core/training/story_conflict.py | 6 +++--- tests/core/test_story_conflict.py | 7 ++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 3b965f080707..78c49ca0dc6e 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -210,12 +210,12 @@ def _get_previous_event( return prev_event_type, prev_event_name for k in state: - if k.startswith(PREV_PREFIX) and k[len(PREV_PREFIX) :] != ACTION_LISTEN_NAME: + if k.startswith(PREV_PREFIX) and k.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME: prev_event_type = "action" - prev_event_name = k[len(PREV_PREFIX) :] + prev_event_name = k.replace(PREV_PREFIX, "") if not prev_event_type and k.startswith(INTENT_ATTRIBUTE + "_"): prev_event_type = "intent" - prev_event_name = k[len(INTENT_ATTRIBUTE + "_") :] + prev_event_name = k.replace(INTENT_ATTRIBUTE + "_", "") return prev_event_type, prev_event_name diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index 5d6461f86496..3cdf08a51aed 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -1,4 +1,4 @@ -from rasa.core.training.story_conflict import StoryConflict, find_story_conflicts +from rasa.core.training.story_conflict import StoryConflict, find_story_conflicts, _get_previous_event from rasa.core.training.generator import TrainingDataGenerator from rasa.core.validator import Validator from rasa.importers.rasa import RasaFileImporter @@ -130,6 +130,11 @@ async def test_has_prior_events(): assert conflict.has_prior_events +async def test_get_previous_event(): + assert _get_previous_event({"prev_utter_greet": 1.0, "intent_greet": 1.0}) == ("action", "utter_greet") + assert _get_previous_event({"intent_greet": 1.0, "prev_action_listen": 1.0}) == ("intent", "greet") + + async 
def test_has_no_prior_events(): sliced_states = [None] conflict = StoryConflict(sliced_states) From c8e2e7e3bd7aea06cffde30709d609e580f75d4a Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 11:13:59 +0100 Subject: [PATCH 141/209] Simplify StoryConflict.__str__ --- rasa/core/training/story_conflict.py | 29 +++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 78c49ca0dc6e..6bda6282a7ae 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -60,22 +60,25 @@ def __str__(self) -> Text: # List which stories are in conflict with one another for action, stories in self._conflicting_actions.items(): - # Summarize if necessary - if len(stories) > 3: - # Four or more stories are present - conflict_description = ( - f"'{stories[0]}' and {len(stories) - 1} other trackers" - ) - else: - conflict_description = ( - {1: "'{}'", 2: "'{}' and '{}'", 3: "'{}', '{}', and '{}'",} + conflict_message += " " + self._summarize_conflict(action, stories) + + return conflict_message + + @staticmethod + def _summarize_conflict(action, stories): + if len(stories) > 3: + # Four or more stories are present + conflict_description = ( + f"'{stories[0]}' and {len(stories) - 1} other trackers" + ) + else: + conflict_description = ( + {1: "'{}'", 2: "'{}' and '{}'", 3: "'{}', '{}', and '{}'", } .get(len(stories)) .format(*stories) - ) - - conflict_message += f" {action} predicted in {conflict_description}\n" + ) - return conflict_message + return f"{action} predicted in {conflict_description}\n" TrackerEventStateTuple = namedtuple( From 070b597ed46e658ef2f931c511f85258fb8a7b42 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 11:35:59 +0100 Subject: [PATCH 142/209] Simplify _build_conflicts_from_states --- rasa/core/training/story_conflict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 6bda6282a7ae..53a1af15c80d 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -137,7 +137,7 @@ def _build_conflicts_from_states( trackers: List["TrackerWithCachedStates"], domain: Domain, max_history: int, - rules: Dict[int, Optional[List[Text]]], + state_action_dict: Dict[int, Optional[List[Text]]], ) -> List["StoryConflict"]: # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs @@ -145,10 +145,10 @@ def _build_conflicts_from_states( for element in _sliced_states_iterator(trackers, domain, max_history): hashed_state = hash(str(list(element.sliced_states))) - if hashed_state in rules and hashed_state not in conflicts: - conflicts[hashed_state] = StoryConflict(element.sliced_states) + if hashed_state in state_action_dict: + if hashed_state not in conflicts: + conflicts[hashed_state] = StoryConflict(element.sliced_states) - if hashed_state in rules: conflicts[hashed_state].add_conflicting_action( action=element.event.as_story_string(), story_name=element.tracker.sender_id, From 1b95d6c667599767e135ffde05565968e8019ae4 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 11:45:39 +0100 Subject: [PATCH 143/209] Simplify _get_previous_event --- rasa/core/training/story_conflict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 53a1af15c80d..f222f5a1b1bc 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -216,9 +216,11 @@ def _get_previous_event( if k.startswith(PREV_PREFIX) and k.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME: prev_event_type = "action" prev_event_name = k.replace(PREV_PREFIX, "") + break - if not prev_event_type and k.startswith(INTENT_ATTRIBUTE + "_"): + if k.startswith(INTENT_ATTRIBUTE + "_"): prev_event_type = "intent" prev_event_name = k.replace(INTENT_ATTRIBUTE + "_", "") + break return prev_event_type, prev_event_name From f3d2937ac75aa2b6d592c272b14a33f2e4bf8720 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 11:58:07 +0100 Subject: [PATCH 144/209] Simplify _get_previous_event --- rasa/core/training/story_conflict.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index f222f5a1b1bc..565bb4f3f423 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -217,8 +217,7 @@ def _get_previous_event( prev_event_type = "action" prev_event_name = k.replace(PREV_PREFIX, "") break - - if k.startswith(INTENT_ATTRIBUTE + "_"): + elif k.startswith(INTENT_ATTRIBUTE + "_"): prev_event_type = "intent" prev_event_name = k.replace(INTENT_ATTRIBUTE + "_", "") break From a4b74cc04942d8fd7d536d374fe74d15c27f970b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 14:55:21 +0100 Subject: [PATCH 145/209] Fix return type of StoryConflict._sliced_states_iterator --- rasa/core/training/story_conflict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 565bb4f3f423..8e07cfeee5e1 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -1,5 +1,5 @@ from collections import defaultdict, namedtuple -from typing import List, Optional, Dict, Text, Tuple +from typing import List, Optional, Dict, Text, Tuple, Generator from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.domain import PREV_PREFIX, Domain @@ -164,7 +164,7 @@ def _build_conflicts_from_states( def _sliced_states_iterator( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int -) -> TrackerEventStateTuple: +) -> Generator[TrackerEventStateTuple, None, None]: """Creates an iterator over sliced states. Iterate over all given trackers and all sliced states within each tracker, From 03952f5783121b99424631e2a927e4b008952de5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 16:53:20 +0100 Subject: [PATCH 146/209] Avoid `from` --- rasa/cli/data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index fe87c3cdf4b2..36b65033da8e 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -6,7 +6,7 @@ from rasa import data from rasa.cli.arguments import data as arguments -from rasa.cli.utils import get_validated_path, print_error +import rasa.cli.utils from rasa.constants import DEFAULT_DATA_PATH from typing import NoReturn @@ -120,7 +120,7 @@ def split_nlu_data(args) -> None: from rasa.nlu.training_data.loading import load_data from rasa.nlu.training_data.util import get_file_format - data_path = get_validated_path(args.nlu, "nlu", DEFAULT_DATA_PATH) + data_path = rasa.cli.utils.get_validated_path(args.nlu, "nlu", DEFAULT_DATA_PATH) data_path = data.get_nlu_directory(data_path) nlu_data = load_data(data_path) @@ -163,7 +163,7 @@ def validate_files(args: "argparse.Namespace") -> None: ) if not everything_is_alright: - print_error("Project validation completed with errors.") + rasa.cli.utils.print_error("Project validation completed with errors.") sys.exit(1) @@ -196,5 +196,5 @@ def validate_stories(args: "argparse.Namespace") -> None: ) if not everything_is_alright: - print_error("Story validation completed with errors.") + rasa.cli.utils.print_error("Story validation completed with errors.") sys.exit(1) From 54936b0f3611c1fd0e390614928892475f73254e Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 16:54:34 +0100 Subject: [PATCH 147/209] Update rasa/cli/data.py Co-Authored-By: Tobias Wochinger --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index fe87c3cdf4b2..29da5a4d944a 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -36,7 +36,7 @@ def add_subparser( def _add_data_convert_parsers( data_subparsers, parents: List[argparse.ArgumentParser] ) -> None: - import rasa.nlu.convert as convert + from rasa.nlu import convert convert_parser = data_subparsers.add_parser( "convert", From c33eef5282655e0505d1ded2a7a599b8ae39d622 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 16:57:38 +0100 Subject: [PATCH 148/209] Optimize imports Co-Authored-By: Tobias Wochinger --- rasa/cli/data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index b29cd88d0087..073a17c03ea3 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -8,7 +8,6 @@ from rasa.cli.arguments import data as arguments import rasa.cli.utils from rasa.constants import DEFAULT_DATA_PATH -from typing import NoReturn logger = logging.getLogger(__name__) From f787eb3cb96c31e31855b4daa2056418412cfbd7 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 16:58:22 +0100 Subject: [PATCH 149/209] Remove quotes from type Co-Authored-By: Tobias Wochinger --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 073a17c03ea3..4e90c2e9b2b5 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -131,7 +131,7 @@ def split_nlu_data(args) -> None: test.persist(args.out, filename=f"test_data.{fformat}") -def validate_files(args: "argparse.Namespace") -> None: +def validate_files(args: argparse.Namespace) -> None: """Validate all files needed for training a model. 
Fails with a non-zero exit code if there are any errors in the data.""" From 90b3452228a439bc246a9c16f3b04e0e6dae087d Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 17:02:04 +0100 Subject: [PATCH 150/209] Clarify doc string --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 4e90c2e9b2b5..1df335c53101 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -167,7 +167,7 @@ def validate_files(args: argparse.Namespace) -> None: def validate_stories(args: "argparse.Namespace") -> None: - """Validate all files needed for training a model. + """Validate only the story structure. Fails with a non-zero exit code if there are any errors in the data.""" from rasa.core.validator import Validator From 8ab78691cf1e8112bba250502260960b83691b3d Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 17:03:26 +0100 Subject: [PATCH 151/209] Rename variable for clarity --- rasa/cli/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 1df335c53101..a832b38a96af 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -190,10 +190,10 @@ def validate_stories(args: "argparse.Namespace") -> None: validator = loop.run_until_complete(Validator.from_importer(file_importer)) # If names are unique, look for story conflicts - everything_is_alright = validator.verify_story_structure( + stories_are_consistent = validator.verify_story_structure( not args.fail_on_warnings, max_history=args.max_history ) - if not everything_is_alright: + if not stories_are_consistent: rasa.cli.utils.print_error("Story validation completed with errors.") sys.exit(1) From 19cd71ff9b482d6d4f58b2df660506a6f93d16c5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 17:06:47 +0100 Subject: [PATCH 152/209] Apply BLACK formatting --- rasa/cli/data.py | 2 +- rasa/core/training/story_conflict.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index a832b38a96af..77da050ac47a 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -169,7 +169,7 @@ def validate_files(args: argparse.Namespace) -> None: def validate_stories(args: "argparse.Namespace") -> None: """Validate only the story structure. - Fails with a non-zero exit code if there are any errors in the data.""" + Fails with a non-zero exit code if there are any errors in the data.""" from rasa.core.validator import Validator from rasa.importers.rasa import RasaFileImporter diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 8e07cfeee5e1..3c58a565d881 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -73,9 +73,9 @@ def _summarize_conflict(action, stories): ) else: conflict_description = ( - {1: "'{}'", 2: "'{}' and '{}'", 3: "'{}', '{}', and '{}'", } - .get(len(stories)) - .format(*stories) + {1: "'{}'", 2: "'{}' and '{}'", 3: "'{}', '{}', and '{}'",} + .get(len(stories)) + .format(*stories) ) return f"{action} predicted in {conflict_description}\n" @@ -213,7 +213,10 @@ def _get_previous_event( return prev_event_type, prev_event_name for k in state: - if k.startswith(PREV_PREFIX) and k.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME: + if ( + k.startswith(PREV_PREFIX) + and k.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME + ): prev_event_type = "action" prev_event_name = k.replace(PREV_PREFIX, "") break From 8c17f8ccdc768df5dfa55a47f73ee8d01cd46b49 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 17:07:48 +0100 Subject: [PATCH 153/209] Delete unused property --- rasa/core/training/story_conflict.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 3c58a565d881..1f09f7864e21 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -15,7 +15,6 @@ def __init__(self, sliced_states: List[Optional[Dict[Text, float]]],) -> None: self._conflicting_actions = defaultdict( list ) # {"action": ["story_1", ...], ...} - self.correct_response = None def __hash__(self) -> int: return hash(str(list(self.sliced_states))) From a14c9b73dc40b9bde63530fe8ea1e4af6fc38b50 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 17:10:29 +0100 Subject: [PATCH 154/209] Rename `state_action_mapping` --- rasa/core/training/story_conflict.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 1f09f7864e21..9c331e95c71b 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -102,12 +102,12 @@ def find_story_conflicts( # Create a 'state -> list of actions' dict, where the state is # represented by its hash - state_action_dict = _find_conflicting_states(trackers, domain, max_history) + state_action_mapping = _find_conflicting_states(trackers, domain, max_history) # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs conflicts = _build_conflicts_from_states( - trackers, domain, max_history, state_action_dict + trackers, domain, max_history, state_action_mapping ) return conflicts @@ -118,16 +118,16 @@ def _find_conflicting_states( ) -> Dict[int, Optional[List[Text]]]: # Create a 'state -> list of actions' dict, where the state is # represented by its hash - state_action_dict = defaultdict(list) + state_action_mapping = 
defaultdict(list) for element in _sliced_states_iterator(trackers, domain, max_history): hashed_state = hash(str(list(element.sliced_states))) - if element.event.as_story_string() not in state_action_dict[hashed_state]: - state_action_dict[hashed_state] += [element.event.as_story_string()] + if element.event.as_story_string() not in state_action_mapping[hashed_state]: + state_action_mapping[hashed_state] += [element.event.as_story_string()] - # Keep only conflicting `state_action_dict`s + # Keep only conflicting `state_action_mapping`s return { state_hash: actions - for (state_hash, actions) in state_action_dict.items() + for (state_hash, actions) in state_action_mapping.items() if len(actions) > 1 } From 2dbc7c2a2631f205b1bfe239cffd3131436d1f8f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 17:13:02 +0100 Subject: [PATCH 155/209] Remove quotes from type declaration --- rasa/core/training/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 9c331e95c71b..d3040eb0a7ca 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -133,7 +133,7 @@ def _find_conflicting_states( def _build_conflicts_from_states( - trackers: List["TrackerWithCachedStates"], + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int, state_action_dict: Dict[int, Optional[List[Text]]], From a88b6f9ec446989144586afd94742ac7306ebf06 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 17:14:37 +0100 Subject: [PATCH 156/209] Add comment for clarification --- rasa/core/training/story_conflict.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index d3040eb0a7ca..f9e2f692b620 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -154,6 +154,7 @@ def _build_conflicts_from_states( ) # Remove conflicts that arise from unpredictable actions + # (actions that start the conversation) return [ conflict for (hashed_state, conflict) in conflicts.items() From 71074d8654c271943bad428b05b35343a95b0baa Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 17:15:34 +0100 Subject: [PATCH 157/209] Spell out variable names --- rasa/core/training/story_conflict.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index f9e2f692b620..4a8a958533a7 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -206,23 +206,23 @@ def _get_previous_event( Returns: Tuple of (type, name) strings of the prior event. 
""" - prev_event_type = None - prev_event_name = None + previous_event_type = None + previous_event_name = None if not state: - return prev_event_type, prev_event_name + return previous_event_type, previous_event_name for k in state: if ( k.startswith(PREV_PREFIX) and k.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME ): - prev_event_type = "action" - prev_event_name = k.replace(PREV_PREFIX, "") + previous_event_type = "action" + previous_event_name = k.replace(PREV_PREFIX, "") break elif k.startswith(INTENT_ATTRIBUTE + "_"): - prev_event_type = "intent" - prev_event_name = k.replace(INTENT_ATTRIBUTE + "_", "") + previous_event_type = "intent" + previous_event_name = k.replace(INTENT_ATTRIBUTE + "_", "") break - return prev_event_type, prev_event_name + return previous_event_type, previous_event_name From 797df7d3eeae893eb0145e0da6e1231abf0cc4f7 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 17:17:16 +0100 Subject: [PATCH 158/209] Rename `turn_label` --- rasa/core/training/story_conflict.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 4a8a958533a7..3163a408d0c6 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -212,17 +212,17 @@ def _get_previous_event( if not state: return previous_event_type, previous_event_name - for k in state: + for turn_label in state: if ( - k.startswith(PREV_PREFIX) - and k.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME + turn_label.startswith(PREV_PREFIX) + and turn_label.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME ): previous_event_type = "action" - previous_event_name = k.replace(PREV_PREFIX, "") + previous_event_name = turn_label.replace(PREV_PREFIX, "") break - elif k.startswith(INTENT_ATTRIBUTE + "_"): + elif turn_label.startswith(INTENT_ATTRIBUTE + "_"): previous_event_type = "intent" - previous_event_name = k.replace(INTENT_ATTRIBUTE + "_", 
"") + previous_event_name = turn_label.replace(INTENT_ATTRIBUTE + "_", "") break return previous_event_type, previous_event_name From 3a7ccdf32a756011edb7a2f452ef5dd1764f2632 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 20 Jan 2020 17:26:29 +0100 Subject: [PATCH 159/209] Use subclassing to define `TrackerEventStateTuple` --- rasa/core/training/story_conflict.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 3163a408d0c6..5b29086db637 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -1,5 +1,5 @@ from collections import defaultdict, namedtuple -from typing import List, Optional, Dict, Text, Tuple, Generator +from typing import List, Optional, Dict, Text, Tuple, Generator, NamedTuple from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.domain import PREV_PREFIX, Domain @@ -80,9 +80,12 @@ def _summarize_conflict(action, stories): return f"{action} predicted in {conflict_description}\n" -TrackerEventStateTuple = namedtuple( - "TrackerEventStateTuple", "tracker event sliced_states" -) +class TrackerEventStateTuple(NamedTuple): + """Holds a tracker, an event, and sliced states associated with those.""" + + tracker: TrackerWithCachedStates + event: Event + sliced_states: List[Dict[Text, float]] def find_story_conflicts( From d902a72c498f2bc86fb39d5fd061c03f7745eeef Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 20 Jan 2020 17:29:07 +0100 Subject: [PATCH 160/209] Define `TrackerEventStateTuple.sliced_states_hash` --- rasa/core/training/story_conflict.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 5b29086db637..7f1d35b04478 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -87,6 +87,10 @@ class TrackerEventStateTuple(NamedTuple): event: Event sliced_states: List[Dict[Text, float]] + @property + def sliced_states_hash(self): + return hash(str(list(self.sliced_states))) + def find_story_conflicts( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int @@ -123,7 +127,7 @@ def _find_conflicting_states( # represented by its hash state_action_mapping = defaultdict(list) for element in _sliced_states_iterator(trackers, domain, max_history): - hashed_state = hash(str(list(element.sliced_states))) + hashed_state = element.sliced_states_hash if element.event.as_story_string() not in state_action_mapping[hashed_state]: state_action_mapping[hashed_state] += [element.event.as_story_string()] @@ -145,7 +149,7 @@ def _build_conflicts_from_states( # for which a conflict occurs conflicts = {} for element in _sliced_states_iterator(trackers, domain, max_history): - hashed_state = hash(str(list(element.sliced_states))) + hashed_state = element.sliced_states_hash if hashed_state in state_action_dict: if hashed_state not in conflicts: From 9225a19c3e7e3964a07ebbc0552cef5f09716570 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 27 Jan 2020 17:47:38 +0100 Subject: [PATCH 161/209] Use double quotemarks in rst Co-Authored-By: Tobias Wochinger --- docs/user-guide/validate-files.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index d4f526703809..0e08cb6b8878 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -19,7 +19,7 @@ You can run it with the following command: rasa data validate The script above runs all the validations on your files, except for story structure validation, -which is omitted unless you provide the `--max-history` argument. Here is the list of options to +which is omitted unless you provide the ``--max-history`` argument. Here is the list of options to the script: .. program-output:: rasa data validate --help From 303b2a7704f57c02fbde13b9f77750b9674a7ef4 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 17:49:08 +0100 Subject: [PATCH 162/209] Update rasa/core/training/story_conflict.py Co-Authored-By: Tobias Wochinger --- rasa/core/training/story_conflict.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 7f1d35b04478..92f5e93aef6e 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -104,7 +104,6 @@ def find_story_conflicts( Returns: List of conflicts. """ - # We do this in two steps, to reduce memory consumption: # Create a 'state -> list of actions' dict, where the state is From 8f7074a23c8d8841ccc37d3fda5409672e1d06fd Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 27 Jan 2020 17:56:23 +0100 Subject: [PATCH 163/209] Simplify code with `_append_story_structure_arguments` --- rasa/cli/data.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 77da050ac47a..7df3b0c46432 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -90,12 +90,7 @@ def _add_data_validate_parsers( parents=parents, help="Validates domain and data files to check for possible mistakes.", ) - validate_parser.add_argument( - "--max-history", - type=int, - default=None, - help="Number of turns taken into account for story structure validation.", - ) + _append_story_structure_arguments(validate_parser) validate_parser.set_defaults(func=validate_files) arguments.set_validator_arguments(validate_parser) @@ -106,13 +101,18 @@ def _add_data_validate_parsers( parents=parents, help="Checks for inconsistencies in the story files.", ) - story_structure_parser.add_argument( + _append_story_structure_arguments(story_structure_parser) + story_structure_parser.set_defaults(func=validate_stories) + arguments.set_validator_arguments(story_structure_parser) + + +def _append_story_structure_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( "--max-history", type=int, + default=None, help="Number of turns taken into account for story structure validation.", ) - story_structure_parser.set_defaults(func=validate_stories) - arguments.set_validator_arguments(story_structure_parser) def split_nlu_data(args) -> None: From bd2f56df5d5c7e589800ff72fb861f41524be5a0 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 27 Jan 2020 18:53:23 +0100 Subject: [PATCH 164/209] Absorb `validate_stories` into `validate_files` --- rasa/cli/data.py | 57 +++++++++++++++--------------------------- rasa/core/validator.py | 2 +- 2 files changed, 21 insertions(+), 38 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 7df3b0c46432..b4e3fa9e4c23 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -8,6 +8,8 @@ from rasa.cli.arguments import data as arguments import rasa.cli.utils from rasa.constants import DEFAULT_DATA_PATH +from rasa.core.validator import Validator +from rasa.importers.rasa import RasaFileImporter logger = logging.getLogger(__name__) @@ -102,7 +104,7 @@ def _add_data_validate_parsers( help="Checks for inconsistencies in the story files.", ) _append_story_structure_arguments(story_structure_parser) - story_structure_parser.set_defaults(func=validate_stories) + story_structure_parser.set_defaults(func=validate_files, stories_only=True) arguments.set_validator_arguments(story_structure_parser) @@ -132,47 +134,42 @@ def split_nlu_data(args) -> None: def validate_files(args: argparse.Namespace) -> None: - """Validate all files needed for training a model. - - Fails with a non-zero exit code if there are any errors in the data.""" - from rasa.core.validator import Validator - from rasa.importers.rasa import RasaFileImporter - loop = asyncio.get_event_loop() file_importer = RasaFileImporter( domain_path=args.domain, training_data_paths=args.data ) validator = loop.run_until_complete(Validator.from_importer(file_importer)) - domain_is_valid = validator.verify_domain_validity() - if not domain_is_valid: - sys.exit(1) - everything_is_alright = validator.verify_all(not args.fail_on_warnings) - if not args.max_history: + if "stories_only" in args: + all_good = _validate_story_structure(validator, args) + elif not args.max_history: logger.info( "Will not test for inconsistencies in stories since " "you did not provide a value for `--max-history`." 
) + all_good = _validate_domain(validator) and _validate_nlu(validator, args) else: - # Only run story structure validation if everything else is fine - # since this might take a while - everything_is_alright = validator.verify_story_structure( - not args.fail_on_warnings, max_history=args.max_history + all_good = ( + _validate_domain(validator) + and _validate_nlu(validator, args) + and _validate_story_structure(validator, args) ) - if not everything_is_alright: + if not all_good: rasa.cli.utils.print_error("Project validation completed with errors.") sys.exit(1) -def validate_stories(args: "argparse.Namespace") -> None: - """Validate only the story structure. +def _validate_domain(validator: Validator) -> bool: + return validator.verify_domain_validity() + - Fails with a non-zero exit code if there are any errors in the data.""" - from rasa.core.validator import Validator - from rasa.importers.rasa import RasaFileImporter +def _validate_nlu(validator: Validator, args: argparse.Namespace) -> bool: + return validator.verify_nlu(not args.fail_on_warnings) + +def _validate_story_structure(validator: Validator, args: argparse.Namespace) -> bool: # Check if a valid setting for `max_history` was given if not isinstance(args.max_history, int) or args.max_history < 1: raise argparse.ArgumentError( @@ -180,20 +177,6 @@ def validate_stories(args: "argparse.Namespace") -> None: "You have to provide a positive integer for --max-history.", ) - # Prepare story and domain file import - loop = asyncio.get_event_loop() - file_importer = RasaFileImporter( - domain_path=args.domain, training_data_paths=args.data - ) - - # Loads the stories - validator = loop.run_until_complete(Validator.from_importer(file_importer)) - - # If names are unique, look for story conflicts - stories_are_consistent = validator.verify_story_structure( + return validator.verify_story_structure( not args.fail_on_warnings, max_history=args.max_history ) - - if not stories_are_consistent: - 
rasa.cli.utils.print_error("Story validation completed with errors.") - sys.exit(1) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 8a3b8faff167..6a17e464ba1a 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -218,7 +218,7 @@ def verify_story_structure( return ignore_warnings or len(conflicts) == 0 - def verify_all(self, ignore_warnings: bool = True) -> bool: + def verify_nlu(self, ignore_warnings: bool = True) -> bool: """Runs all the validations on intents and utterances.""" logger.info("Validating intents...") From 7510b7b77a230f10f5fa336f5689a184dd1d5326 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 18:54:44 +0100 Subject: [PATCH 165/209] Make `StoryConflict._sliced_states` private --- rasa/core/training/story_conflict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 92f5e93aef6e..476562cefe44 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -11,13 +11,13 @@ class StoryConflict: def __init__(self, sliced_states: List[Optional[Dict[Text, float]]],) -> None: - self.sliced_states = sliced_states + self._sliced_states = sliced_states self._conflicting_actions = defaultdict( list ) # {"action": ["story_1", ...], ...} def __hash__(self) -> int: - return hash(str(list(self.sliced_states))) + return hash(str(list(self._sliced_states))) def add_conflicting_action(self, action: Text, story_name: Text) -> None: """Adds another action that follows from the same state. @@ -45,11 +45,11 @@ def has_prior_events(self) -> bool: Returns: True if anything has happened before this conflict, otherwise False. 
""" - return _get_previous_event(self.sliced_states[-1])[0] is not None + return _get_previous_event(self._sliced_states[-1])[0] is not None def __str__(self) -> Text: # Describe where the conflict occurs in the stories - last_event_type, last_event_name = _get_previous_event(self.sliced_states[-1]) + last_event_type, last_event_name = _get_previous_event(self._sliced_states[-1]) if last_event_type: conflict_message = ( f"CONFLICT after {last_event_type} '{last_event_name}':\n" From 06a6622e2cbadf7d42ce6859d0119699ac2a9ec3 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 18:57:08 +0100 Subject: [PATCH 166/209] Rename `conflict_has_prior_events` --- rasa/core/training/story_conflict.py | 4 ++-- tests/core/test_story_conflict.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 476562cefe44..942fd2c00221 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -39,7 +39,7 @@ def conflicting_actions(self) -> List[Text]: return list(self._conflicting_actions.keys()) @property - def has_prior_events(self) -> bool: + def conflict_has_prior_events(self) -> bool: """Checks if prior events exist. 
Returns: @@ -164,7 +164,7 @@ def _build_conflicts_from_states( return [ conflict for (hashed_state, conflict) in conflicts.items() - if conflict.has_prior_events + if conflict.conflict_has_prior_events ] diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index 3cdf08a51aed..3dc569f54000 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -127,7 +127,7 @@ async def test_has_prior_events(): {"prev_utter_greet": 1.0, "intent_greet": 1.0}, ] conflict = StoryConflict(sliced_states) - assert conflict.has_prior_events + assert conflict.conflict_has_prior_events async def test_get_previous_event(): @@ -138,4 +138,4 @@ async def test_get_previous_event(): async def test_has_no_prior_events(): sliced_states = [None] conflict = StoryConflict(sliced_states) - assert not conflict.has_prior_events + assert not conflict.conflict_has_prior_events From fcd61b20c4e423fc9181d21e2d2cbc9efa525474 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 18:58:18 +0100 Subject: [PATCH 167/209] Add quote ticks Co-Authored-By: Tobias Wochinger --- rasa/core/training/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 942fd2c00221..d6477c663e09 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -43,7 +43,7 @@ def conflict_has_prior_events(self) -> bool: """Checks if prior events exist. Returns: - True if anything has happened before this conflict, otherwise False. + `True` if anything has happened before this conflict, otherwise `False`. """ return _get_previous_event(self._sliced_states[-1])[0] is not None From bd2a5ed8c9c59156dbf6208a04085242b16921b0 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 27 Jan 2020 19:01:16 +0100 Subject: [PATCH 168/209] Declare types for `StoryConflict._summarize_conflict` --- rasa/core/training/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 942fd2c00221..4b5b20843e42 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -64,7 +64,7 @@ def __str__(self) -> Text: return conflict_message @staticmethod - def _summarize_conflict(action, stories): + def _summarize_conflict(action: Text, stories: List[Text]) -> Text: if len(stories) > 3: # Four or more stories are present conflict_description = ( From 047c73dbffe00da8dce3560f5d07f76f8d8be006 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 19:02:50 +0100 Subject: [PATCH 169/209] Let conflict summary always show at least two names --- rasa/core/training/story_conflict.py | 2 +- tests/core/test_story_conflict.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 4b5b20843e42..ce4f7ea88cc3 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -68,7 +68,7 @@ def _summarize_conflict(action: Text, stories: List[Text]) -> Text: if len(stories) > 3: # Four or more stories are present conflict_description = ( - f"'{stories[0]}' and {len(stories) - 1} other trackers" + f"'{stories[0]}', '{stories[1]}', and {len(stories) - 2} other trackers" ) else: conflict_description = ( diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index 3dc569f54000..3882cd38a77a 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -102,7 +102,7 @@ async def test_find_conflicts_multiple_stories(): conflicts = find_story_conflicts(trackers, domain, 5) assert len(conflicts) == 1 - assert "and 3 other trackers" 
in str(conflicts[0]) + assert "and 2 other trackers" in str(conflicts[0]) async def test_add_conflicting_action(): From 5f5df0dfd343726301c922a07672408f8fecda86 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 19:03:56 +0100 Subject: [PATCH 170/209] Declare return type of `TrackerEventStateTuple.sliced_states_hash` --- rasa/core/training/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index ce4f7ea88cc3..75fdc8357f3f 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -88,7 +88,7 @@ class TrackerEventStateTuple(NamedTuple): sliced_states: List[Dict[Text, float]] @property - def sliced_states_hash(self): + def sliced_states_hash(self) -> int: return hash(str(list(self.sliced_states))) From 7f72d7c97ec5a5e43986457b40d43f03e0006523 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 19:04:09 +0100 Subject: [PATCH 171/209] Update rasa/core/training/story_conflict.py Co-Authored-By: Tobias Wochinger --- rasa/core/training/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index d6477c663e09..fbd210ca73e9 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -95,7 +95,7 @@ def sliced_states_hash(self): def find_story_conflicts( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> List[StoryConflict]: - """Generates a list of StoryConflict objects, describing conflicts in the given trackers. + """Generates a list of `StoryConflict` objects, describing conflicts in the given trackers. Args: trackers: Trackers in which to search for conflicts. From c4c936f0ff253edceffa528d4bc7d77289569bc3 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 27 Jan 2020 19:06:23 +0100 Subject: [PATCH 172/209] Rename local variable `conflicting_state_action_mapping` --- rasa/core/training/story_conflict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 75fdc8357f3f..8c19fba8c2a6 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -108,12 +108,12 @@ def find_story_conflicts( # Create a 'state -> list of actions' dict, where the state is # represented by its hash - state_action_mapping = _find_conflicting_states(trackers, domain, max_history) + conflicting_state_action_mapping = _find_conflicting_states(trackers, domain, max_history) # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs conflicts = _build_conflicts_from_states( - trackers, domain, max_history, state_action_mapping + trackers, domain, max_history, conflicting_state_action_mapping ) return conflicts @@ -142,7 +142,7 @@ def _build_conflicts_from_states( trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int, - state_action_dict: Dict[int, Optional[List[Text]]], + conflicting_state_action_mapping: Dict[int, Optional[List[Text]]], ) -> List["StoryConflict"]: # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs @@ -150,7 +150,7 @@ def _build_conflicts_from_states( for element in _sliced_states_iterator(trackers, domain, max_history): hashed_state = element.sliced_states_hash - if hashed_state in state_action_dict: + if hashed_state in conflicting_state_action_mapping: if hashed_state not in conflicts: conflicts[hashed_state] = StoryConflict(element.sliced_states) From aac90e0dc99934eb42253fb7b8433fdf31d194da Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 27 Jan 2020 19:07:40 +0100 Subject: [PATCH 173/209] Clarify comment --- rasa/core/training/story_conflict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 8c19fba8c2a6..343de48d8153 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -159,7 +159,7 @@ def _build_conflicts_from_states( story_name=element.tracker.sender_id, ) - # Remove conflicts that arise from unpredictable actions + # Return list of conflicts that arise from unpredictable actions # (actions that start the conversation) return [ conflict From 4e2b767de79cd216548b2b33c2bab73c6b336d2f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 19:10:57 +0100 Subject: [PATCH 174/209] Use `return` instead of `break` --- rasa/core/training/story_conflict.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 343de48d8153..83261c7e0ad9 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -108,7 +108,9 @@ def find_story_conflicts( # Create a 'state -> list of actions' dict, where the state is # represented by its hash - conflicting_state_action_mapping = _find_conflicting_states(trackers, domain, max_history) + conflicting_state_action_mapping = _find_conflicting_states( + trackers, domain, max_history + ) # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs @@ -212,23 +214,17 @@ def _get_previous_event( Returns: Tuple of (type, name) strings of the prior event. 
""" - previous_event_type = None - previous_event_name = None if not state: - return previous_event_type, previous_event_name + return None, None for turn_label in state: if ( turn_label.startswith(PREV_PREFIX) and turn_label.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME ): - previous_event_type = "action" - previous_event_name = turn_label.replace(PREV_PREFIX, "") - break + return "action", turn_label.replace(PREV_PREFIX, "") elif turn_label.startswith(INTENT_ATTRIBUTE + "_"): - previous_event_type = "intent" - previous_event_name = turn_label.replace(INTENT_ATTRIBUTE + "_", "") - break + return "intent", turn_label.replace(INTENT_ATTRIBUTE + "_", "") - return previous_event_type, previous_event_name + return None, None From af9d87704c0f90e6b53bbfb868c541b64e4c075a Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 19:20:51 +0100 Subject: [PATCH 175/209] Fix _get_previous_event --- rasa/core/training/story_conflict.py | 16 +++++++++++++--- tests/core/test_story_conflict.py | 1 + 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 83261c7e0ad9..a14e32ec165f 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -215,16 +215,26 @@ def _get_previous_event( Tuple of (type, name) strings of the prior event. """ + previous_event_type = None + previous_event_name = None + if not state: - return None, None + return previous_event_type, previous_event_name + # A typical state is, for example, + # `{'prev_action_listen': 1.0, 'intent_greet': 1.0, 'slot_cuisine_0': 1.0}`. + # We need to look out for `prev_` and `intent_` prefixes in the labels. 
for turn_label in state: if ( turn_label.startswith(PREV_PREFIX) and turn_label.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME ): + # The `prev_...` was an action that was NOT `action_listen` return "action", turn_label.replace(PREV_PREFIX, "") elif turn_label.startswith(INTENT_ATTRIBUTE + "_"): - return "intent", turn_label.replace(INTENT_ATTRIBUTE + "_", "") + # We found an intent, but it is only the previous event if + # the `prev_...` was `prev_action_listen`, so we don't return. + previous_event_type = "intent" + previous_event_name = turn_label.replace(INTENT_ATTRIBUTE + "_", "") - return None, None + return previous_event_type, previous_event_name diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index 3882cd38a77a..f58066d1d485 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -132,6 +132,7 @@ async def test_has_prior_events(): async def test_get_previous_event(): assert _get_previous_event({"prev_utter_greet": 1.0, "intent_greet": 1.0}) == ("action", "utter_greet") + assert _get_previous_event({"intent_greet": 1.0, "prev_utter_greet": 1.0}) == ("action", "utter_greet") assert _get_previous_event({"intent_greet": 1.0, "prev_action_listen": 1.0}) == ("intent", "greet") From 339299b6456091541bd6fa51b3b013b0d44d2b9f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 27 Jan 2020 19:27:16 +0100 Subject: [PATCH 176/209] Expand doc string --- rasa/core/validator.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 6a17e464ba1a..e7c43edc97d0 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -193,7 +193,16 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: def verify_story_structure( self, ignore_warnings: bool = True, max_history: int = 5 ) -> bool: - """Verifies that bot behaviour in stories is deterministic.""" + """Verifies that bot behaviour in stories is deterministic. + + Args: + ignore_warnings: When `True`, return `True` even if conflicts were found. + max_history: Maximal number of events to take into account for conflict identification. + + Returns: + `False` if a conflict was found and `ignore_warnings` is `False`. + `True` otherwise. + """ logger.info("Story structure validation...") logger.info( From 6c9764a6d53e9c8057b08050d197206a57ee3f95 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 19:28:02 +0100 Subject: [PATCH 177/209] Use `not` instead of `len(...) 
== 0` --- rasa/core/validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index e7c43edc97d0..5f235dc80480 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -219,13 +219,13 @@ def verify_story_structure( # Create a list of `StoryConflict` objects conflicts = find_story_conflicts(trackers, self.domain, max_history) - if len(conflicts) == 0: + if not conflicts: logger.info("No story structure conflicts found.") else: for conflict in conflicts: logger.warning(conflict) - return ignore_warnings or len(conflicts) == 0 + return ignore_warnings or not conflicts def verify_nlu(self, ignore_warnings: bool = True) -> bool: """Runs all the validations on intents and utterances.""" From 1b44bef232d2533ebb3b0e87169f8c477bdc9b46 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 27 Jan 2020 19:49:13 +0100 Subject: [PATCH 178/209] Add tests for `data validate ...` warnings --- rasa/cli/data.py | 2 +- tests/cli/test_rasa_data.py | 25 ++++++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index b4e3fa9e4c23..3dabf598eb67 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -174,7 +174,7 @@ def _validate_story_structure(validator: Validator, args: argparse.Namespace) -> if not isinstance(args.max_history, int) or args.max_history < 1: raise argparse.ArgumentError( args.max_history, - "You have to provide a positive integer for --max-history.", + "You have to provide a positive integer for `--max-history`.", ) return validator.verify_story_structure( diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 010098363e6c..25ef91406350 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -1,7 +1,7 @@ import os import pytest from collections import namedtuple -from typing import Callable +from typing import Callable, Text from _pytest.pytester import RunResult from 
rasa.cli import data @@ -69,6 +69,29 @@ def test_data_validate_help(run: Callable[..., RunResult]): assert output.outlines[i] == line +def _text_is_part_of_output_error(text: Text, output: RunResult) -> bool: + found_info_string = False + for line in output.errlines: + if text in line: + found_info_string = True + return found_info_string + + +def test_data_validate_without_max_history(run: Callable[..., RunResult]): + output = run("data", "validate") + assert _text_is_part_of_output_error("did not provide a value for `--max-history`", output) + + +def test_data_validate_stories_without_max_history(run: Callable[..., RunResult]): + output = run("data", "validate", "stories") + assert _text_is_part_of_output_error("have to provide a positive integer for `--max-history`", output) + + +def test_data_validate_stories_with_max_history_zero(run: Callable[..., RunResult]): + output = run("data", "validate", "stories", "--max-history", "0") + assert _text_is_part_of_output_error("have to provide a positive integer for `--max-history`", output) + + def test_validate_files_exit_early(): with pytest.raises(SystemExit) as pytest_e: args = {"domain": "data/test_domains/duplicate_intents.yml", "data": None} From 87e99b5df8d4b25330e900bc42913a03546be3d8 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 27 Jan 2020 19:56:53 +0100 Subject: [PATCH 179/209] Rephrase Co-Authored-By: Brian Hopkins --- docs/user-guide/validate-files.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/validate-files.rst b/docs/user-guide/validate-files.rst index 0e08cb6b8878..2421eea49e05 100644 --- a/docs/user-guide/validate-files.rst +++ b/docs/user-guide/validate-files.rst @@ -91,7 +91,7 @@ Take, for example, the following two stories: * inform_happy - utter_goodbye -These two stories are inconsistent, because Rasa cannot know if it should predict ``utter_happy`` or ``utter_goodbye`` +These two stories are inconsistent, because Rasa doesn't know if it should predict ``utter_happy`` or ``utter_goodbye`` after ``inform_happy``, as there is nothing that would distinguish the dialogue states at ``inform_happy`` in the two stories and the subsequent actions are different in Story 1 and Story 2. From f0c8eb54db738f8638c510cd1abdc12d68163526 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 28 Jan 2020 07:43:33 +0100 Subject: [PATCH 180/209] Add `run_in_default_project_with_info` --- rasa/core/training/story_conflict.py | 2 +- tests/cli/conftest.py | 12 ++++++++++++ tests/cli/test_rasa_data.py | 12 ++++++------ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 314bb567d979..66d7222086ab 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -72,7 +72,7 @@ def _summarize_conflict(action: Text, stories: List[Text]) -> Text: ) else: conflict_description = ( - {1: "'{}'", 2: "'{}' and '{}'", 3: "'{}', '{}', and '{}'",} + {1: "'{}'", 2: "'{}' and '{}'", 3: "'{}', '{}', and '{}'"} .get(len(stories)) .format(*stories) ) diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index b7294eaeadff..75ccf71b1f18 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -32,3 +32,15 @@ def do_run(*args): return testdir.run(*args) return do_run + + +@pytest.fixture +def run_in_default_project_with_info(testdir: Testdir) -> Callable[..., RunResult]: + os.environ["LOG_LEVEL"] = "INFO" + testdir.run("rasa", "init", "--no-prompt") + + def do_run(*args): + args = ["rasa"] + list(args) + return testdir.run(*args) + + return do_run diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 25ef91406350..7a0e3910085f 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -77,18 +77,18 @@ def _text_is_part_of_output_error(text: Text, output: RunResult) -> bool: return found_info_string -def test_data_validate_without_max_history(run: Callable[..., RunResult]): - output = run("data", "validate") +def test_data_validate_without_max_history(run_in_default_project_with_info: Callable[..., RunResult]): + output = run_in_default_project_with_info("data", "validate") assert _text_is_part_of_output_error("did not provide a value for `--max-history`", output) -def 
test_data_validate_stories_without_max_history(run: Callable[..., RunResult]): - output = run("data", "validate", "stories") +def test_data_validate_stories_without_max_history(run_in_default_project_with_info: Callable[..., RunResult]): + output = run_in_default_project_with_info("data", "validate", "stories") assert _text_is_part_of_output_error("have to provide a positive integer for `--max-history`", output) -def test_data_validate_stories_with_max_history_zero(run: Callable[..., RunResult]): - output = run("data", "validate", "stories", "--max-history", "0") +def test_data_validate_stories_with_max_history_zero(run_in_default_project_with_info: Callable[..., RunResult]): + output = run_in_default_project_with_info("data", "validate", "stories", "--max-history", "0") assert _text_is_part_of_output_error("have to provide a positive integer for `--max-history`", output) From eea7a95e99c4d5a284dab60fc20602b788eac4a5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Tue, 28 Jan 2020 07:44:27 +0100 Subject: [PATCH 181/209] Apply BLACK formatting --- tests/cli/test_rasa_data.py | 28 +++++++++++++++++++++------- tests/core/test_story_conflict.py | 21 +++++++++++++++++---- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 7a0e3910085f..111aefc082d2 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -77,19 +77,33 @@ def _text_is_part_of_output_error(text: Text, output: RunResult) -> bool: return found_info_string -def test_data_validate_without_max_history(run_in_default_project_with_info: Callable[..., RunResult]): +def test_data_validate_without_max_history( + run_in_default_project_with_info: Callable[..., RunResult] +): output = run_in_default_project_with_info("data", "validate") - assert _text_is_part_of_output_error("did not provide a value for `--max-history`", output) + assert _text_is_part_of_output_error( + "did not provide a value for 
`--max-history`", output + ) -def test_data_validate_stories_without_max_history(run_in_default_project_with_info: Callable[..., RunResult]): +def test_data_validate_stories_without_max_history( + run_in_default_project_with_info: Callable[..., RunResult] +): output = run_in_default_project_with_info("data", "validate", "stories") - assert _text_is_part_of_output_error("have to provide a positive integer for `--max-history`", output) + assert _text_is_part_of_output_error( + "have to provide a positive integer for `--max-history`", output + ) -def test_data_validate_stories_with_max_history_zero(run_in_default_project_with_info: Callable[..., RunResult]): - output = run_in_default_project_with_info("data", "validate", "stories", "--max-history", "0") - assert _text_is_part_of_output_error("have to provide a positive integer for `--max-history`", output) +def test_data_validate_stories_with_max_history_zero( + run_in_default_project_with_info: Callable[..., RunResult] +): + output = run_in_default_project_with_info( + "data", "validate", "stories", "--max-history", "0" + ) + assert _text_is_part_of_output_error( + "have to provide a positive integer for `--max-history`", output + ) def test_validate_files_exit_early(): diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index f58066d1d485..28f8746d8d1b 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -1,4 +1,8 @@ -from rasa.core.training.story_conflict import StoryConflict, find_story_conflicts, _get_previous_event +from rasa.core.training.story_conflict import ( + StoryConflict, + find_story_conflicts, + _get_previous_event, +) from rasa.core.training.generator import TrainingDataGenerator from rasa.core.validator import Validator from rasa.importers.rasa import RasaFileImporter @@ -131,9 +135,18 @@ async def test_has_prior_events(): async def test_get_previous_event(): - assert _get_previous_event({"prev_utter_greet": 1.0, "intent_greet": 
1.0}) == ("action", "utter_greet") - assert _get_previous_event({"intent_greet": 1.0, "prev_utter_greet": 1.0}) == ("action", "utter_greet") - assert _get_previous_event({"intent_greet": 1.0, "prev_action_listen": 1.0}) == ("intent", "greet") + assert _get_previous_event({"prev_utter_greet": 1.0, "intent_greet": 1.0}) == ( + "action", + "utter_greet", + ) + assert _get_previous_event({"intent_greet": 1.0, "prev_utter_greet": 1.0}) == ( + "action", + "utter_greet", + ) + assert _get_previous_event({"intent_greet": 1.0, "prev_action_listen": 1.0}) == ( + "intent", + "greet", + ) async def test_has_no_prior_events(): From 68d43be96e4ee005a7253155a52e0a8506910c28 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Tue, 28 Jan 2020 09:30:36 +0100 Subject: [PATCH 182/209] Fix `args.max_history` --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 3dabf598eb67..ca958f80ffb4 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -143,7 +143,7 @@ def validate_files(args: argparse.Namespace) -> None: if "stories_only" in args: all_good = _validate_story_structure(validator, args) - elif not args.max_history: + elif "max_history" not in args or args.max_history is None: logger.info( "Will not test for inconsistencies in stories since " "you did not provide a value for `--max-history`." From bb2535497c78d9796cf71863714a02b7f6e4edda Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 28 Jan 2020 15:05:54 +0100 Subject: [PATCH 183/209] Simplify `StoryConflict._summarize_conflict` --- rasa/core/training/story_conflict.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 66d7222086ab..7b15af0bd254 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -70,12 +70,20 @@ def _summarize_conflict(action: Text, stories: List[Text]) -> Text: conflict_description = ( f"'{stories[0]}', '{stories[1]}', and {len(stories) - 2} other trackers" ) - else: + elif len(stories) == 3: + conflict_description = ( + f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'" + ) + elif len(stories) == 2: conflict_description = ( - {1: "'{}'", 2: "'{}' and '{}'", 3: "'{}', '{}', and '{}'"} - .get(len(stories)) - .format(*stories) + f"'{stories[0]}' and '{stories[1]}'" ) + elif len(stories) == 1: + conflict_description = ( + f"'{stories[0]}'" + ) + else: + raise ValueError("Trying to summarize conflict without stories.") return f"{action} predicted in {conflict_description}\n" From d73f89b828901e1f7c6ddf8c3cdbde7f2ab666e5 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 28 Jan 2020 16:43:59 +0100 Subject: [PATCH 184/209] Apply BLACK formatting --- rasa/core/training/story_conflict.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 7b15af0bd254..7502b95cbdda 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -71,17 +71,11 @@ def _summarize_conflict(action: Text, stories: List[Text]) -> Text: f"'{stories[0]}', '{stories[1]}', and {len(stories) - 2} other trackers" ) elif len(stories) == 3: - conflict_description = ( - f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'" - ) + conflict_description = f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'" elif len(stories) == 2: - conflict_description = ( - f"'{stories[0]}' and '{stories[1]}'" - ) + conflict_description = f"'{stories[0]}' and '{stories[1]}'" elif len(stories) == 1: - conflict_description = ( - f"'{stories[0]}'" - ) + conflict_description = f"'{stories[0]}'" else: raise ValueError("Trying to summarize conflict without stories.") From bde9ad49d27c743c53a17253685ab2124aa16991 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Tue, 28 Jan 2020 16:48:02 +0100 Subject: [PATCH 185/209] Add changelog --- changelog/4088.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/4088.feature.rst diff --git a/changelog/4088.feature.rst b/changelog/4088.feature.rst new file mode 100644 index 000000000000..654e3b72c23e --- /dev/null +++ b/changelog/4088.feature.rst @@ -0,0 +1 @@ +Add story structure validation functionality (e.g. `rasa data validate stories --max-history 5`). From 90c4a30b64f9cff9af814a0d2235e4bf6cf4a9de Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Wed, 29 Jan 2020 18:36:22 +0100 Subject: [PATCH 186/209] Fix `args.max_history` and `stories_only` --- rasa/cli/data.py | 12 ++++++++---- tests/cli/test_rasa_data.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index ca958f80ffb4..1b593dfc272e 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -104,7 +104,7 @@ def _add_data_validate_parsers( help="Checks for inconsistencies in the story files.", ) _append_story_structure_arguments(story_structure_parser) - story_structure_parser.set_defaults(func=validate_files, stories_only=True) + story_structure_parser.set_defaults(func=validate_stories) arguments.set_validator_arguments(story_structure_parser) @@ -133,7 +133,7 @@ def split_nlu_data(args) -> None: test.persist(args.out, filename=f"test_data.{fformat}") -def validate_files(args: argparse.Namespace) -> None: +def validate_files(args: argparse.Namespace, stories_only: bool = False) -> None: loop = asyncio.get_event_loop() file_importer = RasaFileImporter( domain_path=args.domain, training_data_paths=args.data @@ -141,9 +141,9 @@ def validate_files(args: argparse.Namespace) -> None: validator = loop.run_until_complete(Validator.from_importer(file_importer)) - if "stories_only" in args: + if stories_only: all_good = _validate_story_structure(validator, args) - elif "max_history" not in args or args.max_history is None: + elif not args.max_history: logger.info( "Will not test for inconsistencies in stories since " "you did not provide a value for `--max-history`." 
@@ -161,6 +161,10 @@ def validate_files(args: argparse.Namespace) -> None: sys.exit(1) +def validate_stories(args: argparse.Namespace) -> None: + validate_files(args, stories_only=True) + + def _validate_domain(validator: Validator) -> bool: return validator.verify_domain_validity() diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 111aefc082d2..23c8bb3574f6 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -78,7 +78,7 @@ def _text_is_part_of_output_error(text: Text, output: RunResult) -> bool: def test_data_validate_without_max_history( - run_in_default_project_with_info: Callable[..., RunResult] + run_in_default_project_with_info: Callable[..., RunResult] ): output = run_in_default_project_with_info("data", "validate") assert _text_is_part_of_output_error( @@ -87,7 +87,7 @@ def test_data_validate_without_max_history( def test_data_validate_stories_without_max_history( - run_in_default_project_with_info: Callable[..., RunResult] + run_in_default_project_with_info: Callable[..., RunResult] ): output = run_in_default_project_with_info("data", "validate", "stories") assert _text_is_part_of_output_error( @@ -96,7 +96,7 @@ def test_data_validate_stories_without_max_history( def test_data_validate_stories_with_max_history_zero( - run_in_default_project_with_info: Callable[..., RunResult] + run_in_default_project_with_info: Callable[..., RunResult] ): output = run_in_default_project_with_info( "data", "validate", "stories", "--max-history", "0" @@ -108,7 +108,11 @@ def test_data_validate_stories_with_max_history_zero( def test_validate_files_exit_early(): with pytest.raises(SystemExit) as pytest_e: - args = {"domain": "data/test_domains/duplicate_intents.yml", "data": None} + args = { + "domain": "data/test_domains/duplicate_intents.yml", + "data": None, + "max_history": None + } data.validate_files(namedtuple("Args", args.keys())(*args.values())) assert pytest_e.type == SystemExit From 
c2c9a370ab49775786e3a022ced9cc71728632e6 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 29 Jan 2020 18:37:36 +0100 Subject: [PATCH 187/209] Update docstring Co-Authored-By: Tobias Wochinger --- rasa/core/validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index 5f235dc80480..d24475ef6021 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -193,7 +193,7 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: def verify_story_structure( self, ignore_warnings: bool = True, max_history: int = 5 ) -> bool: - """Verifies that bot behaviour in stories is deterministic. + """Verifies that the bot behaviour in stories is deterministic. Args: ignore_warnings: When `True`, return `True` even if conflicts were found. From 0a0e70326a219f67426fbc168d0eb7521a518345 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 29 Jan 2020 18:41:13 +0100 Subject: [PATCH 188/209] Rename `test_find_conflicts_slots_that_break` and `_dont_break` --- tests/core/test_story_conflict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index 28f8746d8d1b..1d6b5fd4931b 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -74,7 +74,7 @@ async def test_find_conflicts_or(): assert conflicts[0].conflicting_actions == ["utter_default", "utter_goodbye"] -async def test_find_conflicts_slots(): +async def test_find_conflicts_slots_that_break(): trackers, domain = await _setup_trackers_for_testing( "data/test_domains/default.yml", "data/test_stories/stories_conflicting_4.md" ) @@ -86,7 +86,7 @@ async def test_find_conflicts_slots(): assert conflicts[0].conflicting_actions == ["utter_default", "utter_greet"] -async def test_find_conflicts_slots_2(): +async def test_find_conflicts_slots_that_dont_break(): trackers, domain = await 
_setup_trackers_for_testing( "data/test_domains/default.yml", "data/test_stories/stories_conflicting_5.md" ) From 7ccdef05a0adb55facbe5be1fd16abdb65c4bc25 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 29 Jan 2020 19:04:29 +0100 Subject: [PATCH 189/209] Add doc-strings --- rasa/cli/data.py | 23 ++++++--- rasa/core/training/story_conflict.py | 76 +++++++++++++++++++++++----- 2 files changed, 78 insertions(+), 21 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 1b593dfc272e..e565df81bacc 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -16,7 +16,7 @@ # noinspection PyProtectedMember def add_subparser( - subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] + subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] ): data_parser = subparsers.add_parser( "data", @@ -35,7 +35,7 @@ def add_subparser( def _add_data_convert_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] + data_subparsers, parents: List[argparse.ArgumentParser] ) -> None: from rasa.nlu import convert @@ -60,7 +60,7 @@ def _add_data_convert_parsers( def _add_data_split_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] + data_subparsers, parents: List[argparse.ArgumentParser] ) -> None: split_parser = data_subparsers.add_parser( "split", @@ -76,7 +76,7 @@ def _add_data_split_parsers( parents=parents, formatter_class=argparse.ArgumentDefaultsHelpFormatter, help="Performs a split of your NLU data into training and test data " - "according to the specified percentages.", + "according to the specified percentages.", ) nlu_split_parser.set_defaults(func=split_nlu_data) @@ -84,7 +84,7 @@ def _add_data_split_parsers( def _add_data_validate_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] + data_subparsers, parents: List[argparse.ArgumentParser] ) -> None: validate_parser = data_subparsers.add_parser( "validate", @@ -134,6 +134,13 @@ def split_nlu_data(args) -> None: def 
validate_files(args: argparse.Namespace, stories_only: bool = False) -> None: + """ + Validates either the story structure or the entire project. + + Args: + args: Commandline arguments + stories_only: If `True`, only the story structure is validated. + """ loop = asyncio.get_event_loop() file_importer = RasaFileImporter( domain_path=args.domain, training_data_paths=args.data @@ -151,9 +158,9 @@ def validate_files(args: argparse.Namespace, stories_only: bool = False) -> None all_good = _validate_domain(validator) and _validate_nlu(validator, args) else: all_good = ( - _validate_domain(validator) - and _validate_nlu(validator, args) - and _validate_story_structure(validator, args) + _validate_domain(validator) + and _validate_nlu(validator, args) + and _validate_story_structure(validator, args) ) if not all_good: diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 7502b95cbdda..9cf374a34141 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -10,7 +10,25 @@ class StoryConflict: - def __init__(self, sliced_states: List[Optional[Dict[Text, float]]],) -> None: + """ + Represents a conflict between two or more stories. + + Here, a conflict means that different actions are supposed to follow from + the same dialogue state, which most policies cannot learn. + + Attributes: + conflicting_actions: A list of actions that all follow from the same state. + conflict_has_prior_events: If `False`, then the conflict occurs without any + prior events (i.e. at the beginning of a dialogue). + """ + + def __init__(self, sliced_states: List[Optional[Dict[Text, float]]], ) -> None: + """ + Creates a `StoryConflict` from a given state. + + Args: + sliced_states: The (sliced) dialogue state at which the conflict occurs. 
+ """ self._sliced_states = sliced_states self._conflicting_actions = defaultdict( list @@ -59,12 +77,21 @@ def __str__(self) -> Text: # List which stories are in conflict with one another for action, stories in self._conflicting_actions.items(): - conflict_message += " " + self._summarize_conflict(action, stories) + conflict_message += f" {self._summarize_action_occurence(action, stories)}" return conflict_message @staticmethod - def _summarize_conflict(action: Text, stories: List[Text]) -> Text: + def _summarize_action_occurence(action: Text, stories: List[Text]) -> Text: + """Gives a summarized textual description of where one action occurs. + + Args: + action: The name of the action. + stories: The stories in which the action occurs. + + Returns: + A textural summary. + """ if len(stories) > 3: # Four or more stories are present conflict_description = ( @@ -95,7 +122,7 @@ def sliced_states_hash(self) -> int: def find_story_conflicts( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> List[StoryConflict]: """Generates a list of `StoryConflict` objects, describing conflicts in the given trackers. @@ -124,8 +151,18 @@ def find_story_conflicts( def _find_conflicting_states( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> Dict[int, Optional[List[Text]]]: + """Identifies all states from which different actions follow. + + Args: + trackers: Trackers that contain the states. + domain: The domain object. + max_history: Number of turns to take into account for the state descriptions. + + Returns: + A dictionary mapping state-hashes to a list of actions that follow from each state. 
+ """ # Create a 'state -> list of actions' dict, where the state is # represented by its hash state_action_mapping = defaultdict(list) @@ -143,11 +180,24 @@ def _find_conflicting_states( def _build_conflicts_from_states( - trackers: List[TrackerWithCachedStates], - domain: Domain, - max_history: int, - conflicting_state_action_mapping: Dict[int, Optional[List[Text]]], + trackers: List[TrackerWithCachedStates], + domain: Domain, + max_history: int, + conflicting_state_action_mapping: Dict[int, Optional[List[Text]]], ) -> List["StoryConflict"]: + """Builds a list of `StoryConflict` objects for each given conflict. + + Args: + trackers: Trackers that contain the states. + domain: The domain object. + max_history: Number of turns to take into account for the state descriptions. + conflicting_state_action_mapping: A dictionary mapping state-hashes to a list of actions + that follow from each state. + + Returns: + A list of `StoryConflict` objects that describe inconsistencies in the story + structure. These objects also contain the history that leads up to the conflict. + """ # Iterate once more over all states and note the (unhashed) state, # for which a conflict occurs conflicts = {} @@ -173,7 +223,7 @@ def _build_conflicts_from_states( def _sliced_states_iterator( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> Generator[TrackerEventStateTuple, None, None]: """Creates an iterator over sliced states. @@ -203,7 +253,7 @@ def _sliced_states_iterator( def _get_previous_event( - state: Optional[Dict[Text, float]] + state: Optional[Dict[Text, float]] ) -> Tuple[Optional[Text], Optional[Text]]: """Returns previous event type and name. @@ -228,8 +278,8 @@ def _get_previous_event( # We need to look out for `prev_` and `intent_` prefixes in the labels. 
for turn_label in state: if ( - turn_label.startswith(PREV_PREFIX) - and turn_label.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME + turn_label.startswith(PREV_PREFIX) + and turn_label.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME ): # The `prev_...` was an action that was NOT `action_listen` return "action", turn_label.replace(PREV_PREFIX, "") From 300a8ae768b45dac0b00dfb77e7743aa2de9d560 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Wed, 29 Jan 2020 20:15:43 +0100 Subject: [PATCH 190/209] Fix `run_in_default_project` vis. `os.environ["LOG_LEVEL"]` --- tests/cli/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 75ccf71b1f18..4340c3a4649d 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -29,14 +29,15 @@ def run_in_default_project(testdir: Testdir) -> Callable[..., RunResult]: def do_run(*args): args = ["rasa"] + list(args) - return testdir.run(*args) + result = testdir.run(*args) + os.environ["LOG_LEVEL"] = "INFO" + return result return do_run @pytest.fixture def run_in_default_project_with_info(testdir: Testdir) -> Callable[..., RunResult]: - os.environ["LOG_LEVEL"] = "INFO" testdir.run("rasa", "init", "--no-prompt") def do_run(*args): From 34ba1def216a1472da1e617067c100f157f75579 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Thu, 30 Jan 2020 10:49:52 +0100 Subject: [PATCH 191/209] Apply BLACK formatting --- rasa/cli/data.py | 16 ++++++++-------- rasa/core/training/story_conflict.py | 22 +++++++++++----------- tests/cli/test_rasa_data.py | 8 ++++---- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index e565df81bacc..b9ef4705efeb 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -16,7 +16,7 @@ # noinspection PyProtectedMember def add_subparser( - subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] + subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser] ): data_parser = subparsers.add_parser( "data", @@ -35,7 +35,7 @@ def add_subparser( def _add_data_convert_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] + data_subparsers, parents: List[argparse.ArgumentParser] ) -> None: from rasa.nlu import convert @@ -60,7 +60,7 @@ def _add_data_convert_parsers( def _add_data_split_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] + data_subparsers, parents: List[argparse.ArgumentParser] ) -> None: split_parser = data_subparsers.add_parser( "split", @@ -76,7 +76,7 @@ def _add_data_split_parsers( parents=parents, formatter_class=argparse.ArgumentDefaultsHelpFormatter, help="Performs a split of your NLU data into training and test data " - "according to the specified percentages.", + "according to the specified percentages.", ) nlu_split_parser.set_defaults(func=split_nlu_data) @@ -84,7 +84,7 @@ def _add_data_split_parsers( def _add_data_validate_parsers( - data_subparsers, parents: List[argparse.ArgumentParser] + data_subparsers, parents: List[argparse.ArgumentParser] ) -> None: validate_parser = data_subparsers.add_parser( "validate", @@ -158,9 +158,9 @@ def validate_files(args: argparse.Namespace, stories_only: bool = False) -> None all_good = _validate_domain(validator) and _validate_nlu(validator, args) else: all_good = ( - 
_validate_domain(validator) - and _validate_nlu(validator, args) - and _validate_story_structure(validator, args) + _validate_domain(validator) + and _validate_nlu(validator, args) + and _validate_story_structure(validator, args) ) if not all_good: diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 9cf374a34141..d043541bfd08 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -22,7 +22,7 @@ class StoryConflict: prior events (i.e. at the beginning of a dialogue). """ - def __init__(self, sliced_states: List[Optional[Dict[Text, float]]], ) -> None: + def __init__(self, sliced_states: List[Optional[Dict[Text, float]]],) -> None: """ Creates a `StoryConflict` from a given state. @@ -122,7 +122,7 @@ def sliced_states_hash(self) -> int: def find_story_conflicts( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> List[StoryConflict]: """Generates a list of `StoryConflict` objects, describing conflicts in the given trackers. @@ -151,7 +151,7 @@ def find_story_conflicts( def _find_conflicting_states( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> Dict[int, Optional[List[Text]]]: """Identifies all states from which different actions follow. @@ -180,10 +180,10 @@ def _find_conflicting_states( def _build_conflicts_from_states( - trackers: List[TrackerWithCachedStates], - domain: Domain, - max_history: int, - conflicting_state_action_mapping: Dict[int, Optional[List[Text]]], + trackers: List[TrackerWithCachedStates], + domain: Domain, + max_history: int, + conflicting_state_action_mapping: Dict[int, Optional[List[Text]]], ) -> List["StoryConflict"]: """Builds a list of `StoryConflict` objects for each given conflict. 
@@ -223,7 +223,7 @@ def _build_conflicts_from_states( def _sliced_states_iterator( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int ) -> Generator[TrackerEventStateTuple, None, None]: """Creates an iterator over sliced states. @@ -253,7 +253,7 @@ def _sliced_states_iterator( def _get_previous_event( - state: Optional[Dict[Text, float]] + state: Optional[Dict[Text, float]] ) -> Tuple[Optional[Text], Optional[Text]]: """Returns previous event type and name. @@ -278,8 +278,8 @@ def _get_previous_event( # We need to look out for `prev_` and `intent_` prefixes in the labels. for turn_label in state: if ( - turn_label.startswith(PREV_PREFIX) - and turn_label.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME + turn_label.startswith(PREV_PREFIX) + and turn_label.replace(PREV_PREFIX, "") != ACTION_LISTEN_NAME ): # The `prev_...` was an action that was NOT `action_listen` return "action", turn_label.replace(PREV_PREFIX, "") diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 23c8bb3574f6..6cf331a8feb5 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -78,7 +78,7 @@ def _text_is_part_of_output_error(text: Text, output: RunResult) -> bool: def test_data_validate_without_max_history( - run_in_default_project_with_info: Callable[..., RunResult] + run_in_default_project_with_info: Callable[..., RunResult] ): output = run_in_default_project_with_info("data", "validate") assert _text_is_part_of_output_error( @@ -87,7 +87,7 @@ def test_data_validate_without_max_history( def test_data_validate_stories_without_max_history( - run_in_default_project_with_info: Callable[..., RunResult] + run_in_default_project_with_info: Callable[..., RunResult] ): output = run_in_default_project_with_info("data", "validate", "stories") assert _text_is_part_of_output_error( @@ -96,7 +96,7 @@ def test_data_validate_stories_without_max_history( def 
test_data_validate_stories_with_max_history_zero( - run_in_default_project_with_info: Callable[..., RunResult] + run_in_default_project_with_info: Callable[..., RunResult] ): output = run_in_default_project_with_info( "data", "validate", "stories", "--max-history", "0" @@ -111,7 +111,7 @@ def test_validate_files_exit_early(): args = { "domain": "data/test_domains/duplicate_intents.yml", "data": None, - "max_history": None + "max_history": None, } data.validate_files(namedtuple("Args", args.keys())(*args.values())) From 2f54c1f0b9f140d139f0e03740d5ee42d1211be7 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 10 Feb 2020 17:30:59 +0100 Subject: [PATCH 192/209] Enable story structure validation without `max_history` --- rasa/cli/data.py | 10 ++-------- rasa/core/training/story_conflict.py | 13 ++++++++++++- rasa/core/validator.py | 7 ++----- tests/cli/test_rasa_data.py | 20 +------------------- 4 files changed, 17 insertions(+), 33 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index b9ef4705efeb..399097753ca3 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -150,12 +150,6 @@ def validate_files(args: argparse.Namespace, stories_only: bool = False) -> None if stories_only: all_good = _validate_story_structure(validator, args) - elif not args.max_history: - logger.info( - "Will not test for inconsistencies in stories since " - "you did not provide a value for `--max-history`." 
- ) - all_good = _validate_domain(validator) and _validate_nlu(validator, args) else: all_good = ( _validate_domain(validator) @@ -182,10 +176,10 @@ def _validate_nlu(validator: Validator, args: argparse.Namespace) -> bool: def _validate_story_structure(validator: Validator, args: argparse.Namespace) -> bool: # Check if a valid setting for `max_history` was given - if not isinstance(args.max_history, int) or args.max_history < 1: + if isinstance(args.max_history, int) and args.max_history < 1: raise argparse.ArgumentError( args.max_history, - "You have to provide a positive integer for `--max-history`.", + f"The value of `--max-history {args.max_history}` is not a positive integer.", ) return validator.verify_story_structure( diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index d043541bfd08..42536a90f55d 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -1,3 +1,4 @@ +import logging from collections import defaultdict, namedtuple from typing import List, Optional, Dict, Text, Tuple, Generator, NamedTuple @@ -8,6 +9,8 @@ from rasa.nlu.constants import INTENT_ATTRIBUTE from rasa.core.training.generator import TrackerWithCachedStates +logger = logging.getLogger(__name__) + class StoryConflict: """ @@ -122,7 +125,9 @@ def sliced_states_hash(self) -> int: def find_story_conflicts( - trackers: List[TrackerWithCachedStates], domain: Domain, max_history: int + trackers: List[TrackerWithCachedStates], + domain: Domain, + max_history: Optional[int] = None, ) -> List[StoryConflict]: """Generates a list of `StoryConflict` objects, describing conflicts in the given trackers. @@ -133,6 +138,12 @@ def find_story_conflicts( Returns: List of conflicts. 
""" + # Use the length of the longest story for `max_history` if not specified otherwise + if not max_history: + max_history = max([len(tracker.past_states(domain)) for tracker in trackers]) + + logger.info(f"Considering the preceding {max_history} turns for conflict analysis.") + # We do this in two steps, to reduce memory consumption: # Create a 'state -> list of actions' dict, where the state is diff --git a/rasa/core/validator.py b/rasa/core/validator.py index d24475ef6021..d954d53f6509 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -1,7 +1,7 @@ import logging import warnings from collections import defaultdict -from typing import Set, Text +from typing import Set, Text, Optional from rasa.core.domain import Domain from rasa.core.training.generator import TrainingDataGenerator from rasa.importers.importer import TrainingDataImporter @@ -191,7 +191,7 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool: return everything_is_alright def verify_story_structure( - self, ignore_warnings: bool = True, max_history: int = 5 + self, ignore_warnings: bool = True, max_history: Optional[int] = None ) -> bool: """Verifies that the bot behaviour in stories is deterministic. @@ -205,9 +205,6 @@ def verify_story_structure( """ logger.info("Story structure validation...") - logger.info( - f"Considering the preceding {max_history} turns for conflict analysis." 
- ) trackers = TrainingDataGenerator( self.story_graph, diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 6cf331a8feb5..86e0020677ef 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -77,24 +77,6 @@ def _text_is_part_of_output_error(text: Text, output: RunResult) -> bool: return found_info_string -def test_data_validate_without_max_history( - run_in_default_project_with_info: Callable[..., RunResult] -): - output = run_in_default_project_with_info("data", "validate") - assert _text_is_part_of_output_error( - "did not provide a value for `--max-history`", output - ) - - -def test_data_validate_stories_without_max_history( - run_in_default_project_with_info: Callable[..., RunResult] -): - output = run_in_default_project_with_info("data", "validate", "stories") - assert _text_is_part_of_output_error( - "have to provide a positive integer for `--max-history`", output - ) - - def test_data_validate_stories_with_max_history_zero( run_in_default_project_with_info: Callable[..., RunResult] ): @@ -102,7 +84,7 @@ def test_data_validate_stories_with_max_history_zero( "data", "validate", "stories", "--max-history", "0" ) assert _text_is_part_of_output_error( - "have to provide a positive integer for `--max-history`", output + "is not a positive integer", output ) From aea42dbae1d278464b5d123c10f023f4d38a3cad Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 10 Feb 2020 17:41:45 +0100 Subject: [PATCH 193/209] Use `print_error_and_exit` --- rasa/cli/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 399097753ca3..34637eb24380 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -158,8 +158,7 @@ def validate_files(args: argparse.Namespace, stories_only: bool = False) -> None ) if not all_good: - rasa.cli.utils.print_error("Project validation completed with errors.") - sys.exit(1) + rasa.cli.utils.print_error_and_exit("Project validation completed with errors.") def validate_stories(args: argparse.Namespace) -> None: From e4d484abf4d1ea313a01666a17c61aaa3d2121ce Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 10 Feb 2020 17:44:25 +0100 Subject: [PATCH 194/209] Rename `_summarize_conflicting_actions` --- rasa/core/training/story_conflict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 42536a90f55d..6d1d59dbeb5d 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -80,12 +80,12 @@ def __str__(self) -> Text: # List which stories are in conflict with one another for action, stories in self._conflicting_actions.items(): - conflict_message += f" {self._summarize_action_occurence(action, stories)}" + conflict_message += f" {self._summarize_conflicting_actions(action, stories)}" return conflict_message @staticmethod - def _summarize_action_occurence(action: Text, stories: List[Text]) -> Text: + def _summarize_conflicting_actions(action: Text, stories: List[Text]) -> Text: """Gives a summarized textual description of where one action occurs. Args: From b1e4d71d56f6c3da48f0e4503f887bd2869efda8 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 10 Feb 2020 17:46:05 +0100 Subject: [PATCH 195/209] Change error message text --- rasa/core/training/story_conflict.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 6d1d59dbeb5d..53445b18205c 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -80,7 +80,9 @@ def __str__(self) -> Text: # List which stories are in conflict with one another for action, stories in self._conflicting_actions.items(): - conflict_message += f" {self._summarize_conflicting_actions(action, stories)}" + conflict_message += ( + f" {self._summarize_conflicting_actions(action, stories)}" + ) return conflict_message @@ -107,7 +109,10 @@ def _summarize_conflicting_actions(action: Text, stories: List[Text]) -> Text: elif len(stories) == 1: conflict_description = f"'{stories[0]}'" else: - raise ValueError("Trying to summarize conflict without stories.") + raise ValueError( + "An internal error occurred while trying to summarise a conflict without stories. " + "Please file a bug report at https://github.com/RasaHQ/rasa." + ) return f"{action} predicted in {conflict_description}\n" From c7fae3b9d99b28bc7f65e897f22ebdfccd5993c8 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 10 Feb 2020 17:48:44 +0100 Subject: [PATCH 196/209] Fix docstring formatting --- rasa/core/training/story_conflict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 53445b18205c..3238dbf34d4c 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -134,14 +134,15 @@ def find_story_conflicts( domain: Domain, max_history: Optional[int] = None, ) -> List[StoryConflict]: - """Generates a list of `StoryConflict` objects, describing conflicts in the given trackers. 
+ """Generates `StoryConflict` objects, describing conflicts in the given trackers. Args: trackers: Trackers in which to search for conflicts. domain: The domain. max_history: The maximum history length to be taken into account. + Returns: - List of conflicts. + StoryConflict objects. """ # Use the length of the longest story for `max_history` if not specified otherwise if not max_history: From cebf957232f2cd47f4d4dc40e839c166f9f82756 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 10 Feb 2020 17:50:12 +0100 Subject: [PATCH 197/209] Avoid importing `find_story_conflicts` directly --- rasa/core/validator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rasa/core/validator.py b/rasa/core/validator.py index d954d53f6509..f4c6d8a524bf 100644 --- a/rasa/core/validator.py +++ b/rasa/core/validator.py @@ -10,7 +10,7 @@ from rasa.core.training.dsl import UserUttered from rasa.core.training.dsl import ActionExecuted from rasa.core.constants import UTTER_PREFIX -from rasa.core.training.story_conflict import find_story_conflicts +import rasa.core.training.story_conflict logger = logging.getLogger(__name__) @@ -214,7 +214,9 @@ def verify_story_structure( ).generate() # Create a list of `StoryConflict` objects - conflicts = find_story_conflicts(trackers, self.domain, max_history) + conflicts = rasa.core.training.story_conflict.find_story_conflicts( + trackers, self.domain, max_history + ) if not conflicts: logger.info("No story structure conflicts found.") From 4ab25371bb5e234a9468ada40cc1da3852e84875 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 11 Feb 2020 10:58:39 +0100 Subject: [PATCH 198/209] Apply BLACK formatting --- tests/cli/test_rasa_data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 86e0020677ef..bf182cc3eb50 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -83,9 +83,7 @@ def test_data_validate_stories_with_max_history_zero( output = run_in_default_project_with_info( "data", "validate", "stories", "--max-history", "0" ) - assert _text_is_part_of_output_error( - "is not a positive integer", output - ) + assert _text_is_part_of_output_error("is not a positive integer", output) def test_validate_files_exit_early(): From eb2a5f8104c95c071135c83d8970e014af23bbab Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 17 Feb 2020 14:11:17 +0100 Subject: [PATCH 199/209] Avoid all-caps output --- rasa/core/training/story_conflict.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index 3238dbf34d4c..a147a73390ae 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -72,11 +72,11 @@ def __str__(self) -> Text: # Describe where the conflict occurs in the stories last_event_type, last_event_name = _get_previous_event(self._sliced_states[-1]) if last_event_type: + conflict_message = f"Story structure conflict after {last_event_type} '{last_event_name}':\n" + else: conflict_message = ( - f"CONFLICT after {last_event_type} '{last_event_name}':\n" + f"Story structure conflict at the beginning of stories:\n" ) - else: - conflict_message = f"CONFLICT at the beginning of stories:\n" # List which stories are in conflict with one another for action, stories in self._conflicting_actions.items(): From d5344d9185d90d09d43c268822190cade9600373 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 18 Feb 2020 11:14:53 +0100 Subject: [PATCH 200/209] Add _get_length_of_longest_story --- rasa/core/training/story_conflict.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/rasa/core/training/story_conflict.py b/rasa/core/training/story_conflict.py index a147a73390ae..2510608a68b8 100644 --- a/rasa/core/training/story_conflict.py +++ b/rasa/core/training/story_conflict.py @@ -129,6 +129,21 @@ def sliced_states_hash(self) -> int: return hash(str(list(self.sliced_states))) +def _get_length_of_longest_story( + trackers: List[TrackerWithCachedStates], domain: Domain +) -> int: + """Returns the longest story in the given trackers. + + Args: + trackers: Trackers to get stories from. + domain: The domain. + + Returns: + The maximal length of any story + """ + return max([len(tracker.past_states(domain)) for tracker in trackers]) + + def find_story_conflicts( trackers: List[TrackerWithCachedStates], domain: Domain, @@ -144,9 +159,8 @@ def find_story_conflicts( Returns: StoryConflict objects. """ - # Use the length of the longest story for `max_history` if not specified otherwise if not max_history: - max_history = max([len(tracker.past_states(domain)) for tracker in trackers]) + max_history = _get_length_of_longest_story(trackers, domain) logger.info(f"Considering the preceding {max_history} turns for conflict analysis.") From 5ec7e549ea1ca73c7790a538146f9c63be1adc50 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 18 Feb 2020 11:20:46 +0100 Subject: [PATCH 201/209] Add types for _setup_trackers_for_testing --- tests/core/test_story_conflict.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index 1d6b5fd4931b..2e530f31a22a 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -1,15 +1,20 @@ +from typing import Text, List, Tuple + +from rasa.core.domain import Domain from rasa.core.training.story_conflict import ( StoryConflict, find_story_conflicts, _get_previous_event, ) -from rasa.core.training.generator import TrainingDataGenerator +from rasa.core.training.generator import TrainingDataGenerator, TrackerWithCachedStates from rasa.core.validator import Validator from rasa.importers.rasa import RasaFileImporter from tests.core.conftest import DEFAULT_STORIES_FILE, DEFAULT_DOMAIN_PATH_WITH_SLOTS -async def _setup_trackers_for_testing(domain_path, training_data_file): +async def _setup_trackers_for_testing( + domain_path: Text, training_data_file: Text +) -> Tuple[List[TrackerWithCachedStates], Domain]: importer = RasaFileImporter( domain_path=domain_path, training_data_paths=[training_data_file], ) From 31e02b8e45d6ead9f4bf0bab79ff46a3118e04cb Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Tue, 18 Feb 2020 11:45:46 +0100 Subject: [PATCH 202/209] Add type hint for `split_nlu_data` --- rasa/cli/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 34637eb24380..d858d35041b2 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -117,7 +117,7 @@ def _append_story_structure_arguments(parser: argparse.ArgumentParser) -> None: ) -def split_nlu_data(args) -> None: +def split_nlu_data(args: argparse.Namespace) -> None: from rasa.nlu.training_data.loading import load_data from rasa.nlu.training_data.util import get_file_format From bf16cc4e4e65d4cce28caee4e4f21e90aa20dd31 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 21 Feb 2020 14:50:24 +0100 Subject: [PATCH 203/209] Use Monkeypatch for `test_data_validate_stories_with_max_history_zero` --- rasa/cli/data.py | 3 +-- tests/cli/conftest.py | 11 ----------- tests/cli/test_rasa_data.py | 35 +++++++++++++++++++++++++++++------ 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index d858d35041b2..102c032a34cb 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -176,8 +176,7 @@ def _validate_nlu(validator: Validator, args: argparse.Namespace) -> bool: def _validate_story_structure(validator: Validator, args: argparse.Namespace) -> bool: # Check if a valid setting for `max_history` was given if isinstance(args.max_history, int) and args.max_history < 1: - raise argparse.ArgumentError( - args.max_history, + raise argparse.ArgumentTypeError( f"The value of `--max-history {args.max_history}` is not a positive integer.", ) diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index d912c8d31781..2ac6e97f47e3 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -57,14 +57,3 @@ def do_run(*args): return result return do_run - - -@pytest.fixture -def run_in_default_project_with_info(testdir: Testdir) -> Callable[..., RunResult]: - testdir.run("rasa", "init", 
"--no-prompt") - - def do_run(*args): - args = ["rasa"] + list(args) - return testdir.run(*args) - - return do_run diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index bf182cc3eb50..1a84b3b48bf8 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -1,7 +1,11 @@ +import argparse import os +from unittest.mock import Mock import pytest from collections import namedtuple from typing import Callable, Text + +from _pytest.monkeypatch import MonkeyPatch from _pytest.pytester import RunResult from rasa.cli import data @@ -77,13 +81,32 @@ def _text_is_part_of_output_error(text: Text, output: RunResult) -> bool: return found_info_string -def test_data_validate_stories_with_max_history_zero( - run_in_default_project_with_info: Callable[..., RunResult] -): - output = run_in_default_project_with_info( - "data", "validate", "stories", "--max-history", "0" +def test_data_validate_stories_with_max_history_zero(monkeypatch: MonkeyPatch): + import rasa.cli.data as data + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(help="Rasa commands") + data.add_subparser(subparsers, parents=[]) + + args = parser.parse_args( + [ + "data", + "validate", + "stories", + "--max-history", + 0 + ] ) - assert _text_is_part_of_output_error("is not a positive integer", output) + + import rasa.cli.data as data + + async def from_importer(_) -> "Validator": + return Mock() + + monkeypatch.setattr("rasa.core.validator.Validator.from_importer", from_importer) + + with pytest.raises(argparse.ArgumentTypeError): + data.validate_files(args) def test_validate_files_exit_early(): From c1e70c9eca86dac59e631dcd6df15c2d58396d4f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Fri, 21 Feb 2020 14:56:15 +0100 Subject: [PATCH 204/209] Apply BLACK formatting --- tests/cli/test_rasa_data.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 1a84b3b48bf8..209f123219ac 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -88,15 +88,7 @@ def test_data_validate_stories_with_max_history_zero(monkeypatch: MonkeyPatch): subparsers = parser.add_subparsers(help="Rasa commands") data.add_subparser(subparsers, parents=[]) - args = parser.parse_args( - [ - "data", - "validate", - "stories", - "--max-history", - 0 - ] - ) + args = parser.parse_args(["data", "validate", "stories", "--max-history", 0]) import rasa.cli.data as data From 5b18b280da530192bb14fed16d4a0005f5298b14 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Fri, 21 Feb 2020 15:31:31 +0100 Subject: [PATCH 205/209] Clean up test_data_validate_stories_with_max_history_zero --- tests/cli/test_rasa_data.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 209f123219ac..38dbf417fb9c 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -82,20 +82,16 @@ def _text_is_part_of_output_error(text: Text, output: RunResult) -> bool: def test_data_validate_stories_with_max_history_zero(monkeypatch: MonkeyPatch): - import rasa.cli.data as data - parser = argparse.ArgumentParser() subparsers = parser.add_subparsers(help="Rasa commands") data.add_subparser(subparsers, parents=[]) args = parser.parse_args(["data", "validate", "stories", "--max-history", 0]) - import rasa.cli.data as data - - async def from_importer(_) -> "Validator": + async def mock_from_importer(importer: "TrainingDataImporter") -> "Validator": return Mock() - monkeypatch.setattr("rasa.core.validator.Validator.from_importer", from_importer) + 
monkeypatch.setattr("rasa.core.validator.Validator.from_importer", mock_from_importer) with pytest.raises(argparse.ArgumentTypeError): data.validate_files(args) From 040980a6548d43a475ae110b3e5a237663a910a2 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 24 Feb 2020 10:32:24 +0100 Subject: [PATCH 206/209] Move `Validator` up to `rasa.validator` --- rasa/cli/data.py | 3 +-- rasa/{core => }/validator.py | 0 tests/cli/test_rasa_data.py | 4 +++- tests/core/test_story_conflict.py | 2 +- tests/core/test_validator.py | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) rename rasa/{core => }/validator.py (100%) diff --git a/rasa/cli/data.py b/rasa/cli/data.py index 102c032a34cb..f5ed5d0f1da4 100644 --- a/rasa/cli/data.py +++ b/rasa/cli/data.py @@ -1,14 +1,13 @@ import logging import argparse import asyncio -import sys from typing import List from rasa import data from rasa.cli.arguments import data as arguments import rasa.cli.utils from rasa.constants import DEFAULT_DATA_PATH -from rasa.core.validator import Validator +from rasa.validator import Validator from rasa.importers.rasa import RasaFileImporter logger = logging.getLogger(__name__) diff --git a/rasa/core/validator.py b/rasa/validator.py similarity index 100% rename from rasa/core/validator.py rename to rasa/validator.py diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 38dbf417fb9c..f9051b0b9468 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -8,6 +8,8 @@ from _pytest.monkeypatch import MonkeyPatch from _pytest.pytester import RunResult from rasa.cli import data +from rasa.importers.importer import TrainingDataImporter +from rasa.validator import Validator def test_data_split_nlu(run_in_default_project: Callable[..., RunResult]): @@ -88,7 +90,7 @@ def test_data_validate_stories_with_max_history_zero(monkeypatch: MonkeyPatch): args = parser.parse_args(["data", "validate", "stories", "--max-history", 0]) - async def 
mock_from_importer(importer: "TrainingDataImporter") -> "Validator": + async def mock_from_importer(importer: TrainingDataImporter) -> Validator: return Mock() monkeypatch.setattr("rasa.core.validator.Validator.from_importer", mock_from_importer) diff --git a/tests/core/test_story_conflict.py b/tests/core/test_story_conflict.py index 2e530f31a22a..1a426850a6b9 100644 --- a/tests/core/test_story_conflict.py +++ b/tests/core/test_story_conflict.py @@ -7,7 +7,7 @@ _get_previous_event, ) from rasa.core.training.generator import TrainingDataGenerator, TrackerWithCachedStates -from rasa.core.validator import Validator +from rasa.validator import Validator from rasa.importers.rasa import RasaFileImporter from tests.core.conftest import DEFAULT_STORIES_FILE, DEFAULT_DOMAIN_PATH_WITH_SLOTS diff --git a/tests/core/test_validator.py b/tests/core/test_validator.py index db8afc0d6038..c9a9665f18a8 100644 --- a/tests/core/test_validator.py +++ b/tests/core/test_validator.py @@ -1,5 +1,5 @@ import pytest -from rasa.core.validator import Validator +from rasa.validator import Validator from rasa.importers.rasa import RasaFileImporter from tests.core.conftest import ( DEFAULT_STORIES_FILE, From c044a30c09121b6a4560331e04eb91cda2a68b7f Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 24 Feb 2020 11:05:36 +0100 Subject: [PATCH 207/209] Apply BLACK formatting --- tests/cli/test_rasa_data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index f9051b0b9468..2fed2b40c8fb 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -93,7 +93,9 @@ def test_data_validate_stories_with_max_history_zero(monkeypatch: MonkeyPatch): async def mock_from_importer(importer: TrainingDataImporter) -> Validator: return Mock() - monkeypatch.setattr("rasa.core.validator.Validator.from_importer", mock_from_importer) + monkeypatch.setattr( + "rasa.core.validator.Validator.from_importer", mock_from_importer + ) with pytest.raises(argparse.ArgumentTypeError): data.validate_files(args) From 99937db44bcc6ee1e01308b13f0af729f5a11894 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 24 Feb 2020 11:07:33 +0100 Subject: [PATCH 208/209] Fix `test_data_validate_stories_with_max_history_zero` --- tests/cli/test_rasa_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 2fed2b40c8fb..3f92f0cd31c6 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -94,7 +94,7 @@ async def mock_from_importer(importer: TrainingDataImporter) -> Validator: return Mock() monkeypatch.setattr( - "rasa.core.validator.Validator.from_importer", mock_from_importer + "rasa.validator.Validator.from_importer", mock_from_importer ) with pytest.raises(argparse.ArgumentTypeError): From f4ddd78463de716eb7ab59eac379e80723fcccf2 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. 
Mosig" Date: Mon, 24 Feb 2020 11:08:43 +0100 Subject: [PATCH 209/209] Apply BLACK formatting again --- tests/cli/test_rasa_data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/cli/test_rasa_data.py b/tests/cli/test_rasa_data.py index 3f92f0cd31c6..07e54e8c47ed 100644 --- a/tests/cli/test_rasa_data.py +++ b/tests/cli/test_rasa_data.py @@ -93,9 +93,7 @@ def test_data_validate_stories_with_max_history_zero(monkeypatch: MonkeyPatch): async def mock_from_importer(importer: TrainingDataImporter) -> Validator: return Mock() - monkeypatch.setattr( - "rasa.validator.Validator.from_importer", mock_from_importer - ) + monkeypatch.setattr("rasa.validator.Validator.from_importer", mock_from_importer) with pytest.raises(argparse.ArgumentTypeError): data.validate_files(args)