
Commit

Three failing tests left
cweill committed Apr 23, 2019
1 parent 4b76c18 commit 781f8ce
Showing 6 changed files with 169 additions and 108 deletions.
86 changes: 44 additions & 42 deletions adanet/core/estimator.py
@@ -25,6 +25,8 @@
import os
import time

+from absl import logging
+from adanet import tf_compat
from adanet.core.architecture import _Architecture
from adanet.core.candidate import _CandidateBuilder
from adanet.core.ensemble_builder import _EnsembleBuilder
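
The absl import added above replaces the tf.logging calls throughout this file. A minimal sketch of the before/after pattern, assuming absl-py is available (the value below is made up for illustration):

from absl import logging

iteration_number = 3  # hypothetical value, for illustration only
# Before this commit: tf.logging.info("Beginning training AdaNet iteration %s", iteration_number)
logging.info("Beginning training AdaNet iteration %s", iteration_number)
logging.warning("This is an experimental feature.")  # formerly tf.logging.warning
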
@@ -44,7 +46,7 @@
import tensorflow as tf


-class _StopAfterTrainingHook(tf.train.SessionRunHook):
+class _StopAfterTrainingHook(tf.estimator.SessionRunHook):
"""Hook that requests stop once iteration is over."""

def __init__(self, iteration, after_fn):
@@ -65,7 +67,7 @@ def before_run(self, run_context):
"""See `SessionRunHook`."""

del run_context # Unused
-return tf.train.SessionRunArgs(self._iteration.is_over_fn())
+return tf.estimator.SessionRunArgs(self._iteration.is_over_fn())

def after_run(self, run_context, run_values):
"""See `SessionRunHook`."""
@@ -77,7 +79,7 @@ def after_run(self, run_context, run_values):
self._after_fn()


-class _EvalMetricSaverHook(tf.train.SessionRunHook):
+class _EvalMetricSaverHook(tf.estimator.SessionRunHook):
"""A hook for writing candidate evaluation metrics as summaries to disk."""

def __init__(self, name, kind, eval_metrics, output_dir):
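
The hook classes above now derive from tf.estimator.SessionRunHook instead of tf.train.SessionRunHook. A minimal, self-contained sketch of a hook written against the estimator aliases this commit adopts; the class below is illustrative and not part of the commit, and it assumes a TF release where tf.estimator exposes SessionRunHook and SessionRunArgs:

from absl import logging
import tensorflow as tf

class _LogGlobalStepHook(tf.estimator.SessionRunHook):  # hypothetical example hook
  """Requests the global step on each run call and logs it afterwards."""

  def begin(self):
    self._global_step_tensor = tf.train.get_global_step()

  def before_run(self, run_context):
    del run_context  # Unused
    return tf.estimator.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    del run_context  # Unused
    logging.info("global step after run: %s", run_values.results)
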
@@ -110,8 +112,8 @@ def begin(self):
# the metric variables. The metrics themselves are computed as a result of
# being returned in the EstimatorSpec by _adanet_model_fn.
metric_fn, tensors = self._eval_metrics
-tensors = [tf.placeholder(t.dtype, t.shape) for t in tensors]
-eval_metric_ops = metric_fn(*tensors)
+tensors = {k: tf.placeholder(v.dtype, v.shape) for k, v in tensors.items()}
+eval_metric_ops = metric_fn(**tensors)
self._eval_metric_tensors = {k: v[0] for k, v in eval_metric_ops.items()}

def _dict_to_str(self, dictionary):
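
The begin() change above switches the eval metric tensors from a positional list to a dict keyed by argument name, so the mirrored placeholders can be passed to metric_fn as keyword arguments. A rough standalone sketch of that pattern in TF 1.x graph mode; the metric_fn and tensor names are made up:

import tensorflow as tf

def metric_fn(labels, predictions):  # hypothetical metric_fn for illustration
  return {"accuracy": tf.metrics.accuracy(labels, predictions)}

tensors = {
    "labels": tf.constant([1., 1., 0.]),
    "predictions": tf.constant([1., 0., 0.]),
}
# Mirror each tensor with a placeholder of the same dtype and shape, keyed by
# the metric_fn argument it feeds, then call metric_fn with keyword arguments.
placeholders = {k: tf.placeholder(v.dtype, v.shape) for k, v in tensors.items()}
eval_metric_ops = metric_fn(**placeholders)
# Each value is a (metric_value, update_op) pair; keep only the value tensors.
metric_values = {k: v[0] for k, v in eval_metric_ops.items()}
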
@@ -135,7 +137,7 @@ def end(self, session):
eval_dict, current_global_step = session.run(
(self._eval_metric_tensors, current_global_step))

-tf.logging.info("Saving %s '%s' dict for global step %d: %s",
+logging.info("Saving %s '%s' dict for global step %d: %s",
self._kind, self._name, current_global_step,
self._dict_to_str(eval_dict))
summary_writer = tf.summary.FileWriterCache.get(self._output_dir)
@@ -150,14 +152,14 @@ def end(self, session):
summ.value[i].tag = "{}/{}".format(key, i)
summary_proto.value.extend(summ.value)
else:
-tf.logging.warn(
+logging.warn(
"Skipping summary for %s, must be a float, np.float32, "
"or a serialized string of Summary.", key)
summary_writer.add_summary(summary_proto, current_global_step)
summary_writer.flush()


-class _OverwriteCheckpointHook(tf.train.SessionRunHook):
+class _OverwriteCheckpointHook(tf.estimator.SessionRunHook):
"""Hook to overwrite the latest checkpoint with next iteration variables."""

def __init__(self, current_iteration, iteration_number_tensor,
@@ -227,14 +229,14 @@ def before_run(self, run_context):
self._checkpoint_overwritten = True


-class _HookContextDecorator(tf.train.SessionRunHook):
+class _HookContextDecorator(tf.estimator.SessionRunHook):
"""Decorates a SessionRunHook's public methods to run within a context."""

def __init__(self, hook, context, is_growing_phase):
"""Initializes a _HookContextDecorator instance.
Args:
-hook: The tf.train.SessionRunHook to decorate.
+hook: The tf.estimator.SessionRunHook to decorate.
context: The context to enter before calling the hook's public methods.
is_growing_phase: Whether we are in the AdaNet graph growing phase. If so,
only hook.begin() and hook.end() will be called.
@@ -498,7 +500,7 @@ def __init__(self,
k: v for k, v in kwargs.items() if k in default_ensembler_args
}
if default_ensembler_kwargs:
-tf.logging.warn(
+logging.warn(
"The following arguments have been moved to "
"`adanet.ensemble.ComplexityRegularizedEnsembler` which can be "
"specified in the `ensemblers` argument: {}".format(
@@ -510,7 +512,7 @@ def __init__(self,
placement_strategy_arg = "experimental_placement_strategy"
placement_strategy = kwargs.pop(placement_strategy_arg, None)
if placement_strategy:
-tf.logging.warning(
+logging.warning(
"%s is an experimental feature. Its behavior is not guaranteed "
"to be backwards compatible.", placement_strategy_arg)

@@ -636,7 +638,7 @@ def train(self,
), self._train_loop_context():
while True:
current_iteration = self._latest_checkpoint_iteration_number()
-tf.logging.info("Beginning training AdaNet iteration %s",
+logging.info("Beginning training AdaNet iteration %s",
current_iteration)
self._iteration_ended = False
result = super(Estimator, self).train(
@@ -645,7 +647,7 @@
max_steps=max_steps,
saving_listeners=saving_listeners)

-tf.logging.info("Finished training Adanet iteration %s",
+logging.info("Finished training Adanet iteration %s",
current_iteration)

# If training ended because the maximum number of training steps
@@ -658,7 +660,7 @@
if not self._iteration_ended:
return result

-tf.logging.info("Beginning bookkeeping phase for iteration %s",
+logging.info("Beginning bookkeeping phase for iteration %s",
current_iteration)

# The chief prepares the next AdaNet iteration, and increments the
@@ -695,13 +697,13 @@ def train(self,

# Check timeout when waiting for potentially downed chief.
if timer.secs_remaining() == 0:
-tf.logging.error(
+logging.error(
"Chief job did not prepare next iteration after %s secs. It "
"may have been preempted, been turned down, or crashed. This "
"worker is now exiting training.",
self._worker_wait_timeout_secs)
return result
-tf.logging.info("Waiting for chief to finish")
+logging.info("Waiting for chief to finish")
time.sleep(self._worker_wait_secs)

# Stagger starting workers to prevent training instability.
@@ -712,11 +714,11 @@ def train(self,
delay_secs = min(self._max_worker_delay_secs,
(task_id + 1.) * self._delay_secs_per_worker)
if delay_secs > 0.:
-tf.logging.info("Waiting %d secs before continuing training.",
+logging.info("Waiting %d secs before continuing training.",
delay_secs)
time.sleep(delay_secs)

-tf.logging.info("Finished bookkeeping phase for iteration %s",
+logging.info("Finished bookkeeping phase for iteration %s",
current_iteration)

def evaluate(self,
@@ -786,31 +788,31 @@ def _prepare_next_iteration(self, train_input_fn):
Args:
train_input_fn: The input_fn used during training.
"""
-tf.logging.info("Preparing next iteration:")
+logging.info("Preparing next iteration:")

# First, evaluate and choose the best ensemble for this iteration.
-tf.logging.info("Evaluating candidates...")
+logging.info("Evaluating candidates...")
self._prepare_next_iteration_state = self._Keys.EVALUATE_ENSEMBLES
if self._evaluator:
evaluator_input_fn = self._evaluator.input_fn
else:
evaluator_input_fn = train_input_fn
self._call_adanet_model_fn(evaluator_input_fn, tf.estimator.ModeKeys.EVAL)
-tf.logging.info("Done evaluating candidates.")
+logging.info("Done evaluating candidates.")

# Then materialize and store the subnetwork reports.
if self._report_materializer:
-tf.logging.info("Materializing reports...")
+logging.info("Materializing reports...")
self._prepare_next_iteration_state = self._Keys.MATERIALIZE_REPORT
self._call_adanet_model_fn(self._report_materializer.input_fn,
tf.estimator.ModeKeys.EVAL)
-tf.logging.info("Done materializing reports.")
+logging.info("Done materializing reports.")

self._best_ensemble_index = None

# Finally, create the graph for the next iteration and overwrite the model
# directory checkpoint with the expanded graph.
-tf.logging.info("Adapting graph and incrementing iteration number...")
+logging.info("Adapting graph and incrementing iteration number...")
self._prepare_next_iteration_state = self._Keys.INCREMENT_ITERATION
temp_model_dir = os.path.join(self.model_dir, "temp_model_dir")
if tf.gfile.Exists(temp_model_dir):
@@ -826,9 +828,9 @@ def _prepare_next_iteration(self, train_input_fn):
saving_listeners=None)
tf.gfile.DeleteRecursively(temp_model_dir)
self._prepare_next_iteration_state = None
-tf.logging.info("Done adapting graph and incrementing iteration number.")
+logging.info("Done adapting graph and incrementing iteration number.")

-tf.logging.info("Finished preparing next iteration.")
+logging.info("Finished preparing next iteration.")

def _architecture_filename(self, iteration_number):
"""Returns the filename of the given iteration's frozen graph."""
@@ -865,7 +867,7 @@ def _get_best_ensemble_index(self, current_iteration):

# Skip the evaluation phase when there is only one candidate subnetwork.
if len(current_iteration.candidates) == 1:
-tf.logging.info(
+logging.info(
"As the only candidate, '%s' is moving onto the next iteration.",
current_iteration.candidates[0].ensemble_spec.name)
return 0
@@ -874,14 +876,14 @@
# previous_ensemble.
if current_iteration.number > 0 and self._force_grow and (len(
current_iteration.candidates) == 2):
-tf.logging.info(
+logging.info(
"As the only candidate with `force_grow` enabled, '%s' is moving"
"onto the next iteration.",
current_iteration.candidates[1].ensemble_spec.name)
return 1

latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
-tf.logging.info("Starting ensemble evaluation for iteration %s",
+logging.info("Starting ensemble evaluation for iteration %s",
current_iteration.number)
with tf.Session() as sess:
init = tf.group(tf.global_variables_initializer(),
@@ -906,9 +908,9 @@ def _get_best_ensemble_index(self, current_iteration):
ensemble_name = current_iteration.candidates[i].ensemble_spec.name
values.append("{}/{} = {:.6f}".format(metric_name, ensemble_name,
adanet_losses[i]))
-tf.logging.info("Computed ensemble metrics: %s", ", ".join(values))
+logging.info("Computed ensemble metrics: %s", ", ".join(values))
if self._force_grow and current_iteration.number > 0:
-tf.logging.info(
+logging.info(
"The `force_grow` override is enabled, so the "
"the performance of the previous ensemble will be ignored.")
# NOTE: The zero-th index candidate at iteration t>0 is always the
@@ -917,9 +919,9 @@
index = np.argmin(adanet_losses) + 1
else:
index = np.argmin(adanet_losses)
-tf.logging.info("Finished ensemble evaluation for iteration %s",
+logging.info("Finished ensemble evaluation for iteration %s",
current_iteration.number)
-tf.logging.info("'%s' at index %s is moving onto the next iteration",
+logging.info("'%s' at index %s is moving onto the next iteration",
current_iteration.candidates[index].ensemble_spec.name,
index)
return index
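
With force_grow enabled after the first iteration, the candidate at index 0 is always the previous ensemble, so it is skipped and the winning index is offset by one. A standalone sketch of that selection logic; the losses are made-up numbers, not from a real run:

import numpy as np

adanet_losses = [0.41, 0.35, 0.38]  # hypothetical losses; index 0 is the previous ensemble
force_grow = True
iteration_number = 1

if force_grow and iteration_number > 0:
  # Ignore the previous ensemble and choose the best of the new candidates.
  index = int(np.argmin(adanet_losses[1:])) + 1
else:
  index = int(np.argmin(adanet_losses))
print(index)  # 1
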
@@ -936,7 +938,7 @@ def _materialize_report(self, current_iteration):
"""

latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
-tf.logging.info("Starting metric logging for iteration %s",
+logging.info("Starting metric logging for iteration %s",
current_iteration.number)

assert self._best_ensemble_index is not None
@@ -961,7 +963,7 @@ def _materialize_report(self, current_iteration):
self._report_accessor.write_iteration_report(current_iteration.number,
materialized_reports)

-tf.logging.info("Finished saving subnetwork reports for iteration %s",
+logging.info("Finished saving subnetwork reports for iteration %s",
current_iteration.number)

def _decorate_hooks(self, hooks):
@@ -989,7 +991,7 @@ def _training_chief_hooks(self, current_iteration, training):
training: Whether in training mode.
Returns:
-A list of `tf.train.SessionRunHook` instances.
+A list of `tf.estimator.SessionRunHook` instances.
"""

if not training:
@@ -1000,7 +1002,7 @@ def _training_chief_hooks(self, current_iteration, training):
output_dir = self.model_dir
if summary.scope:
output_dir = os.path.join(output_dir, summary.namespace, summary.scope)
-summary_saver_hook = tf.train.SummarySaverHook(
+summary_saver_hook = tf.estimator.SummarySaverHook(
save_steps=self.config.save_summary_steps,
output_dir=output_dir,
summary_op=summary.merge_all())
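
The chief's summary-saving hook likewise moves to the estimator namespace. A sketch of an equivalent standalone hook, assuming a TF release that exposes tf.estimator.SummarySaverHook; the step count and path are placeholders:

import tensorflow as tf

summary_saver_hook = tf.estimator.SummarySaverHook(
    save_steps=100,  # hypothetical summary frequency
    output_dir="/tmp/adanet_model/summaries",  # hypothetical output directory
    summary_op=tf.summary.merge_all())
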
@@ -1022,7 +1024,7 @@ def _training_hooks(self, current_iteration, training,
_OverwriteCheckpointHook will be created.
Returns:
-A list of `tf.train.SessionRunHook` instances.
+A list of `tf.estimator.SessionRunHook` instances.
"""

if not training:
@@ -1051,7 +1053,7 @@ def _evaluation_hooks(self, current_iteration, training):
training: Whether in training mode.
Returns:
-A list of `tf.train.SessionRunHook` instances.
+A list of `tf.estimator.SessionRunHook` instances.
"""

if training:
@@ -1376,7 +1378,7 @@ def _adanet_model_fn(self, features, labels, mode, params):
if not tf.gfile.Exists(architecture_filename):
continue
architecture = self._read_architecture(architecture_filename)
-tf.logging.info(
+logging.info(
"Importing architecture from %s: [%s].", architecture_filename,
", ".join(
sorted([
@@ -1456,7 +1458,7 @@ def _adanet_model_fn(self, features, labels, mode, params):
assert mode == tf.estimator.ModeKeys.TRAIN
assert self.config.is_chief
latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
-tf.logging.info(
+logging.info(
"Overwriting checkpoint with new graph for iteration %s to %s",
iteration_number, latest_checkpoint)
return self._create_estimator_spec(current_iteration, mode,
