Skip to content

Commit

Permalink
Merge pull request #226 from benjaminrwilson/detection-eval-max-boxes…
Browse files Browse the repository at this point in the history
…-fix

Limit max number of scores associated with max boxes.
  • Loading branch information
jagjeet-singh authored May 24, 2021
2 parents 27134b9 + cbd1e06 commit 778da7e
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions argoverse/evaluation/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,10 @@ def accumulate(
logger.info(f"{dt_filtered.shape[0]} detections")
logger.info(f"{gt_filtered.shape[0]} ground truth")
if dt_filtered.shape[0] > 0:
ranked_detections, scores = rank(dt_filtered)
metrics = assign(ranked_detections, gt_filtered, cfg)
cls_to_accum[class_name] = np.hstack((metrics, scores))
ranked_dts, ranked_scores = rank(dt_filtered)

metrics = assign(ranked_dts, gt_filtered, cfg)
cls_to_accum[class_name] = np.hstack((metrics, ranked_scores[:, None]))

cls_to_ninst[class_name] = gt_filtered.shape[0]
return cls_to_accum, cls_to_ninst
Expand Down Expand Up @@ -212,10 +213,6 @@ def assign(dts: np.ndarray, gts: np.ndarray, cfg: DetectionCfg) -> np.ndarray:
of true positive thresholds used for AP computation and S is the number of true positive errors.
"""

# Ensure the number of boxes considered per class is at most `MAX_NUM_BOXES`.
if dts.shape[0] > MAX_NUM_BOXES:
dts = dts[:MAX_NUM_BOXES]

n_threshs = len(cfg.affinity_threshs)
metrics = np.zeros((dts.shape[0], n_threshs + N_TP_ERRORS))

Expand Down Expand Up @@ -323,20 +320,28 @@ def filter_instances(
return filtered_instances


def rank(dts: List[ObjectLabelRecord]) -> Tuple[np.ndarray, np.ndarray]:
"""Get the rankings for the detections, according to detector confidence.
def rank(dts: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Rank the detections in descending order according to score (detector confidence).
Args:
dts: Detections (N,).
dts: Array of `ObjectLabelRecord` objects. (N,).
Returns:
ranks: Ranking for the detections (N,).
scores: Detection scores (N,).
ranked_dts: Array of `ObjectLabelRecord` objects ranked by score (N,) where N <= MAX_NUM_BOXES.
ranked_scores: Array of floats sorted in descending order (N,) where N <= MAX_NUM_BOXES.
"""
scores = np.array([dt.score for dt in dts])
scores = np.array([dt.score for dt in dts.tolist()])
ranks = scores.argsort()[::-1]
ranked_detections = dts[ranks]
return ranked_detections, scores[:, np.newaxis]

ranked_dts = dts[ranks]
ranked_scores = scores[ranks]

# Ensure the number of boxes considered per class is at most `MAX_NUM_BOXES`.
if ranked_dts.shape[0] > MAX_NUM_BOXES:
ranked_dts = ranked_dts[:MAX_NUM_BOXES]
ranked_scores = ranked_scores[:MAX_NUM_BOXES]
return ranked_dts, ranked_scores


def interp(prec: np.ndarray, method: InterpType = InterpType.ALL) -> np.ndarray:
Expand Down

0 comments on commit 778da7e

Please sign in to comment.