diff --git a/adanet/core/estimator.py b/adanet/core/estimator.py index e66bf442..45f50d8c 100644 --- a/adanet/core/estimator.py +++ b/adanet/core/estimator.py @@ -25,6 +25,8 @@ import os import time +from absl import logging +from adanet import tf_compat from adanet.core.architecture import _Architecture from adanet.core.candidate import _CandidateBuilder from adanet.core.ensemble_builder import _EnsembleBuilder @@ -44,7 +46,7 @@ import tensorflow as tf -class _StopAfterTrainingHook(tf.train.SessionRunHook): +class _StopAfterTrainingHook(tf.estimator.SessionRunHook): """Hook that requests stop once iteration is over.""" def __init__(self, iteration, after_fn): @@ -65,7 +67,7 @@ def before_run(self, run_context): """See `SessionRunHook`.""" del run_context # Unused - return tf.train.SessionRunArgs(self._iteration.is_over_fn()) + return tf.estimator.SessionRunArgs(self._iteration.is_over_fn()) def after_run(self, run_context, run_values): """See `SessionRunHook`.""" @@ -77,7 +79,7 @@ def after_run(self, run_context, run_values): self._after_fn() -class _EvalMetricSaverHook(tf.train.SessionRunHook): +class _EvalMetricSaverHook(tf.estimator.SessionRunHook): """A hook for writing candidate evaluation metrics as summaries to disk.""" def __init__(self, name, kind, eval_metrics, output_dir): @@ -110,8 +112,8 @@ def begin(self): # the metric variables. The metrics themselves are computed as a result of # being returned in the EstimatorSpec by _adanet_model_fn. metric_fn, tensors = self._eval_metrics - tensors = [tf.placeholder(t.dtype, t.shape) for t in tensors] - eval_metric_ops = metric_fn(*tensors) + tensors = {k: tf.placeholder(v.dtype, v.shape) for k, v in tensors.items()} + eval_metric_ops = metric_fn(**tensors) self._eval_metric_tensors = {k: v[0] for k, v in eval_metric_ops.items()} def _dict_to_str(self, dictionary): @@ -135,7 +137,7 @@ def end(self, session): eval_dict, current_global_step = session.run( (self._eval_metric_tensors, current_global_step)) - tf.logging.info("Saving %s '%s' dict for global step %d: %s", + logging.info("Saving %s '%s' dict for global step %d: %s", self._kind, self._name, current_global_step, self._dict_to_str(eval_dict)) summary_writer = tf.summary.FileWriterCache.get(self._output_dir) @@ -150,14 +152,14 @@ def end(self, session): summ.value[i].tag = "{}/{}".format(key, i) summary_proto.value.extend(summ.value) else: - tf.logging.warn( + logging.warn( "Skipping summary for %s, must be a float, np.float32, " "or a serialized string of Summary.", key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() -class _OverwriteCheckpointHook(tf.train.SessionRunHook): +class _OverwriteCheckpointHook(tf.estimator.SessionRunHook): """Hook to overwrite the latest checkpoint with next iteration variables.""" def __init__(self, current_iteration, iteration_number_tensor, @@ -227,14 +229,14 @@ def before_run(self, run_context): self._checkpoint_overwritten = True -class _HookContextDecorator(tf.train.SessionRunHook): +class _HookContextDecorator(tf.estimator.SessionRunHook): """Decorates a SessionRunHook's public methods to run within a context.""" def __init__(self, hook, context, is_growing_phase): """Initializes a _HookContextDecorator instance. Args: - hook: The tf.train.SessionRunHook to decorate. + hook: The tf.estimator.SessionRunHook to decorate. context: The context to enter before calling the hook's public methods. is_growing_phase: Whether we are in the AdaNet graph growing phase. 
If so, only hook.begin() and hook.end() will be called. @@ -498,7 +500,7 @@ def __init__(self, k: v for k, v in kwargs.items() if k in default_ensembler_args } if default_ensembler_kwargs: - tf.logging.warn( + logging.warn( "The following arguments have been moved to " "`adanet.ensemble.ComplexityRegularizedEnsembler` which can be " "specified in the `ensemblers` argument: {}".format( @@ -510,7 +512,7 @@ def __init__(self, placement_strategy_arg = "experimental_placement_strategy" placement_strategy = kwargs.pop(placement_strategy_arg, None) if placement_strategy: - tf.logging.warning( + logging.warning( "%s is an experimental feature. Its behavior is not guaranteed " "to be backwards compatible.", placement_strategy_arg) @@ -636,7 +638,7 @@ def train(self, ), self._train_loop_context(): while True: current_iteration = self._latest_checkpoint_iteration_number() - tf.logging.info("Beginning training AdaNet iteration %s", + logging.info("Beginning training AdaNet iteration %s", current_iteration) self._iteration_ended = False result = super(Estimator, self).train( @@ -645,7 +647,7 @@ def train(self, max_steps=max_steps, saving_listeners=saving_listeners) - tf.logging.info("Finished training Adanet iteration %s", + logging.info("Finished training Adanet iteration %s", current_iteration) # If training ended because the maximum number of training steps @@ -658,7 +660,7 @@ def train(self, if not self._iteration_ended: return result - tf.logging.info("Beginning bookkeeping phase for iteration %s", + logging.info("Beginning bookkeeping phase for iteration %s", current_iteration) # The chief prepares the next AdaNet iteration, and increments the @@ -695,13 +697,13 @@ def train(self, # Check timeout when waiting for potentially downed chief. if timer.secs_remaining() == 0: - tf.logging.error( + logging.error( "Chief job did not prepare next iteration after %s secs. It " "may have been preempted, been turned down, or crashed. This " "worker is now exiting training.", self._worker_wait_timeout_secs) return result - tf.logging.info("Waiting for chief to finish") + logging.info("Waiting for chief to finish") time.sleep(self._worker_wait_secs) # Stagger starting workers to prevent training instability. @@ -712,11 +714,11 @@ def train(self, delay_secs = min(self._max_worker_delay_secs, (task_id + 1.) * self._delay_secs_per_worker) if delay_secs > 0.: - tf.logging.info("Waiting %d secs before continuing training.", + logging.info("Waiting %d secs before continuing training.", delay_secs) time.sleep(delay_secs) - tf.logging.info("Finished bookkeeping phase for iteration %s", + logging.info("Finished bookkeeping phase for iteration %s", current_iteration) def evaluate(self, @@ -786,31 +788,31 @@ def _prepare_next_iteration(self, train_input_fn): Args: train_input_fn: The input_fn used during training. """ - tf.logging.info("Preparing next iteration:") + logging.info("Preparing next iteration:") # First, evaluate and choose the best ensemble for this iteration. - tf.logging.info("Evaluating candidates...") + logging.info("Evaluating candidates...") self._prepare_next_iteration_state = self._Keys.EVALUATE_ENSEMBLES if self._evaluator: evaluator_input_fn = self._evaluator.input_fn else: evaluator_input_fn = train_input_fn self._call_adanet_model_fn(evaluator_input_fn, tf.estimator.ModeKeys.EVAL) - tf.logging.info("Done evaluating candidates.") + logging.info("Done evaluating candidates.") # Then materialize and store the subnetwork reports. 
if self._report_materializer: - tf.logging.info("Materializing reports...") + logging.info("Materializing reports...") self._prepare_next_iteration_state = self._Keys.MATERIALIZE_REPORT self._call_adanet_model_fn(self._report_materializer.input_fn, tf.estimator.ModeKeys.EVAL) - tf.logging.info("Done materializing reports.") + logging.info("Done materializing reports.") self._best_ensemble_index = None # Finally, create the graph for the next iteration and overwrite the model # directory checkpoint with the expanded graph. - tf.logging.info("Adapting graph and incrementing iteration number...") + logging.info("Adapting graph and incrementing iteration number...") self._prepare_next_iteration_state = self._Keys.INCREMENT_ITERATION temp_model_dir = os.path.join(self.model_dir, "temp_model_dir") if tf.gfile.Exists(temp_model_dir): @@ -826,9 +828,9 @@ def _prepare_next_iteration(self, train_input_fn): saving_listeners=None) tf.gfile.DeleteRecursively(temp_model_dir) self._prepare_next_iteration_state = None - tf.logging.info("Done adapting graph and incrementing iteration number.") + logging.info("Done adapting graph and incrementing iteration number.") - tf.logging.info("Finished preparing next iteration.") + logging.info("Finished preparing next iteration.") def _architecture_filename(self, iteration_number): """Returns the filename of the given iteration's frozen graph.""" @@ -865,7 +867,7 @@ def _get_best_ensemble_index(self, current_iteration): # Skip the evaluation phase when there is only one candidate subnetwork. if len(current_iteration.candidates) == 1: - tf.logging.info( + logging.info( "As the only candidate, '%s' is moving onto the next iteration.", current_iteration.candidates[0].ensemble_spec.name) return 0 @@ -874,14 +876,14 @@ def _get_best_ensemble_index(self, current_iteration): # previous_ensemble. 
if current_iteration.number > 0 and self._force_grow and (len( current_iteration.candidates) == 2): - tf.logging.info( + logging.info( "As the only candidate with `force_grow` enabled, '%s' is moving" "onto the next iteration.", current_iteration.candidates[1].ensemble_spec.name) return 1 latest_checkpoint = tf.train.latest_checkpoint(self.model_dir) - tf.logging.info("Starting ensemble evaluation for iteration %s", + logging.info("Starting ensemble evaluation for iteration %s", current_iteration.number) with tf.Session() as sess: init = tf.group(tf.global_variables_initializer(), @@ -906,9 +908,9 @@ def _get_best_ensemble_index(self, current_iteration): ensemble_name = current_iteration.candidates[i].ensemble_spec.name values.append("{}/{} = {:.6f}".format(metric_name, ensemble_name, adanet_losses[i])) - tf.logging.info("Computed ensemble metrics: %s", ", ".join(values)) + logging.info("Computed ensemble metrics: %s", ", ".join(values)) if self._force_grow and current_iteration.number > 0: - tf.logging.info( + logging.info( "The `force_grow` override is enabled, so the " "the performance of the previous ensemble will be ignored.") # NOTE: The zero-th index candidate at iteration t>0 is always the @@ -917,9 +919,9 @@ def _get_best_ensemble_index(self, current_iteration): index = np.argmin(adanet_losses) + 1 else: index = np.argmin(adanet_losses) - tf.logging.info("Finished ensemble evaluation for iteration %s", + logging.info("Finished ensemble evaluation for iteration %s", current_iteration.number) - tf.logging.info("'%s' at index %s is moving onto the next iteration", + logging.info("'%s' at index %s is moving onto the next iteration", current_iteration.candidates[index].ensemble_spec.name, index) return index @@ -936,7 +938,7 @@ def _materialize_report(self, current_iteration): """ latest_checkpoint = tf.train.latest_checkpoint(self.model_dir) - tf.logging.info("Starting metric logging for iteration %s", + logging.info("Starting metric logging for iteration %s", current_iteration.number) assert self._best_ensemble_index is not None @@ -961,7 +963,7 @@ def _materialize_report(self, current_iteration): self._report_accessor.write_iteration_report(current_iteration.number, materialized_reports) - tf.logging.info("Finished saving subnetwork reports for iteration %s", + logging.info("Finished saving subnetwork reports for iteration %s", current_iteration.number) def _decorate_hooks(self, hooks): @@ -989,7 +991,7 @@ def _training_chief_hooks(self, current_iteration, training): training: Whether in training mode. Returns: - A list of `tf.train.SessionRunHook` instances. + A list of `tf.estimator.SessionRunHook` instances. """ if not training: @@ -1000,7 +1002,7 @@ def _training_chief_hooks(self, current_iteration, training): output_dir = self.model_dir if summary.scope: output_dir = os.path.join(output_dir, summary.namespace, summary.scope) - summary_saver_hook = tf.train.SummarySaverHook( + summary_saver_hook = tf.estimator.SummarySaverHook( save_steps=self.config.save_summary_steps, output_dir=output_dir, summary_op=summary.merge_all()) @@ -1022,7 +1024,7 @@ def _training_hooks(self, current_iteration, training, _OverwriteCheckpointHook will be created. Returns: - A list of `tf.train.SessionRunHook` instances. + A list of `tf.estimator.SessionRunHook` instances. """ if not training: @@ -1051,7 +1053,7 @@ def _evaluation_hooks(self, current_iteration, training): training: Whether in training mode. Returns: - A list of `tf.train.SessionRunHook` instances. 
+ A list of `tf.estimator.SessionRunHook` instances. """ if training: @@ -1376,7 +1378,7 @@ def _adanet_model_fn(self, features, labels, mode, params): if not tf.gfile.Exists(architecture_filename): continue architecture = self._read_architecture(architecture_filename) - tf.logging.info( + logging.info( "Importing architecture from %s: [%s].", architecture_filename, ", ".join( sorted([ @@ -1456,7 +1458,7 @@ def _adanet_model_fn(self, features, labels, mode, params): assert mode == tf.estimator.ModeKeys.TRAIN assert self.config.is_chief latest_checkpoint = tf.train.latest_checkpoint(self.model_dir) - tf.logging.info( + logging.info( "Overwriting checkpoint with new graph for iteration %s to %s", iteration_number, latest_checkpoint) return self._create_estimator_spec(current_iteration, mode, diff --git a/adanet/core/eval_metrics.py b/adanet/core/eval_metrics.py index 32c4a485..2a8ce8e6 100644 --- a/adanet/core/eval_metrics.py +++ b/adanet/core/eval_metrics.py @@ -104,7 +104,14 @@ def create_eval_metrics(self, features, labels, estimator_spec, metric_fn): self._eval_metrics_store.add_eval_metrics( self._templatize_metric_fn(spec_fn), spec_args) - loss_fn = lambda loss: {"loss": tf.compat.v1.metrics.mean(loss)} + if tf.executing_eagerly(): + loss_metric = tf.keras.metrics.Mean() + + def loss_fn(loss): + loss_metric(loss) + return {"loss": loss_metric} + else: + loss_fn = lambda loss: {"loss": tf.compat.v1.metrics.mean(loss)} loss_fn_args = [tf.reshape(estimator_spec.loss, [1])] self._eval_metrics_store.add_eval_metrics( self._templatize_metric_fn(loss_fn), loss_fn_args) @@ -137,6 +144,9 @@ def _templatize_metric_fn(self, metric_fn): The original metric_fn wrapped with a template function. """ + if tf.executing_eagerly(): + return metric_fn + def _metric_fn(*args, **kwargs): """The wrapping function to be returned.""" @@ -179,6 +189,17 @@ def _metric_fn(*args): return _metric_fn, self._eval_metrics_store.flatten_args() +class _StringMetric(tf.keras.metrics.Metric): + + def __init__(self, name='string_metric', **kwargs): + super(_StringMetric, self).__init__(name=name, **kwargs) + self._value = self.add_weight(name='value', initializer=tf.keras.initializers.Constant(''), dtype=tf.string) + + def update_state(self, value): + self._value.assign(value) + + def result(self): + return self._value class _EnsembleMetrics(_SubnetworkMetrics): """A object which creates evaluation metrics for Ensembles.""" @@ -219,9 +240,13 @@ def _architecture_metric_fn(): tensor=tf.compat.v1.make_tensor_proto(architecture_, dtype=tf.string)) architecture_summary = tf.convert_to_tensor( value=summary_proto.SerializeToString(), name="architecture") + + if tf.executing_eagerly(): + metric = _StringMetric() + metric(architecture_summary) + else: + metric = (architecture_summary, tf.no_op()) - return { - "architecture/adanet/ensembles": (architecture_summary, tf.no_op()) - } + return {"architecture/adanet/ensembles": metric} return _architecture_metric_fn @@ -285,7 +310,11 @@ def _best_eval_metrics_fn(*args): with tf.compat.v1.variable_scope("best_eval_metrics"): args = list(args) - idx, idx_update_op = tf.compat.v1.metrics.mean(args.pop()) + if tf.executing_eagerly(): + mean = tf.keras.metrics.Mean() + idx, idx_update_op = mean.result(), mean(args.pop()) + else: + idx, idx_update_op = tf.compat.v1.metrics.mean(args.pop()) metric_fns = self._candidates_eval_metrics_store.metric_fns metric_fn_args = self._candidates_eval_metrics_store.pack_args( diff --git a/adanet/core/eval_metrics_test.py b/adanet/core/eval_metrics_test.py
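# --- Illustrative sketch only, not part of the patch: the dual-mode metric
# pattern the eval_metrics.py changes above rely on. In graph mode,
# tf.compat.v1.metrics.mean returns a (value_tensor, update_op) pair; in eager
# mode a tf.keras.metrics.Mean object is updated by calling it and read with
# .result(). The helper name below is hypothetical.
import tensorflow as tf

def _make_loss_metric(loss):
  """Returns {"loss": metric} in a form valid for the current execution mode."""
  if tf.executing_eagerly():
    metric = tf.keras.metrics.Mean()
    metric(loss)  # Calling a Keras metric updates its internal state.
    return {"loss": metric}  # Read the current value later via metric.result().
  # Graph mode: a (value, update_op) tuple, as tf.estimator expects.
  return {"loss": tf.compat.v1.metrics.mean(loss)}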
index 12935492..fb6e3bc4 100644 --- a/adanet/core/eval_metrics_test.py +++ b/adanet/core/eval_metrics_test.py @@ -35,6 +35,8 @@ def _run_metrics(sess, metrics): metric_ops = metrics if isinstance(metric_ops, tuple): metric_ops = call_eval_metrics(metric_ops) + if tf.executing_eagerly(): + return {k: metric_ops[k].result().numpy() for k in metric_ops} sess.run((tf.compat.v1.global_variables_initializer(), tf.compat.v1.local_variables_initializer())) sess.run(metric_ops) @@ -65,8 +67,9 @@ def setUp(self): })) def _assert_tensors_equal(self, actual, expected): - with self.test_session() as sess: - actual, expected = sess.run((actual, expected)) + if not tf.executing_eagerly(): + with self.test_session() as sess: + actual, expected = sess.run((actual, expected)) self.assertEqual(actual, expected) def _spec_metric_fn(self, features, labels, predictions, loss): @@ -76,12 +79,20 @@ def _spec_metric_fn(self, features, labels, predictions, loss): self._estimator_spec.loss ] self._assert_tensors_equal(actual, expected) + if tf.executing_eagerly(): + metric = tf.metrics.Mean() + metric(tf.constant(1.)) + return {"metric_1": metric} return {"metric_1": tf.compat.v1.metrics.mean(tf.constant(1.))} def _metric_fn(self, features, predictions): actual = [features, predictions] expected = [self._features, self._estimator_spec.predictions] self._assert_tensors_equal(actual, expected) + if tf.executing_eagerly(): + metric = tf.metrics.Mean() + metric(tf.constant(2.)) + return {"metric_2": metric} return {"metric_2": tf.compat.v1.metrics.mean(tf.constant(2.))} @parameterized.named_parameters({ @@ -110,7 +121,12 @@ def test_subnetwork_metrics_user_metric_fn_overrides_metrics(self): overridden_value = 100. def _overriding_metric_fn(): - return {"metric_1": tf.compat.v1.metrics.mean(tf.constant(overridden_value))} + value = tf.constant(overridden_value) + if tf.executing_eagerly(): + metric = tf.metrics.Mean() + metric.update_state(value) + return {"metric_1": metric} + return {"metric_1": tf.compat.v1.metrics.mean(value)} metrics = _SubnetworkMetrics() metrics.create_eval_metrics(self._features, self._labels, @@ -164,6 +180,10 @@ def test_iteration_metrics(self, use_tpu, mode): for i in range(10): def metric_fn(val=i): + if tf.executing_eagerly(): + metric = tf.metrics.Mean() + metric(tf.constant(val)) + return {"ensemble_metric": metric} return {"ensemble_metric": tf.compat.v1.metrics.mean(tf.constant(val))} spec = _EnsembleSpec( diff --git a/adanet/core/summary.py b/adanet/core/summary.py index 789bf057..2df23a1a 100644 --- a/adanet/core/summary.py +++ b/adanet/core/summary.py @@ -24,7 +24,7 @@ import os import tensorflow as tf -from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.tpu import tpu_function # pylint: disable=g-direct-tensorflow-import from tensorflow.python.ops import summary_op_util from tensorflow.python.ops import summary_ops_v2 as summary_v2_lib @@ -201,8 +201,8 @@ def __init__(self, scope=None, skip_summary=False, namespace=None): """ if tpu_function.get_tpu_context().number_of_shards: - tf.logging.log_first_n( - tf.logging.WARN, + tf.compat.v1.logging.log_first_n( + tf.compat.v1.logging.WARN, "Scoped summaries will be skipped since they do not support TPU", 1) skip_summary = True @@ -211,10 +211,10 @@ def __init__(self, scope=None, skip_summary=False, namespace=None): self._additional_scope = None self._skip_summary = skip_summary self._summary_ops = [] - self._actual_summary_scalar_fn = tf.summary.scalar - self._actual_summary_image_fn = tf.summary.image - 
self._actual_summary_histogram_fn = tf.summary.histogram - self._actual_summary_audio_fn = tf.summary.audio + self._actual_summary_scalar_fn = tf.compat.v1.summary.scalar + self._actual_summary_image_fn = tf.compat.v1.summary.image + self._actual_summary_histogram_fn = tf.compat.v1.summary.histogram + self._actual_summary_audio_fn = tf.compat.v1.summary.audio @property def scope(self): @@ -232,7 +232,7 @@ def namespace(self): def current_scope(self): """Registers the current context's scope to strip it from summary tags.""" - self._additional_scope = tf.get_default_graph().get_name_scope() + self._additional_scope = tf.compat.v1.get_default_graph().get_name_scope() yield self._additional_scope = None @@ -334,7 +334,7 @@ def merge_all(self): only used in the internal implementation, so this should be OK. """ - current_graph = tf.get_default_graph() + current_graph = tf.compat.v1.get_default_graph() return [op for op in self._summary_ops if op.graph == current_graph] @@ -380,10 +380,10 @@ def __init__(self, logdir, namespace=None, scope=None, skip_summary=False): self._scope = scope self._additional_scope = None self._skip_summary = skip_summary - self._actual_summary_scalar_fn = tf.contrib.summary.scalar - self._actual_summary_image_fn = tf.contrib.summary.image - self._actual_summary_histogram_fn = tf.contrib.summary.histogram - self._actual_summary_audio_fn = tf.contrib.summary.audio + self._actual_summary_scalar_fn = tf.compat.v2.summary.scalar + self._actual_summary_image_fn = tf.compat.v2.summary.image + self._actual_summary_histogram_fn = tf.compat.v2.summary.histogram + self._actual_summary_audio_fn = tf.compat.v2.summary.audio self._summary_tuples = [] @property @@ -408,7 +408,7 @@ def logdir(self): def current_scope(self): """Registers the current context's scope to strip it from summary tags.""" - self._additional_scope = tf.get_default_graph().get_name_scope() + self._additional_scope = tf.compat.v1.get_default_graph().get_name_scope() yield self._additional_scope = None @@ -458,7 +458,7 @@ def _create_summary(self, summary_fn, name, tensor): additional_scope = self._additional_scope # name_scope is from whichever scope the summary actually gets called in. # e.g. "foo/bar/baz" - name_scope = tf.get_default_graph().get_name_scope() + name_scope = tf.compat.v1.get_default_graph().get_name_scope() # Reuse name_scope if it exists by appending "/" to it. name_scope = name_scope + "/" if name_scope else name_scope @@ -471,8 +471,8 @@ def _summary_fn(tensor, step): # e.g. "foo/bar/baz/scalar" will become "baz/scalar" when # additional_scope is "foo/bar". # TODO: Figure out a cleaner way to handle this. - assert not tf.get_default_graph().get_name_scope() - with tf.name_scope(name_scope): + assert not tf.compat.v1.get_default_graph().get_name_scope() + with tf.compat.v1.name_scope(name_scope): with self._strip_tag_scope(additional_scope): # TODO: Do summaries need to be reduced before writing? 
# Presumably each tensor core creates its own summary so we may be @@ -488,7 +488,7 @@ def _summary_fn(name, tensor, step): name=name, tensor=tensor, family=family, step=step) self._create_summary(_summary_fn, name, - tf.reshape(tf.convert_to_tensor(tensor), [1])) + tf.reshape(tf.convert_to_tensor(value=tensor), [1])) def image(self, name, tensor, max_outputs=3, family=None): @@ -508,7 +508,7 @@ def _summary_fn(name, tensor, step): return self._actual_summary_histogram_fn( name=name, tensor=tensor, family=family, step=step) - self._create_summary(_summary_fn, name, tf.convert_to_tensor(values)) + self._create_summary(_summary_fn, name, tf.convert_to_tensor(value=values)) def audio(self, name, tensor, sample_rate, max_outputs=3, family=None): @@ -549,7 +549,7 @@ def scalar(self, name, tensor, collections=None, family=None): """See `tf.summary.scalar`.""" if collections is not None: - tf.logging.warning( + tf.compat.v1.logging.warning( "The `collections` argument will be " "ignored for scalar summary: %s, %s", name, tensor) return self._summary.scalar(name=name, tensor=tensor, family=family) @@ -558,7 +558,7 @@ def image(self, name, tensor, max_outputs=3, collections=None, family=None): """See `tf.summary.image`.""" if collections is not None: - tf.logging.warning( + tf.compat.v1.logging.warning( "The `collections` argument will be " "ignored for image summary: %s, %s", name, tensor) return self._summary.image( @@ -568,7 +568,7 @@ def histogram(self, name, values, collections=None, family=None): """See `tf.summary.histogram`.""" if collections is not None: - tf.logging.warning( + tf.compat.v1.logging.warning( "The `collections` argument will be " "ignored for histogram summary: %s, %s", name, values) return self._summary.histogram(name=name, values=values, family=family) @@ -583,7 +583,7 @@ def audio(self, """See `tf.summary.audio`.""" if collections is not None: - tf.logging.warning( + tf.compat.v1.logging.warning( "The `collections` argument will be " "ignored for audio summary: %s, %s", name, tensor) return self._summary.audio( @@ -597,7 +597,7 @@ def scalar_v2(self, name, tensor, family=None, step=None): """See `tf.contrib.summary.scalar`.""" if step is not None: - tf.logging.warning( + tf.compat.v1.logging.warning( "The `step` argument will be ignored to use the global step for " "scalar summary: %s, %s", name, tensor) return self._summary.scalar(name=name, tensor=tensor, family=family) @@ -612,12 +612,12 @@ def image_v2(self, """See `tf.contrib.summary.image`.""" if step is not None: - tf.logging.warning( + tf.compat.v1.logging.warning( "The `step` argument will be ignored to use the global step for " "image summary: %s, %s", name, tensor) # TODO: Add support for `bad_color` arg. 
if bad_color is not None: - tf.logging.warning( + tf.compat.v1.logging.warning( "The `bad_color` arg is not supported for image summary: %s, %s", name, tensor) return self._summary.image( @@ -627,7 +627,7 @@ def histogram_v2(self, name, tensor, family=None, step=None): """See `tf.contrib.summary.histogram`.""" if step is not None: - tf.logging.warning( + tf.compat.v1.logging.warning( "The `step` argument will be ignored to use the global step for " "histogram summary: %s, %s", name, tensor) return self._summary.histogram(name=name, values=tensor, family=family) @@ -642,7 +642,7 @@ def audio_v2(self, """See `tf.contrib.summary.audio`.""" if step is not None: - tf.logging.warning( + tf.compat.v1.logging.warning( "The `step` argument will be ignored to use the global step for " "audio summary: %s, %s", name, tensor) return self._summary.audio( @@ -679,18 +679,18 @@ def monkey_patched_summaries(summary): # Monkey-patch global attributes. wrapped_summary = _SummaryWrapper(summary) - tf.summary.scalar = wrapped_summary.scalar - tf.summary.image = wrapped_summary.image - tf.summary.histogram = wrapped_summary.histogram - tf.summary.audio = wrapped_summary.audio + tf.compat.v1.summary.scalar = wrapped_summary.scalar + tf.compat.v1.summary.image = wrapped_summary.image + tf.compat.v1.summary.histogram = wrapped_summary.histogram + tf.compat.v1.summary.audio = wrapped_summary.audio summary_lib.scalar = wrapped_summary.scalar summary_lib.image = wrapped_summary.image summary_lib.histogram = wrapped_summary.histogram summary_lib.audio = wrapped_summary.audio - tf.contrib.summary.scalar = wrapped_summary.scalar_v2 - tf.contrib.summary.image = wrapped_summary.image_v2 - tf.contrib.summary.histogram = wrapped_summary.histogram_v2 - tf.contrib.summary.audio = wrapped_summary.audio_v2 + tf.compat.v2.summary.scalar = wrapped_summary.scalar_v2 + tf.compat.v2.summary.image = wrapped_summary.image_v2 + tf.compat.v2.summary.histogram = wrapped_summary.histogram_v2 + tf.compat.v2.summary.audio = wrapped_summary.audio_v2 summary_v2_lib.scalar = wrapped_summary.scalar_v2 summary_v2_lib.image = wrapped_summary.image_v2 summary_v2_lib.histogram = wrapped_summary.histogram_v2 @@ -704,15 +704,15 @@ def monkey_patched_summaries(summary): summary_v2_lib.histogram = old_summary_v2_histogram summary_v2_lib.image = old_summary_v2_image summary_v2_lib.scalar = old_summary_v2_scalar - tf.contrib.summary.audio = old_summary_v2_audio - tf.contrib.summary.histogram = old_summary_v2_histogram - tf.contrib.summary.image = old_summary_v2_image - tf.contrib.summary.scalar = old_summary_v2_scalar + tf.compat.v2.summary.audio = old_summary_v2_audio + tf.compat.v2.summary.histogram = old_summary_v2_histogram + tf.compat.v2.summary.image = old_summary_v2_image + tf.compat.v2.summary.scalar = old_summary_v2_scalar summary_lib.audio = old_summary_audio summary_lib.histogram = old_summary_histogram summary_lib.image = old_summary_image summary_lib.scalar = old_summary_scalar - tf.summary.audio = old_summary_audio - tf.summary.histogram = old_summary_histogram - tf.summary.image = old_summary_image - tf.summary.scalar = old_summary_scalar + tf.compat.v1.summary.audio = old_summary_audio + tf.compat.v1.summary.histogram = old_summary_histogram + tf.compat.v1.summary.image = old_summary_image + tf.compat.v1.summary.scalar = old_summary_scalar diff --git a/adanet/core/testing_utils.py b/adanet/core/testing_utils.py index f6ba165a..61bc4902 100644 --- a/adanet/core/testing_utils.py +++ b/adanet/core/testing_utils.py @@ -21,7 +21,9 @@ import os 
import shutil +import sys +from absl import flags from absl.testing import parameterized from adanet.core.architecture import _Architecture from adanet.core.ensemble_builder import _EnsembleSpec @@ -35,7 +37,7 @@ def dummy_tensor(shape=(), random_seed=42): """Returns a randomly initialized tensor.""" return tf.Variable( - tf.random_normal(shape=shape, seed=random_seed), + tf.random.normal(shape=shape, seed=random_seed), trainable=False).read_value() @@ -88,7 +90,7 @@ def dummy_ensemble_spec(name, if adanet_loss is None: adanet_loss = dummy_tensor([], random_seed * 2) else: - adanet_loss = tf.convert_to_tensor(adanet_loss) + adanet_loss = tf.convert_to_tensor(value=adanet_loss) logits = dummy_tensor([], random_seed * 3) if dict_predictions: @@ -181,7 +183,7 @@ def dummy_estimator_spec(loss=None, mode=tf.estimator.ModeKeys.TRAIN, predictions=predictions, loss=loss, - train_op=tf.no_op(), + train_op=loss, eval_metric_ops=eval_metric_ops) @@ -206,11 +208,11 @@ def _input_fn(params=None): del params # Unused. - input_features = tf.data.Dataset.from_tensors( - [features]).make_one_shot_iterator().get_next() + input_features = tf.compat.v1.data.make_one_shot_iterator(tf.data.Dataset.from_tensors( + [features])).get_next() if labels is not None: - input_labels = tf.data.Dataset.from_tensors( - [labels]).make_one_shot_iterator().get_next() + input_labels = tf.compat.v1.data.make_one_shot_iterator(tf.data.Dataset.from_tensors( + [labels])).get_next() else: input_labels = None return {"x": input_features}, input_labels @@ -220,10 +222,10 @@ def _input_fn(params=None): def head(): return tf.contrib.estimator.regression_head( - loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE) + loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE) -class ModifierSessionRunHook(tf.train.SessionRunHook): +class ModifierSessionRunHook(tf.estimator.SessionRunHook): """Modifies the graph by adding a variable.""" def __init__(self, var_name="hook_created_variable"): @@ -240,7 +242,7 @@ def begin(self): if self._begun: raise ValueError("begin called twice without end.") self._begun = True - _ = tf.get_variable(name=self._var_name, initializer="") + _ = tf.compat.v1.get_variable(name=self._var_name, initializer="") def end(self, session): """Adds a variable to the graph. @@ -264,7 +266,8 @@ class AdanetTestCase(parameterized.TestCase, tf.test.TestCase): def setUp(self): super(AdanetTestCase, self).setUp() # Setup and cleanup test directory. - self.test_subdirectory = os.path.join(tf.flags.FLAGS.test_tmpdir, self.id()) + flags.FLAGS(sys.argv) + self.test_subdirectory = os.path.join(flags.FLAGS.test_tmpdir, self.id()) shutil.rmtree(self.test_subdirectory, ignore_errors=True) os.makedirs(self.test_subdirectory) diff --git a/adanet/core/tpu_estimator.py b/adanet/core/tpu_estimator.py index b75b4ba1..7c63f01e 100644 --- a/adanet/core/tpu_estimator.py +++ b/adanet/core/tpu_estimator.py @@ -25,12 +25,16 @@ from adanet.core.estimator import Estimator from distutils.version import LooseVersion import tensorflow as tf -from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.tpu import tpu_function from tensorflow.python.framework import ops # pylint: disable=g-direct-tensorflow-import +try: + _TPU_ESTIMATOR_CLASS = tf.contrib.estimator.tpu.TPUEstimator +except AttributeError: + _TPU_ESTIMATOR_CLASS = object # TODO: Move hooks to their own module. -class _StepCounterHook(tf.train.SessionRunHook): +class _StepCounterHook(tf.estimator.SessionRunHook): """Hook that counts steps per second. 
TODO: Remove once Estimator uses summaries v2 by default. @@ -106,7 +110,7 @@ def end(self, session): self._summary_writer.flush() -class TPUEstimator(Estimator, tf.contrib.tpu.TPUEstimator): +class TPUEstimator(Estimator, _TPU_ESTIMATOR_CLASS): """An :class:`adanet.Estimator` capable of training and evaluating on TPU. Note: Unless :code:`use_tpu=False`, training will run on TPU. However, @@ -158,6 +162,9 @@ def __init__(self, debug=False, **kwargs): + if LooseVersion(tf.__version__) >= LooseVersion("2.0.0"): + raise ValueError("TPUEstimator is not yet supported with TensorFlow 2.0.") + self._use_tpu = use_tpu if not self._use_tpu: tf.logging.warning(