-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunlikely.py
129 lines (109 loc) · 5.04 KB
/
unlikely.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
##################################################################
#Language Model
#This file identifies words that are not well-predicted by a model
##################################################################
import argparse
import csv
import os
import pandas as pd
import torch
from torch.autograd import Variable
import textData
from utils import norm_weights
#Command-line interface for the unlikely-word report.
parser = argparse.ArgumentParser(description='PyTorch Language Model')
#Parameters: each entry is (flag, keyword arguments passed to add_argument)
_cli_options = [
    ('--data', dict(type=str, default='baum_wiz_clean',
                    help='location of the data corpus')),
    ('--checkpoint', dict(type=str, default='model.pt',
                          help='model checkpoint to use')),
    ('--outf', dict(type=str, default='unlikely_words.csv',
                    help='output file for unlikely words report')),
    ('--diff', dict(type=float, default=0.05,
                    help='threshold for determining unlikeliness (default 0.05)')),
    ('--ignore', dict(type=str, default='EOS',
                      help='generated words to be ignored (default "EOS")')),
    ('--text', dict(type=str, default='test',
                    help='text used to assess model (train, valid, test)')),
    ('--seed', dict(type=int, default=1111,
                    help='random seed')),
    ('--cuda', dict(action='store_true',
                    help='use CUDA')),
    ('--temperature', dict(type=float, default=1.0,
                           help='temperature (diversity increases with arg value)')),
    ('--log-interval', dict(type=int, default=5000,
                            help='reporting interval')),
]
for _flag, _kwargs in _cli_options:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()
#Set random seed for reproducibility
torch.manual_seed(args.seed)
#Warn when a GPU is present but unused; pick the device from --cuda
if torch.cuda.is_available() and not args.cuda:
    print('WARNING: You have a CUDA device, so you should probably run with --cuda')
device = torch.device('cuda' if args.cuda else 'cpu')
#Sanity-check numeric and choice arguments before touching the model
if args.temperature < 1e-3:
    parser.error('--temperature has to be greater than or equal to 1e-3')
if args.text not in ['train', 'valid', 'test']:
    raise ValueError( """An invalid option for `--text` was supplied.
options are ['train', 'valid', 'test']""")
#Load the trained model and switch it to evaluation mode
with open(args.checkpoint, 'rb') as f:
    model = torch.load(f).to(device)
model.eval()
#Corpus, ignore list, initial hidden state, and the report accumulator
corpus = textData.Corpus(args.data)
ignored = args.ignore.split()
hidden = model.init_hidden(1)
unlikely_dict = {}
#--text was validated above, so the attribute lookup is safe
corpus_eval = getattr(corpus, args.text)
#Walk the evaluation text: at each position, sample a word from the model
#and count positions where the model rates the true next word at least
#`--diff` less probable than the word it sampled.
with torch.no_grad(): #Do not track history
    hits = 0
    for i, word in enumerate(corpus_eval[:-1]):
        # Compare generated word's probability to true word's probability
        # (renamed from `input` to avoid shadowing the builtin)
        model_input = word.view(1,1)
        output, hidden = model(model_input, hidden)
        word_weights = output.squeeze().div(args.temperature).exp().cpu()
        word_probs = norm_weights(word_weights)
        word_idx = torch.multinomial(word_weights, 1)[0] #Index of generated word
        true_word = corpus.dictionary.idx2word[corpus_eval[i+1]]
        gen_word = corpus.dictionary.idx2word[word_idx] #Generated word
        true_prob = word_probs[corpus_eval[i+1]]
        #Bug fix: look up the generated word's probability directly from its
        #index. The original did `input.fill_(word_idx)` — which wrote through
        #the view into `corpus_eval`, corrupting the corpus tensor in place —
        #and then read `word_probs[corpus_eval[i]]`, which was only the
        #generated word's probability because of that side effect.
        gen_prob = word_probs[word_idx]
        #Report those words which the model predicts to have a probability
        #of `diff` less than the generated word
        if gen_prob - true_prob > args.diff and gen_word not in ignored:
            hits += 1
            key = (true_word, gen_word)
            unlikely_dict[key] = unlikely_dict.get(key, 0) + 1
        #Reporting interval
        if i % args.log_interval == 0:
            print('| Share of unlikely words: {}/{}'.format(hits, i))
print('Detected {} discrepancies with given parameters'.format(hits))
#Write the discrepancy report. If the target file already exists, append an
#integer suffix before the extension (report.csv -> report0.csv, report1.csv,
#...) instead of clobbering it. os.path.splitext replaces the original
#`find('.')` slicing, which garbled the name for extension-less filenames.
pathout = os.path.join(args.data, args.outf)
base, ext = os.path.splitext(args.outf)
suffix = 0
while os.path.exists(pathout):
    pathout = os.path.join(args.data, '{}{}{}'.format(base, suffix, ext))
    suffix += 1
#newline='' is required by the csv module; without it Windows emits blank rows
with open(pathout, 'w', newline='') as fout:
    wrt = csv.writer(fout, delimiter=',')
    #Row 1 records the run parameters; row 2 is the column header
    wrt.writerow(['diff: {}'.format(args.diff), 'ignore: {}'.format(args.ignore), 'text: {}'.format(args.text)])
    wrt.writerow(['actual', 'generated', 'freq'])
    for (actual_word, generated_word), freq in unlikely_dict.items():
        wrt.writerow([actual_word, generated_word, freq])
print('Wrote discrepancy data to {}'.format(pathout))
#Output summary data: reload the report with pandas (skipping the parameter
#row) and show the twenty most frequent discrepancies grouped each way.
df = pd.read_csv(pathout, skiprows=1)
print('--diff {} --ignore {} --text {}'.format(args.diff, args.ignore, args.text))
print('Found {} unique combinations of generated and actual words.'.format(df.shape[0]))
for group_col in ('generated', 'actual'):
    print('Group by {} word:'.format(group_col))
    top = df.groupby(group_col).sum()[['freq']].sort_values('freq', ascending=False)
    print(top[:20])
print('Output limited to the twenty most frequent words.')