-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathconfig.py
138 lines (120 loc) · 6.59 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python3
# Copyright 2018-present, HKUST-KnowComp.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Model architecture/optimization options for WRMCQA document reader."""
import argparse
import logging
logger = logging.getLogger(__name__)

# Namespace attribute names of the options that define the reader's
# architecture: the "Reader Model Architecture" and "Reader Model Details"
# argument groups registered in add_model_args.
MODEL_ARCHITECTURE = {
    # "Reader Model Architecture" group
    'model_type',
    'embedding_dim',
    'char_embedding_dim',
    'hidden_size',
    'char_hidden_size',
    'doc_layers',
    'question_layers',
    'rnn_type',
    # "Reader Model Details" group
    'concat_rnn_layers',
    'question_merge',
    'use_qemb',
    'use_exact_match',
    'use_pos',
    'use_ner',
    'use_lemma',
    'use_tf',
    'hop',
}

# Namespace attribute names of the options that control optimization and
# training: the "Reader Optimization" argument group in add_model_args.
MODEL_OPTIMIZER = {
    'dropout_emb',
    'dropout_rnn',
    'dropout_rnn_output',
    'optimizer',
    'learning_rate',
    'grad_clipping',
    'weight_decay',
    'momentum',
    'rho',
    'eps',
    'fix_embeddings',
    'tune_partial',
    'rnn_padding',
    'max_len',
}
def str2bool(v):
    """Convert a command-line string to a boolean.

    Used as the argparse ``'bool'`` type. Matching is case-insensitive:
    'yes', 'true', 't', '1', 'y' give True; 'no', 'false', 'f', '0', 'n'
    give False. A value that is already a ``bool`` is returned unchanged.

    Args:
        v: the raw option string (or a bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognized boolean
            string. A typo such as ``--use-pos ture`` is then reported as an
            argparse usage error instead of silently becoming False.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', '1', 'y'):
        return True
    if lowered in ('no', 'false', 'f', '0', 'n'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (v,))
def add_model_args(parser):
    """Register the reader's model and optimization options on ``parser``.

    First registers a ``'bool'`` type backed by str2bool, so options declared
    with ``type='bool'`` have their strings parsed by that function. Then adds
    three argument groups, in this order: "Reader Model Architecture",
    "Reader Model Details" and "Reader Optimization".

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.
    """
    parser.register('type', 'bool', str2bool)

    # Each entry is (group title, options); each option is
    # (flag, type, default, help). Options are registered in listed order.
    option_groups = (
        ('Reader Model Architecture', (
            ('--model-type', str, 'mnemonic',
             'Model architecture type: rnn, r_net, mnemonic'),
            ('--embedding-dim', int, 300,
             'Embedding size if embedding_file is not given'),
            ('--char-embedding-dim', int, 50,
             'Embedding size if char_embedding_file is not given'),
            ('--hidden-size', int, 100,
             'Hidden size of RNN units'),
            ('--char-hidden-size', int, 50,
             'Hidden size of char RNN units'),
            ('--doc-layers', int, 3,
             'Number of encoding layers for document'),
            ('--question-layers', int, 3,
             'Number of encoding layers for question'),
            ('--rnn-type', str, 'lstm',
             'RNN type: LSTM, GRU, or RNN'),
        )),
        ('Reader Model Details', (
            ('--concat-rnn-layers', 'bool', True,
             'Combine hidden states from each encoding layer'),
            ('--question-merge', str, 'self_attn',
             'The way of computing the question representation'),
            ('--use-qemb', 'bool', True,
             'Whether to use weighted question embeddings'),
            ('--use-exact-match', 'bool', True,
             'Whether to use in_question_* features'),
            ('--use-pos', 'bool', True,
             'Whether to use pos features'),
            ('--use-ner', 'bool', True,
             'Whether to use ner features'),
            ('--use-lemma', 'bool', True,
             'Whether to use lemma features'),
            ('--use-tf', 'bool', True,
             'Whether to use term frequency features'),
            ('--hop', int, 2,
             'The number of hops for both aligner and the answer pointer in m-reader'),
        )),
        ('Reader Optimization', (
            ('--dropout-emb', float, 0.2,
             'Dropout rate for word embeddings'),
            ('--dropout-rnn', float, 0.2,
             'Dropout rate for RNN states'),
            ('--dropout-rnn-output', 'bool', True,
             'Whether to dropout the RNN output'),
            ('--optimizer', str, 'adamax',
             'Optimizer: sgd, adamax, adadelta'),
            ('--learning-rate', float, 1.0,
             'Learning rate for sgd, adadelta'),
            # NOTE: these three defaults are ints (not floats) in the
            # original; argparse does not convert non-string defaults.
            ('--grad-clipping', float, 10,
             'Gradient clipping'),
            ('--weight-decay', float, 0,
             'Weight decay factor'),
            ('--momentum', float, 0,
             'Momentum factor'),
            ('--rho', float, 0.95,
             'Rho for adadelta'),
            ('--eps', float, 1e-6,
             'Eps for adadelta'),
            ('--fix-embeddings', 'bool', True,
             'Keep word embeddings fixed (use pretrained)'),
            ('--tune-partial', int, 0,
             'Backprop through only the top N question words'),
            ('--rnn-padding', 'bool', False,
             'Explicitly account for padding in RNN encoding'),
            ('--max-len', int, 15,
             'The max span allowed during decoding'),
        )),
    )

    for title, options in option_groups:
        group = parser.add_argument_group(title)
        for flag, arg_type, default_value, help_text in options:
            group.add_argument(flag, type=arg_type, default=default_value,
                               help=help_text)
def get_model_args(args):
    """Return only the model-specific args from an args Namespace.

    Keeps the entries of ``args`` whose names are in MODEL_ARCHITECTURE or
    MODEL_OPTIMIZER, which are the options defined in this module, and drops
    everything else.

    Args:
        args: an ``argparse.Namespace`` (any object accepted by ``vars()``).

    Returns:
        A new ``argparse.Namespace`` holding just the model args, in the
        order they appear in ``args``.
    """
    # The module-level sets are only read here, so no ``global`` is needed.
    model_keys = MODEL_ARCHITECTURE.union(MODEL_OPTIMIZER)
    kept = {}
    for name, value in vars(args).items():
        if name in model_keys:
            kept[name] = value
    return argparse.Namespace(**kept)
def override_model_args(old_args, new_args):
    """Resolve a set of saved model args against newly supplied args.

    Only args that differ between the two sets are considered:
    - Optimization args (names in MODEL_OPTIMIZER) take the new value.
    - Every other arg keeps its saved value, so the saved model architecture
      is left alone.
    Each decision is logged at INFO level.

    Only keys present in ``old_args`` are examined. Keys that exist only in
    ``new_args`` are not copied into the result.

    Args:
        old_args: ``argparse.Namespace`` of saved args.
        new_args: ``argparse.Namespace`` of args for the current run.

    Returns:
        A new ``argparse.Namespace`` with the resolved values. Neither input
        namespace is modified.
    """
    # vars() returns the Namespace's live __dict__. Copy it so the overrides
    # below do not leak back into the caller's ``old_args`` object.
    resolved = dict(vars(old_args))
    incoming = vars(new_args)
    for key, saved in vars(old_args).items():
        if key in incoming and saved != incoming[key]:
            if key in MODEL_OPTIMIZER:
                logger.info('Overriding saved %s: %s --> %s',
                            key, saved, incoming[key])
                resolved[key] = incoming[key]
            else:
                logger.info('Keeping saved %s: %s', key, saved)
    return argparse.Namespace(**resolved)