-
Notifications
You must be signed in to change notification settings - Fork 8
/
mst.py
106 lines (87 loc) · 3.76 KB
/
mst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# https://github.com/mrshu/neural-dependency-parser/blob/master/mst_Tim_Dozat.py
import numpy as np
def find_cycles(edges):
    """Find the cycles of the head-pointer graph described by ``edges``.

    ``edges[v]`` is the head (parent) of vertex ``v``. Tarjan's
    strongly-connected-components algorithm is run, expanding each vertex
    into the vertices whose head it is (``edges == vertex``). Components with
    more than one vertex are the cycles. Single-vertex components are not
    reported, including a self-loop such as ``edges[0] == 0``.

    Args:
        edges: 1-D integer array of head indices, one entry per vertex.

    Returns:
        A list of 1-D integer arrays, one per cycle. Each array holds that
        cycle's vertex ids in the order they were popped off the Tarjan
        stack. The list is empty when ``edges`` contains no cycle.
    """
    vertices = np.arange(len(edges))
    indices = np.zeros_like(vertices) - 1   # DFS discovery index; -1 = unvisited
    lowlinks = np.zeros_like(vertices) - 1  # Tarjan low-link value per vertex
    stack = []
    # ``np.bool`` was a deprecated alias of the builtin ``bool`` and was
    # removed in NumPy 1.24, where it raises AttributeError. The builtin
    # ``bool`` yields the same boolean dtype on every NumPy version.
    onstack = np.zeros_like(vertices, dtype=bool)
    current_index = 0
    cycles = []

    def strong_connect(vertex, current_index):
        """Run Tarjan's visit from ``vertex`` and return the next free index.

        The index counter is threaded through arguments and return values.
        ``indices``, ``lowlinks``, ``stack``, ``onstack`` and ``cycles`` are
        mutated in the enclosing scope.

        NOTE(review): this is recursive, so an extremely long head chain
        (approaching Python's recursion limit) would raise RecursionError.
        """
        indices[vertex] = current_index
        lowlinks[vertex] = current_index
        stack.append(vertex)
        current_index += 1
        onstack[vertex] = True
        # Successors are the vertices that take ``vertex`` as their head.
        for vertex_ in np.where(edges == vertex)[0]:
            if indices[vertex_] == -1:
                current_index = strong_connect(vertex_, current_index)
                lowlinks[vertex] = min(lowlinks[vertex], lowlinks[vertex_])
            elif onstack[vertex_]:
                lowlinks[vertex] = min(lowlinks[vertex], indices[vertex_])
        # ``vertex`` roots a component, so pop the whole component off the stack.
        if lowlinks[vertex] == indices[vertex]:
            cycle = []
            vertex_ = -1
            while vertex_ != vertex:
                vertex_ = stack.pop()
                onstack[vertex_] = False
                cycle.append(vertex_)
            if len(cycle) > 1:
                cycles.append(np.array(cycle))
        return current_index

    for vertex in vertices:
        if indices[vertex] == -1:
            current_index = strong_connect(vertex, current_index)
    return cycles
def find_roots(edges):
    """Return the ids of the non-root vertices whose head is vertex 0.

    Entry 0 (the root's own head) is skipped. The result is a 1-D integer
    array of vertex ids in increasing order.
    """
    attached_to_root = edges[1:] == 0
    offsets = np.nonzero(attached_to_root)[0]
    # Offsets are relative to ``edges[1:]``; shift them back to vertex ids.
    return offsets + 1
def score_edges(probs, edges):
    """Sum the matrix entries selected by a head assignment.

    For every non-root vertex ``d`` (1 .. len(probs) - 1), this takes
    ``probs[d, edges[d]]`` and returns the sum. Row 0 (the root) is skipped.

    NOTE(review): this is an additive score over the raw entries, while
    ``chu_liu_edmonds`` compares alternatives by ratio (multiplicatively).
    Confirm that summing, rather than summing logs, is intended.
    """
    dependents = np.arange(1, len(probs))
    chosen = probs[dependents, edges[1:]]
    return np.sum(chosen)
def chu_liu_edmonds(probs):
    """Decode head assignments from ``probs`` by Chu-Liu/Edmonds contraction.

    Row ``d`` of ``probs`` scores the candidate heads of vertex ``d``, so
    ``probs[d, h]`` scores ``h`` as the head of ``d``. Every vertex first
    takes its best-scoring head. If that creates a cycle, the algorithm:

    1. contracts one cycle into a single pseudo-vertex, appended as the last
       row/column of a smaller matrix;
    2. solves the smaller problem recursively;
    3. expands the solution back.

    Alternatives for a cycle member are compared by ratio against the head
    it currently holds, i.e. multiplicatively. This presumably maximises the
    product of the selected entries; confirm against callers. The entry at
    each cycle member's chosen head is used as a divisor. ``probs`` is only
    read, never written.

    Returns:
        A 1-D integer array ``heads`` with ``heads[d]`` the head of vertex ``d``.
    """
    heads = np.argmax(probs, axis=1)
    cycles = find_cycles(heads)
    if not cycles:
        # The greedy choice is already acyclic.
        return heads

    # Contract the last cycle that find_cycles reported.
    members = cycles.pop()
    others = np.delete(np.arange(len(probs)), members)
    member_heads = heads[members]

    # entering[i, j]: score of outside head j for member i, relative to
    # member i's current in-cycle head.
    # leaving[k, i]: score of outside vertex k taking member i as its head.
    entering = (probs[members][:, others]
                / probs[members, member_heads][:, None])
    leaving = probs[others][:, members]

    # Reduced problem: the non-cycle sub-matrix plus one pseudo-vertex row/column.
    reduced = np.array(probs[others, :][:, others])
    reduced = np.pad(reduced, [[0, 1], [0, 1]], 'constant')
    reduced[-1, :-1] = np.max(entering, axis=0)
    reduced[:-1, -1] = np.max(leaving, axis=1)
    pseudo = len(reduced) - 1

    reduced_heads = chu_liu_edmonds(reduced)
    pseudo_head, reduced_heads = reduced_heads[-1], reduced_heads[:-1]

    # Break the cycle: the member that best accepts the pseudo-vertex's head
    # is re-attached to it, and the other members keep their in-cycle heads.
    entry_member = members[np.argmax(entering[:, pseudo_head])]
    heads[entry_member] = others[pseudo_head]

    # Non-cycle vertices headed by another non-cycle vertex: map the ids back.
    to_other = np.where(reduced_heads < pseudo)
    heads[others[to_other]] = others[reduced_heads[to_other]]

    # Non-cycle vertices headed by the pseudo-vertex: attach each to its best member.
    to_pseudo = np.where(reduced_heads == pseudo)
    best_member = members[np.argmax(leaving, axis=1)]
    heads[others[to_pseudo]] = best_member[to_pseudo]
    return heads
def mst(probs):
    """Decode a single-rooted dependency tree from a head-score matrix.

    Args:
        probs: square matrix where ``probs[d, h]`` scores vertex ``d``
            taking ``h`` as its head; vertex 0 is the root. Any array-like is
            accepted, and integer input is promoted to float64. The caller's
            array is left unmodified.

    Returns:
        A 1-D integer array ``edges`` with ``edges[d]`` the head of ``d``
        (``edges[0] == 0``). The Chu-Liu/Edmonds decode may leave several
        vertices attached to the root. In that case each of them is tried in
        turn as the only root child, with the others re-attached under it,
        and the candidate with the highest ``score_edges`` is returned.
    """
    # Copy first so the caller's matrix is not scaled and normalised in place.
    # Non-float input is promoted because the in-place float ops below
    # cannot cast into an integer array.
    probs = np.array(probs)
    if not np.issubdtype(probs.dtype, np.floating):
        probs = probs.astype(np.float64)
    # Forbid self-attachment.
    probs *= 1 - np.eye(len(probs)).astype(np.float32)
    # The root's only permitted head is itself.
    probs[0] = 0
    probs[0, 0] = 1
    # Row-normalise each vertex's candidate-head scores.
    # NOTE(review): a row with no off-diagonal mass becomes 0/0 = NaN here.
    probs /= np.sum(probs, axis=1, keepdims=True)
    edges = chu_liu_edmonds(probs)
    roots = find_roots(edges)
    best_edges = edges
    best_score = -np.inf
    if len(roots) > 1:
        positions = np.arange(len(edges))
        # Non-root vertices currently headed by the root.
        root_children = (edges == 0) & (positions != 0)
        for root_idx in roots:
            edges_ = edges.copy()
            # Keep root_idx under the root and move its siblings beneath it.
            edges_[root_children & (positions != root_idx)] = root_idx
            score = score_edges(probs, edges_)
            if score > best_score:
                best_edges = edges_
                best_score = score
    return best_edges