-
Notifications
You must be signed in to change notification settings - Fork 4.7k
/
Copy pathtest_graph.py
89 lines (70 loc) · 2.78 KB
/
test_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from rasa.shared.core.training_data.structures import StoryGraph
import rasa.shared.core.training_data.loading
from rasa.shared.core.domain import Domain
def check_graph_is_sorted(g, sorted_nodes, removed_edges):
    """Assert that ``sorted_nodes`` is a valid topological order of ``g``.

    Args:
        g: Adjacency mapping from each node to the nodes it has edges to.
        sorted_nodes: The node order to verify (any iterable, e.g. a deque).
        removed_edges: ``(source, target)`` pairs that were dropped to break
            cycles; these edges are exempt from the ordering check.

    Raises:
        AssertionError: If ``sorted_nodes`` repeats a node, omits a node of
            ``g``, or places a node before one of its (non-removed)
            predecessors.
    """
    # Materialize once so a one-shot iterator can be inspected and iterated.
    sorted_nodes = list(sorted_nodes)

    # Without these two checks an ordering that drops or repeats nodes would
    # pass vacuously, because edges into the missing nodes are never examined.
    assert len(set(sorted_nodes)) == len(
        sorted_nodes
    ), "Found a duplicate node in the sorted nodes!"
    missing = set(g) - set(sorted_nodes)
    assert not missing, f"Nodes missing from the sorted nodes: {sorted(missing)}"

    # Map every node of ``g`` to the nodes that have an edge pointing at it,
    # in a single pass over the edges.
    incoming_edges = {node: [] for node in g}
    for source, targets in g.items():
        for target in targets:
            incoming_edges.setdefault(target, []).append(source)

    visited = set()
    for node in sorted_nodes:
        for dependency in incoming_edges.get(node, ()):
            # Every predecessor must already be placed, unless the edge was
            # removed to break a cycle.
            assert (
                dependency in visited or (dependency, node) in removed_edges
            ), "Found an incoming edge from a node that wasn't visited yet!"
        visited.add(node)
def test_node_ordering():
    """An acyclic graph is sorted without removing any edges."""
    example_graph = {
        "a": ["b", "c", "d"],
        "b": [],
        "c": ["d"],
        "d": [],
        "e": ["f"],
        "f": [],
    }
    sorted_nodes, removed_edges = StoryGraph.topological_sort(example_graph)

    # The graph has no cycles, so no edge needs to be dropped. The removed
    # edges come back as a list (not a set), hence the comparison with ``[]``.
    assert removed_edges == []
    check_graph_is_sorted(example_graph, sorted_nodes, removed_edges)
def test_node_ordering_with_cycle():
    """A cyclic graph is sorted once the reported edges are removed."""
    example_graph = {
        "a": ["b", "c", "d"],
        "b": [],
        "c": ["d"],
        "d": ["a"],
        "e": ["f"],
        "f": ["e"],
    }
    sorted_nodes, removed_edges = StoryGraph.topological_sort(example_graph)

    # The graph contains cycles (a -> d -> a, e -> f -> e), so at least one
    # edge must have been dropped to make it sortable.
    assert removed_edges, "Expected edges to be removed to break the cycles"
    # Only edges that really exist may be reported as removed; otherwise the
    # ordering check below could be satisfied by made-up edges.
    for source, target in removed_edges:
        assert target in example_graph.get(
            source, ()
        ), f"Removed edge ({source!r}, {target!r}) is not part of the graph"
    check_graph_is_sorted(example_graph, sorted_nodes, removed_edges)
def test_is_empty():
    """A story graph built from no story steps reports itself as empty."""
    empty_graph = StoryGraph([])
    assert empty_graph.is_empty()
def test_consistent_fingerprints():
    """Loading the same stories twice yields graphs with equal fingerprints."""
    stories_path = "data/test_yaml_stories/stories.yml"
    domain = Domain.load("data/test_domains/default_with_slots.yml")

    def load_graph():
        # Parse the story file from scratch and wrap the steps in a graph.
        steps = rasa.shared.core.training_data.loading.load_data_from_resource(
            stories_path, domain
        )
        return StoryGraph(steps)

    original_graph = load_graph()
    reloaded_graph = load_graph()  # independent second read of the same file

    assert original_graph.fingerprint() == reloaded_graph.fingerprint()
def test_unique_checkpoint_names():
    """Equal OR statements reuse their generated start checkpoint names."""
    stories_path = "data/test_yaml_stories/story_with_two_equal_or_statements.yml"
    domain = Domain.load("data/test_domains/default_with_slots.yml")
    story_steps = rasa.shared.core.training_data.loading.load_data_from_resource(
        stories_path, domain
    )

    start_checkpoint_names = set()
    for step in story_steps:
        start_checkpoint_names.update(
            checkpoint.name for checkpoint in step.start_checkpoints
        )

    # Expected distinct names:
    #   first story  -> START_CHECKPOINT,
    #                   GENR_OR_XXXXX (first OR), GENR_OR_YYYYY (second OR)
    #   second story -> additionally GENR_OR_ZZZZZ, because its entities
    #                   differ from those of the first OR in the first story
    assert len(start_checkpoint_names) == 4