# pointer.py (forked from manuvn/lpRNN-awd-lstm-lm)
import argparse
import os
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import data
import model
from utils import batchify, get_batch, repackage_hidden
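
# Evaluate a pretrained language model with a cache/pointer on top of its softmax:
# at each position, a distribution over the `window` most recently seen target words
# is built from the similarity between the current RNN output and the cached RNN
# outputs (a softmax scaled by `theta`), and that pointer distribution is mixed with
# the model's vocabulary distribution using weight `lambdasm`. No parameters are
# trained here; the script only re-scores the validation and test sets.
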
def one_hot(idx, size, cuda=True):
    # Build a 1 x size one-hot row vector with a 1 at position idx.
    a = np.zeros((1, size), np.float32)
    a[0][idx] = 1
    v = Variable(torch.from_numpy(a))
    if cuda: v = v.cuda()
    return v

def evaluate(data_source, batch_size=10, window=3785):
    # Turn on evaluation mode which disables dropout.
    if args.model == 'QRNN': model.reset()
    model.eval()
    total_loss = 0
    # ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    next_word_history = None
    pointer_history = None
    for i in range(0, data_source.size(0) - 1, args.bptt):
        if i > 0: print('Iteration properties', i, len(data_source), math.exp(total_loss / i))
        data, targets = get_batch(data_source, i, evaluation=True, args=args)
        output, hidden, rnn_outs, _ = model(data, hidden, return_h=True)
        rnn_out = rnn_outs[-1].squeeze()
        output = model.decoder(rnn_out.view(rnn_out.size(0) * rnn_out.size(1), rnn_out.size(2)))
        rnn_out = rnn_out.view(rnn_out.size(0) * rnn_out.size(1), rnn_out.size(2))
        output_flat = output.view(-1, ntokens)
        ###
        # Fill pointer history: one-hot targets and the RNN outputs that produced them
        start_idx = len(next_word_history) if next_word_history is not None else 0
        next_word_history = torch.cat([one_hot(t.data, ntokens) for t in targets]) if next_word_history is None else torch.cat([next_word_history, torch.cat([one_hot(t.data, ntokens) for t in targets])])
        # print(next_word_history)
        pointer_history = Variable(rnn_out.data) if pointer_history is None else torch.cat([pointer_history, Variable(rnn_out.data)], dim=0)
        # print(pointer_history)
        ###
        # Built-in cross entropy
        # total_loss += len(data) * criterion(output_flat, targets).data[0]
        ###
        # Manual cross entropy
        # softmax_output_flat = torch.nn.functional.softmax(output_flat)
        # soft = torch.gather(softmax_output_flat, dim=1, index=targets.view(-1, 1))
        # entropy = -torch.log(soft)
        # total_loss += len(data) * entropy.mean().data[0]
        ###
        # Pointer manual cross entropy
        loss = 0
        softmax_output_flat = torch.nn.functional.softmax(output_flat)
        for idx, vocab_loss in enumerate(softmax_output_flat):
            p = vocab_loss
            if start_idx + idx > window:
                # Restrict the cache to the last `window` timesteps.
                valid_next_word = next_word_history[start_idx + idx - window:start_idx + idx]
                valid_pointer_history = pointer_history[start_idx + idx - window:start_idx + idx]
                # Pointer attention: similarity of the current RNN output to the cached outputs.
                logits = torch.mv(valid_pointer_history, rnn_out[idx])
                theta = args.theta
                ptr_attn = torch.nn.functional.softmax(theta * logits).view(-1, 1)
                # Scatter the attention onto the vocabulary via the cached one-hot targets.
                ptr_dist = (ptr_attn.expand_as(valid_next_word) * valid_next_word).sum(0).squeeze()
                # Mix the pointer and vocabulary distributions.
                lambdah = args.lambdasm
                p = lambdah * ptr_dist + (1 - lambdah) * vocab_loss
            ###
            target_loss = p[targets[idx].data]
            loss += (-torch.log(target_loss)).data
        total_loss += loss / batch_size
        ###
        hidden = repackage_hidden(hidden)
        # Keep only the last `window` cache entries between batches.
        next_word_history = next_word_history[-window:]
        pointer_history = pointer_history[-window:]
    return total_loss / len(data_source)
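
def _cache_mix_example():
    # Illustrative sketch only; this helper is not called anywhere and is not part
    # of the original script. It shows, on toy tensors, how evaluate() above forms
    # the pointer distribution and mixes it with the vocabulary softmax. All shapes
    # and values below are made up for demonstration.
    hist_len, hidden_size, vocab_size = 4, 3, 5
    cached_outputs = torch.randn(hist_len, hidden_size)   # like pointer_history
    cached_targets = torch.eye(vocab_size)[:hist_len]     # like next_word_history (one-hot)
    current_output = torch.randn(hidden_size)              # like rnn_out[idx]
    vocab_dist = torch.nn.functional.softmax(torch.randn(vocab_size), dim=0)
    theta, lambdah = 0.66, 0.13
    # Similarity of the current output to each cached output, softmaxed with temperature theta.
    logits = torch.mv(cached_outputs, current_output)
    ptr_attn = torch.nn.functional.softmax(theta * logits, dim=0).view(-1, 1)
    # Project the attention onto the vocabulary through the cached one-hot targets.
    ptr_dist = (ptr_attn.expand_as(cached_targets) * cached_targets).sum(0).squeeze()
    # Linear interpolation between pointer and vocabulary distributions.
    return lambdah * ptr_dist + (1 - lambdah) * vocab_dist
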
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
    parser.add_argument('--data', type=str, default='data/penn',
                        help='location of the data corpus')
    parser.add_argument('--model', type=str, default='LSTM',
                        help='type of recurrent net (LSTM, QRNN)')
    parser.add_argument('--save', type=str, default='best.pt',
                        help='model to use the pointer over')
    parser.add_argument('--cuda', action='store_false',
                        help='use CUDA (store_false: CUDA is on by default; passing --cuda disables it)')
    parser.add_argument('--bptt', type=int, default=5000,
                        help='sequence length')
    parser.add_argument('--window', type=int, default=3785,
                        help='pointer window length')
    parser.add_argument('--theta', type=float, default=0.6625523432485668,
                        help='mix between uniform distribution and pointer softmax distribution over previous words')
    parser.add_argument('--lambdasm', type=float, default=0.12785920428335693,
                        help='linear mix between only pointer (1) and only vocab (0) distribution')
    parser.add_argument('--savepath', type=str, default='.',
                        help='location to dump results')
    args = parser.parse_args()

    ###############################################################################
    # Load data
    ###############################################################################
    corpus = data.Corpus(args.data)
    eval_batch_size = 2
    test_batch_size = 2
    # train_data = batchify(corpus.train, args.batch_size)
    val_data = batchify(corpus.valid, test_batch_size, args)
    test_data = batchify(corpus.test, test_batch_size, args)

    ###############################################################################
    # Build the model
    ###############################################################################
    ntokens = len(corpus.dictionary)
    criterion = nn.CrossEntropyLoss()

    args.save = os.path.join(args.savepath, args.save)
    print(args.save)

    # Load the best saved model.
    with open(args.save, 'rb') as f:
        if not args.cuda:
            model = torch.load(f, map_location=lambda storage, loc: storage)
        else:
            model = torch.load(f)
    if type(model) == list:
        model = model[0]  # Why does load return a list?
    print(model)

    # Run on val data.
    val_loss = evaluate(val_data, test_batch_size, args.window)
    print('=' * 89)
    print('| End of pointer | val loss {:5.2f} | val ppl {:8.2f}'.format(
        val_loss, math.exp(val_loss)))
    print('=' * 89)

    # Run on test data.
    test_loss = evaluate(test_data, test_batch_size, args.window)
    print('=' * 89)
    print('| End of pointer | test loss {:5.2f} | test ppl {:8.2f}'.format(
        test_loss, math.exp(test_loss)))
    print('=' * 89)
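
# Example invocation (a sketch; it assumes a checkpoint named best.pt produced by
# the training script of this repository, and a CUDA-capable GPU since CUDA is
# enabled by default):
#
#   python pointer.py --data data/penn --save best.pt --window 3785 --bptt 5000
#
# Note that --cuda uses action='store_false', so passing --cuda turns CUDA off.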