
Tensorflow Model Analysis Metrics

Overview

TFMA supports the following metrics:

  • Standard keras metrics (tf.keras.metrics.*)
    • Note that you do not need a keras model to use keras metrics. Metrics are computed outside of the graph in beam using the metrics classes directly.
  • Standard TFMA metrics (tfma.metrics.*)
  • Custom keras metrics (metrics derived from tf.keras.metrics.Metric)
  • Custom TFMA metrics (metrics derived from tfma.metrics.Metric using custom beam combiners or metrics derived from other metrics).

TFMA also provides built-in support for converting binary classification metrics for use with multi-class/multi-label problems:

  • Binarization based on class ID, top K, etc.
  • Aggregated metrics based on micro averaging, macro averaging, etc.

TFMA also provides built-in support for query/ranking based metrics where the examples are grouped by a query key automatically in the pipeline.

Combined, there are over 50 standard metrics and plots available for a variety of problems, including regression, binary classification, multi-class/multi-label classification, ranking, etc.

Configuration

There are two ways to configure metrics in TFMA: (1) using the MetricsSpec proto or (2) by creating instances of tf.keras.metrics.* and/or tfma.metrics.* classes in python and using tfma.metrics.specs_from_metrics to convert them to MetricsSpecs.

The following sections describe example configurations for different types of machine learning problems.

Regression Metrics

The following is an example configuration setup for a regression problem. Consult the tf.keras.metrics.* and tfma.metrics.* modules for possible additional metrics supported.

import tensorflow_model_analysis as tfma
from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "WeightedExampleCount" }
    metrics { class_name: "MeanSquaredError" }
    metrics { class_name: "Accuracy" }
    metrics { class_name: "MeanLabel" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics {
      class_name: "CalibrationPlot"
      config: '"min_value": 0, "max_value": 10'
    }
  }
""", tfma.EvalConfig()).metrics_specs

This same setup can be created using the following python code:

import tensorflow as tf
import tensorflow_model_analysis as tfma

metrics = [
    tfma.metrics.ExampleCount(name='example_count'),
    tfma.metrics.WeightedExampleCount(name='weighted_example_count'),
    tf.keras.metrics.MeanSquaredError(name='mse'),
    tf.keras.metrics.Accuracy(name='accuracy'),
    tfma.metrics.MeanLabel(name='mean_label'),
    tfma.metrics.MeanPrediction(name='mean_prediction'),
    tfma.metrics.Calibration(name='calibration'),
    # Use a distinct name so the plot does not collide with the
    # 'calibration' metric above.
    tfma.metrics.CalibrationPlot(
        name='calibration_plot', min_value=0, max_value=10)
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

Note that this setup is also available by calling tfma.metrics.default_regression_specs.
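To make a few of the default regression metrics more concrete, here is a plain-Python sketch of how they can be computed from labels, predictions, and optional example weights. It does not use TFMA. It assumes MeanLabel and MeanPrediction are weight-averaged, MeanSquaredError is the weighted mean squared error, and Calibration is total weighted prediction divided by total weighted label; the function name regression_summary is hypothetical.

```python
from typing import Dict, Optional, Sequence


def regression_summary(
    labels: Sequence[float],
    predictions: Sequence[float],
    weights: Optional[Sequence[float]] = None,
) -> Dict[str, float]:
  """Computes a few regression metrics in plain Python (not TFMA code)."""
  if weights is None:
    weights = [1.0] * len(labels)
  total_weight = sum(weights)
  weighted_label_sum = sum(w * y for w, y in zip(weights, labels))
  weighted_pred_sum = sum(w * p for w, p in zip(weights, predictions))
  weighted_squared_error = sum(
      w * (p - y) ** 2 for w, y, p in zip(weights, labels, predictions))
  return {
      'example_count': float(len(labels)),
      'weighted_example_count': total_weight,
      'mean_label': weighted_label_sum / total_weight,
      'mean_prediction': weighted_pred_sum / total_weight,
      # Assumed definition: total weighted prediction / total weighted label.
      'calibration': weighted_pred_sum / weighted_label_sum,
      'mse': weighted_squared_error / total_weight,
  }


summary = regression_summary([1.0, 3.0], [1.0, 1.0], weights=[3.0, 1.0])
```

A calibration value below 1.0 means the model under-predicts on aggregate, which is why Calibration is usually read alongside MeanLabel and MeanPrediction.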

Binary Classification Metrics

The following is an example configuration setup for a binary classification problem. Consult the tf.keras.metrics.* and tfma.metrics.* modules for possible additional metrics supported.

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "WeightedExampleCount" }
    metrics { class_name: "BinaryCrossentropy" }
    metrics { class_name: "BinaryAccuracy" }
    metrics { class_name: "AUC" }
    metrics { class_name: "AUCPrecisionRecall" }
    metrics { class_name: "Precision" }
    metrics { class_name: "Recall" }
    metrics { class_name: "MeanLabel" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics { class_name: "ConfusionMatrixPlot" }
    metrics { class_name: "CalibrationPlot" }
  }
""", tfma.EvalConfig()).metrics_specs

This same setup can be created using the following python code:

metrics = [
    tfma.metrics.ExampleCount(name='example_count'),
    tfma.metrics.WeightedExampleCount(name='weighted_example_count'),
    tf.keras.metrics.BinaryCrossentropy(name='binary_crossentropy'),
    tf.keras.metrics.BinaryAccuracy(name='accuracy'),
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    tf.keras.metrics.AUC(
        name='auc_precision_recall', curve='PR', num_thresholds=10000),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tfma.metrics.MeanLabel(name='mean_label'),
    tfma.metrics.MeanPrediction(name='mean_prediction'),
    tfma.metrics.Calibration(name='calibration'),
    tfma.metrics.ConfusionMatrixPlot(name='confusion_matrix_plot'),
    tfma.metrics.CalibrationPlot(name='calibration_plot')
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

Note that this setup is also available by calling tfma.metrics.default_binary_classification_specs.
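Thresholded metrics such as BinaryAccuracy, Precision, and Recall all reduce to a confusion matrix at a decision threshold. The plain-Python sketch below assumes 0/1 labels and treats a prediction as positive when it is strictly greater than the threshold (0.5 by default, matching the keras default); confusion_at_thresholds is a hypothetical helper, not TFMA's ConfusionMatrixAtThresholds implementation.

```python
from typing import Dict, Sequence


def confusion_at_thresholds(
    labels: Sequence[int],
    predictions: Sequence[float],
    thresholds: Sequence[float] = (0.5,),
) -> Dict[float, Dict[str, float]]:
  """Returns confusion counts plus accuracy/precision/recall per threshold."""
  results = {}
  for t in thresholds:
    tp = fp = tn = fn = 0
    for y, p in zip(labels, predictions):
      predicted_positive = p > t  # strictly greater than the threshold
      if predicted_positive and y == 1:
        tp += 1
      elif predicted_positive:
        fp += 1
      elif y == 1:
        fn += 1
      else:
        tn += 1
    total = tp + fp + tn + fn
    results[t] = {
        'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,
        'accuracy': (tp + tn) / total,
        # Guard against empty denominators (no predicted or actual positives).
        'precision': tp / (tp + fp) if tp + fp else 0.0,
        'recall': tp / (tp + fn) if tp + fn else 0.0,
    }
  return results


matrices = confusion_at_thresholds(
    [1, 1, 0, 0, 1], [0.9, 0.8, 0.7, 0.2, 0.3], thresholds=[0.5, 0.8])
```

Raising the threshold trades recall for precision, which is why these metrics are usually reported at several thresholds at once.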

Multi-class/Multi-label Classification Metrics

The following is an example configuration setup for a multi-class classification problem. Consult the tf.keras.metrics.* and tfma.metrics.* modules for possible additional metrics supported.

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "WeightedExampleCount" }
    metrics { class_name: "SparseCategoricalCrossentropy" }
    metrics { class_name: "SparseCategoricalAccuracy" }
    metrics { class_name: "Precision" config: '"top_k": 1' }
    metrics { class_name: "Precision" config: '"top_k": 3' }
    metrics { class_name: "Recall" config: '"top_k": 1' }
    metrics { class_name: "Recall" config: '"top_k": 3' }

  }
""", tfma.EvalConfig()).metrics_specs

This same setup can be created using the following python code:

metrics = [
    tfma.metrics.ExampleCount(name='example_count'),
    tfma.metrics.WeightedExampleCount(name='weighted_example_count'),
    tf.keras.metrics.SparseCategoricalCrossentropy(
        name='sparse_categorical_crossentropy'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision', top_k=1),
    tf.keras.metrics.Precision(name='precision', top_k=3),
    tf.keras.metrics.Recall(name='recall', top_k=1),
    tf.keras.metrics.Recall(name='recall', top_k=3),
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

Note that this setup is also available by calling tfma.metrics.default_multi_class_classification_specs.
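With top_k set, keras Precision and Recall treat the k highest-scoring classes of each example as the predicted positives. The plain-Python sketch below illustrates that idea; it assumes exactly one true class per example (sparse integer labels) and breaks score ties in favor of the lower class ID. It is not the keras implementation, and top_k_precision_recall is a hypothetical name.

```python
from typing import Sequence, Tuple


def top_k_precision_recall(
    labels: Sequence[int],
    predictions: Sequence[Sequence[float]],
    k: int,
) -> Tuple[float, float]:
  """Precision and recall when the top-k classes count as predicted positive."""
  hits = 0
  predicted_positives = 0
  for label, scores in zip(labels, predictions):
    # sorted() is stable (also with reverse=True), so ties keep class order.
    ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
    top_k = ranked[:k]
    predicted_positives += len(top_k)
    if label in top_k:
      hits += 1
  precision = hits / predicted_positives if predicted_positives else 0.0
  recall = hits / len(labels) if labels else 0.0
  return precision, recall


labels = [0, 2]
scores = [[0.7, 0.2, 0.1], [0.5, 0.3, 0.2]]
precision_at_1, recall_at_1 = top_k_precision_recall(labels, scores, k=1)
```

Under these assumptions, precision@k is bounded by 1/k because each example contributes k predicted positives but at most one hit, while recall@k grows with k.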

Multi-class/Multi-label Binarized Metrics

Multi-class/multi-label metrics can be binarized to produce metrics per class, per top_k, etc. using the tfma.BinarizationOptions. For example:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    binarize: { class_ids: { values: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } }
    # Metrics to binarize
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs

This same setup can be created using the following python code:

metrics = [
    # Metrics to binarize
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, binarize=tfma.BinarizationOptions(
        class_ids={'values': [0,1,2,3,4,5,6,7,8,9]}))
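Conceptually, binarizing by class ID turns one multi-class problem into one binary (one-vs-rest) problem per requested class, and the binary metric (AUC above) is then computed once per class. The plain-Python sketch below shows only that transformation; it assumes sparse integer labels and one score per class, and binarize_by_class_ids is a hypothetical helper rather than TFMA code.

```python
from typing import Dict, List, Sequence, Tuple


def binarize_by_class_ids(
    labels: Sequence[int],
    predictions: Sequence[Sequence[float]],
    class_ids: Sequence[int],
) -> Dict[int, Tuple[List[float], List[float]]]:
  """One-vs-rest view: per class, a 0/1 label and that class's score."""
  binarized = {}
  for class_id in class_ids:
    binary_labels = [1.0 if y == class_id else 0.0 for y in labels]
    binary_scores = [scores[class_id] for scores in predictions]
    binarized[class_id] = (binary_labels, binary_scores)
  return binarized


per_class = binarize_by_class_ids(
    labels=[0, 2, 1],
    predictions=[[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.2, 0.5, 0.3]],
    class_ids=[0, 2])
```

Each entry of the result can be fed to any binary metric, which is how a single AUC definition yields one value per class.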

Multi-class/Multi-label Aggregate Metrics

Multi-class/multi-label metrics can be aggregated to produce a single aggregated value for a binary classification metric.

Micro Average

Micro averaging can be performed either independently or as part of a binarization of metrics by using the micro_average option within tfma.AggregationOptions. For example:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    aggregate: { micro_average: true }
    # Metrics to aggregate
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs

This same setup can be created using the following python code:

metrics = [
    # Metrics to aggregate
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, aggregate=tfma.AggregationOptions(micro_average=True))
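Micro averaging pools the binarized (label, score) pairs of all classes and evaluates the metric once over the pooled data. The plain-Python sketch below uses precision at a 0.5 threshold as a stand-in metric and assumes the input is already in one-vs-rest form ({class_id: (labels, scores)}); both helper names are hypothetical.

```python
from typing import Callable, Dict, List, Sequence, Tuple

PerClass = Dict[int, Tuple[List[float], List[float]]]
MetricFn = Callable[[Sequence[float], Sequence[float]], float]


def precision_at_half(labels: Sequence[float], scores: Sequence[float]) -> float:
  """Binary precision with scores > 0.5 treated as predicted positive."""
  tp = sum(1 for y, s in zip(labels, scores) if s > 0.5 and y == 1)
  fp = sum(1 for y, s in zip(labels, scores) if s > 0.5 and y == 0)
  return tp / (tp + fp) if tp + fp else 0.0


def micro_average(per_class: PerClass, metric_fn: MetricFn) -> float:
  """Pools every class's binary examples, then evaluates metric_fn once."""
  pooled_labels: List[float] = []
  pooled_scores: List[float] = []
  for labels, scores in per_class.values():
    pooled_labels.extend(labels)
    pooled_scores.extend(scores)
  return metric_fn(pooled_labels, pooled_scores)


per_class = {
    0: ([1.0, 0.0, 0.0], [0.9, 0.6, 0.2]),
    1: ([0.0, 1.0, 0.0], [0.6, 0.7, 0.8]),
}
micro = micro_average(per_class, precision_at_half)
```

Here the pooled data has 2 true positives out of 5 predicted positives, so the micro-averaged precision is 0.4; classes with more examples therefore carry more weight.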

Macro / Weighted Macro Average

Macro averaging must be performed as part of a binarization of metrics, in conjunction with the macro_average or weighted_macro_average options within tfma.AggregationOptions. For example:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    binarize: { class_ids: { values: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } }
    aggregate: { macro_average: true }
    # Metrics to both binarize and aggregate
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs

This same setup can be created using the following python code:

metrics = [
    # Metrics to both binarize and aggregate
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics,
    binarize=tfma.BinarizationOptions(
        class_ids={'values': [0,1,2,3,4,5,6,7,8,9]}),
    aggregate=tfma.AggregationOptions(macro_average=True))
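Macro averaging instead computes the metric separately for each binarized class and averages the per-class values; a weighted macro average uses a weighted mean. The plain-Python sketch below uses precision at a 0.5 threshold as a stand-in metric. The class weights are passed explicitly, which is an assumption made for illustration, and all names are hypothetical.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

PerClass = Dict[int, Tuple[List[float], List[float]]]
MetricFn = Callable[[Sequence[float], Sequence[float]], float]


def precision_at_half(labels: Sequence[float], scores: Sequence[float]) -> float:
  """Binary precision with scores > 0.5 treated as predicted positive."""
  tp = sum(1 for y, s in zip(labels, scores) if s > 0.5 and y == 1)
  fp = sum(1 for y, s in zip(labels, scores) if s > 0.5 and y == 0)
  return tp / (tp + fp) if tp + fp else 0.0


def macro_average(per_class: PerClass, metric_fn: MetricFn,
                  class_weights: Optional[Dict[int, float]] = None) -> float:
  """Averages per-class metric values (weighted if class_weights is given)."""
  values = {c: metric_fn(labels, scores)
            for c, (labels, scores) in per_class.items()}
  if class_weights is None:
    return sum(values.values()) / len(values)
  total_weight = sum(class_weights[c] for c in values)
  return sum(class_weights[c] * v for c, v in values.items()) / total_weight


per_class = {
    0: ([1.0, 0.0, 0.0], [0.9, 0.6, 0.2]),  # per-class precision 1/2
    1: ([0.0, 1.0, 0.0], [0.6, 0.7, 0.8]),  # per-class precision 1/3
}
macro = macro_average(per_class, precision_at_half)
weighted = macro_average(per_class, precision_at_half, {0: 3.0, 1: 1.0})
```

The macro value here is about 0.417, whereas pooling the same examples (micro averaging) would give 2 true positives out of 5 predicted positives, 0.4: macro averaging gives each class equal say regardless of its size.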

Query / Ranking Based Metrics

Query/ranking based metrics are enabled by specifying the query_key option in the metrics specs. For example:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    query_key: "doc_id"
    binarize { top_k_list: { values: [1, 2] } }
    metrics { class_name: "NDCG" config: '"gain_key": "gain"' }
  }
  metrics_specs {
    query_key: "doc_id"
    metrics { class_name: "MinLabelPosition" }
  }
""", tfma.EvalConfig()).metrics_specs

This same setup can be created using the following python code:

metrics = [
    tfma.metrics.NDCG(name='ndcg', gain_key='gain'),
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, query_key='doc_id', binarize=tfma.BinarizationOptions(
        top_k_list={'values': [1,2]}))

metrics = [
    tfma.metrics.MinLabelPosition(name='min_label_position')
]
metrics_specs.extend(
    tfma.metrics.specs_from_metrics(metrics, query_key='doc_id'))
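For intuition about what the query_key grouping does, here is a plain-Python sketch of NDCG@k: examples are grouped by the query key, ranked by prediction within each group, and the discounted gain of that ranking (gain / log2(rank + 1)) is divided by the discounted gain of the ideal ranking. It assumes each example is a dict with a 'prediction', the query key, and a gain feature; skipping queries whose ideal gain is zero and averaging queries without weights are also assumptions. ndcg_at_k is hypothetical and is not TFMA's NDCG implementation.

```python
import math
from collections import defaultdict
from typing import Any, Dict, List, Sequence


def ndcg_at_k(examples: Sequence[Dict[str, Any]], k: int,
              query_key: str = 'doc_id', gain_key: str = 'gain') -> float:
  """Unweighted mean NDCG@k over queries (plain-Python sketch)."""
  queries: Dict[Any, List[Dict[str, Any]]] = defaultdict(list)
  for example in examples:
    queries[example[query_key]].append(example)

  def dcg(ranked: List[Dict[str, Any]]) -> float:
    # Rank positions are 1-based, so position i (0-based) is discounted by
    # log2(i + 2).
    return sum(e[gain_key] / math.log2(i + 2) for i, e in enumerate(ranked[:k]))

  scores = []
  for group in queries.values():
    ranked = sorted(group, key=lambda e: e['prediction'], reverse=True)
    ideal = sorted(group, key=lambda e: e[gain_key], reverse=True)
    ideal_dcg = dcg(ideal)
    if ideal_dcg > 0:  # assumption: queries with no gain are skipped
      scores.append(dcg(ranked) / ideal_dcg)
  return sum(scores) / len(scores) if scores else 0.0


examples = [
    {'doc_id': 'q1', 'prediction': 0.9, 'gain': 0.0},
    {'doc_id': 'q1', 'prediction': 0.5, 'gain': 1.0},
    {'doc_id': 'q2', 'prediction': 0.8, 'gain': 2.0},
    {'doc_id': 'q2', 'prediction': 0.3, 'gain': 1.0},
]
ndcg_at_1 = ndcg_at_k(examples, k=1)
```

Computing one value per entry of top_k_list ([1, 2] in the configuration above) corresponds to calling the function once per k.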

Multi-model Evaluation Metrics

TFMA supports evaluating multiple models at the same time. When multi-model evaluation is performed, the names of the models associated with a set of metrics must be specified in the model_names section of the MetricsSpec. For example:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    model_names: ["my-model1", "my-model2"]
    ...
  }
""", tfma.EvalConfig()).metrics_specs

The specs_from_metrics API also supports passing model names:

metrics = [
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, model_names=['my-model1', 'my-model2'])

Multi-output Model Metrics

TFMA supports evaluating metrics on models that have multiple outputs. Multi-output models store their output predictions in the form of a dict keyed by output name. When multi-output models are used, the names of the outputs associated with a set of metrics must be specified in the output_names section of the MetricsSpec. For example:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    output_names: ["my-output"]
    ...
  }
""", tfma.EvalConfig()).metrics_specs

The specs_from_metrics API also supports passing output names:

metrics = [
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, output_names=['my-output'])

Customizing Metric Settings

TFMA allows customizing the settings used with different metrics. For example, you might want to change the name, set thresholds, etc. This is done by adding a config section to the metric config. The config is specified as the JSON string version of the parameters that would be passed to the metric's __init__ method (for ease of use, the leading and trailing '{' and '}' brackets may be omitted). For example:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics {
      class_name: "ConfusionMatrixAtThresholds"
      config: '"thresholds": [0.3, 0.5, 0.8]'
    }
  }
""", tfma.EvalConfig()).metrics_specs

This customization is of course also supported directly:

metrics = [
    tfma.metrics.ConfusionMatrixAtThresholds(thresholds=[0.3, 0.5, 0.8]),
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
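The config-string rule described above (JSON for the __init__ kwargs, with the outer braces optional) can be illustrated with the standard json module. parse_metric_config is a hypothetical helper written for this example, not TFMA's parser; it only shows how the brace-less form maps onto keyword arguments.

```python
import json
from typing import Any, Dict


def parse_metric_config(config: str) -> Dict[str, Any]:
  """Parses a metric config JSON string whose outer braces may be omitted."""
  text = config.strip()
  if not text:
    return {}
  if not text.startswith('{'):
    # Restore the optional leading/trailing braces before parsing.
    text = '{' + text + '}'
  return json.loads(text)


kwargs = parse_metric_config('"thresholds": [0.3, 0.5, 0.8]')
# The resulting dict could then be splatted into the metric's constructor,
# e.g. ConfusionMatrixAtThresholds(**kwargs).
```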

NOTE: It is advisable to set the number of thresholds used with AUC, etc. to 10000, because this is the default value used by the underlying histogram calculation, which is shared between multiple metric implementations.

Outputs

The output of a metric evaluation is a series of metric keys/values and/or plot keys/values based on the configuration used.

Metric Keys

MetricKeys are defined using a structured key type. This key uniquely identifies each of the following aspects of a metric:

  • Metric name (auc, mean_label, etc)
  • Model name (only used if multi-model evaluation)
  • Output name (only used if multi-output models are evaluated)
  • Sub key (e.g. class ID if multi-class model is binarized)
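One way to picture such a structured key is as an immutable, hashable tuple, so that keys differing in any aspect map to separate entries in a results dict. The NamedTuples below are simplified stand-ins written for this sketch, not TFMA's MetricKey and SubKey classes (which carry more fields), and the dict values are placeholders.

```python
from typing import NamedTuple, Optional


class SubKey(NamedTuple):
  """Simplified stand-in for a sub key (class ID, k, or top K)."""
  class_id: Optional[int] = None
  k: Optional[int] = None
  top_k: Optional[int] = None


class MetricKey(NamedTuple):
  """Simplified stand-in for a structured metric key."""
  name: str
  model_name: str = ''   # only set for multi-model evaluation
  output_name: str = ''  # only set for multi-output models
  sub_key: Optional[SubKey] = None


# Placeholder values: the same metric name yields distinct keys per class.
results = {
    MetricKey('auc', sub_key=SubKey(class_id=0)): 0.0,
    MetricKey('auc', sub_key=SubKey(class_id=1)): 0.0,
}
```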

Metric Value

MetricValues are defined using a proto that encapsulates the different value types supported by the different metrics (e.g. double, ConfusionMatrixAtThresholds, etc).

Plot Keys

PlotKeys are similar to metric keys, except that, for historical reasons, all plot values are stored in a single proto, so the plot key does not have a name.

Plot Values

All the supported plots are stored in a single proto called PlotData.

EvalResult

The return value of an evaluation run is an EvalResult. This record contains slicing_metrics, which encode the metric key as a multi-level dict where the levels correspond to output name, class ID, metric name, and metric value, respectively. This is intended to be used for UI display in a Jupyter notebook. If access to the underlying data is needed, the metrics result file should be used instead (see metrics_for_slice.proto).
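A multi-level dict of this kind can be walked with ordinary dict iteration. The plain-Python sketch below flattens one slice's metrics into rows; the nesting order (output name, class ID, metric name, value) follows the description above, but the exact key strings and value encoding in a real EvalResult are not reproduced here, so the sample dict uses placeholder keys and values and flatten_slice_metrics is a hypothetical helper.

```python
from typing import Any, Dict, List, Tuple

NestedMetrics = Dict[str, Dict[str, Dict[str, Any]]]


def flatten_slice_metrics(
    nested: NestedMetrics) -> List[Tuple[str, str, str, Any]]:
  """Turns {output: {sub_key: {metric: value}}} into flat rows."""
  rows = []
  for output_name, by_sub_key in nested.items():
    for sub_key, by_metric in by_sub_key.items():
      for metric_name, value in by_metric.items():
        rows.append((output_name, sub_key, metric_name, value))
  return rows


# Placeholder structure for a single-output model (empty output name).
sample = {
    '': {
        '': {'example_count': 100.0},
        'class_0': {'auc': 0.5},
    }
}
rows = flatten_slice_metrics(sample)
```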

Customization

In addition to custom metrics that are added as part of a saved keras model (or legacy EvalSavedModel), there are two ways to customize metrics in TFMA after saving: (1) by defining a custom keras metric class and (2) by defining a custom TFMA metrics class backed by a beam combiner.

In both cases, the metrics are configured by specifying the name of the metric class and associated module. For example:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "MyMetric" module: "my.module" }
  }
""", tfma.EvalConfig()).metrics_specs

NOTE: When customizing metrics you must ensure that the module is available to beam.
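Resolving a class_name/module pair generally means importing the module and looking up the class by name, which is why the module has to be importable on every beam worker. The sketch below shows that kind of lookup with the standard library's importlib; resolve_metric_class is a hypothetical helper, not TFMA's loader, and the usage line resolves a stdlib class purely so the snippet runs anywhere.

```python
import importlib
from typing import Any


def resolve_metric_class(module: str, class_name: str) -> Any:
  """Imports `module` and returns its attribute named `class_name`."""
  try:
    mod = importlib.import_module(module)
  except ImportError as e:
    raise ImportError(
        f'Module {module!r} is not importable here; make sure it is '
        'installed on every worker that runs the evaluation.') from e
  return getattr(mod, class_name)


cls = resolve_metric_class('collections', 'OrderedDict')
```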

Custom Keras Metrics

To create a custom keras metric, users need to extend tf.keras.metrics.Metric with their implementation and then make sure the metric's module is available at evaluation time.

Note that for metrics added post model save, TFMA only supports metrics that take label (i.e. y_true), prediction (y_pred), and example weight (sample_weight) as parameters to the update_state method.

Keras Metric Example

The following is an example of a custom keras metric:

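import tensorflow as tf


class MyMetric(tf.keras.metrics.Mean):
  """Tracks the (weighted) mean of the predictions; labels are ignored."""

  def __init__(self, name='my_metric', dtype=None):
    super().__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    # TFMA calls update_state(y_true, y_pred, sample_weight); this metric only
    # forwards the predictions (and weights) to Mean.update_state.
    return super().update_state(y_pred, sample_weight=sample_weight)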

Custom TFMA Metrics

To create a custom TFMA metric, users need to extend tfma.metrics.Metric with their implementation and then make sure the metric's module is available at evaluation time.

Metric

A tfma.metrics.Metric implementation is made up of a set of kwargs that define the metric's configuration, along with a function for creating the computations (possibly multiple) needed to calculate the metric's value. There are two main computation types that can be used, tfma.metrics.MetricComputation and tfma.metrics.DerivedMetricComputation, which are described in the sections below. The function that creates these computations will be passed the following parameters as input:

  • eval_config: tfma.EvalConfig
    • The eval config passed to the evaluator (useful for looking up model spec settings such as the prediction key to use, etc.).
  • model_names: List[Text]
    • List of model names to compute metrics for (None if single-model)
  • output_names: List[Text]
    • List of output names to compute metrics for (None if single-output)
  • sub_keys: List[tfma.SubKey]
    • List of sub keys (class ID, top K, etc.) to compute metrics for (or None)
  • class_weights: Dict[int, float]
    • Class weights to use if computing an aggregation metric.
  • query_key: Text
    • Query key used if computing a query/ranking based metric.

If a metric is not associated with one or more of these settings then it may leave those parameters out of its signature definition.
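One plausible way for a caller to honor such partial signatures is to inspect the function and pass only the keyword arguments it declares. The sketch below shows this with the standard library's inspect module; it is an assumption about how this can be done, not a description of TFMA's evaluator, and both function names are hypothetical.

```python
import inspect
from typing import Any, Callable


def call_with_supported_kwargs(fn: Callable[..., Any], **available: Any) -> Any:
  """Calls fn with only the keyword arguments its signature declares."""
  params = inspect.signature(fn).parameters
  accepts_var_kwargs = any(
      p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
  if accepts_var_kwargs:
    return fn(**available)
  return fn(**{k: v for k, v in available.items() if k in params})


def create_computations(name='my_metric', sub_keys=None):
  # A metric that only cares about its name and sub keys.
  return [(name, sub_key) for sub_key in (sub_keys or [None])]


computations = call_with_supported_kwargs(
    create_computations, name='m', eval_config=None, model_names=None,
    output_names=None, sub_keys=[1, 2], class_weights=None, query_key=None)
```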

If a metric is computed the same way for each model, output, and sub key, then the utility tfma.metrics.merge_per_key_computations can be used to perform the same computations for each of these inputs separately.
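The idea behind computing the same thing per model, output, and sub key is a fan-out over every combination of those inputs. The plain-Python sketch below illustrates that fan-out with itertools.product; it is not tfma.metrics.merge_per_key_computations, and the Key type and per_key_computations function are hypothetical.

```python
import itertools
from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar

T = TypeVar('T')


class Key(NamedTuple):
  name: str
  model_name: str = ''
  output_name: str = ''
  sub_key: Optional[int] = None


def per_key_computations(
    make_computation: Callable[[Key], T],
    name: str,
    model_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
    sub_keys: Optional[Sequence[Optional[int]]] = None,
) -> List[T]:
  """Builds one computation per (model, output, sub key) combination."""
  computations = []
  # A missing list means "single model/output" or "no sub key".
  for model_name, output_name, sub_key in itertools.product(
      model_names or [''], output_names or [''], sub_keys or [None]):
    computations.append(
        make_computation(Key(name, model_name, output_name, sub_key)))
  return computations


comps = per_key_computations(
    lambda key: key, 'auc', model_names=['m1', 'm2'], sub_keys=[0, 1, 2])
```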

MetricComputation

A MetricComputation is made up of a combination of a preprocessor and a combiner. The preprocessor is a beam.DoFn that takes extracts as its input and outputs the initial state that will be used by the combiner (see architecture for more info on what extracts are). If a preprocessor is not defined, then the combiner will be passed StandardMetricInputs (standard metric inputs contain labels, predictions, and example_weights). The combiner is a beam.CombineFn that takes a tuple of (slice key, preprocessor output) as its input and outputs a tuple of (slice_key, metric results dict) as its result.

Note that slicing happens between the preprocessor and combiner.

Note that if a metric computation wants to use the standard metric inputs but augment them with a few features from the features extracts, then the special FeaturePreprocessor can be used. It merges the requested features from multiple combiners into a single shared StandardMetricInputs value that is passed to all the combiners (the combiners are responsible for reading the features they are interested in and ignoring the rest).
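To see how the preprocessor, slicing, and combiner stages fit together without running a pipeline, here is a plain-Python simulation. The combiner mirrors the four beam.CombineFn method names (create_accumulator, add_input, merge_accumulators, extract_output) and computes a weighted mean label per slice, and round-robin sharding stands in for beam splitting work across workers. Everything here, including the extract layout with 'label', 'example_weight', and a 'country' feature, is a hypothetical illustration rather than TFMA's data model.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, Hashable, Iterable, List, Tuple

Accumulator = Tuple[float, float]  # (weighted label sum, total weight)


class MeanLabelCombiner:
  """Mirrors the beam.CombineFn method names without depending on beam."""

  def create_accumulator(self) -> Accumulator:
    return (0.0, 0.0)

  def add_input(self, acc: Accumulator,
                state: Tuple[float, float]) -> Accumulator:
    label, weight = state
    return (acc[0] + weight * label, acc[1] + weight)

  def merge_accumulators(self, accs: Iterable[Accumulator]) -> Accumulator:
    label_sum, weight_sum = 0.0, 0.0
    for partial_label_sum, partial_weight in accs:
      label_sum += partial_label_sum
      weight_sum += partial_weight
    return (label_sum, weight_sum)

  def extract_output(self, acc: Accumulator) -> Dict[str, float]:
    return {'mean_label': acc[0] / acc[1] if acc[1] else float('nan')}


def preprocess(extract: Dict[str, Any]) -> Tuple[float, float]:
  """Preprocessor stand-in: turns one extract into the combiner's input."""
  return (extract['label'], extract.get('example_weight', 1.0))


def run_locally(extracts: List[Dict[str, Any]],
                slice_fn: Callable[[Dict[str, Any]], Hashable],
                combiner: MeanLabelCombiner,
                num_shards: int = 2) -> Dict[Hashable, Dict[str, float]]:
  shards: Dict[Hashable, List[Accumulator]] = defaultdict(
      lambda: [combiner.create_accumulator() for _ in range(num_shards)])
  for i, extract in enumerate(extracts):
    state = preprocess(extract)    # 1. preprocessor output
    slice_key = slice_fn(extract)  # 2. slicing happens after preprocessing
    shard = i % num_shards         # 3. combine per slice on each "worker"
    shards[slice_key][shard] = combiner.add_input(
        shards[slice_key][shard], state)
  # 4. Merge partial accumulators and extract the final metric per slice.
  return {key: combiner.extract_output(combiner.merge_accumulators(accs))
          for key, accs in shards.items()}


extracts = [
    {'country': 'US', 'label': 1.0},
    {'country': 'US', 'label': 0.0},
    {'country': 'CA', 'label': 1.0, 'example_weight': 2.0},
    {'country': 'CA', 'label': 0.0, 'example_weight': 2.0},
    {'country': 'CA', 'label': 1.0},
]
per_slice = run_locally(extracts, lambda e: e['country'], MeanLabelCombiner())
```

Because merge_accumulators must give the same answer however the inputs were split, accumulators are kept as sums (not running means) until extract_output.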

Example

The following is a very simple example of TFMA metric definition for computing the ExampleCount:

from typing import Dict, Iterable, Text

import apache_beam as beam
import tensorflow_model_analysis as tfma


class ExampleCount(tfma.metrics.Metric):
  """Counts the number of examples in each slice."""

  def __init__(self, name: Text = 'example_count'):
    # The base class takes the function that creates the computations plus
    # the kwargs that are forwarded to it (here, only `name`).
    super().__init__(_example_count, name=name)


def _example_count(
    name: Text = 'example_count') -> tfma.metrics.MetricComputations:
  key = tfma.metrics.MetricKey(name=name)
  return [
      tfma.metrics.MetricComputation(
          keys=[key],
          preprocessor=_ExampleCountPreprocessor(),
          combiner=_ExampleCountCombiner(key))
  ]


class _ExampleCountPreprocessor(beam.DoFn):

  def process(self, extracts: tfma.Extracts) -> Iterable[int]:
    # Emit a count of 1 for every example; slicing happens after this step.
    yield 1


class _ExampleCountCombiner(beam.CombineFn):

  def __init__(self, metric_key: tfma.metrics.MetricKey):
    self._metric_key = metric_key

  def create_accumulator(self) -> int:
    return 0

  def add_input(self, accumulator: int, state: int) -> int:
    return accumulator + state

  def merge_accumulators(self, accumulators: Iterable[int]) -> int:
    # Partial counts may come from different workers; just add them up.
    return sum(accumulators)

  def extract_output(self,
                     accumulator: int) -> Dict[tfma.metrics.MetricKey, int]:
    return {self._metric_key: accumulator}

DerivedMetricComputation

A DerivedMetricComputation is made up of a result function that is used to calculate metric values based on the output of other metric computations. The result function takes a dict of computed values as its input and outputs a dict of additional metric results.

Note that it is acceptable (and recommended) to include the computations that a derived computation depends on in the list of computations created by a metric. This avoids having to pre-create and pass computations that are shared between multiple metrics. The evaluator will automatically de-dup computations that have the same definition, so only one computation is actually run.
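A small sketch of why this de-duplication is possible: if computation definitions are hashable values, definitions that compare equal can be collapsed while preserving order. The frozen dataclass and its fields below are hypothetical stand-ins for TFMA's computation types, and the evaluator's real equality rules are not shown here.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class Computation:
  """Hypothetical, hashable description of a computation."""
  kind: str
  keys: Tuple[str, ...]


def dedup(computations: Sequence[Computation]) -> List[Computation]:
  # dict preserves insertion order, so the first occurrence of each wins.
  return list(dict.fromkeys(computations))


shared = Computation('weighted_sums', ('pos_pred_sum', 'neg_pred_sum'))
metric_a = [shared, Computation('derived', ('metric_a',))]
metric_b = [Computation('weighted_sums', ('pos_pred_sum', 'neg_pred_sum')),
            Computation('derived', ('metric_b',))]
unique = dedup(metric_a + metric_b)
```

Each metric stays self-contained by listing its own dependencies, and the shared base computation still runs only once.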

Example

The TJUR metrics provide a good example of derived metrics.
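The plain-Python sketch below mirrors the structure of a derived computation for Tjur's coefficient of discrimination (the mean prediction among positive examples minus the mean prediction among negative examples). A base computation produces weighted sums, and a result function turns that dict of computed values into a dict of additional metrics. The function and key names are hypothetical and this is not the TFMA source; using the label as a weight (so soft labels in [0, 1] also work) is an assumption of the sketch.

```python
from typing import Dict, Optional, Sequence


def tjur_intermediates(
    labels: Sequence[float],
    predictions: Sequence[float],
    weights: Optional[Sequence[float]] = None,
) -> Dict[str, float]:
  """Base computation: weighted prediction sums for positives and negatives."""
  if weights is None:
    weights = [1.0] * len(labels)
  out = {'pos_pred_sum': 0.0, 'pos_weight': 0.0,
         'neg_pred_sum': 0.0, 'neg_weight': 0.0}
  for y, p, w in zip(labels, predictions, weights):
    out['pos_pred_sum'] += w * y * p
    out['pos_weight'] += w * y
    out['neg_pred_sum'] += w * (1.0 - y) * p
    out['neg_weight'] += w * (1.0 - y)
  return out


def tjur_result(computed: Dict[str, float]) -> Dict[str, float]:
  """Result function: dict of computed values in, dict of new metrics out."""
  mean_pos = computed['pos_pred_sum'] / computed['pos_weight']
  mean_neg = computed['neg_pred_sum'] / computed['neg_weight']
  return {'coefficient_of_discrimination': mean_pos - mean_neg}


result = tjur_result(tjur_intermediates([1, 1, 0, 0], [0.9, 0.7, 0.2, 0.4]))
```

Splitting the work this way lets the expensive pass over the data (the weighted sums) be shared, while each derived metric is only a cheap function of its outputs.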