Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit max number of scores associated with max boxes. #226

Merged
merged 9 commits into the base branch from the source branch on
May 24, 2021
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions argoverse/evaluation/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,10 @@ def accumulate(
logger.info(f"{dt_filtered.shape[0]} detections")
logger.info(f"{gt_filtered.shape[0]} ground truth")
if dt_filtered.shape[0] > 0:
ranked_detections, scores = rank(dt_filtered)
metrics = assign(ranked_detections, gt_filtered, cfg)
ranked_label_records = rank(dt_filtered)
benjaminrwilson marked this conversation as resolved.
Show resolved Hide resolved
metrics = assign(ranked_label_records, gt_filtered, cfg)

scores = [[record.score] for record in ranked_label_records.tolist()]
cls_to_accum[class_name] = np.hstack((metrics, scores))

cls_to_ninst[class_name] = gt_filtered.shape[0]
Expand Down Expand Up @@ -212,10 +214,6 @@ def assign(dts: np.ndarray, gts: np.ndarray, cfg: DetectionCfg) -> np.ndarray:
of true positive thresholds used for AP computation and S is the number of true positive errors.
"""

# Ensure the number of boxes considered per class is at most `MAX_NUM_BOXES`.
benjaminrwilson marked this conversation as resolved.
Show resolved Hide resolved
if dts.shape[0] > MAX_NUM_BOXES:
dts = dts[:MAX_NUM_BOXES]

n_threshs = len(cfg.affinity_threshs)
metrics = np.zeros((dts.shape[0], n_threshs + N_TP_ERRORS))

Expand Down Expand Up @@ -323,20 +321,24 @@ def filter_instances(
return filtered_instances


def rank(label_records: np.ndarray) -> np.ndarray:
    """Rank `ObjectLabelRecord` objects in descending order by detection score.

    Args:
        label_records: (N,) array of `ObjectLabelRecord` objects, each
            carrying a `score` attribute (detector confidence).

    Returns:
        ranked_label_records: (K,) array of `ObjectLabelRecord` objects sorted
            from highest to lowest score, where K = min(N, MAX_NUM_BOXES).
    """
    # `tolist()` converts the object array into a plain list so we can read
    # each record's score attribute.
    scores = np.array([record.score for record in label_records.tolist()])
    ranks = scores.argsort()[::-1]

    ranked_label_records = label_records[ranks]

    # Ensure the number of boxes considered per class is at most `MAX_NUM_BOXES`.
    if ranked_label_records.shape[0] > MAX_NUM_BOXES:
        ranked_label_records = ranked_label_records[:MAX_NUM_BOXES]
    return ranked_label_records


def interp(prec: np.ndarray, method: InterpType = InterpType.ALL) -> np.ndarray:
Expand Down