-
Notifications
You must be signed in to change notification settings - Fork 4
/
config.py
129 lines (105 loc) · 4.6 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import os
import time
# ---------------------------------------------------------------------------
# Command-line flags. Model_Config and File_Config (below) read these values
# through the module-level FLAGS object defined at the end of this section.
# ---------------------------------------------------------------------------

# Dataset, vocabulary and word-embedding inputs.
tf.flags.DEFINE_string(
    "dataset", "refcoco",
    "Dataset for training/test, option:refcoco/refcoco+/refcocog.")
tf.flags.DEFINE_string(
    "vocab_file", "./data/word_embedding/vocabulary_72700.txt",
    "Vocabulary file path.")
tf.flags.DEFINE_string(
    "wordembed_params", "./data/word_embedding/embed_matrix.npy",
    "word embedding initialization file path.")

# Checkpoint to restore; empty by default.
tf.flags.DEFINE_string(
    "checkpoint", "",
    "checkpoint for evaluation (only used during evaluation stage).")

# Logging and snapshot schedule (in training steps).
tf.flags.DEFINE_integer(
    "log_interval", 500,
    "Interval for saving log.")
tf.flags.DEFINE_integer(
    "snapshot_start", 120000,
    "Start step for saving snapshot.")
tf.flags.DEFINE_integer(
    "snapshot_interval", 1000,
    "Interval for saving snapshot.")

# Optimisation schedule.
tf.flags.DEFINE_integer(
    "lr_decay_step", 120000,
    "Learning rate will decay after this step.")
tf.flags.DEFINE_integer(
    "max_iter", 180000,
    "Maximum iterations for training.")

# Training mode: supervised (default) or unsupervised.
tf.flags.DEFINE_boolean(
    "supervised", True,
    "If false, the model will be trained under unsupervised learning")

FLAGS = tf.app.flags.FLAGS
class Model_Config(object):
    """Bundle of model-architecture and training hyperparameters.

    Every setting is exposed as a plain instance attribute. Three of them
    (``is_supervised``, ``lr_decay_step``, ``max_iter``) come from the
    module-level ``FLAGS``; all others are fixed defaults defined here.
    """

    def __init__(self):
        """Populate the default model and optimisation hyperparameters."""
        # Values driven by command-line flags.
        self.is_supervised = FLAGS.supervised  # False -> unsupervised training
        self.lr_decay_step = FLAGS.lr_decay_step
        self.max_iter = FLAGS.max_iter

        # Language side: LSTM input (word embedding) and output sizes.
        self.embed_dim = 300
        self.lstm_dim = 1000
        # Maximum sequence length and vocabulary size.
        self.L = 20
        # NOTE(review): the default --vocab_file is vocabulary_72700.txt while
        # this is 72704; presumably 4 extra ids are reserved tokens - confirm.
        self.num_vocab = 72704

        # Visual and spatial feature sizes.
        self.vis_dim = 4096
        self.spa_dim = 5

        # Encoder, decoder and regularizer embedding sizes (all equal).
        self.enc_dim = self.dec_dim = self.reg_dim = 512

        # Whether dropout is applied to the LSTM variables.
        self.lstm_dropout = False

        # Learning-rate schedule and Momentum optimizer settings.
        self.start_lr = 0.01
        self.lr_decay_rate = 0.1
        self.momentum = 0.95
        # When not None, gradients are clipped to this value.
        self.clip_gradients = 10.0
        # Weight-decay coefficient used for regularization.
        self.weight_decay = 0.0005
        # Decay factor for the running averages of loss and accuracy.
        self.avg_decay = 0.99
class File_Config(object):
    """Data paths plus log/snapshot settings for the reader and main script.

    Construction reads ``FLAGS.dataset`` and ``FLAGS.checkpoint`` and then
    calls ``setup()``, which fills in the split name, the log/snapshot
    options and the vocabulary/embedding initialization parameters.

    Args:
        model: model-type tag stored on ``self.model`` (default ``'vc'``).

    Raises:
        ValueError: if ``FLAGS.dataset`` is not refcoco/refcoco+/refcocog.
    """

    # Datasets accepted by ``set_split``.
    _VALID_DATASETS = ('refcoco', 'refcoco+', 'refcocog')

    def __init__(self, model='vc'):
        """Read the dataset/checkpoint flags and derive all dependent settings."""
        self.dataset = FLAGS.dataset  # refcoco / refcoco+ / refcocog
        # If True, print loading information.
        self.info_print = True
        self.model = model
        # Checkpoint path; only used during evaluation.
        self.checkpoint = FLAGS.checkpoint
        self.setup()

    def set_split(self):
        """Set ``self.split``: 'unc' for refcoco/refcoco+, 'google' for refcocog.

        Raises:
            ValueError: if ``self.dataset`` is not a supported dataset.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an unknown dataset silently fall
        # through to the 'google' split.
        if self.dataset not in self._VALID_DATASETS:
            raise ValueError(
                "Dataset should be refcoco/refcoco+/refcocog, got %r"
                % (self.dataset,))
        if self.dataset in ('refcoco', 'refcoco+'):
            self.split = 'unc'
        else:
            self.split = 'google'

    def set_log_options(self):
        """Set TensorFlow log/snapshot directories and intervals from the flags.

        Unsupervised runs use a ``<dataset>_un`` sub-directory, so their logs
        and snapshots are kept apart from supervised runs.
        """
        run_name = self.dataset if FLAGS.supervised else self.dataset + '_un'
        self.log_dir = './tflog/%s/' % run_name
        self.snapshot_dir = './tfmodel/%s/' % run_name
        self.log_interval = FLAGS.log_interval
        # Keeps a literal '%d' placeholder; presumably the caller fills in the
        # iteration number when saving - confirm against the training script.
        self.snapshot_file = os.path.join(self.snapshot_dir, 'iter_%d.tfmodel')
        self.snapshot_start = FLAGS.snapshot_start
        self.snapshot_interval = FLAGS.snapshot_interval

    def set_init_params(self):
        """Set vocabulary/embedding sizes and word-embedding init file paths."""
        # NOTE(review): these sizes duplicate Model_Config.num_vocab /
        # Model_Config.embed_dim; keep the two classes in sync.
        self.num_vocab = 72704
        self.embed_dim = 300
        self.vocab_file = FLAGS.vocab_file
        self.wordembed_params = FLAGS.wordembed_params

    def setup(self):
        """Derive the split, the log/snapshot options and the init parameters."""
        self.set_split()
        self.set_log_options()
        self.set_init_params()