Skip to content

Commit

Permalink
Fix compatibility.
Browse files Browse the repository at this point in the history
  • Loading branch information
cweill committed Apr 23, 2019
1 parent 2668785 commit e333310
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 28 deletions.
43 changes: 20 additions & 23 deletions adanet/core/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import tensorflow as tf


class _StopAfterTrainingHook(tf.estimator.SessionRunHook):
class _StopAfterTrainingHook(tf_compat.SessionRunHook):
"""Hook that requests stop once iteration is over."""

def __init__(self, iteration, after_fn):
Expand Down Expand Up @@ -79,7 +79,7 @@ def after_run(self, run_context, run_values):
self._after_fn()


class _EvalMetricSaverHook(tf.estimator.SessionRunHook):
class _EvalMetricSaverHook(tf_compat.SessionRunHook):
"""A hook for writing candidate evaluation metrics as summaries to disk."""

def __init__(self, name, kind, eval_metrics, output_dir):
Expand Down Expand Up @@ -137,9 +137,8 @@ def end(self, session):
eval_dict, current_global_step = session.run(
(self._eval_metric_tensors, current_global_step))

logging.info("Saving %s '%s' dict for global step %d: %s",
self._kind, self._name, current_global_step,
self._dict_to_str(eval_dict))
logging.info("Saving %s '%s' dict for global step %d: %s", self._kind,
self._name, current_global_step, self._dict_to_str(eval_dict))
summary_writer = tf.summary.FileWriterCache.get(self._output_dir)
summary_proto = tf.summary.Summary()
for key in eval_dict:
Expand All @@ -159,7 +158,7 @@ def end(self, session):
summary_writer.flush()


class _OverwriteCheckpointHook(tf.estimator.SessionRunHook):
class _OverwriteCheckpointHook(tf_compat.SessionRunHook):
"""Hook to overwrite the latest checkpoint with next iteration variables."""

def __init__(self, current_iteration, iteration_number_tensor,
Expand Down Expand Up @@ -229,14 +228,14 @@ def before_run(self, run_context):
self._checkpoint_overwritten = True


class _HookContextDecorator(tf.estimator.SessionRunHook):
class _HookContextDecorator(tf_compat.SessionRunHook):
"""Decorates a SessionRunHook's public methods to run within a context."""

def __init__(self, hook, context, is_growing_phase):
"""Initializes a _HookContextDecorator instance.
Args:
hook: The tf.estimator.SessionRunHook to decorate.
hook: The tf_compat.SessionRunHook to decorate.
context: The context to enter before calling the hook's public methods.
is_growing_phase: Whether we are in the AdaNet graph growing phase. If so,
only hook.begin() and hook.end() will be called.
Expand Down Expand Up @@ -639,16 +638,15 @@ def train(self,
while True:
current_iteration = self._latest_checkpoint_iteration_number()
logging.info("Beginning training AdaNet iteration %s",
current_iteration)
current_iteration)
self._iteration_ended = False
result = super(Estimator, self).train(
input_fn=input_fn,
hooks=hooks,
max_steps=max_steps,
saving_listeners=saving_listeners)

logging.info("Finished training Adanet iteration %s",
current_iteration)
logging.info("Finished training Adanet iteration %s", current_iteration)

# If training ended because the maximum number of training steps
# occurred, exit training.
Expand All @@ -661,7 +659,7 @@ def train(self,
return result

logging.info("Beginning bookkeeping phase for iteration %s",
current_iteration)
current_iteration)

# The chief prepares the next AdaNet iteration, and increments the
# iteration number by 1.
Expand Down Expand Up @@ -715,11 +713,11 @@ def train(self,
(task_id + 1.) * self._delay_secs_per_worker)
if delay_secs > 0.:
logging.info("Waiting %d secs before continuing training.",
delay_secs)
delay_secs)
time.sleep(delay_secs)

logging.info("Finished bookkeeping phase for iteration %s",
current_iteration)
current_iteration)

def evaluate(self,
input_fn,
Expand Down Expand Up @@ -884,7 +882,7 @@ def _get_best_ensemble_index(self, current_iteration):

latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
logging.info("Starting ensemble evaluation for iteration %s",
current_iteration.number)
current_iteration.number)
with tf.Session() as sess:
init = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer(), tf.tables_initializer())
Expand Down Expand Up @@ -920,10 +918,9 @@ def _get_best_ensemble_index(self, current_iteration):
else:
index = np.argmin(adanet_losses)
logging.info("Finished ensemble evaluation for iteration %s",
current_iteration.number)
current_iteration.number)
logging.info("'%s' at index %s is moving onto the next iteration",
current_iteration.candidates[index].ensemble_spec.name,
index)
current_iteration.candidates[index].ensemble_spec.name, index)
return index

def _materialize_report(self, current_iteration):
Expand All @@ -939,7 +936,7 @@ def _materialize_report(self, current_iteration):

latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
logging.info("Starting metric logging for iteration %s",
current_iteration.number)
current_iteration.number)

assert self._best_ensemble_index is not None
best_candidate = current_iteration.candidates[self._best_ensemble_index]
Expand All @@ -964,7 +961,7 @@ def _materialize_report(self, current_iteration):
materialized_reports)

logging.info("Finished saving subnetwork reports for iteration %s",
current_iteration.number)
current_iteration.number)

def _decorate_hooks(self, hooks):
"""Decorate hooks to reset AdaNet state before calling their methods."""
Expand All @@ -991,7 +988,7 @@ def _training_chief_hooks(self, current_iteration, training):
training: Whether in training mode.
Returns:
A list of `tf.estimator.SessionRunHook` instances.
A list of `SessionRunHook` instances.
"""

if not training:
Expand Down Expand Up @@ -1024,7 +1021,7 @@ def _training_hooks(self, current_iteration, training,
_OverwriteCheckpointHook will be created.
Returns:
A list of `tf.estimator.SessionRunHook` instances.
A list of `SessionRunHook` instances.
"""

if not training:
Expand Down Expand Up @@ -1053,7 +1050,7 @@ def _evaluation_hooks(self, current_iteration, training):
training: Whether in training mode.
Returns:
A list of `tf.estimator.SessionRunHook` instances.
A list of `SessionRunHook` instances.
"""

if training:
Expand Down
11 changes: 6 additions & 5 deletions adanet/core/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from absl import flags
from absl.testing import parameterized
from adanet import tf_compat
from adanet.core.architecture import _Architecture
from adanet.core.ensemble_builder import _EnsembleSpec
from adanet.ensemble import ComplexityRegularized
Expand Down Expand Up @@ -208,10 +209,10 @@ def _input_fn(params=None):

del params # Unused.

input_features = tf.compat.v1.data.make_one_shot_iterator(tf.data.Dataset.from_tensors(
input_features = tf_compat.v1.data.make_one_shot_iterator(tf.data.Dataset.from_tensors(
[features])).get_next()
if labels is not None:
input_labels = tf.compat.v1.data.make_one_shot_iterator(tf.data.Dataset.from_tensors(
input_labels = tf_compat.v1.data.make_one_shot_iterator(tf.data.Dataset.from_tensors(
[labels])).get_next()
else:
input_labels = None
Expand All @@ -222,10 +223,10 @@ def _input_fn(params=None):

def head():
return tf.contrib.estimator.regression_head(
loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE)
loss_reduction=tf_compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE)


class ModifierSessionRunHook(tf.estimator.SessionRunHook):
class ModifierSessionRunHook(tf_compat.SessionRunHook):
"""Modifies the graph by adding a variable."""

def __init__(self, var_name="hook_created_variable"):
Expand All @@ -242,7 +243,7 @@ def begin(self):
if self._begun:
raise ValueError("begin called twice without end.")
self._begun = True
_ = tf.compat.v1.get_variable(name=self._var_name, initializer="")
_ = tf_compat.v1.get_variable(name=self._var_name, initializer="")

def end(self, session):
"""Adds a variable to the graph.
Expand Down
10 changes: 10 additions & 0 deletions adanet/tf_compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@
except AttributeError:
v1 = tf

try:
v2 = tf.compat.v2
except AttributeError:
v2 = tf.contrib

try:
SessionRunHook = tf.estimator.SessionRunHook
except AttributeError:
SessionRunHook = tf.train.SessionRunHook


def tensor_name(tensor):
"""Returns the Tensor's name.
Expand Down

0 comments on commit e333310

Please sign in to comment.