-
Notifications
You must be signed in to change notification settings - Fork 52
/
test_random.py
126 lines (114 loc) · 3.9 KB
/
test_random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import itertools
import pytest
from pyrdf2vec.graphs import KG, Vertex
from pyrdf2vec.walkers import RandomWalker
# Namespace prepended (as "<URL>#<name>") to every entity and predicate name.
URL = "http://pyRDF2Vec"

# Triples (subject, predicate, object) containing a cycle:
# Alice -> Bob -> Dean -> Alice.
LOOP = [
    ["Alice", "knows", "Bob"],
    ["Alice", "knows", "Dean"],
    ["Bob", "knows", "Dean"],
    ["Dean", "loves", "Alice"],
]

# Triples (subject, predicate, object) forming one long "knows" chain that
# starts at Alice, plus a short Alice -> Dean branch.
LONG_CHAIN = [
    ["Alice", "knows", "Bob"],
    ["Alice", "knows", "Dean"],
    ["Bob", "knows", "Mathilde"],
    ["Mathilde", "knows", "Alfy"],
    ["Alfy", "knows", "Stephane"],
    ["Stephane", "knows", "Alfred"],
    ["Alfred", "knows", "Emma"],
    ["Emma", "knows", "Julio"],
]

# Graphs start empty; the ``setup`` fixture of the test class fills them.
KG_LOOP, KG_CHAIN = KG(), KG()
KGS = [KG_LOOP, KG_CHAIN]

# Value grids combined with ``itertools.product`` by the parametrized tests.
ROOTS_WITHOUT_URL = ["Alice", "Bob", "Dean"]
MAX_DEPTHS = range(15)
MAX_WALKS = list(range(6))
WITH_REVERSE = [False, True]
class TestRandomWalker:
    """Checks on the walks that ``RandomWalker`` produces for two small graphs."""

    @pytest.fixture(scope="session")
    def setup(self):
        """Fill ``KG_LOOP`` and ``KG_CHAIN`` with their respective triples.

        Each name of a triple is expanded to ``"<URL>#<name>"`` and wrapped in
        a ``Vertex``; the predicate vertex is linked to its subject and object
        through ``vprev`` / ``vnext`` before the triple is added to the graph.
        """
        for kg, triples in ((KG_LOOP, LOOP), (KG_CHAIN, LONG_CHAIN)):
            for subj_name, pred_name, obj_name in triples:
                subj = Vertex(f"{URL}#{subj_name}")
                obj = Vertex(f"{URL}#{obj_name}")
                pred = Vertex(
                    f"{URL}#{pred_name}",
                    predicate=True,
                    vprev=subj,
                    vnext=obj,
                )
                kg.add_walk(subj, pred, obj)

    @pytest.mark.parametrize(
        "kg, root, max_depth, is_reverse",
        list(
            itertools.product(
                KGS,
                ROOTS_WITHOUT_URL,
                MAX_DEPTHS,
                WITH_REVERSE,
            )
        ),
    )
    def test_bfs(self, setup, kg, root, max_depth, is_reverse):
        """Every BFS walk is bounded in length and anchored on the root.

        A walk may hold at most ``2 * max_depth + 1`` vertices, and the root
        must be its first vertex (its last one when ``is_reverse`` is set).
        """
        root_iri = f"{URL}#{root}"
        walker = RandomWalker(max_depth, None, random_state=42)
        longest = 2 * max_depth + 1
        # Reversed walks end on the root instead of starting from it.
        root_pos = -1 if is_reverse else 0

        for walk in walker._bfs(kg, Vertex(root_iri), is_reverse):
            assert len(walk) <= longest
            assert walk[root_pos].name == root_iri

    @pytest.mark.parametrize(
        "kg, root, max_depth, max_walks, is_reverse",
        list(
            itertools.product(
                KGS,
                ROOTS_WITHOUT_URL,
                MAX_DEPTHS,
                MAX_WALKS,
                WITH_REVERSE,
            )
        ),
    )
    def test_dfs(self, setup, kg, root, max_depth, max_walks, is_reverse):
        """Every DFS walk is bounded in length and anchored on the root.

        A walk may hold at most ``2 * max_depth + 1`` vertices, and the root
        must be its first vertex (its last one when ``is_reverse`` is set).
        """
        root_iri = f"{URL}#{root}"
        walker = RandomWalker(max_depth, max_walks, random_state=42)
        longest = 2 * max_depth + 1
        # Reversed walks end on the root instead of starting from it.
        root_pos = -1 if is_reverse else 0

        for walk in walker._dfs(kg, Vertex(root_iri), is_reverse):
            assert len(walk) <= longest
            assert walk[root_pos].name == root_iri

    @pytest.mark.parametrize(
        "kg, root, max_depth, max_walks, with_reverse",
        list(
            itertools.product(
                KGS,
                ROOTS_WITHOUT_URL,
                MAX_DEPTHS,
                MAX_WALKS,
                WITH_REVERSE,
            )
        ),
    )
    def test_extract(
        self, setup, kg, root, max_depth, max_walks, with_reverse
    ):
        """Extracted walks respect the walk-count and walk-length limits.

        The number of walks is capped by ``max_walks`` (``max_walks`` squared
        when reverse walks are enabled). Every object-position hop that is
        not one of the walker's entities must start with ``"b'"``. Without
        reverse walks each walk starts at the root and holds at most
        ``2 * max_depth + 1`` hops; with reverse walks that bound is doubled.
        """
        root_iri = f"{URL}#{root}"
        walker = RandomWalker(
            max_depth,
            max_walks,
            with_reverse=with_reverse,
            random_state=42,
        )
        walks = walker.extract(kg, [root_iri])[0]

        if max_walks is not None:
            walk_cap = max_walks * max_walks if with_reverse else max_walks
            assert len(walks) <= walk_cap

        one_way_len = 2 * max_depth + 1
        for walk in walks:
            # Positions 2, 4, ... of a walk are the object hops.
            for hop in walk[2::2]:
                assert hop in walker._entities or hop.startswith("b'")

            if with_reverse:
                assert len(walk) <= one_way_len * 2
            else:
                assert walk[0] == root_iri
                assert len(walk) <= one_way_len

    def test_inverse_extract(self, setup):
        """With reverse walks, some walk goes from Alice to Bob in ``KG_LOOP``."""
        alice = f"{URL}#Alice"
        bob = f"{URL}#Bob"
        walker = RandomWalker(1, None, with_reverse=True, random_state=42)
        per_entity = walker.extract(KG_LOOP, [bob, alice])

        candidates = per_entity[0] + per_entity[1]
        assert any(
            walk[0] == alice and walk[2] == bob for walk in candidates
        )