diff --git a/lite/tests/object_detection/test_accuracy.py b/lite/tests/object_detection/test_accuracy.py
new file mode 100644
index 000000000..d2e6f36c4
--- /dev/null
+++ b/lite/tests/object_detection/test_accuracy.py
@@ -0,0 +1,492 @@
+import numpy as np
+from valor_lite.object_detection import DataLoader, Detection, MetricType
+from valor_lite.object_detection.computation import compute_metrics
+
+
+def test__compute_accuracy():
+
+    sorted_pairs = np.array(
+        [
+            # dt, gt, pd, iou, gl, pl, score,
+            [0.0, 0.0, 2.0, 0.25, 0.0, 0.0, 0.95],
+            [0.0, 0.0, 3.0, 0.33333, 0.0, 0.0, 0.9],
+            [0.0, 0.0, 4.0, 0.66667, 0.0, 0.0, 0.65],
+            [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.1],
+            [0.0, 0.0, 1.0, 0.5, 0.0, 0.0, 0.01],
+        ]
+    )
+
+    label_metadata = np.array([[1, 5, 0]])
+    iou_thresholds = np.array([0.1, 0.6])
+    score_thresholds = np.array([0.0])
+
+    (_, _, accuracy, _, _) = compute_metrics(
+        sorted_pairs,
+        label_metadata=label_metadata,
+        iou_thresholds=iou_thresholds,
+        score_thresholds=score_thresholds,
+    )
+
+    expected = np.array(
+        [
+            [0.2],  # iou = 0.1
+            [0.2],  # iou = 0.6
+        ]
+    )
+    assert (accuracy == expected).all()
+
+
+def test_accuracy_using_torch_metrics_example(
+    torchmetrics_detections: list[Detection],
+):
+    """
+    Compare with the torchmetrics/pycocotools results listed here:
+    https://github.com/Lightning-AI/metrics/blob/107dbfd5fb158b7ae6d76281df44bd94c836bfce/tests/unittests/detection/test_map.py#L231
+    """
+
+    loader = DataLoader()
+    loader.add_bounding_boxes(torchmetrics_detections)
+    evaluator = loader.finalize()
+
+    assert evaluator.ignored_prediction_labels == ["3"]
+    assert evaluator.missing_prediction_labels == []
+    assert evaluator.n_datums == 4
+    assert evaluator.n_labels == 6
+    assert evaluator.n_groundtruths == 20
+    assert evaluator.n_predictions == 19
+
+    metrics = evaluator.evaluate(
+        iou_thresholds=[0.5, 0.75],
+        as_dict=True,
+    )
+
+    # test Accuracy
+    actual_metrics = [m for m in metrics[MetricType.Accuracy]]
+    expected_metrics = [
+        {
+            "type": "Accuracy",
+            "value": 9 / 19,
+            "parameters": {
+                "iou_threshold": 0.5,
+                "score_threshold": 0.5,
+            },
+        },
+        {
+            "type": "Accuracy",
+            "value": 8 / 19,
+            "parameters": {
+                "iou_threshold": 0.75,
+                "score_threshold": 0.5,
+            },
+        },
+    ]
+    for m in actual_metrics:
+        assert m in expected_metrics
+    for m in expected_metrics:
+        assert m in actual_metrics
+
+
+def test_accuracy_metrics_first_class(
+    basic_detections_first_class: list[Detection],
+    basic_rotated_detections_first_class: list[Detection],
+):
+    """
+    Basic object detection test.
+ + groundtruths + datum uid1 + box 1 - label v1 - tp + box 3 - label v2 - fn missing prediction + datum uid2 + box 2 - label v1 - fn missing prediction + + predictions + datum uid1 + box 1 - label v1 - score 0.3 - tp + datum uid2 + box 2 - label v2 - score 0.98 - fp + """ + for input_, method in [ + (basic_detections_first_class, DataLoader.add_bounding_boxes), + (basic_rotated_detections_first_class, DataLoader.add_polygons), + ]: + loader = DataLoader() + method(loader, input_) + evaluator = loader.finalize() + + metrics = evaluator.evaluate( + iou_thresholds=[0.1, 0.6], + score_thresholds=[0.0, 0.5], + as_dict=True, + ) + + assert evaluator.ignored_prediction_labels == [] + assert evaluator.missing_prediction_labels == [] + assert evaluator.n_datums == 2 + assert evaluator.n_labels == 1 + assert evaluator.n_groundtruths == 2 + assert evaluator.n_predictions == 1 + + # test Accuracy + actual_metrics = [m for m in metrics[MetricType.Accuracy]] + expected_metrics = [ + { + "type": "Accuracy", + "value": 1.0, + "parameters": { + "iou_threshold": 0.1, + "score_threshold": 0.0, + }, + }, + { + "type": "Accuracy", + "value": 1.0, + "parameters": { + "iou_threshold": 0.6, + "score_threshold": 0.0, + }, + }, + { + "type": "Accuracy", + "value": 0.0, + "parameters": { + "iou_threshold": 0.1, + "score_threshold": 0.5, + }, + }, + { + "type": "Accuracy", + "value": 0.0, + "parameters": { + "iou_threshold": 0.6, + "score_threshold": 0.5, + }, + }, + ] + for m in actual_metrics: + assert m in expected_metrics + for m in expected_metrics: + assert m in actual_metrics + + +def test_accuracy_metrics_second_class( + basic_detections_second_class: list[Detection], + basic_rotated_detections_second_class: list[Detection], +): + """ + Basic object detection test. + + groundtruths + datum uid1 + box 3 - label v2 - fn missing prediction + datum uid2 + none + predictions + datum uid1 + none + datum uid2 + box 2 - label v2 - score 0.98 - fp + """ + for input_, method in [ + (basic_detections_second_class, DataLoader.add_bounding_boxes), + (basic_rotated_detections_second_class, DataLoader.add_polygons), + ]: + loader = DataLoader() + method(loader, input_) + evaluator = loader.finalize() + + metrics = evaluator.evaluate( + iou_thresholds=[0.1, 0.6], + score_thresholds=[0.0, 0.5], + as_dict=True, + ) + + assert evaluator.ignored_prediction_labels == [] + assert evaluator.missing_prediction_labels == [] + assert evaluator.n_datums == 2 + assert evaluator.n_labels == 1 + assert evaluator.n_groundtruths == 1 + assert evaluator.n_predictions == 1 + + # test Accuracy + actual_metrics = [m for m in metrics[MetricType.Accuracy]] + expected_metrics = [ + { + "type": "Accuracy", + "value": 0.0, + "parameters": { + "iou_threshold": 0.1, + "score_threshold": 0.0, + }, + }, + { + "type": "Accuracy", + "value": 0.0, + "parameters": { + "iou_threshold": 0.6, + "score_threshold": 0.0, + }, + }, + { + "type": "Accuracy", + "value": 0.0, + "parameters": { + "iou_threshold": 0.1, + "score_threshold": 0.5, + }, + }, + { + "type": "Accuracy", + "value": 0.0, + "parameters": { + "iou_threshold": 0.6, + "score_threshold": 0.5, + }, + }, + ] + for m in actual_metrics: + assert m in expected_metrics + for m in expected_metrics: + assert m in actual_metrics + + +def test_accuracy_false_negatives_single_datum_baseline( + false_negatives_single_datum_baseline_detections: list[Detection], +): + """This is the baseline for the below test. 
In this case there are two predictions and one groundtruth,
+    but the highest-confidence prediction overlaps sufficiently with the groundtruth. With one
+    true positive out of two predictions, the Accuracy is 0.5 at a score threshold of 0.0 and
+    0.0 at a threshold of 0.9, where no true positive remains above the threshold.
+    """
+
+    loader = DataLoader()
+    loader.add_bounding_boxes(false_negatives_single_datum_baseline_detections)
+    evaluator = loader.finalize()
+
+    metrics = evaluator.evaluate(
+        iou_thresholds=[0.5],
+        score_thresholds=[0.0, 0.9],
+        as_dict=True,
+    )
+
+    actual_metrics = [m for m in metrics[MetricType.Accuracy]]
+    expected_metrics = [
+        {
+            "type": "Accuracy",
+            "value": 0.5,
+            "parameters": {
+                "iou_threshold": 0.5,
+                "score_threshold": 0.0,
+            },
+        },
+        {
+            "type": "Accuracy",
+            "value": 0.0,
+            "parameters": {
+                "iou_threshold": 0.5,
+                "score_threshold": 0.9,
+            },
+        },
+    ]
+    for m in actual_metrics:
+        assert m in expected_metrics
+    for m in expected_metrics:
+        assert m in actual_metrics
+
+
+def test_accuracy_false_negatives_single_datum(
+    false_negatives_single_datum_detections: list[Detection],
+):
+    """Tests a case where a high-confidence false negative was not being penalized. The
+    difference between this test and the one above is that here the prediction with higher
+    confidence does not sufficiently overlap the groundtruth, so it is penalized and we get
+    an Accuracy of 0.5.
+    """
+
+    loader = DataLoader()
+    loader.add_bounding_boxes(false_negatives_single_datum_detections)
+    evaluator = loader.finalize()
+    metrics = evaluator.evaluate(
+        iou_thresholds=[0.5],
+        score_thresholds=[0.0],
+        as_dict=True,
+    )
+
+    actual_metrics = [m for m in metrics[MetricType.Accuracy]]
+    expected_metrics = [
+        {
+            "type": "Accuracy",
+            "value": 0.5,
+            "parameters": {
+                "iou_threshold": 0.5,
+                "score_threshold": 0.0,
+            },
+        }
+    ]
+    for m in actual_metrics:
+        assert m in expected_metrics
+    for m in expected_metrics:
+        assert m in actual_metrics
+
+
+def test_accuracy_false_negatives_two_datums_one_empty_low_confidence_of_fp(
+    false_negatives_two_datums_one_empty_low_confidence_of_fp_detections: list[
+        Detection
+    ],
+):
+    """In this test we have
+    1. An image with a matching groundtruth and prediction (same class and high IOU)
+    2. A second image with an empty groundtruth annotation but a prediction with lower confidence
+    than the prediction on the first image.
+
+    In this case, the Accuracy should be 0.5: one of the two predictions is a true positive,
+    and the false positive is counted even though it has lower confidence than the true positive.
+
+    """
+
+    loader = DataLoader()
+    loader.add_bounding_boxes(
+        false_negatives_two_datums_one_empty_low_confidence_of_fp_detections
+    )
+    evaluator = loader.finalize()
+    metrics = evaluator.evaluate(
+        iou_thresholds=[0.5],
+        score_thresholds=[0.0],
+        as_dict=True,
+    )
+
+    actual_metrics = [m for m in metrics[MetricType.Accuracy]]
+    expected_metrics = [
+        {
+            "type": "Accuracy",
+            "value": 0.5,
+            "parameters": {
+                "iou_threshold": 0.5,
+                "score_threshold": 0.0,
+            },
+        }
+    ]
+    for m in actual_metrics:
+        assert m in expected_metrics
+    for m in expected_metrics:
+        assert m in actual_metrics
+
+
+def test_accuracy_false_negatives_two_datums_one_empty_high_confidence_of_fp(
+    false_negatives_two_datums_one_empty_high_confidence_of_fp_detections: list[
+        Detection
+    ],
+):
+    """In this test we have
+    1. An image with a matching groundtruth and prediction (same class and high IOU)
+    2. A second image with an empty groundtruth annotation and a prediction with higher confidence
+    than the prediction on the first image.
+
+    In this case, the Accuracy should again be 0.5: one of the two predictions is a true positive,
+    and the false positive is counted regardless of its higher confidence.
+    """
+
+    loader = DataLoader()
+    loader.add_bounding_boxes(
+        false_negatives_two_datums_one_empty_high_confidence_of_fp_detections
+    )
+    evaluator = loader.finalize()
+    metrics = evaluator.evaluate(
+        iou_thresholds=[0.5],
+        score_thresholds=[0.0],
+        as_dict=True,
+    )
+
+    actual_metrics = [m for m in metrics[MetricType.Accuracy]]
+    expected_metrics = [
+        {
+            "type": "Accuracy",
+            "value": 0.5,
+            "parameters": {
+                "iou_threshold": 0.5,
+                "score_threshold": 0.0,
+            },
+        }
+    ]
+    for m in actual_metrics:
+        assert m in expected_metrics
+    for m in expected_metrics:
+        assert m in actual_metrics
+
+
+def test_accuracy_false_negatives_two_datums_one_only_with_different_class_low_confidence_of_fp(
+    false_negatives_two_datums_one_only_with_different_class_low_confidence_of_fp_detections: list[
+        Detection
+    ],
+):
+    """In this test we have
+    1. An image with a matching groundtruth and prediction (same class, `"value"`, and high IOU)
+    2. A second image with a groundtruth annotation with class `"other value"` and a prediction with lower confidence
+    than the prediction on the first image.
+
+    In this case, the Accuracy for class `"value"` should be 0.5 since the false positive is counted
+    even though it has lower confidence than the true positive.
+    Accuracy for class `"other value"` should be 0 since there is no prediction for the `"other value"` groundtruth
+    """
+    loader = DataLoader()
+    loader.add_bounding_boxes(
+        false_negatives_two_datums_one_only_with_different_class_low_confidence_of_fp_detections
+    )
+    evaluator = loader.finalize()
+    metrics = evaluator.evaluate(
+        iou_thresholds=[0.5],
+        score_thresholds=[0.0],
+        as_dict=True,
+    )
+
+    actual_metrics = [m for m in metrics[MetricType.Accuracy]]
+    expected_metrics = [
+        {
+            "type": "Accuracy",
+            "value": 0.5,
+            "parameters": {
+                "iou_threshold": 0.5,
+                "score_threshold": 0.0,
+            },
+        },
+    ]
+    for m in actual_metrics:
+        assert m in expected_metrics
+    for m in expected_metrics:
+        assert m in actual_metrics
+
+
+def test_accuracy_false_negatives_two_datums_one_only_with_different_class_high_confidence_of_fp(
+    false_negatives_two_images_one_only_with_different_class_high_confidence_of_fp_detections: list[
+        Detection
+    ],
+):
+    """In this test we have
+    1. An image with a matching groundtruth and prediction (same class, `"value"`, and high IOU)
+    2. A second image with a groundtruth annotation with class `"other value"` and a prediction with higher confidence
+    than the prediction on the first image.
+
+    In this case, the Accuracy for class `"value"` should be 0.5 since the false positive has higher confidence than the true positive.
+ Accuracy for class `"other value"` should be 0 since there is no prediction for the `"other value"` groundtruth + """ + loader = DataLoader() + loader.add_bounding_boxes( + false_negatives_two_images_one_only_with_different_class_high_confidence_of_fp_detections + ) + evaluator = loader.finalize() + metrics = evaluator.evaluate( + iou_thresholds=[0.5], + score_thresholds=[0.0], + as_dict=True, + ) + + actual_metrics = [m for m in metrics[MetricType.Accuracy]] + expected_metrics = [ + { + "type": "Accuracy", + "value": 0.5, + "parameters": { + "iou_threshold": 0.5, + "score_threshold": 0.0, + }, + }, + ] + for m in actual_metrics: + assert m in expected_metrics + for m in expected_metrics: + assert m in actual_metrics diff --git a/lite/tests/object_detection/test_average_precision.py b/lite/tests/object_detection/test_average_precision.py index f2b9a6e4f..8697be30f 100644 --- a/lite/tests/object_detection/test_average_precision.py +++ b/lite/tests/object_detection/test_average_precision.py @@ -24,7 +24,7 @@ def test__compute_average_precision(): iou_thresholds = np.array([0.1, 0.6]) score_thresholds = np.array([0.0]) - (results, _, _, _,) = compute_metrics( + (results, _, _, _, _) = compute_metrics( sorted_pairs, label_metadata=label_metadata, iou_thresholds=iou_thresholds, diff --git a/lite/tests/object_detection/test_average_recall.py b/lite/tests/object_detection/test_average_recall.py index 8350ce607..45686c753 100644 --- a/lite/tests/object_detection/test_average_recall.py +++ b/lite/tests/object_detection/test_average_recall.py @@ -25,7 +25,7 @@ def test__compute_average_recall(): iou_thresholds = np.array([0.1, 0.6]) score_thresholds = np.array([0.5, 0.93, 0.98]) - (_, results, _, _,) = compute_metrics( + (_, results, _, _, _,) = compute_metrics( sorted_pairs, label_metadata=label_metadata, iou_thresholds=iou_thresholds, diff --git a/lite/tests/object_detection/test_pr_curve.py b/lite/tests/object_detection/test_pr_curve.py index 8a661cff4..acb997315 100644 --- a/lite/tests/object_detection/test_pr_curve.py +++ b/lite/tests/object_detection/test_pr_curve.py @@ -24,7 +24,7 @@ def test_pr_curve_simple(): iou_thresholds = np.array([0.1, 0.6]) score_thresholds = np.array([0.0]) - (_, _, _, pr_curve) = compute_metrics( + (_, _, _, _, pr_curve) = compute_metrics( sorted_pairs, label_metadata=label_metadata, iou_thresholds=iou_thresholds, diff --git a/lite/valor_lite/object_detection/computation.py b/lite/valor_lite/object_detection/computation.py index 71b5193e0..dfd4a1076 100644 --- a/lite/valor_lite/object_detection/computation.py +++ b/lite/valor_lite/object_detection/computation.py @@ -282,6 +282,7 @@ def compute_metrics( ], NDArray[np.float64], NDArray[np.float64], + NDArray[np.float64], ]: """ Computes Object Detection metrics. @@ -309,13 +310,15 @@ def compute_metrics( Returns ------- - tuple[NDArray, NDArray, NDArray, float] + tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64], float] Average Precision results. - tuple[NDArray, NDArray, NDArray, float] + tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64], float] Average Recall results. - np.ndarray - Precision, Recall, TP, FP, FN, F1 Score, Accuracy. - np.ndarray + NDArray[np.float64] + Accuracy. + NDArray[np.float64] + Precision, Recall, TP, FP, FN, F1 Score. + NDArray[np.float64] Interpolated Precision-Recall Curves. 
""" @@ -329,9 +332,10 @@ def compute_metrics( elif n_scores == 0: raise ValueError("At least one score threshold must be passed.") - average_precision = np.zeros((n_ious, n_labels)) - average_recall = np.zeros((n_scores, n_labels)) - counts = np.zeros((n_ious, n_scores, n_labels, 7)) + average_precision = np.zeros((n_ious, n_labels), dtype=np.float64) + average_recall = np.zeros((n_scores, n_labels), dtype=np.float64) + accuracy = np.zeros((n_ious, n_scores), dtype=np.float64) + counts = np.zeros((n_ious, n_scores, n_labels, 6), dtype=np.float64) pd_labels = data[:, 5].astype(np.int32) scores = data[:, 6] @@ -417,14 +421,6 @@ def compute_metrics( out=f1_score, ) - accuracy = np.zeros_like(tp_count) - np.divide( - tp_count, - (gt_count + pd_count), - where=(gt_count + pd_count) > 1e-9, - out=accuracy, - ) - counts[iou_idx][score_idx] = np.concatenate( ( tp_count[:, np.newaxis], @@ -433,11 +429,18 @@ def compute_metrics( precision[:, np.newaxis], recall[:, np.newaxis], f1_score[:, np.newaxis], - accuracy[:, np.newaxis], ), axis=1, ) + # caluculate accuracy + total_pd_count = label_metadata[:, 1].sum() + accuracy[iou_idx, score_idx] = ( + (tp_count.sum() / total_pd_count) + if total_pd_count > 1e-9 + else 0.0 + ) + # calculate recall for AR average_recall[score_idx] += recall @@ -552,6 +555,7 @@ def compute_metrics( return ( ap_results, ar_results, + accuracy, counts, pr_curve, ) diff --git a/lite/valor_lite/object_detection/manager.py b/lite/valor_lite/object_detection/manager.py index db3274945..acd11b764 100644 --- a/lite/valor_lite/object_detection/manager.py +++ b/lite/valor_lite/object_detection/manager.py @@ -506,6 +506,7 @@ def compute_precision_recall( average_recall_averaged_over_scores, mean_average_recall_averaged_over_scores, ), + accuracy, precision_recall, pr_curves, ) = compute_metrics( @@ -593,6 +594,16 @@ def compute_precision_recall( ) ] + metrics[MetricType.Accuracy] = [ + Accuracy( + value=float(accuracy[iou_idx, score_idx]), + iou_threshold=iou_thresholds[iou_idx], + score_threshold=score_thresholds[score_idx], + ) + for iou_idx in range(accuracy.shape[0]) + for score_idx in range(accuracy.shape[1]) + ] + metrics[MetricType.PrecisionRecallCurve] = [ PrecisionRecallCurve( precisions=pr_curves[iou_idx, label_idx, :, 0] @@ -650,12 +661,6 @@ def compute_precision_recall( **kwargs, ) ) - metrics[MetricType.Accuracy].append( - Accuracy( - value=float(row[6]), - **kwargs, - ) - ) if as_dict: return { diff --git a/lite/valor_lite/object_detection/metric.py b/lite/valor_lite/object_detection/metric.py index a2dd0d33d..4e5b7c52b 100644 --- a/lite/valor_lite/object_detection/metric.py +++ b/lite/valor_lite/object_detection/metric.py @@ -160,9 +160,9 @@ class Recall(_ClassMetric): pass -class Accuracy(_ClassMetric): +class F1(_ClassMetric): """ - Accuracy metric for a specific class label in object detection. + F1 score for a specific class label in object detection. This class encapsulates a metric value for a particular class label, along with the associated Intersection over Union (IoU) threshold and @@ -190,20 +190,18 @@ class Accuracy(_ClassMetric): pass -class F1(_ClassMetric): +@dataclass +class Accuracy: """ - F1 score for a specific class label in object detection. + Accuracy metric for the object detection task type. - This class encapsulates a metric value for a particular class label, - along with the associated Intersection over Union (IoU) threshold and - confidence score threshold. 
+ This class encapsulates a metric value at a specific Intersection + over Union (IoU) threshold and confidence score threshold. Attributes ---------- value : float The metric value. - label : str - The class label for which the metric is calculated. iou_threshold : float The IoU threshold used to determine matches between predicted and ground truth boxes. score_threshold : float @@ -217,7 +215,22 @@ class F1(_ClassMetric): Converts the instance to a dictionary representation. """ - pass + value: float + iou_threshold: float + score_threshold: float + + def to_metric(self) -> Metric: + return Metric( + type=type(self).__name__, + value=self.value, + parameters={ + "iou_threshold": self.iou_threshold, + "score_threshold": self.score_threshold, + }, + ) + + def to_dict(self) -> dict: + return self.to_metric().to_dict() @dataclass
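As a quick illustration of the accuracy definition added above — total true-positive count divided by total prediction count at a fixed IoU/score threshold pair — here is a minimal, self-contained sketch. The helper name `detection_accuracy` and the aggregated example inputs are illustrative only, not part of the valor_lite API.

```python
import numpy as np


def detection_accuracy(tp_count: np.ndarray, label_metadata: np.ndarray) -> float:
    """Illustrative restatement of the accuracy computation above.

    tp_count       : per-label true-positive counts at one (iou, score) threshold pair.
    label_metadata : per-label metadata whose second column holds the prediction count,
                     mirroring ``label_metadata[:, 1]`` in ``compute_metrics``.
    """
    total_pd_count = label_metadata[:, 1].sum()
    # Guard against division by zero when there are no predictions at all.
    return float(tp_count.sum() / total_pd_count) if total_pd_count > 1e-9 else 0.0


# e.g. the torchmetrics example above: 9 true positives out of 19 predictions
# at iou=0.5 / score=0.5 gives 9 / 19.
assert np.isclose(
    detection_accuracy(np.array([9.0]), np.array([[20, 19, 0]])),
    9 / 19,
)
```

This also reflects why the new `Accuracy` metric carries no class label: the denominator sums prediction counts over every label, so `compute_metrics` reports one value per `(iou_threshold, score_threshold)` pair.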