Commit cd5e9b7

Fix a bug in the im2txt code where the Saver is created before the optimizer.

cshallue committed Sep 23, 2016 (1 parent: 71f239f)

Showing 5 changed files with 16 additions and 19 deletions.
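Why the order matters: tf.train.Saver() with no explicit var_list captures only the variables that exist in the graph at the moment the Saver is constructed. Since the optimizer creates its own variables after model.build(), a Saver built inside the model silently omitted the optimizer's state from every checkpoint. A minimal sketch of the failure mode, on a toy graph with TF 1.x-era APIs rather than the im2txt model:

```python
import tensorflow as tf

# tf.train.Saver() with no var_list saves only the variables that exist
# at construction time.
w = tf.Variable(1.0, name="w")
loss = tf.square(w - 2.0)

early_saver = tf.train.Saver()  # buggy order: only "w" exists yet

# minimize() adds slot variables such as "w/Momentum" to the graph.
train_op = tf.train.MomentumOptimizer(0.1, 0.9).minimize(loss)

late_saver = tf.train.Saver()  # fixed order: sees "w" and "w/Momentum"

# Checkpoints written by early_saver silently omit the momentum
# accumulator, so a run restored from them resets the optimizer state.
```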
im2txt/im2txt/configuration.py (3 additions, 4 deletions)

```diff
@@ -77,10 +77,6 @@ def __init__(self):
     # If < 1.0, the dropout keep probability applied to LSTM variables.
     self.lstm_dropout_keep_prob = 0.7
 
-    # How many model checkpoints to keep.
-    self.max_checkpoints_to_keep = 5
-    self.keep_checkpoint_every_n_hours = 10000
-
 
 class TrainingConfig(object):
   """Wrapper class for training hyperparameters."""
@@ -103,3 +99,6 @@ def __init__(self):
 
     # If not None, clip gradients to this value.
     self.clip_gradients = 5.0
+
+    # How many model checkpoints to keep.
+    self.max_checkpoints_to_keep = 5
```
im2txt/im2txt/evaluate.py (7 additions, 3 deletions)

```diff
@@ -104,11 +104,12 @@ def evaluate_model(sess, model, global_step, summary_writer, summary_op):
                            global_step)
 
 
-def run_once(model, summary_writer, summary_op):
+def run_once(model, saver, summary_writer, summary_op):
   """Evaluates the latest model checkpoint.
 
   Args:
     model: Instance of ShowAndTellModel; the model to evaluate.
+    saver: Instance of tf.train.Saver for restoring model Variables.
     summary_writer: Instance of SummaryWriter.
     summary_op: Op for generating model summaries.
   """
@@ -121,7 +122,7 @@ def run_once(model, summary_writer, summary_op):
   with tf.Session() as sess:
     # Load model from checkpoint.
     tf.logging.info("Loading model from checkpoint: %s", model_path)
-    model.saver.restore(sess, model_path)
+    saver.restore(sess, model_path)
     global_step = tf.train.global_step(sess, model.global_step.name)
     tf.logging.info("Successfully loaded %s at global step = %d.",
                     os.path.basename(model_path), global_step)
@@ -166,6 +167,9 @@ def run():
     model = show_and_tell_model.ShowAndTellModel(model_config, mode="eval")
     model.build()
 
+    # Create the Saver to restore model Variables.
+    saver = tf.train.Saver()
+
     # Create the summary operation and the summary writer.
     summary_op = tf.merge_all_summaries()
     summary_writer = tf.train.SummaryWriter(eval_dir)
@@ -177,7 +181,7 @@
       start = time.time()
       tf.logging.info("Starting evaluation at " + time.strftime(
           "%Y-%m-%d-%H:%M:%S", time.localtime()))
-      run_once(model, summary_writer, summary_op)
+      run_once(model, saver, summary_writer, summary_op)
       time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
       if time_to_next_eval > 0:
         time.sleep(time_to_next_eval)
```
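The evaluation side now threads a single externally built Saver through every run. Condensed to its essentials, the restore path under the new contract looks like this (a hypothetical helper; checkpoint_dir and the global-step tensor are assumed to come from the surrounding code):

```python
import os.path
import tensorflow as tf

def restore_latest(saver, checkpoint_dir, global_step_tensor):
  """Hypothetical helper: restores the newest checkpoint, returns its step."""
  model_path = tf.train.latest_checkpoint(checkpoint_dir)
  if not model_path:
    return None  # no checkpoint written yet; the caller can retry later
  with tf.Session() as sess:
    # The caller-supplied Saver replaces the old model.saver attribute.
    saver.restore(sess, model_path)
    step = sess.run(global_step_tensor)
    tf.logging.info("Loaded %s at global step = %d.",
                    os.path.basename(model_path), step)
    return step
```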
im2txt/im2txt/inference_utils/inference_wrapper_base.py (2 additions, 4 deletions)

```diff
@@ -112,10 +112,8 @@ def build_graph_from_config(self, model_config, checkpoint_path):
         from the checkpoint file.
     """
     tf.logging.info("Building model.")
-    model = self.build_model(model_config)
-    saver = model.saver
-    if not saver:
-      saver = tf.Saver()
+    self.build_model(model_config)
+    saver = tf.train.Saver()
 
     return self._create_restore_fn(checkpoint_path, saver)
```
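Note that the old fallback called tf.Saver(), which is not a valid constructor (the class lives at tf.train.Saver), so that path could never have worked; the rewrite also stops relying on a saver attribute of the model. A sketch of how the returned restore function is consumed, mirroring im2txt's inference entry point (wrapper, model_config, and checkpoint_path are assumed to be supplied by the caller):

```python
import tensorflow as tf

def load_inference_model(wrapper, model_config, checkpoint_path):
  """Hypothetical driver: build the inference graph, then restore weights."""
  g = tf.Graph()
  with g.as_default():
    restore_fn = wrapper.build_graph_from_config(model_config, checkpoint_path)
  g.finalize()  # catch accidental graph modification after this point
  sess = tf.Session(graph=g)
  restore_fn(sess)  # internally runs saver.restore on the checkpoint
  return sess
```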
im2txt/im2txt/show_and_tell_model.py (0 additions, 7 deletions)

```diff
@@ -347,12 +347,6 @@ def setup_global_step(self):
 
     self.global_step = global_step
 
-  def setup_saver(self):
-    """Sets up the Saver for loading and saving model checkpoints."""
-    self.saver = tf.train.Saver(
-        max_to_keep=self.config.max_checkpoints_to_keep,
-        keep_checkpoint_every_n_hours=self.config.keep_checkpoint_every_n_hours)
-
   def build(self):
     """Creates all ops for training and evaluation."""
     self.build_inputs()
@@ -361,4 +355,3 @@ def build(self):
     self.build_model()
     self.setup_inception_initializer()
     self.setup_global_step()
-    self.setup_saver()
```
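With setup_saver() removed, ShowAndTellModel no longer owns checkpointing: build() produces model ops only, and each entry point constructs its own Saver once the complete graph, optimizer included, exists. The new division of labor in miniature (build_optimizer_ops is a hypothetical stand-in for train.py's optimize_loss call):

```python
model = show_and_tell_model.ShowAndTellModel(model_config, mode="train")
model.build()                          # model ops only; no Saver created here

train_op = build_optimizer_ops(model)  # hypothetical; adds optimizer variables
saver = tf.train.Saver()               # built last, so it sees every variable
```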
im2txt/im2txt/train.py (4 additions, 1 deletion)

```diff
@@ -95,6 +95,9 @@ def _learning_rate_decay_fn(learning_rate, global_step):
       clip_gradients=training_config.clip_gradients,
       learning_rate_decay_fn=learning_rate_decay_fn)
 
+  # Set up the Saver for saving and restoring model checkpoints.
+  saver = tf.train.Saver(max_to_keep=training_config.max_checkpoints_to_keep)
+
   # Run training.
   tf.contrib.slim.learning.train(
       train_op,
@@ -104,7 +107,7 @@ def _learning_rate_decay_fn(learning_rate, global_step):
       global_step=model.global_step,
       number_of_steps=FLAGS.number_of_steps,
       init_fn=model.init_fn,
-      saver=model.saver)
+      saver=saver)
 
 
 if __name__ == "__main__":
```
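One way to confirm the fix took effect is to list what a training checkpoint actually contains: with the Saver now created after optimize_loss, the optimizer-created variables appear alongside the model weights. A quick inspection sketch (TF 1.x-era API; the checkpoint path is hypothetical):

```python
import tensorflow as tf

# Enumerate every variable stored in a checkpoint file.
reader = tf.train.NewCheckpointReader("/tmp/im2txt/train/model.ckpt-10000")
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
  print(name, shape)
# After this commit, optimizer state and global_step are present, so an
# interrupted training run can resume without resetting them.
```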
