Support tf.keras.metrics.Metrics during evaluation.
Provides support equivalent to tf.estimator.Estimator's.

Required for TensorFlow 2.0 migration #93.

PiperOrigin-RevId: 261331779
cweill committed Aug 2, 2019
1 parent e2d829b commit fdbf4c4
Showing 9 changed files with 82 additions and 18 deletions.
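For reference, a minimal sketch of the user-facing change (the head and subnetwork generator below are placeholders, not from this commit; only metric_fn is the point): adanet.Estimator's metric_fn may now return tf.keras.metrics.Metric instances directly, in addition to the usual (value_tensor, update_op) tuples:

import adanet
import tensorflow as tf

def metric_fn(predictions):
  # A Keras metric can now be returned directly; AdaNet converts it
  # internally into a (value_tensor, update_op) tuple during evaluation.
  mean = tf.keras.metrics.Mean()
  mean.update_state(predictions["predictions"])
  return {"mean_prediction": mean}  # "mean_prediction" is illustrative

estimator = adanet.Estimator(
    head=head,                       # placeholder
    subnetwork_generator=generator,  # placeholder
    max_iteration_steps=100,
    metric_fn=metric_fn)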
1 change: 1 addition & 0 deletions RELEASE.md
@@ -16,6 +16,7 @@ limitations under the License.
# Current version (0.8.0.dev)
* Under development.
* TODO: Add official Keras Model support, including Keras layers, Sequential, and Model subclasses for defining subnetworks.
* Support `tf.keras.metrics.Metrics` during evaluation.
* Stop individual subnetwork training on `OutOfRangeError` raised during bagging.
* Gracefully handle NaN losses from ensembles during training. When an ensemble or subnetwork has a NaN loss during training, its training is marked as terminated. As long as one ensemble (and therefore underlying subnetworks) does not have a NaN loss, training will continue.
* Train forever if `max_steps` and `steps` are both `None`.
1 change: 1 addition & 0 deletions adanet/core/BUILD
@@ -316,6 +316,7 @@ py_test(
        ":report_materializer",
        ":testing_utils",
        "//adanet/subnetwork",
        "//adanet/tf_compat",
        "@absl_py//absl/testing:parameterized",
    ],
)
5 changes: 4 additions & 1 deletion adanet/core/estimator.py
@@ -121,7 +121,10 @@ def begin(self):
    metric_fn, tensors = self._eval_metrics
    tensors = [tf_compat.v1.placeholder(t.dtype, t.shape) for t in tensors]
    eval_metric_ops = metric_fn(*tensors)
    self._eval_metric_tensors = {k: v[0] for k, v in eval_metric_ops.items()}
    self._eval_metric_tensors = {}
    for key in sorted(eval_metric_ops):
      value = tf_compat.metric_op(eval_metric_ops[key])
      self._eval_metric_tensors[key] = value[0]

  def _dict_to_str(self, dictionary):
    """Get a `str` representation of a `dict`.
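A sketch of what the updated hook now tolerates (the names below are illustrative, not from the diff): metric_fn may mix the legacy tuple style with Keras metric objects, because tf_compat.metric_op reduces each entry to a (value_tensor, update_op) tuple before the hook keeps the value:

def metric_fn(labels, predictions):
  # Legacy style: tf.compat.v1 metrics already return a tuple.
  accuracy = tf_compat.v1.metrics.accuracy(labels, predictions)
  # Keras style: a Metric object, normalized by tf_compat.metric_op.
  mean = tf.keras.metrics.Mean()
  mean.update_state(predictions)
  return {"accuracy": accuracy, "mean_prediction": mean}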
48 changes: 34 additions & 14 deletions adanet/core/estimator_test.py
@@ -720,6 +720,11 @@ def test_lifecycle(self,

    run_config = tf.estimator.RunConfig(tf_random_seed=42)

    def _metric_fn(predictions):
      mean = tf.keras.metrics.Mean()
      mean.update_state(predictions["predictions"])
      return {"keras_mean": mean}

    default_ensembler_kwargs = {
        "mixture_weight_type": mixture_weight_type,
        "mixture_weight_initializer": tf_compat.v1.zeros_initializer(),
@@ -737,6 +742,7 @@ def test_lifecycle(self,
        ensemble_strategies=ensemble_strategies,
        report_materializer=report_materializer,
        replicate_ensemble_in_training=replicate_ensemble_in_training,
        metric_fn=_metric_fn,
        model_dir=self.test_subdirectory,
        config=run_config,
        **default_ensembler_kwargs)
@@ -1366,6 +1372,14 @@ def create_estimator_spec(self,
        train_op=train_op_fn(1))


def _mean_keras_metric(value):
  """Returns the mean of the given value as a Keras metric."""

  mean = tf.keras.metrics.Mean()
  mean.update_state(value)
  return mean


class EstimatorSummaryWriterTest(tu.AdanetTestCase):
  """Test that Tensorboard summaries get written correctly."""

@@ -1497,35 +1511,43 @@ def test_disable_summaries(self):
"mixture_weight_norms/adanet/"
"adanet_weighted_ensemble/subnetwork_0", ensemble_subdir)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      {
          "testcase_name": "none_metrics",
          "head": _EvalMetricsHead(None),
          "want_summaries": [],
          "want_loss": -1.791,
      },
      {
      }, {
          "testcase_name":
              "metrics_fn",
          "head":
              _EvalMetricsHead(None),
          # pylint: disable=g-long-lambda
          "metric_fn":
              lambda predictions: {
                  "avg": tf_compat.v1.metrics.mean(predictions)
              },
          # pylint: enable=g-long-lambda
          "want_summaries": ["avg"],
          "want_loss":
              -1.791,
      },
      {
      }, {
          "testcase_name":
              "keras_metrics_fn",
          "head":
              _EvalMetricsHead(None),
          "metric_fn":
              lambda predictions: {
                  "avg": _mean_keras_metric(predictions)
              },
          "want_summaries": ["avg"],
          "want_loss":
              -1.791,
      }, {
          "testcase_name": "empty_metrics",
          "head": _EvalMetricsHead({}),
          "want_summaries": [],
          "want_loss": -1.791,
      },
      {
      }, {
          "testcase_name":
              "evaluation_name",
          "head":
@@ -1541,8 +1563,7 @@ def test_disable_summaries(self):
"subnetwork/t0_dnn/eval_continuous",
"ensemble_subdir":
"ensemble/t0_dnn_grow_complexity_regularized/eval_continuous",
},
{
}, {
"testcase_name":
"regression_head",
"head":
@@ -1551,8 +1572,7 @@ def test_disable_summaries(self):
"want_summaries": ["average_loss"],
"want_loss":
.256,
},
{
}, {
"testcase_name":
"binary_classification_head",
"head":
@@ -1563,8 +1583,7 @@ def test_disable_summaries(self):
"want_summaries": ["average_loss", "accuracy", "recall"],
"want_loss":
0.122,
},
{
}, {
"testcase_name":
"all_metrics",
"head":
@@ -1588,6 +1607,7 @@ def test_disable_summaries(self):
"want_loss":
-1.791,
})
# pylint: enable=g-long-lambda
def test_eval_metrics(
self,
head,
2 changes: 1 addition & 1 deletion adanet/core/eval_metrics.py
@@ -417,6 +417,6 @@ def _group_metric_ops(self, metric_fns, metric_fn_args):
    for metric_fn, args in zip(metric_fns, metric_fn_args):
      eval_metric_ops = call_eval_metrics((metric_fn, args))
      for metric_name in sorted(eval_metric_ops):
        metric_op = eval_metric_ops[metric_name]
        metric_op = tf_compat.metric_op(eval_metric_ops[metric_name])
        grouped_metrics[metric_name].append(metric_op)
    return grouped_metrics
5 changes: 3 additions & 2 deletions adanet/core/report_materializer.py
@@ -23,6 +23,7 @@

from absl import logging
from adanet import subnetwork
from adanet import tf_compat
import tensorflow as tf


@@ -96,13 +97,13 @@ def materialize_subnetwork_reports(self, sess, iteration_number,
    metric_update_ops = []
    for subnetwork_report in subnetwork_reports.values():
      for metric_tuple in subnetwork_report.metrics.values():
        metric_update_ops.append(metric_tuple[1])
        metric_update_ops.append(tf_compat.metric_op(metric_tuple)[1])

    # Extract the Tensors to be materialized.
    tensors_to_materialize = {}
    for name, subnetwork_report in subnetwork_reports.items():
      metrics = {
          metric_key: metric_tuple[0]
          metric_key: tf_compat.metric_op(metric_tuple)[0]
          for metric_key, metric_tuple in subnetwork_report.metrics.items()
      }
      tensors_to_materialize[name] = {
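The materialization pattern this change enables, sketched under graph-mode assumptions (simplified: the real code batches update ops across all subnetwork reports, and sess, metric, and eval_steps here are placeholders):

value_op, update_op = tf_compat.metric_op(metric)  # works for both styles
for _ in range(eval_steps):
  sess.run(update_op)              # accumulate metric state over the eval set
materialized = sess.run(value_op)  # read the final scalar value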
1 change: 1 addition & 0 deletions adanet/subnetwork/BUILD
@@ -36,6 +36,7 @@ py_library(
name = "report",
srcs = ["report.py"],
deps = [
"//adanet/tf_compat",
"@six_archive//:six",
],
)
2 changes: 2 additions & 0 deletions adanet/subnetwork/report.py
@@ -21,6 +21,7 @@

import collections

from adanet import tf_compat
import six
import tensorflow as tf

@@ -90,6 +91,7 @@ def _is_accepted_dtype(tensor):
    # Validate metrics
    metrics_copy = {}
    for key, value in metrics.items():
      value = tf_compat.metric_op(value)
      if not isinstance(value, tuple):
        raise ValueError(
            "metric '{}' has invalid type {}. Must be a tuple.".format(
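An assumed usage sketch (the hparams and metric names are placeholders): with the conversion above, a Keras metric can be stored directly in a subnetwork Report, because it is normalized to a tuple before the type check runs:

mean = tf.keras.metrics.Mean()
mean.update_state(predictions)
report = subnetwork.Report(
    hparams={"layer_size": 64},         # placeholder hparams
    attributes={},
    metrics={"mean_prediction": mean})  # Keras metric accepted here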
35 changes: 35 additions & 0 deletions adanet/tf_compat/__init__.py
@@ -143,3 +143,38 @@ def random_normal(*args, **kwargs):
    return tf.random.normal(*args, **kwargs)
  except AttributeError:
    return tf.random_normal(*args, **kwargs)


def metric_op(metric):
  """Converts a Keras metric into a metric op tuple.

  NOTE: If this method is called in a for loop, the runtime is O(n^2).
  However, the number of eval metrics at any given time should be small
  enough that this does not affect performance. Any impact is only during
  graph construction time, and therefore has no effect on steps/s.

  Args:
    metric: Either a `tf.keras.metrics.Metric` instance or a tuple of a
      Tensor value and an update op.

  Returns:
    A tuple of metric Tensor value and update op.
  """

  if not isinstance(metric, tf.keras.metrics.Metric):
    return metric
  vars_to_add = set()
  vars_to_add.update(metric.variables)
  metric = (metric.result(), metric.updates[0])
  _update_variable_collection(tf.GraphKeys.LOCAL_VARIABLES, vars_to_add)
  _update_variable_collection(tf.GraphKeys.METRIC_VARIABLES, vars_to_add)
  return metric
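Roughly, both input forms normalize to the same shape; a sketch assuming graph mode, where Metric.updates is populated (labels, predictions, and values are placeholder tensors):

# Non-Keras metrics pass through untouched.
accuracy = tf.compat.v1.metrics.accuracy(labels, predictions)
assert metric_op(accuracy) is accuracy

# Keras metrics become a (result_tensor, update_op) tuple, and their
# internal variables are registered as local/metric variables.
mean = tf.keras.metrics.Mean()
mean.update_state(values)
value_op, update_op = metric_op(mean)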


def _update_variable_collection(collection_name, vars_to_add):
  """Adds the given variables to a collection."""
  collection = set(tf.get_collection(collection_name))
  # Skip variables that are already in the collection: O(n) runtime.
  vars_to_add = vars_to_add - collection
  for v in vars_to_add:
    tf.add_to_collection(collection_name, v)
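Why the collections matter, as far as I can tell: tf.estimator's evaluation loop initializes local variables before each evaluation, so registering a Keras metric's accumulators (e.g. Mean's total and count) in LOCAL_VARIABLES and METRIC_VARIABLES lets them be reset the same way tf.metrics state is. A small graph-mode sketch of the observable effect:

mean = tf.keras.metrics.Mean()
mean.update_state(tf.constant([1.0, 2.0]))
_ = metric_op(mean)  # registers mean's variables in both collections

local_vars = set(tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES))
metric_vars = set(tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))
assert set(mean.variables) <= local_vars
assert set(mean.variables) <= metric_vars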
