-
Notifications
You must be signed in to change notification settings - Fork 40
/
configurations.py
115 lines (99 loc) · 4.47 KB
/
configurations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
from Transparency.common_code.common import *
def generate_basic_config(dataset, exp_name):
    """Build the base experiment configuration shared by the model variants.

    Args:
        dataset: Object exposing ``vec.vocab_size``, ``vec.word_dim``,
            ``output_size`` and ``name``. The optional attributes ``bsize``,
            ``pos_weight`` and ``basepath`` override the training defaults
            (32, ``None`` and ``'outputs'`` respectively).
        exp_name: Experiment name; joined onto ``dataset.name`` to form
            ``config['training']['exp_dirname']``.

    Returns:
        A freshly built nested dict with a ``'model'`` section (encoder and
        decoder settings, tanh attention) and a ``'training'`` section.
    """
    config = {
        'model': {
            'encoder': {
                'vocab_size': dataset.vec.vocab_size,
                'embed_size': dataset.vec.word_dim,
            },
            'decoder': {
                'attention': {
                    'type': 'tanh',
                },
                'output_size': dataset.output_size,
            },
        },
        'training': {
            # getattr with a default replaces the hasattr/ternary double lookup.
            'bsize': getattr(dataset, 'bsize', 32),
            'weight_decay': 1e-5,
            'pos_weight': getattr(dataset, 'pos_weight', None),
            'basepath': getattr(dataset, 'basepath', 'outputs'),
            'exp_dirname': os.path.join(dataset.name, exp_name),
        },
    }
    return config
def generate_lstm_config(dataset):
    """Return the base config specialised for an RNN encoder.

    Starts from ``generate_basic_config`` with experiment name
    ``'lstm+tanh'`` and adds ``type='rnn'`` plus ``hidden_size`` to the
    encoder section.

    Args:
        dataset: Dataset object accepted by ``generate_basic_config``; its
            optional ``hidden_size`` attribute overrides the default of 128.

    Returns:
        The configuration dict.
    """
    config = generate_basic_config(dataset, exp_name='lstm+tanh')
    # getattr with a default replaces the hasattr/ternary double lookup.
    hidden_size = getattr(dataset, 'hidden_size', 128)
    config['model']['encoder'].update({'type': 'rnn', 'hidden_size': hidden_size})
    return config
def generate_average_config(dataset):
    """Return the base config specialised for the 'average' encoder.

    Starts from ``generate_basic_config`` with experiment name
    ``'average+tanh'`` and adds ``projection=True``, ``hidden_size``,
    ``activation='tanh'`` and ``type='average'`` to the encoder section.

    Args:
        dataset: Dataset object accepted by ``generate_basic_config``; its
            optional ``hidden_size`` attribute overrides the default of 128.

    Returns:
        The configuration dict.
    """
    config = generate_basic_config(dataset, exp_name='average+tanh')
    # getattr with a default replaces the hasattr/ternary double lookup.
    hidden_size = getattr(dataset, 'hidden_size', 128)
    config['model']['encoder'].update({
        'projection': True,
        'hidden_size': hidden_size,
        'activation': 'tanh',
        'type': 'average',
    })
    return config
def generate_cnn_config(dataset, filters=(1, 3, 5, 7)):
    """Return the base config specialised for a CNN encoder.

    The experiment name embeds the filter sizes with spaces stripped, e.g.
    ``'cnn(1,3,5,7)+tanh'`` for the default tuple.

    Args:
        dataset: Dataset object accepted by ``generate_basic_config``; its
            optional ``hidden_size`` attribute overrides the default of 128.
        filters: Sized collection of kernel sizes, stored as
            ``kernel_sizes``. Must be non-empty: an empty collection raises
            ``ZeroDivisionError`` when the hidden size is divided.

    Returns:
        The configuration dict, with encoder ``type='cnn'`` and
        ``activation='relu'``.
    """
    exp_name = 'cnn' + str(filters).replace(' ', '') + '+tanh'
    config = generate_basic_config(dataset, exp_name=exp_name)
    # getattr with a default replaces the hasattr/ternary double lookup.
    hidden_size = getattr(dataset, 'hidden_size', 128)
    config['model']['encoder'].update({
        'kernel_sizes': filters,
        # Floor division: the total hidden size is split across the filters,
        # so any remainder is dropped.
        'hidden_size': hidden_size // len(filters),
        'activation': 'relu',
        'type': 'cnn',
    })
    return config
def generate_vanilla_lstm_config(dataset):
    """Return the LSTM config with attention turned off.

    Takes the result of ``generate_lstm_config``, sets the decoder flag
    ``use_attention`` to ``False`` and stores the experiment under
    ``<dataset.name>/lstm``. The decoder's ``attention`` sub-dict is left
    in place.
    """
    vanilla_cfg = generate_lstm_config(dataset)
    decoder_cfg = vanilla_cfg['model']['decoder']
    decoder_cfg['use_attention'] = False
    training_cfg = vanilla_cfg['training']
    training_cfg['exp_dirname'] = os.path.join(dataset.name, 'lstm')
    return vanilla_cfg
def generate_logodds_config(dataset):
    """Return the LSTM config whose attention is replaced by log-odds.

    The decoder's ``attention`` entry is overwritten with a ``'logodds'``
    spec pointing at ``logodds.p`` inside the directory returned by
    ``get_latest_model`` for ``outputs/<dataset.name>/LR+TFIDF``. The
    experiment is stored under ``<dataset.name>/lstm+logodds``.
    """
    # NOTE(review): the lookup root is the literal 'outputs' and ignores
    # config['training']['basepath'] -- confirm this is intended when a
    # dataset defines its own basepath.
    lr_root = os.path.join('outputs', dataset.name, 'LR+TFIDF')
    # get_latest_model is not defined in this file; presumably supplied by
    # the star import from Transparency.common_code.common.
    latest_lr_dir = get_latest_model(lr_root)
    logodds_spec = {
        'type': 'logodds',
        'logodds_file': os.path.join(latest_lr_dir, 'logodds.p'),
    }

    cfg = generate_lstm_config(dataset)
    cfg['model']['decoder']['attention'] = logodds_spec
    cfg['training']['exp_dirname'] = os.path.join(dataset.name, 'lstm+logodds')
    return cfg
def generate_logodds_regularised_config(dataset):
    """Return the LSTM config with an added log-odds attention regulariser.

    The decoder keeps its ``attention`` entry and gains a
    ``regularizer_attention`` entry: a ``'logodds'`` spec pointing at
    ``logodds.p`` inside the directory returned by ``get_latest_model`` for
    ``outputs/<dataset.name>/LR+TFIDF``. The experiment is stored under
    ``<dataset.name>/lstm+logodds(Reg)``.
    """
    # NOTE(review): the lookup root is the literal 'outputs' and ignores
    # config['training']['basepath'] -- confirm this is intended when a
    # dataset defines its own basepath.
    lr_root = os.path.join('outputs', dataset.name, 'LR+TFIDF')
    # get_latest_model is not defined in this file; presumably supplied by
    # the star import from Transparency.common_code.common.
    latest_lr_dir = get_latest_model(lr_root)
    regulariser_spec = {
        'type': 'logodds',
        'logodds_file': os.path.join(latest_lr_dir, 'logodds.p'),
    }

    cfg = generate_lstm_config(dataset)
    cfg['model']['decoder']['regularizer_attention'] = regulariser_spec
    cfg['training']['exp_dirname'] = os.path.join(dataset.name, 'lstm+logodds(Reg)')
    return cfg
def generate_vanilla_cnn_config(dataset):
    """Return the default CNN config with attention turned off.

    Takes the result of ``generate_cnn_config`` (default filters), sets the
    decoder flag ``use_attention`` to ``False`` and stores the experiment
    under ``<dataset.name>/cnn``. The decoder's ``attention`` sub-dict is
    left in place.
    """
    vanilla_cfg = generate_cnn_config(dataset)
    decoder_cfg = vanilla_cfg['model']['decoder']
    decoder_cfg['use_attention'] = False
    training_cfg = vanilla_cfg['training']
    training_cfg['exp_dirname'] = os.path.join(dataset.name, 'cnn')
    return vanilla_cfg
def generate_single_cnn_config(dataset):
    """Return a CNN config that uses a single kernel size of 3.

    The experiment directory is overridden to ``<dataset.name>/cnn(3)+tanh``
    instead of the name ``generate_cnn_config`` derives from the one-element
    tuple.
    """
    single_cfg = generate_cnn_config(dataset, filters=(3,))
    target_dir = os.path.join(dataset.name, 'cnn(3)+tanh')
    single_cfg['training']['exp_dirname'] = target_dir
    return single_cfg
def make_attention_dot(func):
    """Wrap a config generator so that it produces a dot-attention variant.

    Args:
        func: Callable taking ``dataset`` and returning a config dict that
            contains ``['model']['decoder']['attention']`` (a dict) and
            ``['training']['exp_dirname']`` (a path string).

    Returns:
        A callable with the same single-argument signature. It returns the
        config from ``func`` with the attention type set to ``'dot'`` and
        ``'tanh'`` replaced by ``'dot'`` in the experiment name.
    """
    import functools  # local import: only needed while building the wrapper

    @functools.wraps(func)  # keep the wrapped generator's name and docstring
    def new_func(dataset):
        config = func(dataset)
        config['model']['decoder']['attention']['type'] = 'dot'
        # Rename only the final path component (the experiment name), so a
        # dataset directory that happens to contain 'tanh' is left intact.
        parent, exp_name = os.path.split(config['training']['exp_dirname'])
        config['training']['exp_dirname'] = os.path.join(
            parent, exp_name.replace('tanh', 'dot')
        )
        return config

    return new_func
# Registry mapping a configuration name to the function that builds it.
configurations = {
    'vanilla_lstm': generate_vanilla_lstm_config,
    'vanilla_cnn': generate_vanilla_cnn_config,
    'lstm': generate_lstm_config,
    'average': generate_average_config,
    'cnn': generate_cnn_config,
    'logodds_lstm': generate_logodds_config,
    'logodds_lstm_reg': generate_logodds_regularised_config,
    'single_cnn': generate_single_cnn_config,
}

# Dot-attention variants, appended after the base entries in this order.
configurations.update({
    base_name + '_dot': make_attention_dot(configurations[base_name])
    for base_name in ('lstm', 'cnn', 'average')
})
def wrap_config_for_qa(func):
    """Wrap a config generator so that its attention type gets a '_qa' suffix.

    Args:
        func: Callable taking ``dataset`` and returning a config dict that
            contains ``['model']['decoder']['attention']['type']`` (a string).

    Returns:
        A callable with the same single-argument signature. It returns the
        config from ``func`` with ``'_qa'`` appended to the attention type,
        e.g. ``'tanh'`` becomes ``'tanh_qa'``.
    """
    import functools  # local import: only needed while building the wrapper

    @functools.wraps(func)  # keep the wrapped generator's name and docstring
    def new_func(dataset):
        config = func(dataset)
        config['model']['decoder']['attention']['type'] += '_qa'
        return config

    return new_func
# QA registry: the same names as `configurations`, each generator wrapped by
# wrap_config_for_qa.
configurations_qa = {
    config_name: wrap_config_for_qa(generator)
    for config_name, generator in configurations.items()
}