-
Notifications
You must be signed in to change notification settings - Fork 84
/
inference.py
63 lines (51 loc) · 2.86 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
class ViterbiDecoder:
    """
    Viterbi Decoder.

    Finds, for every sequence in a batch, the tag sequence with the highest accumulated CRF score.
    """

    def __init__(self, tag_map):
        """
        :param tag_map: tag map; must contain the keys '<start>' and '<end>'
        """
        self.tagset_size = len(tag_map)
        self.start_tag = tag_map['<start>']
        self.end_tag = tag_map['<end>']

    def decode(self, scores, lengths):
        """
        Decode the best tag sequence for each item in the batch.

        The first `k` rows of the batch are treated as the active ones at timestep `t`, where `k` is the
        number of lengths greater than `t`. The batch must therefore be sorted by decreasing length.

        :param scores: CRF scores, indexed as [batch, timestep, previous_tag, current_tag]
        :param lengths: word sequence lengths, one per batch item (sorted in decreasing order)
        :return: decoded sequences, a long tensor of shape (batch_size, max(lengths)) on the device of `scores`
        """
        batch_size = scores.size(0)
        max_len = int(max(lengths))  # computed once; also turns a 0-dim tensor into a plain int
        device = scores.device  # keep every working tensor next to `scores` so the additions below are legal

        # Accumulated sequence scores at each current tag
        scores_upto_t = torch.zeros(batch_size, self.tagset_size, dtype=scores.dtype, device=device)

        # Back-pointers, i.e., indices of the previous_tag that corresponds to the maximum accumulated score
        # at the current tag. Pads are the <end> tag index, since that was the last tag in the decoded
        # sequence; tracing back through a pad therefore keeps pointing at <end>.
        backpointers = torch.full((batch_size, max_len, self.tagset_size), self.end_tag,
                                  dtype=torch.long, device=device)

        for t in range(max_len):
            # Effective batch size (sans pads) at this timestep
            batch_size_t = sum(1 for length in lengths if length > t)

            if t == 0:
                # The only possible previous tag at the first timestep is <start>
                scores_upto_t[:batch_size_t] = scores[:batch_size_t, t, self.start_tag, :]
                backpointers[:batch_size_t, t, :] = self.start_tag
            else:
                # Add the scores at the current timestep to the scores accumulated up to the previous
                # timestep, and keep, for each current tag, the previous tag with the max. accumulated score
                candidates = scores[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t].unsqueeze(2)
                best_scores, best_prev_tags = torch.max(candidates, dim=1)  # (batch_size_t, tagset_size)
                scores_upto_t[:batch_size_t] = best_scores
                backpointers[:batch_size_t, t, :] = best_prev_tags

        # Decode/trace best path backwards
        decoded = torch.zeros((batch_size, max_len), dtype=torch.long, device=device)
        # The pointers at the ends are all <end> tags
        pointer = torch.full((batch_size, 1), self.end_tag, dtype=torch.long, device=device)

        for t in reversed(range(max_len)):
            decoded[:, t] = torch.gather(backpointers[:, t, :], 1, pointer).squeeze(1)
            pointer = decoded[:, t].unsqueeze(1)  # (batch_size, 1)

        # Sanity check: every traced path must begin at <start>
        assert torch.equal(decoded[:, 0],
                           torch.full((batch_size,), self.start_tag, dtype=torch.long, device=device))

        # Remove the <start>s at the beginning, and append <end>s (to compare to targets, if any).
        # `decoded[:, t]` holds the tag of position t - 1, so the tag of the final position is missing;
        # the trace started from <end> there, hence <end> is what gets appended.
        end_column = torch.full((batch_size, 1), self.end_tag, dtype=torch.long, device=device)
        decoded = torch.cat([decoded[:, 1:], end_column], dim=1)

        return decoded