diff --git a/RELEASE.md b/RELEASE.md
index 4573aa8f..1ba9c32f 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -16,6 +16,7 @@ limitations under the License.
 # Current version (0.8.0.dev)
 * Under development.
 * TODO: Add official Keras Model support, including Keras layers, Sequential, and Model subclasses for defining subnetworks.
+* Support `tf.keras.metrics.Metric` subclasses during evaluation.
 * Stop individual subnetwork training on `OutOfRangeError` raised during bagging.
 * Gracefully handle NaN losses from ensembles during training. When an ensemble or subnetwork has a NaN loss during training, its training is marked as terminated. As long as one ensemble (and therefore underlying subnetworks) does not have a NaN loss, training will continue.
 * Train forever if `max_steps` and `steps` are both `None`.
diff --git a/adanet/core/BUILD b/adanet/core/BUILD
index 9d74d7c8..efb91a0b 100644
--- a/adanet/core/BUILD
+++ b/adanet/core/BUILD
@@ -316,6 +316,7 @@ py_test(
         ":report_materializer",
         ":testing_utils",
         "//adanet/subnetwork",
+        "//adanet/tf_compat",
         "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/adanet/core/estimator.py b/adanet/core/estimator.py
index 362a023d..b37b25b2 100644
--- a/adanet/core/estimator.py
+++ b/adanet/core/estimator.py
@@ -121,7 +121,10 @@ def begin(self):
     metric_fn, tensors = self._eval_metrics
     tensors = [tf_compat.v1.placeholder(t.dtype, t.shape) for t in tensors]
     eval_metric_ops = metric_fn(*tensors)
-    self._eval_metric_tensors = {k: v[0] for k, v in eval_metric_ops.items()}
+    self._eval_metric_tensors = {}
+    for key in sorted(eval_metric_ops):
+      value = tf_compat.metric_op(eval_metric_ops[key])
+      self._eval_metric_tensors[key] = value[0]
 
   def _dict_to_str(self, dictionary):
     """Get a `str` representation of a `dict`.
diff --git a/adanet/core/estimator_test.py b/adanet/core/estimator_test.py
index 08c5ceaf..fef6813b 100644
--- a/adanet/core/estimator_test.py
+++ b/adanet/core/estimator_test.py
@@ -720,6 +720,11 @@ def test_lifecycle(self,
 
     run_config = tf.estimator.RunConfig(tf_random_seed=42)
 
+    def _metric_fn(predictions):
+      mean = tf.keras.metrics.Mean()
+      mean.update_state(predictions["predictions"])
+      return {"keras_mean": mean}
+
     default_ensembler_kwargs = {
         "mixture_weight_type": mixture_weight_type,
         "mixture_weight_initializer": tf_compat.v1.zeros_initializer(),
@@ -737,6 +742,7 @@ def test_lifecycle(self,
         ensemble_strategies=ensemble_strategies,
         report_materializer=report_materializer,
         replicate_ensemble_in_training=replicate_ensemble_in_training,
+        metric_fn=_metric_fn,
         model_dir=self.test_subdirectory,
         config=run_config,
         **default_ensembler_kwargs)
@@ -1366,6 +1372,14 @@ def create_estimator_spec(self,
         train_op=train_op_fn(1))
 
 
+def _mean_keras_metric(value):
+  """Returns the mean of the given value as a Keras metric."""
+
+  mean = tf.keras.metrics.Mean()
+  mean.update_state(value)
+  return mean
+
+
 class EstimatorSummaryWriterTest(tu.AdanetTestCase):
   """Test that Tensorboard summaries get written correctly."""
 
@@ -1497,35 +1511,43 @@ def test_disable_summaries(self):
         "mixture_weight_norms/adanet/"
         "adanet_weighted_ensemble/subnetwork_0", ensemble_subdir)
 
+  # pylint: disable=g-long-lambda
   @parameterized.named_parameters(
       {
           "testcase_name": "none_metrics",
           "head": _EvalMetricsHead(None),
           "want_summaries": [],
           "want_loss": -1.791,
-      },
-      {
+      }, {
           "testcase_name": "metrics_fn",
           "head": _EvalMetricsHead(None),
-          # pylint: disable=g-long-lambda
           "metric_fn": lambda predictions: {
              "avg": tf_compat.v1.metrics.mean(predictions)
          },
-          # pylint: enable=g-long-lambda
           "want_summaries": ["avg"],
           "want_loss": -1.791,
-      },
-      {
+      }, {
+          "testcase_name":
+              "keras_metrics_fn",
+          "head":
+              _EvalMetricsHead(None),
+          "metric_fn":
+              lambda predictions: {
+                  "avg": _mean_keras_metric(predictions)
+              },
+          "want_summaries": ["avg"],
+          "want_loss":
+              -1.791,
+      }, {
           "testcase_name": "empty_metrics",
           "head": _EvalMetricsHead({}),
           "want_summaries": [],
           "want_loss": -1.791,
-      },
-      {
+      }, {
           "testcase_name":
               "evaluation_name",
           "head":
@@ -1541,8 +1563,7 @@ def test_disable_summaries(self):
               "subnetwork/t0_dnn/eval_continuous",
           "ensemble_subdir":
               "ensemble/t0_dnn_grow_complexity_regularized/eval_continuous",
-      },
-      {
+      }, {
           "testcase_name":
               "regression_head",
           "head":
@@ -1551,8 +1572,7 @@ def test_disable_summaries(self):
           "want_summaries": ["average_loss"],
           "want_loss": .256,
-      },
-      {
+      }, {
           "testcase_name":
               "binary_classification_head",
           "head":
@@ -1563,8 +1583,7 @@ def test_disable_summaries(self):
           "want_summaries": ["average_loss", "accuracy", "recall"],
           "want_loss": 0.122,
-      },
-      {
+      }, {
           "testcase_name":
               "all_metrics",
           "head":
@@ -1588,6 +1607,7 @@ def test_disable_summaries(self):
           "want_loss": -1.791,
       })
+  # pylint: enable=g-long-lambda
   def test_eval_metrics(
       self,
       head,
diff --git a/adanet/core/eval_metrics.py b/adanet/core/eval_metrics.py
index 64149d21..85442e28 100644
--- a/adanet/core/eval_metrics.py
+++ b/adanet/core/eval_metrics.py
@@ -417,6 +417,6 @@ def _group_metric_ops(self, metric_fns, metric_fn_args):
     for metric_fn, args in zip(metric_fns, metric_fn_args):
       eval_metric_ops = call_eval_metrics((metric_fn, args))
       for metric_name in sorted(eval_metric_ops):
-        metric_op = eval_metric_ops[metric_name]
+        metric_op = tf_compat.metric_op(eval_metric_ops[metric_name])
         grouped_metrics[metric_name].append(metric_op)
     return grouped_metrics
diff --git a/adanet/core/report_materializer.py b/adanet/core/report_materializer.py
index f38a153a..8f50d0c5 100644
--- a/adanet/core/report_materializer.py
+++ b/adanet/core/report_materializer.py
@@ -23,6 +23,7 @@
 
 from absl import logging
 from adanet import subnetwork
+from adanet import tf_compat
 import tensorflow as tf
 
@@ -96,13 +97,13 @@ def materialize_subnetwork_reports(self, sess, iteration_number,
     metric_update_ops = []
     for subnetwork_report in subnetwork_reports.values():
       for metric_tuple in subnetwork_report.metrics.values():
-        metric_update_ops.append(metric_tuple[1])
+        metric_update_ops.append(tf_compat.metric_op(metric_tuple)[1])
 
     # Extract the Tensors to be materialized.
     tensors_to_materialize = {}
     for name, subnetwork_report in subnetwork_reports.items():
       metrics = {
-          metric_key: metric_tuple[0]
+          metric_key: tf_compat.metric_op(metric_tuple)[0]
           for metric_key, metric_tuple in subnetwork_report.metrics.items()
       }
       tensors_to_materialize[name] = {
diff --git a/adanet/subnetwork/BUILD b/adanet/subnetwork/BUILD
index e9f9ab9b..f5f5d465 100644
--- a/adanet/subnetwork/BUILD
+++ b/adanet/subnetwork/BUILD
@@ -36,6 +36,7 @@ py_library(
     name = "report",
     srcs = ["report.py"],
     deps = [
+        "//adanet/tf_compat",
         "@six_archive//:six",
     ],
 )
diff --git a/adanet/subnetwork/report.py b/adanet/subnetwork/report.py
index 2c6a835d..84ee299b 100644
--- a/adanet/subnetwork/report.py
+++ b/adanet/subnetwork/report.py
@@ -21,6 +21,7 @@
 
 import collections
 
+from adanet import tf_compat
 import six
 import tensorflow as tf
 
@@ -90,6 +91,7 @@ def _is_accepted_dtype(tensor):
     # Validate metrics
     metrics_copy = {}
     for key, value in metrics.items():
+      value = tf_compat.metric_op(value)
       if not isinstance(value, tuple):
         raise ValueError(
             "metric '{}' has invalid type {}. Must be a tuple.".format(
diff --git a/adanet/tf_compat/__init__.py b/adanet/tf_compat/__init__.py
index a0c0c859..211fb5c6 100644
--- a/adanet/tf_compat/__init__.py
+++ b/adanet/tf_compat/__init__.py
@@ -143,3 +143,38 @@ def random_normal(*args, **kwargs):
     return tf.random.normal(*args, **kwargs)
   except AttributeError:
     return tf.random_normal(*args, **kwargs)
+
+
+def metric_op(metric):
+  """Converts a Keras metric into a metric op tuple.
+
+  NOTE: If this method is called in a for loop, the runtime is O(n^2).
+  However, the number of eval metrics at any given time should be small
+  enough that this does not affect performance. Any impact is only during
+  graph construction time, and therefore has no effect on steps/s.
+
+  Args:
+    metric: Either a `tf.keras.metrics.Metric` instance or a tuple of metric
+      value Tensor and update op.
+
+  Returns:
+    A tuple of metric value Tensor and update op.
+  """
+
+  if not isinstance(metric, tf.keras.metrics.Metric):
+    return metric
+  vars_to_add = set()
+  vars_to_add.update(metric.variables)
+  metric = (metric.result(), metric.updates[0])
+  _update_variable_collection(tf.GraphKeys.LOCAL_VARIABLES, vars_to_add)
+  _update_variable_collection(tf.GraphKeys.METRIC_VARIABLES, vars_to_add)
+  return metric
+
+
+def _update_variable_collection(collection_name, vars_to_add):
+  """Adds variables to a collection, skipping any already present."""
+  collection = set(tf.get_collection(collection_name))
+  # Skip variables that are in the collection already: O(n) runtime.
+  vars_to_add = vars_to_add - collection
+  for v in vars_to_add:
+    tf.add_to_collection(collection_name, v)
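
Usage sketch (illustrative, not part of the diff): the new `adanet.tf_compat.metric_op` helper bridges `tf.keras.metrics.Metric` objects and the legacy `(value_tensor, update_op)` pairs that `tf.estimator` expects in `eval_metric_ops`. The snippet below performs the equivalent conversion standalone; it assumes TF 1.x graph mode, and the constant input and session setup are demonstration-only assumptions.

    import tensorflow as tf

    mean = tf.keras.metrics.Mean()
    mean.update_state(tf.constant([1., 2., 3.]))  # in graph mode this records a pending update op

    value_op = mean.result()     # Tensor holding the current metric value
    update_op = mean.updates[0]  # op that folds the batch into the metric state

    with tf.compat.v1.Session() as sess:
      # The diff registers `mean.variables` in the LOCAL_VARIABLES and
      # METRIC_VARIABLES collections so the Estimator initializes them;
      # standalone, we initialize them directly.
      sess.run(tf.compat.v1.variables_initializer(mean.variables))
      sess.run(update_op)
      print(sess.run(value_op))  # -> 2.0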