diff --git a/im2txt/im2txt/configuration.py b/im2txt/im2txt/configuration.py
index 294be81bbdb9dc..3b664eb9f0cd96 100644
--- a/im2txt/im2txt/configuration.py
+++ b/im2txt/im2txt/configuration.py
@@ -77,10 +77,6 @@ def __init__(self):
     # If < 1.0, the dropout keep probability applied to LSTM variables.
     self.lstm_dropout_keep_prob = 0.7
 
-    # How many model checkpoints to keep.
-    self.max_checkpoints_to_keep = 5
-    self.keep_checkpoint_every_n_hours = 10000
-
 
 class TrainingConfig(object):
   """Wrapper class for training hyperparameters."""
@@ -103,3 +99,6 @@ def __init__(self):
 
     # If not None, clip gradients to this value.
     self.clip_gradients = 5.0
+
+    # How many model checkpoints to keep.
+    self.max_checkpoints_to_keep = 5
diff --git a/im2txt/im2txt/evaluate.py b/im2txt/im2txt/evaluate.py
index 3a95558c512f49..3ff6e5932dd722 100644
--- a/im2txt/im2txt/evaluate.py
+++ b/im2txt/im2txt/evaluate.py
@@ -104,11 +104,12 @@ def evaluate_model(sess, model, global_step, summary_writer, summary_op):
                   global_step)
 
 
-def run_once(model, summary_writer, summary_op):
+def run_once(model, saver, summary_writer, summary_op):
   """Evaluates the latest model checkpoint.
 
   Args:
     model: Instance of ShowAndTellModel; the model to evaluate.
+    saver: Instance of tf.train.Saver for restoring model Variables.
     summary_writer: Instance of SummaryWriter.
     summary_op: Op for generating model summaries.
   """
@@ -121,7 +122,7 @@ def run_once(model, summary_writer, summary_op):
   with tf.Session() as sess:
     # Load model from checkpoint.
     tf.logging.info("Loading model from checkpoint: %s", model_path)
-    model.saver.restore(sess, model_path)
+    saver.restore(sess, model_path)
     global_step = tf.train.global_step(sess, model.global_step.name)
     tf.logging.info("Successfully loaded %s at global step = %d.",
                     os.path.basename(model_path), global_step)
@@ -166,6 +167,9 @@ def run():
     model = show_and_tell_model.ShowAndTellModel(model_config, mode="eval")
     model.build()
 
+    # Create the Saver to restore model Variables.
+    saver = tf.train.Saver()
+
     # Create the summary operation and the summary writer.
     summary_op = tf.merge_all_summaries()
     summary_writer = tf.train.SummaryWriter(eval_dir)
@@ -177,7 +181,7 @@ def run():
       start = time.time()
      tf.logging.info("Starting evaluation at " + time.strftime(
           "%Y-%m-%d-%H:%M:%S", time.localtime()))
-      run_once(model, summary_writer, summary_op)
+      run_once(model, saver, summary_writer, summary_op)
       time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
       if time_to_next_eval > 0:
         time.sleep(time_to_next_eval)
diff --git a/im2txt/im2txt/inference_utils/inference_wrapper_base.py b/im2txt/im2txt/inference_utils/inference_wrapper_base.py
index d305101710e604..e94cd6af474488 100644
--- a/im2txt/im2txt/inference_utils/inference_wrapper_base.py
+++ b/im2txt/im2txt/inference_utils/inference_wrapper_base.py
@@ -112,10 +112,8 @@ def build_graph_from_config(self, model_config, checkpoint_path):
         from the checkpoint file.
     """
     tf.logging.info("Building model.")
-    model = self.build_model(model_config)
-    saver = model.saver
-    if not saver:
-      saver = tf.Saver()
+    self.build_model(model_config)
+    saver = tf.train.Saver()
 
     return self._create_restore_fn(checkpoint_path, saver)
 
diff --git a/im2txt/im2txt/show_and_tell_model.py b/im2txt/im2txt/show_and_tell_model.py
index 5faad9452f9baf..5bddade3760d47 100644
--- a/im2txt/im2txt/show_and_tell_model.py
+++ b/im2txt/im2txt/show_and_tell_model.py
@@ -347,12 +347,6 @@ def setup_global_step(self):
 
     self.global_step = global_step
 
-  def setup_saver(self):
-    """Sets up the Saver for loading and saving model checkpoints."""
-    self.saver = tf.train.Saver(
-        max_to_keep=self.config.max_checkpoints_to_keep,
-        keep_checkpoint_every_n_hours=self.config.keep_checkpoint_every_n_hours)
-
   def build(self):
     """Creates all ops for training and evaluation."""
     self.build_inputs()
@@ -361,4 +355,3 @@ def build(self):
     self.build_model()
     self.setup_inception_initializer()
     self.setup_global_step()
-    self.setup_saver()
diff --git a/im2txt/im2txt/train.py b/im2txt/im2txt/train.py
index 2c2df210cc2444..db602735ba11e7 100644
--- a/im2txt/im2txt/train.py
+++ b/im2txt/im2txt/train.py
@@ -95,6 +95,9 @@ def _learning_rate_decay_fn(learning_rate, global_step):
         clip_gradients=training_config.clip_gradients,
         learning_rate_decay_fn=learning_rate_decay_fn)
 
+    # Set up the Saver for saving and restoring model checkpoints.
+    saver = tf.train.Saver(max_to_keep=training_config.max_checkpoints_to_keep)
+
   # Run training.
   tf.contrib.slim.learning.train(
       train_op,
@@ -104,7 +107,7 @@ def _learning_rate_decay_fn(learning_rate, global_step):
       global_step=model.global_step,
       number_of_steps=FLAGS.number_of_steps,
       init_fn=model.init_fn,
-      saver=model.saver)
+      saver=saver)
 
 
 if __name__ == "__main__":