-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainUNSUP.py
223 lines (174 loc) · 8.17 KB
/
trainUNSUP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import argparse
import core.myTheta as myTheta
import core.trainingRoutines as training
from collections import defaultdict, Counter
import sys, os, pickle
import numpy as np
def getVocabulary(vocabularies):
    """Merge pickled vocabulary files into a single word list.

    :param vocabularies: paths of pickled word collections (VOC.pik files)
    :return: list of unique words with 'UNKNOWN' at index 0; the remaining
             words are sorted so repeated runs produce the same ordering
             (the original list(set) order was nondeterministic).
    """
    print('Loading vocabulary.')
    vocabulary = set()
    for m in vocabularies:
        with open(m, 'rb') as f:
            # NOTE(review): pickle.load is unsafe on untrusted files;
            # assumed here to be project-generated VOC.pik files.
            vocabulary.update(pickle.load(f))
    vocabulary = sorted(vocabulary)
    print('\tThere are %d words in the combined treebanks.' % len(vocabulary))
    # Make sure 'UNKNOWN' occurs exactly once and sits at index 0; the rest
    # of the pipeline treats index 0 as the unknown-word slot.
    try:
        vocabulary.remove('UNKNOWN')
    except ValueError:
        # 'UNKNOWN' was not in the treebanks; it is inserted below anyway.
        # (Original code used a bare `except: True`, hiding all errors.)
        pass
    vocabulary.insert(0, 'UNKNOWN')
    return vocabulary
def getGrammar(option, grammars):
    """Merge pickled rule-count files into one grammar.

    :param option: the grammar-specialization style (printed for logging only;
                   it does not affect the merge).
    :param grammars: paths of pickled {lhs: {rhs: count}} mappings
                     (RULES.pik files).
    :return: defaultdict(Counter) with per-lhs rule counts summed over files.
    """
    rules = defaultdict(Counter)
    print('Grammar-based parameter selection: %s' % option)
    if not grammars:
        print('No RULES.pik file found. Exit program')
        sys.exit()
    for m in grammars:
        with open(m, 'rb') as f:
            newRules = pickle.load(f)
        # Counter.update adds counts rather than overwriting, so this sums
        # the frequency of each rule across all files.
        # (items() replaces the Py2-only iteritems(); same iteration.)
        for lhs, rhss in newRules.items():
            rules[lhs].update(rhss)
    return rules
def initializeTheta(kind, args, vocabulary, grammar, maxArity):
    """Create (or reload) the network parameter object Theta.

    If args['pars'] names a pickled Theta it is loaded from disk and all
    other dimensionality settings are ignored.  Otherwise a fresh Theta is
    built, optionally seeding the word embeddings from args['emb'].

    :param kind: 'IORNN' or 'RAE' (forwarded to myTheta.Theta)
    :param args: dict of parsed command-line options
    :param vocabulary: word list with 'UNKNOWN' first (see getVocabulary)
    :param grammar: rule-count mapping (see getGrammar)
    :param maxArity: maximal rule arity the network must support
    :return: an initialized myTheta.Theta instance
    """
    print('Initializing theta.')
    if args['pars']:
        with open(args['pars'], 'rb') as f:
            theta = pickle.load(f)
        print('Retrieved Theta from disk.')
    else:
        dims = dict((k, args[k]) for k in ['inside', 'outside'])
        dims['maxArity'] = maxArity
        if args['emb']:
            with open(args['emb'], 'rb') as f:
                V, voc = pickle.load(f)
            # BUG FIX: original called voc.insert('UNKNOWN', 0) with the
            # arguments swapped; list.insert takes (index, item).
            if 'UNKNOWN' not in voc:
                voc.insert(0, 'UNKNOWN')
                # NOTE(review): this adds 'UNKNOWN' to voc without adding a
                # matching row to V, which shifts every V lookup below by
                # one -- confirm the embedding files already contain an
                # 'UNKNOWN' entry so this branch never fires in practice.
            # Keep only words that actually have an embedding, and reorder
            # the embedding matrix to match the filtered vocabulary.
            # (Index map + set hoisted out of the loops: the original did a
            # linear voc.index()/`in voc` scan per word, i.e. O(n^2).)
            position = dict((w, i) for i, w in enumerate(voc))
            vocabulary = [w for w in vocabulary if w in position]
            V = np.vstack([V[position[w]] for w in vocabulary])
            dims['word'] = len(V[0])
        else:
            V = None
            dims['word'] = args['word']
            if dims['word'] is None:
                print('Either embeddings or dword must be specified. Stop program.')
                sys.exit()
        # Unspecified inside/outside dimensionalities default to word dim.
        if not dims['inside']:
            dims['inside'] = dims['word']
        if not dims['outside']:
            dims['outside'] = dims['word']
        dims['nwords'] = len(vocabulary)
        theta = myTheta.Theta(kind, dims, grammar, V, vocabulary)
    theta.printDims()
    return theta
def main(args):
    # Entry point: gathers treebank / vocabulary / grammar files for training
    # and validation, builds theta, then dispatches to a training routine
    # depending on the requested grammar-specialization style.
    print 'Start (part of) experiment '+ args['experiment']
    kind = args['kind']
    if kind not in [ 'IORNN', 'RAE']:
        raise KeyError('not a valid kind (IORNN/RAE):'+kind)
    # Collect the relevant pickled files from the training source directory;
    # files are selected purely by substring markers in their names
    # ('IORNNS'/'RAES' for treebanks, 'VOC.pik', 'RULES.pik').
    source = args['sourceTrain']
    if os.path.isdir(source):
        files = [f for f in [os.path.join(source,f) for f in os.listdir(source)] if os.path.isfile(f)]
        if kind == 'IORNN': treebanksTrain = [f for f in files if 'IORNNS' in f]
        elif kind == 'RAE': treebanksTrain = [f for f in files if 'RAES' in f]
        vocabularies = [f for f in files if 'VOC.pik' in f]
        grammars = [f for f in files if 'RULES.pik' in f]
    else:
        print 'no valid source directory:',source
        sys.exit()
    # Same scan for the validation directory; vocabulary and grammar files
    # from both directories are combined.
    source = args['sourceValid']
    if os.path.isdir(source):
        files = [f for f in [os.path.join(source,f) for f in os.listdir(source)] if os.path.isfile(f)]
        if kind == 'IORNN': treebanksValid = [f for f in files if 'IORNNS' in f]
        elif kind == 'RAE': treebanksValid = [f for f in files if 'RAES' in f]
        vocabularies.extend([f for f in files if 'VOC.pik' in f])
        grammars.extend([f for f in files if 'RULES.pik' in f])
    else:
        print 'no valid source directory:',source
        sys.exit()
    # Sanity checks: expect at least one vocabulary/grammar file per source
    # directory (hence "two"), and data for both training and validation.
    if len(vocabularies)<2:
        print 'no two vocabulary files.'
        sys.exit()
    if len(grammars)<2:
        print 'no two grammar files.'
        sys.exit()
    if len(treebanksTrain)<1 or len(treebanksValid)<1:
        print 'no training or validation data obtained. Abort execution.'
        sys.exit()
    vocabulary = getVocabulary(vocabularies)
    # args['grammar'] is the (style, nRules, startAt) triple produced by the
    # ValidateGrammar argparse action.
    style =args['grammar'][0]
    grammar = getGrammar(style, grammars)
    maxArity=6
    theta=initializeTheta(kind, args,vocabulary, grammar,maxArity)
    hyperParams = dict((k, args[k]) for k in ['nEpochs','bSize','lambda','alpha'])
    cores = max(1,args['cores']-1) # keep one core free for optimal efficiency
    hyperParams['nRules']=args['grammar'][1]
    hyperParams['startAt']=args['grammar'][2]
    hyperParams['ada'] = True        # AdaGrad-style updates always on
    hyperParams['fixEmb'] = args['fix']
    print 'Hyper parameters:'
    for param, value in hyperParams.iteritems():
        print '\t',param, '-' ,value
    print '\tnumber of cores -', cores
    outDir = args['out']
    if not os.path.isdir(outDir):
        print 'Not a valid output directory:', outDir
        sys.exit()
    tTreebank = training.Treebank(treebanksTrain,maxArity)
    # Only the first validation treebank file is used.
    vTreebank = training.Treebank(treebanksValid[:1],maxArity)
    # Snapshot the initial parameters before any training.
    training.storeTheta(theta, os.path.join(outDir,'initialTheta.pik'))
    # training...
    # Dispatch on specialization style; LHS/Rules variants specialize theta
    # before plain training, 'beginSmall' uses its own curriculum routine.
    if style == 'beginSmall': training.beginSmall(tTreebank, vTreebank, hyperParams, theta, outDir, cores)
    elif style == 'None': training.plainTrain(tTreebank, vTreebank, hyperParams, theta, outDir, cores)
    elif style == 'LHS':
        theta.specializeHeads()
        training.plainTrain(tTreebank, vTreebank, hyperParams, theta, outDir, cores)
    elif style == 'Rules':
        theta.specializeRules(hyperParams['nRules'])
        training.plainTrain(tTreebank, vTreebank, hyperParams, theta, outDir, cores)
    elif style == 'LHS+Rules':
        theta.specializeHeads()
        theta.specializeRules(hyperParams['nRules'])
        training.plainTrain(tTreebank, vTreebank, hyperParams, theta, outDir, cores)
class ValidateGrammar(argparse.Action):
    """argparse action for -g/--grammar.

    Validates the specialization style and stores the destination attribute
    as a (kind, nRules, startAt) triple, with defaults nRules=0, startAt=10.
    """
    def __call__(self, parser, args, values, option_string=None):
        valid_subjects = ['None','LHS','Rules','beginSmall','LHS+Rules']
        kind = values[0]
        if kind not in valid_subjects:
            raise ValueError('invalid grammar-option {s!r}'.format(s=kind))
        # Optional second value: number of rules to specialize.
        if len(values) > 1:
            n = int(values[1])
        else:
            n = 0
        # Optional third value: phase/epoch at which specialization starts.
        if len(values) > 2:
            startAt = int(values[2])
        else:
            startAt = 10
        if len(values) > 3:
            # BUG FIX: original reported values[2:] as ignored, although
            # values[2] is consumed as startAt above.
            print('-g grammar options %s are ignored' % (values[3:],))
        setattr(args, self.dest, (kind, n, startAt))
def mybool(string):
    """argparse `type` callable: parse T/F-style strings into a bool.

    Raises argparse.ArgumentTypeError for anything else, so argparse shows a
    clean usage error.  (The original raised a bare Exception, which argparse
    does not intercept, crashing with a traceback on a bad flag value.)
    """
    if string in ['F', 'f', 'false', 'False']:
        return False
    if string in ['T', 't', 'true', 'True']:
        return True
    raise argparse.ArgumentTypeError('Not a valid choice for arg: ' + string)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train RAE/ IORNN unsupervised on a treebank')
    # Help text fixed: main() accepts exactly 'IORNN' or 'RAE'; the old text
    # said 'IO/RNN', which matched neither accepted value.
    parser.add_argument('-k','--kind', type=str, help='kind (IORNN/RAE)', required=True)
    # data:
    parser.add_argument('-exp','--experiment', type=str, help='Identifier of the experiment', required=True)
    parser.add_argument('-st','--sourceTrain', type=str, help='Directory with pickled treebank(s), grammar(s) and vocabulary(s) for training', required=True)
    parser.add_argument('-sv','--sourceValid', type=str, help='Directory with pickled treebank(s), grammar(s) and vocabulary(s) for validation', required=True)
    parser.add_argument('-e','--emb', type=str, help='File with pickled embeddings', required=False)
    # -g takes 1-3 values: style [nRules [startAt]]; see ValidateGrammar.
    parser.add_argument('-g','--grammar', action=ValidateGrammar, help='Kind of parameter specialization', nargs='+',required=True)
    parser.add_argument('-o','--out', type=str, help='Output file to store pickled theta', required=True)
    parser.add_argument('-p','--pars', type=str, help='File with pickled theta to initialize with', required=False)
    # network hyperparameters:
    parser.add_argument('-din','--inside', type=int, help='Dimensionality of inside representations', required=False)
    parser.add_argument('-dwrd','--word', type=int, help='Dimensionality of leaves (word nodes)', required=False)
    parser.add_argument('-dout','--outside', type=int, help='Dimensionality of outside representations', required=False)
    # training hyperparameters:
    parser.add_argument('-n','--nEpochs', type=int, help='Maximal number of epochs to train per phase', required=True)
    parser.add_argument('-b','--bSize', type=int, default = 50, help='Batch size for minibatch training', required=False)
    parser.add_argument('-l','--lambda', type=float, help='Regularization parameter lambdaL2', required=True)
    parser.add_argument('-a','--alpha', type=float, help='Learning rate parameter alpha', required=True)
    # Note: the string default 'F' is passed through mybool by argparse.
    parser.add_argument('-f','--fix', type=mybool, default = 'F', help='Whether the word embeddings should be fixed', required=False)
    # computation:
    parser.add_argument('-c','--cores', type=int, default=1, help='Number of cores to use for parallel processing', required=False)
    args = vars(parser.parse_args())
    main(args)