-
Notifications
You must be signed in to change notification settings - Fork 277
/
metrics_for_slice.proto
490 lines (425 loc) · 17.3 KB
/
metrics_for_slice.proto
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package tensorflow_model_analysis;
import "google/protobuf/wrappers.proto";
// Sync with PerformanceStatistics because of b/110954446.
// LINT.IfChange
//
// A sub key identifies specialized sub-types of metrics and plots.
message SubKey {
  // Class ID the metric applies to. Used with multi-class metrics.
  google.protobuf.Int32Value class_id = 1;

  // Identifies the kth predicted value. Used with multi-class metrics.
  google.protobuf.Int32Value k = 2;

  // Identifies the top-k predicted values. Used with multi-class and ranking
  // metrics.
  google.protobuf.Int32Value top_k = 3;
}
// LINT.ThenChange(
// ../metrics/metric_types.py,
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// LINT.IfChange
//
// Aggregation types used with AggregationOptions.
message AggregationType {
  // At most one aggregation kind can be set at a time.
  oneof type {
    // Micro averaging; see AggregationOptions.micro_average.
    bool micro_average = 1;

    // Macro averaging; see AggregationOptions.macro_average.
    bool macro_average = 2;

    // Weighted macro averaging; see
    // AggregationOptions.weighted_macro_average.
    bool weighted_macro_average = 3;
  }
}
// LINT.ThenChange(
// ../metrics/metric_types.py,
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// Sync with PerformanceStatistics because of b/110954446.
// LINT.IfChange
//
// A metric key uniquely identifies a metric.
message MetricKey {
  // Metric name (e.g. 'auc').
  string name = 1;

  // Model the metric belongs to. Optional; relevant for multi-model
  // evaluation.
  string model_name = 4;

  // Output the metric belongs to. Optional; relevant for multi-output models.
  string output_name = 2;

  // Optional sub key further qualifying the metric.
  SubKey sub_key = 3;

  // Optional aggregation type (set when AggregationOptions are used).
  AggregationType aggregation_type = 6;

  // Tri-state example weighting flag:
  //   * true  - the metric is weighted by examples.
  //   * false - the metric is not weighted by examples.
  //   * unset - unknown whether the metric was weighted by examples (i.e. the
  //             metrics were defined inside of a model).
  // See MetricsSpecs.example_weighted for more information.
  google.protobuf.BoolValue example_weighted = 7;

  // True when this is a diff metric, i.e. one based on a comparison with the
  // baseline.
  bool is_diff = 5;
}
// LINT.ThenChange(
// ../metrics/metric_types.py,
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// Sync the MetricValue types with PerformanceStatistics because of b/110954446.
// LINT.IfChange
// The value will be converted into an error message if we do not know its type.
message UnknownType {
  // Error message produced for the value whose type is not known.
  string error = 1;

  // NOTE(review): presumably the raw bytes of the value that could not be
  // typed - confirm against the producer.
  bytes value = 2;
}
// Represents a real value which could be a pointwise estimate, optionally with
// approximate bounds of some sort. For instance, for AUC, these bounds could be
// the upper and lower Riemann sum of the integral.
message BoundedValue {
  // Lower end of the range.
  google.protobuf.DoubleValue lower_bound = 1;

  // Upper end of the range.
  google.protobuf.DoubleValue upper_bound = 2;

  // The point value. When lower_bound and upper_bound are both unset this is
  // an exact value; otherwise it is an approximation that should lie within
  // [lower_bound, upper_bound].
  google.protobuf.DoubleValue value = 3;

  // How the bounds were obtained.
  enum Methodology {
    UNKNOWN = 0;

    // Upper and lower Riemann sum for an integral; used to calculate AUC.
    RIEMANN_SUM = 1;

    // Confidence intervals computed with Poisson bootstrapping. For details:
    // http://www.unofficialgoogledatascience.com/2015/08/an-introduction-to-poisson-bootstrap26.html
    POISSON_BOOTSTRAP = 2;
  }

  // Optional description of the methodology used to calculate the bounds.
  Methodology methodology = 4;
}
// Represents a t-distribution, which includes sample mean, sample standard
// deviation and degrees of freedom of samples. It's calculated when evaluation
// runs on multiple samples, which by default are generated by the Poisson
// bootstrapping method:
// http://www.unofficialgoogledatascience.com/2015/08/an-introduction-to-poisson-bootstrap26.html
message TDistributionValue {
  // Mean of the samples.
  google.protobuf.DoubleValue sample_mean = 1;

  // Standard deviation of the samples.
  google.protobuf.DoubleValue sample_standard_deviation = 2;

  // Number of degrees of freedom of the samples.
  google.protobuf.Int64Value sample_degrees_of_freedom = 3;

  // Value of the data as calculated without bootstrapping.
  //
  // Deprecated: going forward TDistributionValue will be removed from the
  // oneof in MetricValue, and the unsampled value will be populated in
  // MetricValue.double_value instead.
  google.protobuf.DoubleValue unsampled_value = 4 [deprecated = true];
}
// Value at cutoffs, e.g. for precision@K, recall@K
message ValueAtCutoffs {
  // One metric value paired with the cutoff it was computed at (the K in
  // e.g. precision@K).
  message ValueCutoffPair {
    // The cutoff.
    int32 cutoff = 1;

    // The metric value at `cutoff`.
    double value = 2;

    // Deprecated bounded representation of the value.
    BoundedValue bounded_value = 3 [deprecated = true];

    // Deprecated t-distribution representation of the value.
    TDistributionValue t_distribution_value = 4 [deprecated = true];
  }

  // The (cutoff, value) pairs.
  repeated ValueCutoffPair values = 1;
}
// Confusion matrix at thresholds.
// Next tag: 24
message ConfusionMatrixAtThresholds {
  // Confusion matrix values for a single threshold.
  message ConfusionMatrixAtThreshold {
    // The threshold these values belong to.
    double threshold = 1;

    // The four confusion-matrix cells, stored as doubles.
    double false_negatives = 2;
    double true_negatives = 3;
    double false_positives = 4;
    double true_positives = 5;

    // Statistics stored alongside the cells.
    // NOTE(review): presumably derived from the four cells above - confirm.
    double precision = 6;
    double recall = 7;
    double false_positive_rate = 20;
    double f1 = 21;
    double accuracy = 22;
    double false_omission_rate = 23;

    // Deprecated BoundedValue representations.
    BoundedValue bounded_false_negatives = 8 [deprecated = true];
    BoundedValue bounded_true_negatives = 9 [deprecated = true];
    BoundedValue bounded_false_positives = 10 [deprecated = true];
    BoundedValue bounded_true_positives = 11 [deprecated = true];
    BoundedValue bounded_precision = 12 [deprecated = true];
    BoundedValue bounded_recall = 13 [deprecated = true];

    // Deprecated TDistributionValue representations.
    TDistributionValue t_distribution_false_negatives = 14 [deprecated = true];
    TDistributionValue t_distribution_true_negatives = 15 [deprecated = true];
    TDistributionValue t_distribution_false_positives = 16 [deprecated = true];
    TDistributionValue t_distribution_true_positives = 17 [deprecated = true];
    TDistributionValue t_distribution_precision = 18 [deprecated = true];
    TDistributionValue t_distribution_recall = 19 [deprecated = true];
  }

  // The matrices. Each one can carry its values in three representations:
  // bounded, t-distribution and double.
  //   1. Bounded values are provided if the metrics are calculated using
  //      bootstrapping (note: the confidence level is set to 95%).
  //   2. T-distribution values are provided if the metrics are calculated
  //      using bootstrapping and the confidence level isn't set. The user
  //      then configures the confidence level through the frontend to get the
  //      final confidence intervals. Both TDistributionValue and BoundedValue
  //      are supported for now, but BoundedValue will eventually be
  //      deprecated.
  //   3. Double values are being deprecated.
  // NOTE(review): the field options disagree with items 2 and 3 - the
  // bounded_* and t_distribution_* fields carry [deprecated = true] while the
  // double fields do not. Confirm which statement is current.
  repeated ConfusionMatrixAtThreshold matrices = 1;
}
// For metrics which return an array of values.
message ArrayValue {
  // Element type of the array; tells which *_values field is populated.
  enum DataType {
    UNKNOWN = 0;
    BYTES = 1;
    INT32 = 2;
    INT64 = 3;
    FLOAT32 = 4;
    FLOAT64 = 5;
  }

  // Element type of this array.
  DataType data_type = 1;

  // Shape of the array.
  repeated int32 shape = 2;

  // Storage for the elements. Exactly one of the fields below, the one that
  // corresponds to `data_type`, should be set.
  repeated bytes bytes_values = 3;
  repeated int32 int32_values = 4;
  repeated int64 int64_values = 5;
  repeated float float32_values = 6;
  repeated double float64_values = 7;
}
// It stores metrics values in different types, so that the frontend will know
// how to visualize the values based on the types.
message MetricValue {
  // The value, in exactly one of the supported representations.
  oneof type {
    // A single double.
    google.protobuf.DoubleValue double_value = 1;

    // A value with bounds. Deprecated as a confidence interval container;
    // only use it to encode non-CI bounds, such as approximation bounds.
    BoundedValue bounded_value = 2;

    // Deprecated t-distribution representation.
    TDistributionValue t_distribution_value = 9 [deprecated = true];

    // Values at cutoffs.
    ValueAtCutoffs value_at_cutoffs = 4;

    // Confusion matrices at thresholds.
    ConfusionMatrixAtThresholds confusion_matrix_at_thresholds = 5;

    // Multi-class confusion matrices at thresholds.
    MultiClassConfusionMatrixAtThresholds
        multi_class_confusion_matrix_at_thresholds = 11;

    // A value whose type is not known.
    UnknownType unknown_type = 3;

    // Raw bytes.
    bytes bytes_value = 6;

    // An array of values.
    ArrayValue array_value = 7;

    // Generic message used to communicate any extra information, such as in
    // a scenario when no data is aggregated for a small data slice due to
    // privacy concerns.
    string debug_message = 10;
  }

  // Tag numbers that must not be reused.
  reserved 8, 14;

  // Next tag = 16;
}
// LINT.ThenChange(
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// Sync the SliceKey types with PerformanceStatistics because of b/110954446.
// LINT.IfChange
// A single slice key.
message SingleSliceKey {
  // Column of the slice.
  string column = 1;

  // Value for `column`, in one of the supported types.
  oneof kind {
    bytes bytes_value = 2;
    float float_value = 3;
    int64 int64_value = 4;
  }
}
// A slice key, which may consist of multiple single slice keys.
message SliceKey {
  // The single slice keys that together make up this slice key.
  repeated SingleSliceKey single_slice_keys = 1;
}
// LINT.ThenChange(
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// CrossSliceKey contains two slices which are compared with each other.
message CrossSliceKey {
  // Slice on the baseline side of the comparison.
  SliceKey baseline_slice_key = 1;

  // Slice on the comparison side of the comparison.
  SliceKey comparison_slice_key = 2;
}
// Confidence interval attached to a MetricKeyAndValue.
//
// Every MetricValue field in this message is populated with the same value
// type as MetricKeyAndValue.value, which yields a set of parallel data
// structures providing elementwise confidence intervals. For example, if
// MetricKeyAndValue.value contains an ArrayValue, each field below also
// contains an ArrayValue whose element at a given index is the lower bound,
// upper bound, or standard error for the MetricKeyAndValue.value element at
// that same index.
message ConfidenceInterval {
  // Upper bound of the interval.
  MetricValue upper_bound = 1;

  // Lower bound of the interval.
  MetricValue lower_bound = 2;

  // Standard error.
  MetricValue standard_error = 3;

  // Degrees of freedom.
  google.protobuf.Int64Value degrees_of_freedom = 4;
}
// Metrics for one slice or one cross slice.
message MetricsForSlice {
  // A metric key paired with its value.
  message MetricKeyAndValue {
    // Key of the metric.
    MetricKey key = 1;

    // Value of the metric. When `confidence_interval` is populated, this
    // field contains the point estimate.
    MetricValue value = 2;

    // Confidence interval for `value`.
    ConfidenceInterval confidence_interval = 3;
  }

  // The slice the metrics belong to.
  oneof slicing_spec_oneof {
    // The slice key for the metrics.
    SliceKey slice_key = 1;

    // The cross slice key for the metrics.
    CrossSliceKey cross_slice_key = 3;
  }

  // Metric keys and values.
  repeated MetricKeyAndValue metric_keys_and_values = 51;

  // DEPRECATED
  // A map to store metrics. Currently the post_export_metric provided by TFMA
  // is converted to its appropriate type for better visualization, and all
  // other metrics are mapped to the DoubleValue type.
  map<string, MetricValue> metrics = 2 [deprecated = true];

  reserved 53;  // slice_key used in performance_statistics.proto
}
// Sync the Plot types with PerformanceStatistics because of b/110954446.
// LINT.IfChange
// Buckets of a calibration histogram.
message CalibrationHistogramBuckets {
  // A single bucket. Per the field names it covers the half-open range
  // [lower_threshold_inclusive, upper_threshold_exclusive).
  message Bucket {
    // Lower threshold of the bucket (inclusive).
    double lower_threshold_inclusive = 1;

    // Upper threshold of the bucket (exclusive).
    double upper_threshold_exclusive = 2;

    // Weighted number of examples in the bucket.
    google.protobuf.DoubleValue num_weighted_examples = 3;

    // Total weighted label in the bucket.
    google.protobuf.DoubleValue total_weighted_label = 4;

    // Total weighted refined prediction in the bucket.
    google.protobuf.DoubleValue total_weighted_refined_prediction = 5;
  }

  // The buckets.
  repeated Bucket buckets = 1;
}
// Multi-class confusion matrices, one per threshold.
message MultiClassConfusionMatrixAtThresholds {
  // One (actual class, predicted class) cell of a matrix.
  message MultiClassConfusionMatrixEntry {
    // ID of the actual class.
    int32 actual_class_id = 1;

    // ID of the predicted class.
    int32 predicted_class_id = 2;

    // Weighted number of examples in this cell.
    double num_weighted_examples = 3;
  }

  // The confusion matrix at one threshold.
  message MultiClassConfusionMatrix {
    // Threshold of this matrix.
    double threshold = 1;

    // Cells of the matrix. Only entries with non-zero num_weighted_examples
    // are included. If the top prediction was less than the threshold, then
    // predicted_class_id is set to -1. Entries are sorted by actual_class_id,
    // then by predicted_class_id.
    repeated MultiClassConfusionMatrixEntry entries = 2;
  }

  // The matrices, sorted in order of threshold.
  repeated MultiClassConfusionMatrix matrices = 1;
}
// Multi-label confusion matrices, one per threshold.
message MultiLabelConfusionMatrixAtThresholds {
  // Confusion values for one (actual class, predicted class) pair.
  message MultiLabelConfusionMatrixEntry {
    // ID of the actual class.
    int32 actual_class_id = 1;

    // ID of the predicted class.
    int32 predicted_class_id = 2;

    // Confusion values for this pair, stored as doubles.
    double false_negatives = 3;
    double true_negatives = 4;
    double false_positives = 5;
    double true_positives = 6;
  }

  // The confusion matrix at one threshold.
  message MultiLabelConfusionMatrix {
    // Threshold of this matrix.
    double threshold = 1;

    // Entries of the matrix. Only entries with no non-zero values are
    // included. Entries are sorted by actual_class_id, then by
    // predicted_class_id.
    // NOTE(review): "no non-zero values" reads like a slip for "non-zero
    // values" (compare MultiClassConfusionMatrix.entries) - confirm.
    repeated MultiLabelConfusionMatrixEntry entries = 2;
  }

  // The matrices, sorted in order of threshold.
  repeated MultiLabelConfusionMatrix matrices = 1;
}
// Data backing the supported plots.
message PlotData {
  // Backs the calibration plot and the prediction distribution.
  CalibrationHistogramBuckets calibration_histogram_buckets = 1;

  // Backs the AUC curve and the AUPRC curve.
  ConfusionMatrixAtThresholds confusion_matrix_at_thresholds = 2;

  // Backs the multi-class confusion matrix.
  MultiClassConfusionMatrixAtThresholds
      multi_class_confusion_matrix_at_thresholds = 4;

  // Backs the multi-label confusion matrix.
  MultiLabelConfusionMatrixAtThresholds
      multi_label_confusion_matrix_at_thresholds = 5;

  // Generic message used to communicate any extra information, such as in a
  // scenario when no data is aggregated for a small data slice due to privacy
  // concerns.
  string debug_message = 3;
}
// LINT.ThenChange(
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// Sync with PerformanceStatistics because of b/110954446.
// LINT.IfChange
//
// A plot key uniquely identifies a set of PlotData.
message PlotKey {
  // Plot name. Optional.
  string name = 6;

  // Model the plot belongs to. Optional; relevant for multi-model evaluation.
  string model_name = 4;

  // Output the plot belongs to. Optional; relevant for multi-output models.
  string output_name = 2;

  // Optional sub key further qualifying the plot.
  SubKey sub_key = 3;

  // Tri-state example weighting flag:
  //   * true  - the plot is weighted by examples.
  //   * false - the plot is not weighted by examples.
  //   * unset - unknown whether the plot was weighted by examples.
  // See MetricsSpecs.example_weighted for more information.
  google.protobuf.BoolValue example_weighted = 5;
}
// LINT.ThenChange(
// ../metrics/metric_types.py,
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// Plots for one slice or one cross slice.
message PlotsForSlice {
  // A plot key paired with its data.
  message PlotKeyAndValue {
    // Key of the plot.
    PlotKey key = 1;

    // Data of the plot.
    PlotData value = 2;
  }

  // The slice the plots belong to.
  oneof slicing_spec_oneof {
    // The slice key for the metrics.
    SliceKey slice_key = 1;

    // The cross slice key for the metrics.
    CrossSliceKey cross_slice_key = 4;
  }

  // Plot keys and values.
  repeated PlotKeyAndValue plot_keys_and_values = 8;

  // DEPRECATED
  // The plot data. It was originally superseded by `plots`, which is itself
  // deprecated now; `plot_keys_and_values` is the only non-deprecated plot
  // container in this message.
  PlotData plot_data = 2 [deprecated = true];

  // DEPRECATED
  // Introduced to support multiple plot evaluations in a single evaluator
  // run. Each entry should contain all plots for the same grouping, e.g. for
  // the same head of a multihead model or for the same class in the case of
  // multiclass. For example, the key can be of the form
  // 'post_export_metrics/head_name' for a multihead model.
  map<string, PlotData> plots = 3 [deprecated = true];

  reserved 9;  // slice_key used in performance_statistics.proto
}
// LINT.IfChange
// Attribution keys uniquely identify aggregated attribution values.
message AttributionsKey {
  // Name of the attribution metric (e.g. 'mean', 'total', etc).
  string name = 1;

  // Model name. Optional; relevant for multi-model evaluation.
  string model_name = 2;

  // Output name. Optional; relevant for multi-output models.
  string output_name = 3;

  // Optional sub key associated with the attribution (class_id, etc).
  SubKey sub_key = 4;

  // Tri-state example weighting flag:
  //   * true  - the metric is weighted by examples.
  //   * false - the metric is not weighted by examples.
  //   * unset - unknown whether the metric was weighted by examples (i.e. the
  //             metrics were defined inside of a model).
  // See MetricsSpecs.example_weighted for more information.
  google.protobuf.BoolValue example_weighted = 6;

  // True when this is a diff of attributions based on a comparison with the
  // baseline.
  bool is_diff = 5;
}
// LINT.ThenChange(../metrics/metric_types.py)
// Attributions for one slice or one cross slice.
message AttributionsForSlice {
  // An attributions key paired with its per-feature values.
  message AttributionsKeyAndValues {
    // Key of the attributions.
    AttributionsKey key = 1;

    // Attribution values, keyed by feature key (e.g. 'age', etc).
    map<string, MetricValue> values = 2;
  }

  // The slice the attributions belong to.
  oneof slicing_spec_oneof {
    // The slice key for the metrics.
    SliceKey slice_key = 1;

    // The cross slice key for the metrics.
    CrossSliceKey cross_slice_key = 3;
  }

  // Attribution keys and values.
  repeated AttributionsKeyAndValues attributions_keys_and_values = 2;
}