From 800ad3fc937ee0b3dd2f6e64c4d4b5541a9ae30b Mon Sep 17 00:00:00 2001 From: Charles Weill Date: Wed, 18 Sep 2019 11:51:45 -0400 Subject: [PATCH] Utilize TF 2.0 summaries in core when V2 behavior is enabled. #93 Fallback to old TF v1 versions of summaries when TF v2 behavior is not enabled. Also add additional summary, estimator, and autoensemble tests that only execute when TF v2 is enabled. PiperOrigin-RevId: 269815407 --- adanet/autoensemble/BUILD | 11 + adanet/autoensemble/estimator_test.py | 2 +- adanet/autoensemble/estimator_v2_test.py | 148 +++++++++ adanet/core/BUILD | 27 ++ adanet/core/estimator.py | 97 +++++- adanet/core/estimator_test.py | 10 +- adanet/core/estimator_v2_test.py | 169 ++++++++++ adanet/core/eval_metrics.py | 3 +- adanet/core/summary.py | 384 +++++++++++++++++------ adanet/core/summary_test.py | 30 +- adanet/core/summary_v2_test.py | 299 ++++++++++++++++++ adanet/core/testing_utils.py | 44 +++ adanet/tf_compat/__init__.py | 47 +++ 13 files changed, 1167 insertions(+), 104 deletions(-) create mode 100644 adanet/autoensemble/estimator_v2_test.py create mode 100644 adanet/core/estimator_v2_test.py create mode 100644 adanet/core/summary_v2_test.py diff --git a/adanet/autoensemble/BUILD b/adanet/autoensemble/BUILD index aab9e527..ba4c9a56 100644 --- a/adanet/autoensemble/BUILD +++ b/adanet/autoensemble/BUILD @@ -34,3 +34,14 @@ py_test( "@absl_py//absl/testing:parameterized", ], ) + +py_test( + name = "estimator_v2_test", + size = "large", + srcs = ["estimator_v2_test.py"], + shard_count = 5, + deps = [ + ":estimator", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/adanet/autoensemble/estimator_test.py b/adanet/autoensemble/estimator_test.py index 57a825e2..1553ed7b 100644 --- a/adanet/autoensemble/estimator_test.py +++ b/adanet/autoensemble/estimator_test.py @@ -1,4 +1,4 @@ -"""Tests for AdaNet AutoEnsembleEstimator. +"""Tests for AdaNet AutoEnsembleEstimator in TF 1. Copyright 2018 The AdaNet Authors. All Rights Reserved. diff --git a/adanet/autoensemble/estimator_v2_test.py b/adanet/autoensemble/estimator_v2_test.py new file mode 100644 index 00000000..61ed6ae3 --- /dev/null +++ b/adanet/autoensemble/estimator_v2_test.py @@ -0,0 +1,148 @@ +"""Tests for AdaNet AutoEnsembleEstimator in TF 2. + +Copyright 2019 The AdaNet Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import sys + +from absl import flags +from absl.testing import parameterized +from adanet.autoensemble.estimator import AutoEnsembleEstimator +import tensorflow as tf + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.estimator.export import export +from tensorflow_estimator.python.estimator.head import regression_head +# pylint: enable=g-direct-tensorflow-import + + +class AutoEnsembleEstimatorV2Test(parameterized.TestCase, tf.test.TestCase): + + def setUp(self): + super(AutoEnsembleEstimatorV2Test, self).setUp() + # Setup and cleanup test directory. + # Flags are not automatically parsed at this point. + flags.FLAGS(sys.argv) + self.test_subdirectory = os.path.join(flags.FLAGS.test_tmpdir, self.id()) + shutil.rmtree(self.test_subdirectory, ignore_errors=True) + os.makedirs(self.test_subdirectory) + + def tearDown(self): + super(AutoEnsembleEstimatorV2Test, self).tearDown() + shutil.rmtree(self.test_subdirectory, ignore_errors=True) + + # pylint: disable=g-long-lambda + @parameterized.named_parameters( + { + "testcase_name": + "candidate_pool_lambda", + "candidate_pool": + lambda head, feature_columns, optimizer: lambda config: { + "dnn": + tf.estimator.DNNEstimator( + head=head, + feature_columns=feature_columns, + optimizer=optimizer, + hidden_units=[3], + config=config), + "linear": + tf.estimator.LinearEstimator( + head=head, + feature_columns=feature_columns, + optimizer=optimizer, + config=config), + }, + "want_loss": + .209, + },) + # pylint: enable=g-long-lambda + def test_auto_ensemble_estimator_lifecycle(self, + candidate_pool, + want_loss, + max_train_steps=30): + features = {"input_1": [[1., 0.]]} + labels = [[1.]] + + run_config = tf.estimator.RunConfig(tf_random_seed=42) + head = regression_head.RegressionHead() + + # Always create optimizers in a lambda to prevent error like: + # `RuntimeError: Cannot set `iterations` to a new Variable after the + # Optimizer weights have been created` + optimizer = lambda: tf.keras.optimizers.SGD(lr=.01) + feature_columns = [tf.feature_column.numeric_column("input_1", shape=[2])] + + def train_input_fn(): + input_features = {} + for key, feature in features.items(): + input_features[key] = tf.constant(feature, name=key) + input_labels = tf.constant(labels, name="labels") + return input_features, input_labels + + def test_input_fn(): + dataset = tf.data.Dataset.from_tensors([tf.constant(features["input_1"])]) + input_features = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() + return {"input_1": input_features}, None + + estimator = AutoEnsembleEstimator( + head=head, + candidate_pool=candidate_pool(head, feature_columns, optimizer), + max_iteration_steps=10, + force_grow=True, + model_dir=self.test_subdirectory, + config=run_config) + + # Train for three iterations. + estimator.train(input_fn=train_input_fn, max_steps=max_train_steps) + + # Evaluate. + eval_results = estimator.evaluate(input_fn=train_input_fn, steps=1) + + self.assertAllClose(max_train_steps, eval_results["global_step"]) + self.assertAllClose(want_loss, eval_results["loss"], atol=.3) + + # Predict. + predictions = estimator.predict(input_fn=test_input_fn) + for prediction in predictions: + self.assertIsNotNone(prediction["predictions"]) + + # Export SavedModel. + def serving_input_fn(): + """Input fn for serving export, starting from serialized example.""" + serialized_example = tf.compat.v1.placeholder( + dtype=tf.string, shape=(None), name="serialized_example") + for key, value in features.items(): + features[key] = tf.constant(value) + return export.SupervisedInputReceiver( + features=features, + labels=tf.constant(labels), + receiver_tensors=serialized_example) + + export_dir_base = os.path.join(self.test_subdirectory, "export") + estimator.export_saved_model( + export_dir_base=export_dir_base, + serving_input_receiver_fn=serving_input_fn) + + +if __name__ == "__main__": + tf.enable_v2_behavior() + tf.test.main() diff --git a/adanet/core/BUILD b/adanet/core/BUILD index fab8f8ed..c7797e7b 100644 --- a/adanet/core/BUILD +++ b/adanet/core/BUILD @@ -58,6 +58,22 @@ py_test( ], ) +py_test( + name = "estimator_v2_test", + size = "large", + srcs = ["estimator_v2_test.py"], + shard_count = 2, + deps = [ + ":ensemble_builder", + ":estimator", + ":evaluator", + ":report_materializer", + ":testing_utils", + "//adanet/subnetwork", + "@absl_py//absl/testing:parameterized", + ], +) + py_test( name = "estimator_distributed_test", size = "large", @@ -245,6 +261,17 @@ py_test( ], ) +py_test( + name = "summary_v2_test", + srcs = ["summary_v2_test.py"], + deps = [ + ":summary", + ":testing_utils", + "@absl_py//absl/testing:parameterized", + "@six_archive//:six", + ], +) + py_library( name = "timer", srcs = ["timer.py"], diff --git a/adanet/core/estimator.py b/adanet/core/estimator.py index c8721008..b5125dc1 100644 --- a/adanet/core/estimator.py +++ b/adanet/core/estimator.py @@ -35,6 +35,7 @@ from adanet.core.iteration import _IterationBuilder from adanet.core.report_accessor import _ReportAccessor from adanet.core.summary import _ScopedSummary +from adanet.core.summary import _ScopedSummaryV2 from adanet.core.summary import _TPUScopedSummary from adanet.core.timer import _CountDownTimer from adanet.distributed import ReplicationStrategy @@ -88,6 +89,68 @@ def _stop_if_is_over(self, run_context): self._after_fn() +class _SummaryV2SaverHook(tf_compat.SessionRunHook): + """A hook that writes summaries to the appropriate log directory on disk.""" + + def __init__(self, summaries, save_steps=None, save_secs=None): + """Initializes a `SummaryV2SaverHook` for writing TF 2 summaries. + + Args: + summaries: List of `_ScopedSummaryV2` instances. + save_steps: `int`, save summaries every N steps. Exactly one of + `save_secs` and `save_steps` should be set. + save_secs: `int`, save summaries every N seconds. + """ + + self._summaries = summaries + self._summary_ops = [] + self._writer_init_ops = [] + self._timer = tf_compat.v1.train.SecondOrStepTimer( + every_secs=save_secs, every_steps=save_steps) + + def begin(self): + self._next_step = None + self._global_step_tensor = tf_compat.v1.train.get_global_step() + + for summary in self._summaries: + assert isinstance(summary, _ScopedSummaryV2) + writer = tf_compat.v2.summary.create_file_writer(summary.logdir) + with writer.as_default(): + for summary_fn, tensor in summary.summary_tuples(): + self._summary_ops.append( + summary_fn(tensor, step=tf.compat.v1.train.get_global_step())) + self._writer_init_ops.append(writer.init()) + + def after_create_session(self, session, coord): + session.run(self._writer_init_ops) + + def before_run(self, run_context): + requests = {"global_step": self._global_step_tensor} + self._request_summary = ( + self._next_step is None or + self._timer.should_trigger_for_step(self._next_step)) + if self._request_summary: + requests["summary"] = self._summary_ops + + return tf_compat.SessionRunArgs(requests) + + def after_run(self, run_context, run_values): + stale_global_step = run_values.results["global_step"] + global_step = stale_global_step + 1 + if self._next_step is None or self._request_summary: + global_step = run_context.session.run(self._global_step_tensor) + + if self._request_summary: + self._timer.update_last_triggered_step(global_step) + + self._next_step = global_step + 1 + + def end(self, session): + # TODO: Run writer.flush() at Session end. + # Currently disabled because the flush op crashes between iterations. + return + + class _EvalMetricSaverHook(tf_compat.SessionRunHook): """A hook for writing candidate evaluation metrics as summaries to disk.""" @@ -588,6 +651,13 @@ def __init__(self, def _summary_maker(self, scope=None, skip_summary=False, namespace=None): """Constructs a `_ScopedSummary`.""" + if tf_compat.is_v2_behavior_enabled(): + # Here we assume TF 2 behavior is enabled. + return _ScopedSummaryV2( + logdir=self._model_dir, + scope=scope, + skip_summary=skip_summary, + namespace=namespace) if self._use_tpu: return _TPUScopedSummary( logdir=self._model_dir, @@ -1346,15 +1416,24 @@ def _training_chief_hooks(self, current_iteration, training): return [] training_hooks = [] - for summary in current_iteration.summaries: - output_dir = self.model_dir - if summary.scope: - output_dir = os.path.join(output_dir, summary.namespace, summary.scope) - summary_saver_hook = tf_compat.SummarySaverHook( - save_steps=self.config.save_summary_steps, - output_dir=output_dir, - summary_op=summary.merge_all()) - training_hooks.append(summary_saver_hook) + if tf_compat.is_v2_behavior_enabled(): + # Use V2 summaries and hook when user is using TF 2 behavior. + training_hooks.append( + _SummaryV2SaverHook( + current_iteration.summaries, + save_steps=self.config.save_summary_steps)) + else: + # Fallback to V1 summaries. + for summary in current_iteration.summaries: + output_dir = self.model_dir + if summary.scope: + output_dir = os.path.join(output_dir, summary.namespace, + summary.scope) + summary_saver_hook = tf_compat.SummarySaverHook( + save_steps=self.config.save_summary_steps, + output_dir=output_dir, + summary_op=summary.merge_all()) + training_hooks.append(summary_saver_hook) training_hooks += list( current_iteration.estimator_spec.training_chief_hooks) return training_hooks diff --git a/adanet/core/estimator_test.py b/adanet/core/estimator_test.py index b9b3e9dd..9a8c3ff5 100644 --- a/adanet/core/estimator_test.py +++ b/adanet/core/estimator_test.py @@ -2348,9 +2348,13 @@ def serving_input_fn(): export_dir_base=self.test_subdirectory, serving_input_receiver_fn=serving_input_fn, experimental_mode=tf.estimator.ModeKeys.PREDICT) - estimator.export_savedmodel( - export_dir_base=self.test_subdirectory, - serving_input_receiver_fn=serving_input_fn) + try: + estimator.export_savedmodel( + export_dir_base=self.test_subdirectory, + serving_input_receiver_fn=serving_input_fn) + except AttributeError as error: + # Log deprecation errors. + logging.warning("Testing estimator#export_savedmodel: %s", error) estimator.experimental_export_all_saved_models( export_dir_base=self.test_subdirectory, input_receiver_fn_map={ diff --git a/adanet/core/estimator_v2_test.py b/adanet/core/estimator_v2_test.py new file mode 100644 index 00000000..29171801 --- /dev/null +++ b/adanet/core/estimator_v2_test.py @@ -0,0 +1,169 @@ +"""Test AdaNet estimator single graph implementation for TF 2. + +Copyright 2019 The AdaNet Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import logging +from adanet import tf_compat +from adanet.core import testing_utils as tu +from adanet.core.estimator import Estimator +from adanet.core.report_materializer import ReportMaterializer +from adanet.subnetwork import Builder +from adanet.subnetwork import SimpleGenerator +from adanet.subnetwork import Subnetwork +import tensorflow as tf + +from tensorflow_estimator.python.estimator.head import regression_head + +logging.set_verbosity(logging.INFO) + +XOR_FEATURES = [[1., 0.], [0., 0], [0., 1.], [1., 1.]] +XOR_LABELS = [[1.], [0.], [1.], [0.]] + + +class _SimpleBuilder(Builder): + """A simple subnetwork builder that takes feature_columns.""" + + def __init__(self, name, seed=42): + self._name = name + self._seed = seed + + @property + def name(self): + return self._name + + def build_subnetwork(self, + features, + logits_dimension, + training, + iteration_step, + summary, + previous_ensemble=None): + seed = self._seed + if previous_ensemble: + # Increment seed so different iterations don't learn the exact same thing. + seed += 1 + + with tf_compat.v1.variable_scope("simple"): + input_layer = tf_compat.v1.feature_column.input_layer( + features=features, + feature_columns=tf.feature_column.numeric_column("x", 2)) + last_layer = input_layer + + with tf_compat.v1.variable_scope("logits"): + logits = tf_compat.v1.layers.dense( + last_layer, + logits_dimension, + kernel_initializer=tf_compat.v1.glorot_uniform_initializer(seed=seed)) + + summary.scalar("scalar", 3) + batch_size = features["x"].get_shape().as_list()[0] + summary.image("image", tf.ones([batch_size, 3, 3, 1])) + with tf_compat.v1.variable_scope("nested"): + summary.scalar("scalar", 5) + + return Subnetwork( + last_layer=last_layer, + logits=logits, + complexity=1, + persisted_tensors={}, + ) + + def build_subnetwork_train_op(self, subnetwork, loss, var_list, labels, + iteration_step, summary, previous_ensemble): + optimizer = tf_compat.v1.train.GradientDescentOptimizer(learning_rate=.001) + return optimizer.minimize(loss, var_list=var_list) + + +class EstimatorSummaryWriterTest(tu.AdanetTestCase): + """Test that Tensorboard summaries get written correctly.""" + + def test_summaries(self): + """Tests that summaries are written to candidate directory.""" + + run_config = tf.estimator.RunConfig( + tf_random_seed=42, + log_step_count_steps=2, + save_summary_steps=2, + model_dir=self.test_subdirectory) + subnetwork_generator = SimpleGenerator([_SimpleBuilder("dnn")]) + report_materializer = ReportMaterializer( + input_fn=tu.dummy_input_fn([[1., 1.]], [[0.]]), steps=1) + estimator = Estimator( + head=regression_head.RegressionHead( + loss_reduction=tf_compat.SUM_OVER_BATCH_SIZE), + subnetwork_generator=subnetwork_generator, + report_materializer=report_materializer, + max_iteration_steps=10, + config=run_config) + train_input_fn = tu.dummy_input_fn([[1., 0.]], [[1.]]) + estimator.train(input_fn=train_input_fn, max_steps=3) + + ensemble_loss = 1.52950 + self.assertAlmostEqual( + ensemble_loss, + tu.check_eventfile_for_keyword("loss", self.test_subdirectory), + places=3) + self.assertIsNotNone( + tu.check_eventfile_for_keyword("global_step/sec", + self.test_subdirectory)) + self.assertEqual( + 0., + tu.check_eventfile_for_keyword("iteration/adanet/iteration", + self.test_subdirectory)) + + subnetwork_subdir = os.path.join(self.test_subdirectory, + "subnetwork/t0_dnn") + self.assertAlmostEqual( + 3., + tu.check_eventfile_for_keyword("scalar", subnetwork_subdir), + places=3) + self.assertEqual((3, 3, 1), + tu.check_eventfile_for_keyword("image", subnetwork_subdir)) + self.assertAlmostEqual( + 5., + tu.check_eventfile_for_keyword("nested/scalar", subnetwork_subdir), + places=3) + + ensemble_subdir = os.path.join( + self.test_subdirectory, "ensemble/t0_dnn_grow_complexity_regularized") + self.assertAlmostEqual( + ensemble_loss, + tu.check_eventfile_for_keyword( + "adanet_loss/adanet/adanet_weighted_ensemble", ensemble_subdir), + places=1) + self.assertAlmostEqual( + 0., + tu.check_eventfile_for_keyword( + "complexity_regularization/adanet/adanet_weighted_ensemble", + ensemble_subdir), + places=3) + self.assertAlmostEqual( + 1., + tu.check_eventfile_for_keyword( + "mixture_weight_norms/adanet/" + "adanet_weighted_ensemble/subnetwork_0", ensemble_subdir), + places=3) + + +if __name__ == "__main__": + tf.enable_v2_behavior() + tf.test.main() diff --git a/adanet/core/eval_metrics.py b/adanet/core/eval_metrics.py index d9426b8e..e6d155e1 100644 --- a/adanet/core/eval_metrics.py +++ b/adanet/core/eval_metrics.py @@ -352,8 +352,7 @@ def _replay_eval_metrics(best_candidate_idx, eval_metric_ops): """Saves replay indices as eval metrics.""" # _replay_indices_for_all is a dict: {candidate: [list of replay_indices]} # We are finding the max length replay list. - pad_value = max( - [len(v) for _, v in self._replay_indices_for_all.items()]) + pad_value = max([len(v) for _, v in self._replay_indices_for_all.items()]) # Creating a matrix of (#candidate) times (max length replay indices). # Entry i,j is the jth replay index of the ith candidate (ensemble). diff --git a/adanet/core/summary.py b/adanet/core/summary.py index b6649233..62c0b9b8 100644 --- a/adanet/core/summary.py +++ b/adanet/core/summary.py @@ -25,10 +25,13 @@ from absl import logging from adanet import tf_compat +import tensorflow as tf_v1 import tensorflow as tf # pylint: disable=g-direct-tensorflow-import +from tensorboard.compat import tf2 from tensorflow.python.ops import summary_op_util from tensorflow.python.ops import summary_ops_v2 as summary_v2_lib +from tensorflow.python.ops.summary_ops_v2 import _INVALID_SCOPE_CHARACTERS from tensorflow.python.summary import summary as summary_lib # pylint: enable=g-direct-tensorflow-import @@ -41,17 +44,20 @@ class Summary(object): __metaclass__ = abc.ABCMeta @abc.abstractmethod - def scalar(self, name, tensor, family=None): + def scalar(self, name, tensor, family=None, description=None): """Outputs a `tf.Summary` protocol buffer containing a single scalar value. The generated tf.Summary has a Tensor.proto containing the input Tensor. Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A real numeric Tensor containing a single value. + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + tensor: A real numeric scalar value, convertible to a float32 Tensor. family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. + which controls the tab name used for display on Tensorboard. DEPRECATED + in TF 2. + description: Optional long-form description for this summary, as a + constant str. Markdown is supported. Defaults to empty. Returns: A scalar `Tensor` of type `string`. Which contains a `tf.Summary` @@ -62,7 +68,7 @@ def scalar(self, name, tensor, family=None): """ @abc.abstractmethod - def image(self, name, tensor, max_outputs=3, family=None): + def image(self, name, tensor, max_outputs=3, family=None, description=None): """Outputs a `tf.Summary` protocol buffer with images. The summary has up to `max_outputs` summary values containing images. The @@ -93,13 +99,23 @@ def image(self, name, tensor, max_outputs=3, family=None): generated sequentially as '*name*/image/0', '*name*/image/1', etc. Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A 4-D `uint8` or `float32` `Tensor` of shape `[batch_size, height, - width, channels]` where `channels` is 1, 3, or 4. - max_outputs: Max number of batch elements to generate images for. + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + tensor: A Tensor representing pixel data with shape [k, h, w, c], where k + is the number of images, h and w are the height and width of the images, + and c is the number of channels, which should be 1, 2, 3, or 4 + (grayscale, grayscale with alpha, RGB, RGBA). Any of the dimensions may + be statically unknown (i.e., None). Floating point data will be clipped + to the range [0,1). + max_outputs: Optional int or rank-0 integer Tensor. At most this many + images will be emitted at each step. When more than max_outputs many + images are provided, the first max_outputs many images will be used and + the rest silently discarded. family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. + which controls the tab name used for display on Tensorboard. DEPRECATED + in TF 2. + description: Optional long-form description for this summary, as a + constant str. Markdown is supported. Defaults to empty. Returns: A scalar `Tensor` of type `string`. The serialized `tf.Summary` protocol @@ -107,7 +123,12 @@ def image(self, name, tensor, max_outputs=3, family=None): """ @abc.abstractmethod - def histogram(self, name, values, family=None): + def histogram(self, + name, + values, + family=None, + buckets=None, + description=None): """Outputs a `tf.Summary` protocol buffer with a histogram. Adding a histogram summary makes it possible to visualize your data's @@ -122,12 +143,18 @@ def histogram(self, name, values, family=None): This op reports an `InvalidArgument` error if any value is not finite. Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - values: A real numeric `Tensor`. Any shape. Values to use to build the - histogram. + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + values: A Tensor of any shape. Must be castable to float64. family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. + which controls the tab name used for display on Tensorboard. DEPRECATED + in TF 2. + buckets: Optional positive int. The output will have this many buckets, + except in two edge cases. If there is no data, then there are no + buckets. If there is data but all points have the same value, then there + is one bucket whose left and right endpoints are the same. + description: Optional long-form description for this summary, as a + constant str. Markdown is supported. Defaults to empty. Returns: A scalar `Tensor` of type `string`. The serialized `tf.Summary` protocol @@ -135,33 +162,37 @@ def histogram(self, name, values, family=None): """ @abc.abstractmethod - def audio(self, name, tensor, sample_rate, max_outputs=3, family=None): - """Outputs a `tf.Summary` protocol buffer with audio. - - The summary has up to `max_outputs` summary values containing audio. The - audio is built from `tensor` which must be 3-D with shape `[batch_size, - frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are - assumed to be in the range of `[-1.0, 1.0]` with a sample rate of - `sample_rate`. - - The `tag` in the outputted tf.Summary.Value protobufs is generated based on - the - name, with a suffix depending on the max_outputs setting: - - * If `max_outputs` is 1, the summary value tag is '*name*/audio'. - * If `max_outputs` is greater than 1, the summary value tags are - generated sequentially as '*name*/audio/0', '*name*/audio/1', etc + def audio(self, + name, + tensor, + sample_rate, + max_outputs=3, + family=None, + encoding=None, + description=None): + """Writes an audio summary. Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` - or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`. - sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the - signal in hertz. - max_outputs: Max number of batch elements to generate audio for. + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + tensor: A Tensor representing audio data with shape [k, t, c], where k is + the number of audio clips, t is the number of frames, and c is the + number of channels. Elements should be floating-point values in [-1.0, + 1.0]. Any of the dimensions may be statically unknown (i.e., None). + sample_rate: An int or rank-0 int32 Tensor that represents the sample + rate, in Hz. Must be positive. + max_outputs: Optional int or rank-0 integer Tensor. At most this many + audio clips will be emitted at each step. When more than max_outputs + many clips are provided, the first max_outputs many clips will be used + and the rest silently discarded. family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. + which controls the tab name used for display on Tensorboard. DEPRECATED + in TF 2. + encoding: Optional constant str for the desired encoding. Only "wav" is + currently supported, but this is not guaranteed to remain the default, + so if you want "wav" in particular, set this explicitly. + description: Optional long-form description for this summary, as a + constant str. Markdown is supported. Defaults to empty. Returns: A scalar `Tensor` of type `string`. The serialized `tf.Summary` protocol @@ -339,10 +370,10 @@ def merge_all(self): return [op for op in self._summary_ops if op.graph == current_graph] -# TODO: _ScopedSummary and _TPUScopedSummary share a lot of the same +# TODO: _ScopedSummary and _ScopedSummaryV2 share a lot of the same # methods. Extract a base class for the two, or move shared methods into # Summary. -class _TPUScopedSummary(Summary): +class _ScopedSummaryV2(Summary): """Records summaries in a given scope. Only for TPUEstimator. @@ -381,10 +412,10 @@ def __init__(self, logdir, namespace=None, scope=None, skip_summary=False): self._scope = scope self._additional_scope = None self._skip_summary = skip_summary - self._actual_summary_scalar_fn = summary_v2_lib.scalar - self._actual_summary_image_fn = summary_v2_lib.image - self._actual_summary_histogram_fn = summary_v2_lib.histogram - self._actual_summary_audio_fn = summary_v2_lib.audio + self._actual_summary_scalar_fn = tf_compat.v2.summary.scalar + self._actual_summary_image_fn = tf_compat.v2.summary.image + self._actual_summary_histogram_fn = tf_compat.v2.summary.histogram + self._actual_summary_audio_fn = tf_compat.v2.summary.audio self._summary_tuples = [] @property @@ -405,20 +436,30 @@ def logdir(self): return self._logdir + @property + def writer(self): + """Returns the file writer.""" + + return self._writer + @contextlib.contextmanager def current_scope(self): """Registers the current context's scope to strip it from summary tags.""" self._additional_scope = tf_compat.v1.get_default_graph().get_name_scope() - yield - self._additional_scope = None + try: + yield + finally: + self._additional_scope = None @contextlib.contextmanager def _strip_tag_scope(self, additional_scope): """Monkey patches `summary_op_util.summary_scope` to strip tag scopes.""" original_summary_scope = summary_op_util.summary_scope + original_summary_scope_v2 = getattr(summary_v2_lib, "summary_scope") + # TF 1. @contextlib.contextmanager def strip_tag_scope_fn(name, family=None, default_name=None, values=None): tag, scope = (None, None) @@ -427,16 +468,43 @@ def strip_tag_scope_fn(name, family=None, default_name=None, values=None): scope = s yield tag, scope - summary_op_util.summary_scope = strip_tag_scope_fn - yield - summary_op_util.summary_scope = original_summary_scope + # TF 2. + @contextlib.contextmanager + def monkey_patched_summary_scope_fn(name, + default_name="summary", + values=None): + """Rescopes the summary tag with the ScopedSummary's scope.""" + + name = name or default_name + current_scope = tf_compat.v1.get_default_graph().get_name_scope() + tag = current_scope + "/" + name if current_scope else name + # Strip illegal characters from the scope name, and if that leaves + # nothing, use None instead so we pick up the default name. + name = _INVALID_SCOPE_CHARACTERS.sub("", name) or None + with tf.compat.v1.name_scope(name, default_name, values) as scope: + tag = _strip_scope(tag, self.scope, additional_scope) + yield tag, scope + + setattr(summary_op_util, "summary_scope", strip_tag_scope_fn) + setattr(summary_v2_lib, "summary_scope", monkey_patched_summary_scope_fn) + setattr(tf2.summary.experimental, "summary_scope", + monkey_patched_summary_scope_fn) + setattr(tf2.summary, "summary_scope", monkey_patched_summary_scope_fn) + try: + yield + finally: + setattr(summary_op_util, "summary_scope", original_summary_scope) + setattr(summary_v2_lib, "summary_scope", original_summary_scope_v2) + setattr(tf2.summary.experimental, "summary_scope", + original_summary_scope_v2) + setattr(tf2.summary, "summary_scope", original_summary_scope_v2) def _prefix_scope(self, name): scope = self._scope - if not scope: - scope = _DEFAULT_SCOPE if name[0] == "/": name = name[1:] + if not scope: + scope = _DEFAULT_SCOPE return "{scope}/{name}".format(scope=scope, name=name) def _create_summary(self, summary_fn, name, tensor): @@ -482,44 +550,60 @@ def _summary_fn(tensor, step): self._summary_tuples.append((_summary_fn, tensor)) - def scalar(self, name, tensor, family=None): + def scalar(self, name, tensor, family=None, description=None): def _summary_fn(name, tensor, step): return self._actual_summary_scalar_fn( - name=name, tensor=tensor, family=family, step=step) + name=name, data=tensor, description=description, step=step) self._create_summary(_summary_fn, name, - tf.reshape(tf.convert_to_tensor(value=tensor), [1])) + tf.reshape(tf.convert_to_tensor(value=tensor), [])) - def image(self, name, tensor, max_outputs=3, family=None): + def image(self, name, tensor, max_outputs=3, family=None, description=None): def _summary_fn(name, tensor, step): return self._actual_summary_image_fn( name=name, - tensor=tensor, - max_images=max_outputs, - family=family, + data=tensor, + max_outputs=max_outputs, + description=description, step=step) self._create_summary(_summary_fn, name, tf.cast(tensor, tf.float32)) - def histogram(self, name, values, family=None): + def histogram(self, + name, + values, + family=None, + buckets=None, + description=None): def _summary_fn(name, tensor, step): return self._actual_summary_histogram_fn( - name=name, tensor=tensor, family=family, step=step) + name=name, + data=tensor, + buckets=buckets, + description=description, + step=step) self._create_summary(_summary_fn, name, tf.convert_to_tensor(value=values)) - def audio(self, name, tensor, sample_rate, max_outputs=3, family=None): + def audio(self, + name, + tensor, + sample_rate, + max_outputs=3, + family=None, + encoding=None, + description=None): def _summary_fn(name, tensor, step): return self._actual_summary_audio_fn( name=name, - tensor=tensor, + data=tensor, sample_rate=sample_rate, - max_outputs=max_outputs, - family=family, + encoding=encoding, + description=description, step=step) self._create_summary(_summary_fn, name, tf.cast(tensor, tf.float32)) @@ -540,6 +624,68 @@ def clear_summary_tuples(self): self._summary_tuples = [] +class _TPUScopedSummary(_ScopedSummaryV2): + """Records summaries in a given scope. + + Only for TPUEstimator. + + Each scope gets assigned a different collection where summary ops gets added. + + This allows Tensorboard to display summaries with different scopes but the + same name in the same charts. + """ + + def __init__(self, logdir, namespace=None, scope=None, skip_summary=False): + super(_TPUScopedSummary, self).__init__(logdir, namespace, scope, + skip_summary) + self._actual_summary_scalar_fn = summary_v2_lib.scalar + self._actual_summary_image_fn = summary_v2_lib.image + self._actual_summary_histogram_fn = summary_v2_lib.histogram + self._actual_summary_audio_fn = summary_v2_lib.audio + + def scalar(self, name, tensor, family=None): + + def _summary_fn(name, tensor, step): + return self._actual_summary_scalar_fn( + name=name, tensor=tensor, family=family, step=step) + + self._create_summary(_summary_fn, name, + tf.reshape(tf.convert_to_tensor(value=tensor), [1])) + + def image(self, name, tensor, max_outputs=3, family=None): + + def _summary_fn(name, tensor, step): + return self._actual_summary_image_fn( + name=name, + tensor=tensor, + max_images=max_outputs, + family=family, + step=step) + + self._create_summary(_summary_fn, name, tf.cast(tensor, tf.float32)) + + def histogram(self, name, values, family=None): + + def _summary_fn(name, tensor, step): + return self._actual_summary_histogram_fn( + name=name, tensor=tensor, family=family, step=step) + + self._create_summary(_summary_fn, name, tf.convert_to_tensor(value=values)) + + def audio(self, name, tensor, sample_rate, max_outputs=3, family=None): + + def _summary_fn(name, tensor, step): + return self._actual_summary_audio_fn( + name=name, + tensor=tensor, + sample_rate=sample_rate, + max_outputs=max_outputs, + family=family, + step=step) + + self._create_summary(_summary_fn, name, tf.cast(tensor, tf.float32)) + + class _SummaryWrapper(object): """Wraps an `adanet.Summary` to provide summary-like APIs.""" @@ -653,6 +799,60 @@ def audio_v2(self, max_outputs=max_outputs, family=family) + def scalar_v3(self, name, data, step=None, description=None): + """See `tf.compat.v2.summary.scalar`.""" + + if step is not None: + logging.warning( + "The `step` argument will be ignored to use the iteration step for " + "scalar summary: %s", name) + return self._summary.scalar(name=name, tensor=data, description=description) + + def image_v3(self, name, data, step=None, max_outputs=3, description=None): + """See `tf.compat.v2.summary.image`.""" + + if step is not None: + logging.warning( + "The `step` argument will be ignored to use the iteration step for " + "image summary: %s", name) + return self._summary.image( + name=name, + tensor=data, + max_outputs=max_outputs, + description=description) + + def histogram_v3(self, name, data, step=None, buckets=None, description=None): + """See `tf.compat.v2.summary.histogram`.""" + + if step is not None: + logging.warning( + "The `step` argument will be ignored to use the global step for " + "histogram summary: %s", name) + return self._summary.histogram( + name=name, tensor=data, buckets=buckets, description=description) + + def audio_v3(self, + name, + data, + sample_rate, + step=None, + max_outputs=3, + encoding=None, + description=None): + """See `tf.compat.v2.summary.audio`.""" + + if step is not None: + logging.warning( + "The `step` argument will be ignored to use the global step for " + "audio summary: %s", name) + return self._summary.audio( + name=name, + tensor=data, + sample_rate=sample_rate, + max_outputs=max_outputs, + encoding=encoding, + description=description) + @contextlib.contextmanager def monkey_patched_summaries(summary): @@ -684,10 +884,10 @@ def monkey_patched_summaries(summary): # Monkey-patch global attributes. wrapped_summary = _SummaryWrapper(summary) - setattr(tf.summary, "scalar", wrapped_summary.scalar) - setattr(tf.summary, "image", wrapped_summary.image) - setattr(tf.summary, "histogram", wrapped_summary.histogram) - setattr(tf.summary, "audio", wrapped_summary.audio) + setattr(tf_v1.summary, "scalar", wrapped_summary.scalar) + setattr(tf_v1.summary, "image", wrapped_summary.image) + setattr(tf_v1.summary, "histogram", wrapped_summary.histogram) + setattr(tf_v1.summary, "audio", wrapped_summary.audio) setattr(tf_compat.v1.summary, "scalar", wrapped_summary.scalar) setattr(tf_compat.v1.summary, "image", wrapped_summary.image) setattr(tf_compat.v1.summary, "histogram", wrapped_summary.histogram) @@ -696,20 +896,24 @@ def monkey_patched_summaries(summary): setattr(summary_lib, "image", wrapped_summary.image) setattr(summary_lib, "histogram", wrapped_summary.histogram) setattr(summary_lib, "audio", wrapped_summary.audio) - setattr(tf_compat.v2.summary, "scalar", wrapped_summary.scalar_v2) - setattr(tf_compat.v2.summary, "image", wrapped_summary.image_v2) - setattr(tf_compat.v2.summary, "histogram", wrapped_summary.histogram_v2) - setattr(tf_compat.v2.summary, "audio", wrapped_summary.audio_v2) + setattr(tf_compat.v2.summary, "scalar", wrapped_summary.scalar_v3) + setattr(tf_compat.v2.summary, "image", wrapped_summary.image_v3) + setattr(tf_compat.v2.summary, "histogram", wrapped_summary.histogram_v3) + setattr(tf_compat.v2.summary, "audio", wrapped_summary.audio_v3) + setattr(tf.summary, "scalar", wrapped_summary.scalar_v3) + setattr(tf.summary, "image", wrapped_summary.image_v3) + setattr(tf.summary, "histogram", wrapped_summary.histogram_v3) + setattr(tf.summary, "audio", wrapped_summary.audio_v3) setattr(summary_v2_lib, "scalar", wrapped_summary.scalar_v2) setattr(summary_v2_lib, "image", wrapped_summary.image_v2) setattr(summary_v2_lib, "histogram", wrapped_summary.histogram_v2) setattr(summary_v2_lib, "audio", wrapped_summary.audio_v2) try: # TF 2.0 eliminates tf.contrib. - setattr(tf.contrib.summary, "scalar", wrapped_summary.scalar_v2) - setattr(tf.contrib.summary, "image", wrapped_summary.image_v2) - setattr(tf.contrib.summary, "histogram", wrapped_summary.histogram_v2) - setattr(tf.contrib.summary, "audio", wrapped_summary.audio_v2) + setattr(tf_v1.contrib.summary, "scalar", wrapped_summary.scalar_v2) + setattr(tf_v1.contrib.summary, "image", wrapped_summary.image_v2) + setattr(tf_v1.contrib.summary, "histogram", wrapped_summary.histogram_v2) + setattr(tf_v1.contrib.summary, "audio", wrapped_summary.audio_v2) except AttributeError: # TF 2.0 eliminates tf.contrib. pass @@ -719,10 +923,10 @@ def monkey_patched_summaries(summary): finally: # Revert monkey-patches. try: - setattr(tf.contrib.summary, "audio", old_summary_v2_audio) - setattr(tf.contrib.summary, "histogram", old_summary_v2_histogram) - setattr(tf.contrib.summary, "image", old_summary_v2_image) - setattr(tf.contrib.summary, "scalar", old_summary_v2_scalar) + setattr(tf_v1.contrib.summary, "audio", old_summary_v2_audio) + setattr(tf_v1.contrib.summary, "histogram", old_summary_v2_histogram) + setattr(tf_v1.contrib.summary, "image", old_summary_v2_image) + setattr(tf_v1.contrib.summary, "scalar", old_summary_v2_scalar) except AttributeError: # TF 2.0 eliminates tf.contrib. pass @@ -730,6 +934,10 @@ def monkey_patched_summaries(summary): setattr(summary_v2_lib, "histogram", old_summary_v2_histogram) setattr(summary_v2_lib, "image", old_summary_v2_image) setattr(summary_v2_lib, "scalar", old_summary_v2_scalar) + setattr(tf.summary, "audio", old_summary_compat_v2_audio) + setattr(tf.summary, "histogram", old_summary_compat_v2_histogram) + setattr(tf.summary, "image", old_summary_compat_v2_image) + setattr(tf.summary, "scalar", old_summary_compat_v2_scalar) setattr(tf_compat.v2.summary, "audio", old_summary_compat_v2_audio) setattr(tf_compat.v2.summary, "histogram", old_summary_compat_v2_histogram) setattr(tf_compat.v2.summary, "image", old_summary_compat_v2_image) @@ -742,7 +950,7 @@ def monkey_patched_summaries(summary): setattr(tf_compat.v1.summary, "histogram", old_summary_histogram) setattr(tf_compat.v1.summary, "image", old_summary_image) setattr(tf_compat.v1.summary, "scalar", old_summary_scalar) - setattr(tf.summary, "audio", old_summary_audio) - setattr(tf.summary, "histogram", old_summary_histogram) - setattr(tf.summary, "image", old_summary_image) - setattr(tf.summary, "scalar", old_summary_scalar) + setattr(tf_v1.summary, "audio", old_summary_audio) + setattr(tf_v1.summary, "histogram", old_summary_histogram) + setattr(tf_v1.summary, "image", old_summary_image) + setattr(tf_v1.summary, "scalar", old_summary_scalar) diff --git a/adanet/core/summary_test.py b/adanet/core/summary_test.py index bd9ce7cd..ddaa1dcd 100644 --- a/adanet/core/summary_test.py +++ b/adanet/core/summary_test.py @@ -1,4 +1,4 @@ -"""Test AdaNet summary single graph implementation. +"""Test AdaNet summary single graph implementation for TF 1. Copyright 2018 The AdaNet Authors. All Rights Reserved. @@ -48,6 +48,7 @@ class ScopedSummaryTest(parameterized.TestCase, tf.test.TestCase): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_scope(self, scope): scoped_summary = _ScopedSummary(scope) self.assertEqual(scope, scoped_summary.scope) @@ -64,6 +65,7 @@ def test_scope(self, scope): "scope": None, "skip_summary": True, }) + @tf_compat.skip_for_tf2 def test_scalar_summary(self, scope, skip_summary=False): scoped_summary = _ScopedSummary(scope, skip_summary) with self.test_session() as s: @@ -89,6 +91,7 @@ def test_scalar_summary(self, scope, skip_summary=False): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_scalar_summary_with_family(self, scope): scoped_summary = _ScopedSummary(scope) with self.test_session() as s: @@ -119,6 +122,7 @@ def test_scalar_summary_with_family(self, scope): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_summarizing_variable(self, scope): scoped_summary = _ScopedSummary(scope) with self.test_session() as s: @@ -147,6 +151,7 @@ def test_summarizing_variable(self, scope): "scope": None, "skip_summary": True, }) + @tf_compat.skip_for_tf2 def test_image_summary(self, scope, skip_summary=False): scoped_summary = _ScopedSummary(scope, skip_summary) with self.test_session() as s: @@ -173,6 +178,7 @@ def test_image_summary(self, scope, skip_summary=False): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_image_summary_with_family(self, scope): scoped_summary = _ScopedSummary(scope) with self.test_session() as s: @@ -201,6 +207,7 @@ def test_image_summary_with_family(self, scope): "scope": None, "skip_summary": True, }) + @tf_compat.skip_for_tf2 def test_histogram_summary(self, scope, skip_summary=False): scoped_summary = _ScopedSummary(scope, skip_summary) with self.test_session() as s: @@ -224,6 +231,7 @@ def test_histogram_summary(self, scope, skip_summary=False): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_histogram_summary_with_family(self, scope): scoped_summary = _ScopedSummary(scope) with self.test_session() as s: @@ -248,6 +256,7 @@ def test_histogram_summary_with_family(self, scope): "scope": None, "skip_summary": True, }) + @tf_compat.skip_for_tf2 def test_audio_summary(self, scope, skip_summary=False): scoped_summary = _ScopedSummary(scope, skip_summary) with self.test_session() as s: @@ -274,6 +283,7 @@ def test_audio_summary(self, scope, skip_summary=False): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_audio_summary_with_family(self, scope): scoped_summary = _ScopedSummary(scope) with self.test_session() as s: @@ -299,6 +309,7 @@ def test_audio_summary_with_family(self, scope): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_summary_name_conversion(self, scope): scoped_summary = _ScopedSummary(scope) c = tf.constant(3) @@ -326,6 +337,7 @@ def test_summary_name_conversion(self, scope): "testcase_name": "nested_graph", "nest_graph": True, }) + @tf_compat.skip_for_tf2 def test_merge_all(self, nest_graph): c0 = tf.constant(0) c1 = tf.constant(1) @@ -414,6 +426,7 @@ def write_summaries(self, summary): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_scope(self, scope): scoped_summary = _TPUScopedSummary(self.test_subdirectory, scope=scope) self.assertEqual(scope, scoped_summary.scope) @@ -430,6 +443,7 @@ def test_scope(self, scope): "scope": None, "skip_summary": True, }) + @tf_compat.skip_for_tf2 def test_scalar_summary(self, scope, skip_summary=False): scoped_summary = _TPUScopedSummary( self.test_subdirectory, scope=scope, skip_summary=skip_summary) @@ -453,6 +467,7 @@ def test_scalar_summary(self, scope, skip_summary=False): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_scalar_summary_with_family(self, scope): scoped_summary = _TPUScopedSummary(self.test_subdirectory, scope=scope) i = tf.constant(7) @@ -481,6 +496,7 @@ def test_scalar_summary_with_family(self, scope): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_summarizing_variable(self, scope): scoped_summary = _TPUScopedSummary(self.test_subdirectory, scope=scope) c = tf.constant(42.0) @@ -506,6 +522,7 @@ def test_summarizing_variable(self, scope): "scope": None, "skip_summary": True, }) + @tf_compat.skip_for_tf2 def test_image_summary(self, scope, skip_summary=False): scoped_summary = _TPUScopedSummary( self.test_subdirectory, scope=scope, skip_summary=skip_summary) @@ -530,6 +547,7 @@ def test_image_summary(self, scope, skip_summary=False): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_image_summary_with_family(self, scope): scoped_summary = _TPUScopedSummary(self.test_subdirectory, scope=scope) i = tf.ones((5, 2, 3, 1)) @@ -556,6 +574,7 @@ def test_image_summary_with_family(self, scope): "scope": None, "skip_summary": True, }) + @tf_compat.skip_for_tf2 def test_histogram_summary(self, scope, skip_summary=False): scoped_summary = _TPUScopedSummary( self.test_subdirectory, scope=scope, skip_summary=skip_summary) @@ -578,6 +597,7 @@ def test_histogram_summary(self, scope, skip_summary=False): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_histogram_summary_with_family(self, scope): scoped_summary = _TPUScopedSummary(self.test_subdirectory, scope=scope) i = tf.ones((5, 4, 4, 3)) @@ -601,6 +621,7 @@ def test_histogram_summary_with_family(self, scope): "scope": None, "skip_summary": True, }) + @tf_compat.skip_for_tf2 def test_audio_summary(self, scope, skip_summary=False): scoped_summary = _TPUScopedSummary( self.test_subdirectory, scope=scope, skip_summary=skip_summary) @@ -625,6 +646,7 @@ def test_audio_summary(self, scope, skip_summary=False): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_audio_summary_with_family(self, scope): scoped_summary = _TPUScopedSummary(self.test_subdirectory, scope=scope) i = tf.ones((5, 3, 4)) @@ -647,6 +669,7 @@ def test_audio_summary_with_family(self, scope): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_summary_name_conversion(self, scope): scoped_summary = _TPUScopedSummary(self.test_subdirectory, scope=scope) c = tf.constant(3) @@ -669,6 +692,7 @@ def test_summary_name_conversion(self, scope): "testcase_name": "with_scope", "scope": "with_scope", }) + @tf_compat.skip_for_tf2 def test_current_scope(self, scope): scoped_summary = _TPUScopedSummary(self.test_subdirectory, scope=scope) i = tf.constant(3) @@ -684,6 +708,7 @@ def test_current_scope(self, scope): self.assertEqual(values[0].tag, "inner1/inner2/a/b/c") self.assertEqual(values[0].simple_value, 3.0) + @tf_compat.skip_for_tf2 def test_summary_args(self): summary = _TPUScopedSummary(self.test_subdirectory) summary.scalar("scalar", 1, "family") @@ -692,6 +717,7 @@ def test_summary_args(self): summary.audio("audio", 1, 3, 3, "family") self.assertLen(summary.summary_tuples(), 4) + @tf_compat.skip_for_tf2 def test_summary_kwargs(self): summary = _TPUScopedSummary(self.test_subdirectory) summary.scalar(name="scalar", tensor=1, family="family") @@ -733,6 +759,7 @@ def _get_summary_ops(self, summary): "summary_maker": functools.partial(_TPUScopedSummary, logdir="/tmp/fakedir") }) + @tf_compat.skip_for_tf2 def test_monkey_patched_summaries_args(self, summary_maker): summary = summary_maker() before = _summaries() @@ -761,6 +788,7 @@ def test_monkey_patched_summaries_args(self, summary_maker): "summary_maker": functools.partial(_TPUScopedSummary, logdir="/tmp/fakedir"), }) + @tf_compat.skip_for_tf2 def test_monkey_patched_summaries_kwargs(self, summary_maker): summary = summary_maker() before = _summaries() diff --git a/adanet/core/summary_v2_test.py b/adanet/core/summary_v2_test.py new file mode 100644 index 00000000..4a3b7b74 --- /dev/null +++ b/adanet/core/summary_v2_test.py @@ -0,0 +1,299 @@ +"""Test AdaNet summary single graph implementation for TF 2. + +Copyright 2019 The AdaNet Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import struct + +from absl.testing import parameterized +from adanet import tf_compat +from adanet.core import testing_utils as tu +from adanet.core.summary import _ScopedSummaryV2 +import tensorflow as tf + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util +# pylint: enable=g-direct-tensorflow-import + + +def simple_value(summary_value): + """Returns the scalar parsed from the summary proto tensor_value bytes.""" + + return struct.unpack("