Utilize TF 2.0 summaries in core when V2 behavior is enabled. #93
Fall back to the old TF v1 summaries when TF v2 behavior is not enabled.

Also add summary, estimator, and autoensemble tests that execute only when TF v2 behavior is enabled.

PiperOrigin-RevId: 269815407
cweill committed Sep 18, 2019
1 parent 92b955a commit 800ad3f
Showing 13 changed files with 1,167 additions and 104 deletions.
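In outline, everything in this commit keys off a single runtime check for TF 2 behavior. As a rough illustration (an assumption for orientation only, not the repository's actual `tf_compat.is_v2_behavior_enabled` implementation), such a check can use eager execution as a proxy:

import tensorflow as tf

def is_v2_behavior_enabled():
  """Best-effort proxy check; illustrative, not AdaNet's code."""
  # Under TF 2 behavior, eager execution is enabled at the top level.
  return hasattr(tf, "executing_eagerly") and tf.executing_eagerly()

When the check passes, summaries are built with `_ScopedSummaryV2` and written by a new `_SummaryV2SaverHook`; otherwise the existing V1 `_ScopedSummary` path is used, as the adanet/core/estimator.py diff below shows.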
11 changes: 11 additions & 0 deletions adanet/autoensemble/BUILD
@@ -34,3 +34,14 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)

py_test(
name = "estimator_v2_test",
size = "large",
srcs = ["estimator_v2_test.py"],
shard_count = 5,
deps = [
":estimator",
"@absl_py//absl/testing:parameterized",
],
)
2 changes: 1 addition & 1 deletion adanet/autoensemble/estimator_test.py
@@ -1,4 +1,4 @@
"""Tests for AdaNet AutoEnsembleEstimator.
"""Tests for AdaNet AutoEnsembleEstimator in TF 1.
Copyright 2018 The AdaNet Authors. All Rights Reserved.
148 changes: 148 additions & 0 deletions adanet/autoensemble/estimator_v2_test.py
@@ -0,0 +1,148 @@
"""Tests for AdaNet AutoEnsembleEstimator in TF 2.
Copyright 2019 The AdaNet Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import sys

from absl import flags
from absl.testing import parameterized
from adanet.autoensemble.estimator import AutoEnsembleEstimator
import tensorflow as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.estimator.export import export
from tensorflow_estimator.python.estimator.head import regression_head
# pylint: enable=g-direct-tensorflow-import


class AutoEnsembleEstimatorV2Test(parameterized.TestCase, tf.test.TestCase):

def setUp(self):
super(AutoEnsembleEstimatorV2Test, self).setUp()
# Setup and cleanup test directory.
# Flags are not automatically parsed at this point.
flags.FLAGS(sys.argv)
self.test_subdirectory = os.path.join(flags.FLAGS.test_tmpdir, self.id())
shutil.rmtree(self.test_subdirectory, ignore_errors=True)
os.makedirs(self.test_subdirectory)

def tearDown(self):
super(AutoEnsembleEstimatorV2Test, self).tearDown()
shutil.rmtree(self.test_subdirectory, ignore_errors=True)

# pylint: disable=g-long-lambda
@parameterized.named_parameters(
{
"testcase_name":
"candidate_pool_lambda",
"candidate_pool":
lambda head, feature_columns, optimizer: lambda config: {
"dnn":
tf.estimator.DNNEstimator(
head=head,
feature_columns=feature_columns,
optimizer=optimizer,
hidden_units=[3],
config=config),
"linear":
tf.estimator.LinearEstimator(
head=head,
feature_columns=feature_columns,
optimizer=optimizer,
config=config),
},
"want_loss":
.209,
},)
# pylint: enable=g-long-lambda
def test_auto_ensemble_estimator_lifecycle(self,
candidate_pool,
want_loss,
max_train_steps=30):
features = {"input_1": [[1., 0.]]}
labels = [[1.]]

run_config = tf.estimator.RunConfig(tf_random_seed=42)
head = regression_head.RegressionHead()

# Always create optimizers in a lambda to prevent error like:
# `RuntimeError: Cannot set `iterations` to a new Variable after the
# Optimizer weights have been created`
optimizer = lambda: tf.keras.optimizers.SGD(lr=.01)
feature_columns = [tf.feature_column.numeric_column("input_1", shape=[2])]

def train_input_fn():
input_features = {}
for key, feature in features.items():
input_features[key] = tf.constant(feature, name=key)
input_labels = tf.constant(labels, name="labels")
return input_features, input_labels

def test_input_fn():
dataset = tf.data.Dataset.from_tensors([tf.constant(features["input_1"])])
input_features = tf.compat.v1.data.make_one_shot_iterator(
dataset).get_next()
return {"input_1": input_features}, None

estimator = AutoEnsembleEstimator(
head=head,
candidate_pool=candidate_pool(head, feature_columns, optimizer),
max_iteration_steps=10,
force_grow=True,
model_dir=self.test_subdirectory,
config=run_config)

# Train for three iterations.
estimator.train(input_fn=train_input_fn, max_steps=max_train_steps)

# Evaluate.
eval_results = estimator.evaluate(input_fn=train_input_fn, steps=1)

self.assertAllClose(max_train_steps, eval_results["global_step"])
self.assertAllClose(want_loss, eval_results["loss"], atol=.3)

# Predict.
predictions = estimator.predict(input_fn=test_input_fn)
for prediction in predictions:
self.assertIsNotNone(prediction["predictions"])

# Export SavedModel.
def serving_input_fn():
"""Input fn for serving export, starting from serialized example."""
serialized_example = tf.compat.v1.placeholder(
dtype=tf.string, shape=(None), name="serialized_example")
for key, value in features.items():
features[key] = tf.constant(value)
return export.SupervisedInputReceiver(
features=features,
labels=tf.constant(labels),
receiver_tensors=serialized_example)

export_dir_base = os.path.join(self.test_subdirectory, "export")
estimator.export_saved_model(
export_dir_base=export_dir_base,
serving_input_receiver_fn=serving_input_fn)


if __name__ == "__main__":
tf.enable_v2_behavior()
tf.test.main()
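A note on the optimizer-lambda comment in the test above: Keras V2 optimizers carry state such as the `iterations` variable, so sharing one instance across the separate graphs AdaNet builds per iteration can raise the quoted `RuntimeError`. A minimal sketch of the factory pattern, assuming only stock TensorFlow:

import tensorflow as tf

# Risky: a single optimizer instance shared across independently built
# graphs keeps its `iterations` variable bound to the first graph, which
# can raise "RuntimeError: Cannot set `iterations` to a new Variable
# after the Optimizer weights have been created".
shared_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# Safer: a factory, so each graph constructs a fresh optimizer and state.
optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate=0.01)

opt_a = optimizer_fn()
opt_b = optimizer_fn()
assert opt_a is not opt_b  # each call yields independent state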
27 changes: 27 additions & 0 deletions adanet/core/BUILD
@@ -58,6 +58,22 @@ py_test(
],
)

py_test(
name = "estimator_v2_test",
size = "large",
srcs = ["estimator_v2_test.py"],
shard_count = 2,
deps = [
":ensemble_builder",
":estimator",
":evaluator",
":report_materializer",
":testing_utils",
"//adanet/subnetwork",
"@absl_py//absl/testing:parameterized",
],
)

py_test(
name = "estimator_distributed_test",
size = "large",
@@ -245,6 +261,17 @@ py_test(
],
)

py_test(
name = "summary_v2_test",
srcs = ["summary_v2_test.py"],
deps = [
":summary",
":testing_utils",
"@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)

py_library(
name = "timer",
srcs = ["timer.py"],
97 changes: 88 additions & 9 deletions adanet/core/estimator.py
@@ -35,6 +35,7 @@
from adanet.core.iteration import _IterationBuilder
from adanet.core.report_accessor import _ReportAccessor
from adanet.core.summary import _ScopedSummary
from adanet.core.summary import _ScopedSummaryV2
from adanet.core.summary import _TPUScopedSummary
from adanet.core.timer import _CountDownTimer
from adanet.distributed import ReplicationStrategy
@@ -88,6 +89,68 @@ def _stop_if_is_over(self, run_context):
self._after_fn()


class _SummaryV2SaverHook(tf_compat.SessionRunHook):
"""A hook that writes summaries to the appropriate log directory on disk."""

def __init__(self, summaries, save_steps=None, save_secs=None):
"""Initializes a `SummaryV2SaverHook` for writing TF 2 summaries.
Args:
summaries: List of `_ScopedSummaryV2` instances.
save_steps: `int`, save summaries every N steps. Exactly one of
`save_secs` and `save_steps` should be set.
save_secs: `int`, save summaries every N seconds.
"""

self._summaries = summaries
self._summary_ops = []
self._writer_init_ops = []
self._timer = tf_compat.v1.train.SecondOrStepTimer(
every_secs=save_secs, every_steps=save_steps)

def begin(self):
self._next_step = None
self._global_step_tensor = tf_compat.v1.train.get_global_step()

for summary in self._summaries:
assert isinstance(summary, _ScopedSummaryV2)
writer = tf_compat.v2.summary.create_file_writer(summary.logdir)
with writer.as_default():
for summary_fn, tensor in summary.summary_tuples():
self._summary_ops.append(
summary_fn(tensor, step=tf.compat.v1.train.get_global_step()))
self._writer_init_ops.append(writer.init())

def after_create_session(self, session, coord):
session.run(self._writer_init_ops)

def before_run(self, run_context):
requests = {"global_step": self._global_step_tensor}
self._request_summary = (
self._next_step is None or
self._timer.should_trigger_for_step(self._next_step))
if self._request_summary:
requests["summary"] = self._summary_ops

return tf_compat.SessionRunArgs(requests)

def after_run(self, run_context, run_values):
stale_global_step = run_values.results["global_step"]
global_step = stale_global_step + 1
if self._next_step is None or self._request_summary:
global_step = run_context.session.run(self._global_step_tensor)

if self._request_summary:
self._timer.update_last_triggered_step(global_step)

self._next_step = global_step + 1

def end(self, session):
# TODO: Run writer.flush() at Session end.
# Currently disabled because the flush op crashes between iterations.
return


class _EvalMetricSaverHook(tf_compat.SessionRunHook):
"""A hook for writing candidate evaluation metrics as summaries to disk."""

@@ -588,6 +651,13 @@ def __init__(self,

def _summary_maker(self, scope=None, skip_summary=False, namespace=None):
"""Constructs a `_ScopedSummary`."""
if tf_compat.is_v2_behavior_enabled():
# TF 2 behavior is enabled, so use TF 2 summaries.
return _ScopedSummaryV2(
logdir=self._model_dir,
scope=scope,
skip_summary=skip_summary,
namespace=namespace)
if self._use_tpu:
return _TPUScopedSummary(
logdir=self._model_dir,
@@ -1346,15 +1416,24 @@ def _training_chief_hooks(self, current_iteration, training):
return []

training_hooks = []
for summary in current_iteration.summaries:
output_dir = self.model_dir
if summary.scope:
output_dir = os.path.join(output_dir, summary.namespace, summary.scope)
summary_saver_hook = tf_compat.SummarySaverHook(
save_steps=self.config.save_summary_steps,
output_dir=output_dir,
summary_op=summary.merge_all())
training_hooks.append(summary_saver_hook)
if tf_compat.is_v2_behavior_enabled():
# Use the V2 summaries and saver hook when TF 2 behavior is enabled.
training_hooks.append(
_SummaryV2SaverHook(
current_iteration.summaries,
save_steps=self.config.save_summary_steps))
else:
# Fall back to V1 summaries.
for summary in current_iteration.summaries:
output_dir = self.model_dir
if summary.scope:
output_dir = os.path.join(output_dir, summary.namespace,
summary.scope)
summary_saver_hook = tf_compat.SummarySaverHook(
save_steps=self.config.save_summary_steps,
output_dir=output_dir,
summary_op=summary.merge_all())
training_hooks.append(summary_saver_hook)
training_hooks += list(
current_iteration.estimator_spec.training_chief_hooks)
return training_hooks
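The `_SummaryV2SaverHook` above leans on the TF 2 summary API's graph-mode contract: summary ops created under a writer's `as_default()` scope are returned as fetchable tensors, and the writer exposes an explicit `init()` op. A standalone sketch of that mechanism (illustrative log directory and values, not AdaNet code):

import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf2

tf1.disable_eager_execution()  # build a graph, as Estimator does internally

step = tf1.train.get_or_create_global_step()
writer = tf2.summary.create_file_writer("/tmp/summaries")  # illustrative path
with writer.as_default():
  # In graph mode, tf2.summary.* returns an op to run later rather than
  # writing eagerly; this mirrors `begin()` collecting `_summary_ops`.
  summary_op = tf2.summary.scalar("loss", 0.5, step=step)

with tf1.Session() as sess:
  sess.run(writer.init())  # mirrors `after_create_session`
  sess.run(tf1.global_variables_initializer())
  sess.run(summary_op)  # mirrors the periodic fetch in `before_run`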
10 changes: 7 additions & 3 deletions adanet/core/estimator_test.py
@@ -2348,9 +2348,13 @@ def serving_input_fn():
export_dir_base=self.test_subdirectory,
serving_input_receiver_fn=serving_input_fn,
experimental_mode=tf.estimator.ModeKeys.PREDICT)
estimator.export_savedmodel(
export_dir_base=self.test_subdirectory,
serving_input_receiver_fn=serving_input_fn)
try:
estimator.export_savedmodel(
export_dir_base=self.test_subdirectory,
serving_input_receiver_fn=serving_input_fn)
except AttributeError as error:
# Log deprecation errors.
logging.warning("Testing estimator#export_savedmodel: %s", error)
estimator.experimental_export_all_saved_models(
export_dir_base=self.test_subdirectory,
input_receiver_fn_map={
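The try/except added above lets the same test run against TF versions where `Estimator#export_savedmodel` has been removed. A self-contained sketch of the pattern, using a hypothetical stand-in object:

import logging

class _ModernEstimatorStub:
  """Hypothetical stand-in for an Estimator whose TF version has dropped
  `export_savedmodel`."""

estimator = _ModernEstimatorStub()
try:
  estimator.export_savedmodel  # raises AttributeError on this stub
except AttributeError as error:
  # Log the deprecation instead of failing, as the test above does.
  logging.warning("Testing estimator#export_savedmodel: %s", error)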