Enable more modules for TF 2.0 (#93)
PiperOrigin-RevId: 245298663
cweill committed Apr 25, 2019
1 parent 6798f6c commit 7c6b06f
Showing 8 changed files with 44 additions and 37 deletions.
1 change: 1 addition & 0 deletions adanet/core/BUILD
@@ -161,6 +161,7 @@ py_library(
name = "evaluator",
srcs = ["evaluator.py"],
deps = [
"//adanet/tf_compat",
],
)

9 changes: 4 additions & 5 deletions adanet/core/estimator.py
@@ -592,17 +592,16 @@ def _latest_checkpoint_iteration_number(self):
latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
if latest_checkpoint is None:
return 0
return tf.contrib.framework.load_variable(latest_checkpoint,
self._Keys.CURRENT_ITERATION)
return tf.train.load_variable(latest_checkpoint,
self._Keys.CURRENT_ITERATION)

def _latest_checkpoint_global_step(self):
"""Returns the global step from the latest checkpoint."""

latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
if latest_checkpoint is None:
return 0
return tf.contrib.framework.load_variable(latest_checkpoint,
tf.GraphKeys.GLOBAL_STEP)
return tf.train.load_variable(latest_checkpoint, tf.GraphKeys.GLOBAL_STEP)

@contextlib.contextmanager
def _train_loop_context(self):
@@ -1354,7 +1353,7 @@ def _adanet_model_fn(self, features, labels, mode, params):
# variable values to avoid any race conditions between the first and second
# checkpoint reads.
if mode == tf.estimator.ModeKeys.EVAL and self._evaluation_checkpoint_path:
iteration_number = tf.contrib.framework.load_variable(
iteration_number = tf.train.load_variable(
self._evaluation_checkpoint_path, self._Keys.CURRENT_ITERATION)

if self._prepare_next_iteration_state == self._Keys.INCREMENT_ITERATION:
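For reference, a minimal standalone sketch of the replacement call used above, assuming a model directory whose latest checkpoint contains the iteration-number variable (the name "current_iteration" below is illustrative; the estimator uses its internal _Keys.CURRENT_ITERATION key):

import tensorflow as tf

def latest_checkpoint_iteration_number(model_dir, var_name="current_iteration"):
  """Reads an integer variable from the latest checkpoint, or returns 0 if none exists."""
  latest_checkpoint = tf.train.latest_checkpoint(model_dir)
  if latest_checkpoint is None:
    return 0
  # tf.train.load_variable is the TF 2.0-compatible replacement for the
  # removed tf.contrib.framework.load_variable.
  return tf.train.load_variable(latest_checkpoint, var_name)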
14 changes: 8 additions & 6 deletions adanet/core/evaluator.py
@@ -21,6 +21,8 @@

import math

from absl import logging
from adanet import tf_compat
import tensorflow as tf


@@ -77,22 +79,22 @@ def evaluate_adanet_losses(self, sess, adanet_losses):
logging_frequency = math.floor(self.steps / 10.)

adanet_losses = [
tf.metrics.mean(adanet_loss) for adanet_loss in adanet_losses
tf_compat.v1.metrics.mean(adanet_loss) for adanet_loss in adanet_losses
]
sess.run(tf.local_variables_initializer())
sess.run(tf_compat.v1.local_variables_initializer())
while True:
if self.steps is not None and evals_completed == self.steps:
break
try:
evals_completed += 1
if (evals_completed % logging_frequency == 0 or
self.steps == evals_completed):
tf.logging.info("Ensemble evaluation [%d/%s]", evals_completed,
self.steps or "??")
logging.info("Ensemble evaluation [%d/%s]", evals_completed,
self.steps or "??")
sess.run(adanet_losses)
except tf.errors.OutOfRangeError:
tf.logging.info("Encountered end of input after %d evaluations",
evals_completed)
logging.info("Encountered end of input after %d evaluations",
evals_completed)
break

# Losses are metric op tuples. Evaluating the first element is idempotent.
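A self-contained sketch of the evaluation loop after this change, using tf.compat.v1 directly in place of adanet's tf_compat shim and absl logging in place of tf.logging (a simplification of evaluate_adanet_losses; assumes graph mode with an active session):

from absl import logging
import tensorflow.compat.v1 as tf

def evaluate_mean_losses(sess, loss_tensors, steps=None):
  """Accumulates streaming means of the given loss tensors."""
  # tf.metrics.mean returns (value_op, update_op) pairs backed by local
  # variables, so local variables must be initialized before the loop.
  mean_losses = [tf.metrics.mean(loss) for loss in loss_tensors]
  sess.run(tf.local_variables_initializer())
  evals_completed = 0
  while steps is None or evals_completed < steps:
    try:
      sess.run(mean_losses)
      evals_completed += 1
    except tf.errors.OutOfRangeError:
      logging.info("Encountered end of input after %d evaluations", evals_completed)
      break
  # Evaluating only the value op (first tuple element) is idempotent.
  return sess.run([value_op for value_op, _ in mean_losses])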
15 changes: 10 additions & 5 deletions adanet/core/summary.py
@@ -472,8 +472,8 @@ def _summary_fn(tensor, step):
# e.g. "foo/bar/baz/scalar" will become "baz/scalar" when
# additional_scope is "foo/bar".
# TODO: Figure out a cleaner way to handle this.
assert not tf.get_default_graph().get_name_scope()
with tf.name_scope(name_scope):
assert not tf_compat.v1.get_default_graph().get_name_scope()
with tf_compat.v1.name_scope(name_scope):
with self._strip_tag_scope(additional_scope):
# TODO: Do summaries need to be reduced before writing?
# Presumably each tensor core creates its own summary so we may be
@@ -696,23 +696,28 @@ def monkey_patched_summaries(summary):
setattr(summary_v2_lib, "image", wrapped_summary.image_v2)
setattr(summary_v2_lib, "histogram", wrapped_summary.histogram_v2)
setattr(summary_v2_lib, "audio", wrapped_summary.audio_v2)
if not tf_compat.version_greater_or_equal("2.0.0"):
try:
# TF 2.0 eliminates tf.contrib.
setattr(tf.contrib.summary, "scalar", wrapped_summary.scalar_v2)
setattr(tf.contrib.summary, "image", wrapped_summary.image_v2)
setattr(tf.contrib.summary, "histogram", wrapped_summary.histogram_v2)
setattr(tf.contrib.summary, "audio", wrapped_summary.audio_v2)
except AttributeError:
# TF 2.0 eliminates tf.contrib.
pass

try:
yield
finally:
# Revert monkey-patches.
if not tf_compat.version_greater_or_equal("2.0.0"):
# TF 2.0 eliminates tf.contrib.
try:
setattr(tf.contrib.summary, "audio", old_summary_v2_audio)
setattr(tf.contrib.summary, "histogram", old_summary_v2_histogram)
setattr(tf.contrib.summary, "image", old_summary_v2_image)
setattr(tf.contrib.summary, "scalar", old_summary_v2_scalar)
except AttributeError:
# TF 2.0 eliminates tf.contrib.
pass
setattr(summary_v2_lib, "audio", old_summary_v2_audio)
setattr(summary_v2_lib, "histogram", old_summary_v2_histogram)
setattr(summary_v2_lib, "image", old_summary_v2_image)
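The version-string check is replaced with attribute-based feature detection; a stripped-down sketch of that guard (wrapped_scalar_fn here is a placeholder for the wrapped summary function):

import tensorflow as tf

def patch_contrib_summary_scalar(wrapped_scalar_fn):
  """Monkey-patches tf.contrib.summary.scalar only when tf.contrib exists."""
  try:
    setattr(tf.contrib.summary, "scalar", wrapped_scalar_fn)
  except AttributeError:
    # TF 2.0 eliminates tf.contrib, so there is nothing to patch.
    pass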
29 changes: 13 additions & 16 deletions adanet/ensemble/weighted.py
@@ -255,12 +255,12 @@ def build_ensemble(self, subnetworks, previous_ensemble_subnetworks, features,
if isinstance(weighted_subnetwork.subnetwork.last_layer, dict):
weight_initializer = {
key: self._load_variable_from_model_dir(
tf_compat.tensor_name(weighted_subnetwork.weight[key]))
weighted_subnetwork.weight[key])
for key in sorted(weighted_subnetwork.subnetwork.last_layer)
}
else:
weight_initializer = self._load_variable_from_model_dir(
tf_compat.tensor_name(weighted_subnetwork.weight))
weighted_subnetwork.weight)
with tf_compat.v1.variable_scope(
"weighted_subnetwork_{}".format(subnetwork_index)):
weighted_subnetworks.append(
@@ -284,11 +284,10 @@ def build_ensemble(self, subnetworks, previous_ensemble_subnetworks, features,
weighted_subnetworks, prior=previous_ensemble.bias)
else:
bias = self._create_bias_term(weighted_subnetworks)
logging.info(
"Builders using a pruned set of the subnetworks "
"from the previous ensemble, so its ensemble's bias "
"term will not be warm started with the previous "
"ensemble's bias.")
logging.info("Builders using a pruned set of the subnetworks "
"from the previous ensemble, so its ensemble's bias "
"term will not be warm started with the previous "
"ensemble's bias.")
else:
bias = self._create_bias_term(weighted_subnetworks)

@@ -309,8 +308,8 @@ def build_ensemble(self, subnetworks, previous_ensemble_subnetworks, features,
logits=logits,
complexity_regularization=complexity_regularization)

def _load_variable_from_model_dir(self, var_name):
return tf.contrib.framework.load_variable(self._model_dir, var_name)
def _load_variable_from_model_dir(self, var):
return tf.train.load_variable(self._model_dir, tf_compat.tensor_name(var))

def _compute_adanet_gamma(self, complexity):
"""For a subnetwork, computes: lambda * r(h) + beta."""
@@ -405,11 +404,10 @@ def _build_weighted_subnetwork_helper(self,
# [batch_size x timesteps, emb_dim] for matrix multiplication
# and reshaping back.
if ndims == 3:
logging.info(
"Rank 3 tensors like [batch_size, timesteps, d] are "
"reshaped to rank 2 [ batch_size x timesteps, d] for "
"the weight matrix multiplication, and are reshaped "
"to their original shape afterwards.")
logging.info("Rank 3 tensors like [batch_size, timesteps, d] are "
"reshaped to rank 2 [ batch_size x timesteps, d] for "
"the weight matrix multiplication, and are reshaped "
"to their original shape afterwards.")
last_layer = tf.reshape(last_layer, [-1, last_layer_size])
logits = tf.matmul(last_layer, weight)
if ndims == 3:
@@ -458,8 +456,7 @@ def _create_bias_term_helper(self,
dims = logits.shape.as_list()
shape = dims[-1] if len(dims) > 1 else 1
else:
prior = self._load_variable_from_model_dir(
tf_compat.tensor_name(_lookup_if_dict(prior, key)))
prior = self._load_variable_from_model_dir(_lookup_if_dict(prior, key))
return tf_compat.v1.get_variable(
name="bias_{}".format(index) if index else "bias",
shape=shape,
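The helper now takes the variable itself and resolves its checkpoint name internally; a hedged standalone equivalent, using var.name directly in place of adanet's tf_compat.tensor_name (assumed here to strip the ":0" output suffix):

import tensorflow as tf

def load_variable_from_model_dir(model_dir, var):
  """Loads the checkpointed value of `var` from `model_dir`."""
  # Checkpoints key variables by name without the ":0" output suffix.
  var_name = var.name.split(":")[0]
  return tf.train.load_variable(model_dir, var_name)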
9 changes: 5 additions & 4 deletions adanet/ensemble/weighted_test.py
@@ -85,13 +85,14 @@ def setUp(self):
mock.patch.object(self._optimizer, 'minimize', autospec=True).start()

mock.patch.object(
tf.contrib.framework, 'load_variable', autospec=True).start()
tf.train, 'load_variable', autospec=True).start()

def load_variable(checkpoint_dir, name):
self.assertEqual(checkpoint_dir, 'fake_checkpoint_dir')
return tf.Variable(initial_value=1., name='fake_loaded_variable_' + name)
return tf_compat.v1.get_variable(name='fake_loaded_variable_' + name,
initializer=1.)

tf.contrib.framework.load_variable.side_effect = load_variable
tf.train.load_variable.side_effect = load_variable

self.summary = _FakeSummary()

@@ -110,7 +111,7 @@ def _build_easy_ensemble(self, subnetworks):
def _build_subnetwork(self, multi_head=False):

last_layer = tf.Variable(
tf.random.normal(shape=(2, 3)), trainable=False).read_value()
tf_compat.random_normal(shape=(2, 3)), trainable=False).read_value()

def new_logits():
return tf_compat.v1.layers.dense(
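A condensed sketch of the updated test fixture: the mock now targets tf.train.load_variable, and the fake returns a variable created through the v1 API (names below are illustrative):

from unittest import mock
import tensorflow.compat.v1 as tf

def patch_load_variable(test_case):
  """Replaces tf.train.load_variable with a fake for the duration of a test."""
  mock.patch.object(tf.train, "load_variable", autospec=True).start()

  def load_variable(checkpoint_dir, name):
    test_case.assertEqual(checkpoint_dir, "fake_checkpoint_dir")
    return tf.get_variable(name="fake_loaded_variable_" + name, initializer=1.)

  tf.train.load_variable.side_effect = load_variable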
1 change: 1 addition & 0 deletions adanet/subnetwork/BUILD
@@ -28,6 +28,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":generator",
"//adanet/tf_compat",
"@absl_py//absl/testing:parameterized",
],
)
3 changes: 2 additions & 1 deletion adanet/subnetwork/generator_test.py
@@ -22,6 +22,7 @@
import collections

from absl.testing import parameterized
from adanet import tf_compat
from adanet.subnetwork.generator import Builder
from adanet.subnetwork.generator import Subnetwork
import tensorflow as tf
@@ -31,7 +32,7 @@ def dummy_tensor(shape=(), random_seed=42):
"""Returns a randomly initialized tensor."""

return tf.Variable(
tf.random_normal(shape=shape, seed=random_seed),
tf_compat.random_normal(shape=shape, seed=random_seed),
trainable=False).read_value()


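tf_compat.random_normal is adanet's shim over the random-normal op whose name changed between TF releases; a plausible minimal implementation of such a shim (an assumption for illustration, not the library's actual code):

import tensorflow as tf

def random_normal(shape, seed=None):
  """Dispatches to whichever random-normal op the installed TF provides."""
  if hasattr(tf, "random") and hasattr(tf.random, "normal"):
    return tf.random.normal(shape=shape, seed=seed)  # TF 1.13+ and TF 2.x
  return tf.random_normal(shape=shape, seed=seed)  # older TF 1.x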
