Add units to variables #170

Merged · 53 commits · Apr 25, 2023

Commits
5fa2b67  Distance Threshold and ROI Pruning (Apr 3, 2023)
1427ee5  Distance Threshold and ROI Pruning (Apr 3, 2023)
fa93ea1  Hierarchical AP (Apr 3, 2023)
38398da  Update eval.py (neeharperi, Apr 4, 2023)
c412ab2  Fix partial lint + typing. (benjaminrwilson, Apr 5, 2023)
36f89d7  Reformat. (benjaminrwilson, Apr 5, 2023)
448b2c2  fix some linting and pytest issues (Redrew, Apr 5, 2023)
bd32596  PR Changes (Apr 7, 2023)
a4d82f5  Fixed Linting (Apr 7, 2023)
8cf9c37  Black Autoformat (Apr 7, 2023)
a69acbe  Update src/av2/evaluation/detection/constants.py (neeharperi, Apr 8, 2023)
6251fd8  Update src/av2/evaluation/detection/eval.py (neeharperi, Apr 8, 2023)
2edd244  PR Refinements (Apr 8, 2023)
ebc1315  Minor Fix (Apr 8, 2023)
7f7cdb7  tidy code, fix forecasting evaluation (Redrew, Apr 9, 2023)
1f6ebaa  fix imports, return tuned metric values from tracking evaluation (Redrew, Apr 9, 2023)
662614b  fix linting (Redrew, Apr 9, 2023)
582716a  fix typing (Redrew, Apr 9, 2023)
74a538a  fix ruff (Redrew, Apr 10, 2023)
2cab177  Test revert reformatting. (benjaminrwilson, Apr 11, 2023)
7ca2e76  Merge remote-tracking branch 'upstream/main' (benjaminrwilson, Apr 11, 2023)
78623cd  Undo formatting. (benjaminrwilson, Apr 11, 2023)
bc7acc1  Undo additional formatting. (benjaminrwilson, Apr 11, 2023)
0ecb2d5  Revert a few more formatting changes. (benjaminrwilson, Apr 11, 2023)
b89949f  Simplify expressions. (benjaminrwilson, Apr 11, 2023)
bb89ddb  Refactor eval. (benjaminrwilson, Apr 11, 2023)
6c8fcb0  Fix imports. (benjaminrwilson, Apr 11, 2023)
19c4346  Fix lint. (benjaminrwilson, Apr 11, 2023)
5bc2afe  Fix lint. (benjaminrwilson, Apr 11, 2023)
b79f420  Consolidate constants. (benjaminrwilson, Apr 11, 2023)
fb577f6  Clean up. (benjaminrwilson, Apr 11, 2023)
99ae976  Add unit test stubs. (benjaminrwilson, Apr 11, 2023)
67c0e4f  Update typing. (benjaminrwilson, Apr 11, 2023)
99649a6  Fix typing + fix lint in detection eval. (benjaminrwilson, Apr 11, 2023)
d3722b5  Change detection args names. (benjaminrwilson, Apr 11, 2023)
a6688b2  Fix typing. (benjaminrwilson, Apr 11, 2023)
449b681  Fix lint. (benjaminrwilson, Apr 13, 2023)
89672e0  Fix mypy. (benjaminrwilson, Apr 13, 2023)
7b0c5f0  Reduce number of conversions. (benjaminrwilson, Apr 13, 2023)
45061f8  Make lca columns a constant. (benjaminrwilson, Apr 13, 2023)
7d281c5  fix group_frames bug (Redrew, Apr 13, 2023)
f0d6e37  Add Units (Apr 15, 2023)
d1c2085  Minor Fix (Apr 16, 2023)
be24157  Merge with upstream/main. (benjaminrwilson, Apr 18, 2023)
6c2bd4e  Fix duplicate imports. (benjaminrwilson, Apr 18, 2023)
16bc154  Remove additional duplicate lines. (benjaminrwilson, Apr 18, 2023)
88ba805  Merge branch 'argoverse:main' into main (neeharperi, Apr 21, 2023)
028d08b  Add files via upload (neeharperi, Apr 21, 2023)
5ea4099  Merge branch 'argoverse:main' into main (neeharperi, Apr 25, 2023)
7670c97  Remove PYC (Apr 25, 2023)
96d53d6  Remove __pycache__ (Apr 25, 2023)
3bd61cb  Fix mypy (Apr 25, 2023)
77e5513  Run black. (benjaminrwilson, Apr 25, 2023)
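The through-line of these commits is the change named in the PR title: fields and variables that carry physical quantities gain explicit unit suffixes (_m, _ns, _m_per_s). A minimal sketch of the convention, using key names that appear in the diffs below (the snippet itself is illustrative, not code from this PR):

```python
import numpy as np

# After this PR, the unit is part of the key or variable name.
agent = {
    "timestamp_ns": 315969904359876000,               # nanoseconds
    "current_translation_m": np.array([10.0, 2.0]),   # meters
    "velocity_m_per_s": np.array([1.5, 0.0]),         # meters per second
}
max_range_m = 50  # evaluation range threshold, also in meters

# Downstream code can no longer silently mix units: the reader sees at the
# call site that a distance in meters is compared against a meter threshold.
dist_m = float(np.linalg.norm(agent["current_translation_m"]))
print(dist_m < max_range_m)  # True
```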
1 change: 1 addition & 0 deletions .gitignore

@@ -25,6 +25,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+*.pyc

 # PyInstaller
 # Usually these files are written by a python script from a template
16 changes: 9 additions & 7 deletions src/av2/evaluation/detection/eval.py

@@ -58,6 +58,7 @@
 import numpy as np
 import pandas as pd
+from tqdm import tqdm
 from av2.evaluation.detection.constants import (
     HIERARCHY,
     LCA,

@@ -442,14 +443,14 @@ def evaluate_hierarchy(
         gts_categories_list.append(sweep_gts_categories)

         num_dts = len(sweep_dts)
-        num_gts = len(sweep_dts)
+        num_gts = len(sweep_gts)
         dts_uuids_list.extend(num_dts * [uuid])
         gts_uuids_list.extend(num_gts * [uuid])

-    dts_npy = np.concatenate(dts).astype(np.float64)
-    gts_npy = np.concatenate(gts).astype(np.float64)
-    dts_categories_npy = np.concatenate(dts_categories).astype(np.object_)
-    gts_categories_npy = np.concatenate(gts_categories).astype(np.object_)
+    dts_npy = np.concatenate(dts_list).astype(np.float64)
+    gts_npy = np.concatenate(gts_list).astype(np.float64)
+    dts_categories_npy = np.concatenate(dts_categories_list).astype(np.object_)
+    gts_categories_npy = np.concatenate(gts_categories_list).astype(np.object_)
     dts_uuids_npy = np.array(dts_uuids_list)
     gts_uuids_npy = np.array(gts_uuids_list)

@@ -487,8 +488,9 @@
     )

     logger.info("Starting evaluation ...")
-    with mp.get_context("spawn").Pool(processes=n_jobs) as p:
-        accumulate_outputs: Any = p.starmap(accumulate_hierarchy, accumulate_hierarchy_args_list)
+    accumulate_outputs = []
+    for accumulate_args in tqdm(accumulate_hierarchy_args_list):
+        accumulate_outputs.append(accumulate_hierarchy(*accumulate_args))

     super_categories = list(HIERARCHY.keys())
     metrics = np.zeros((len(cfg.categories), len(HIERARCHY.keys())))
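The last hunk replaces a spawn-based multiprocessing pool with a serial loop wrapped in a tqdm progress bar. A self-contained sketch of the adopted pattern, with a stand-in for accumulate_hierarchy (the real function takes evaluation-specific arguments); the diff itself does not state the motivation, but the serial form is simpler to debug and reports progress:

```python
from tqdm import tqdm

def accumulate(category: str, threshold_m: float) -> tuple:
    """Stand-in for accumulate_hierarchy; returns a dummy result."""
    return (category, threshold_m)

# Each tuple is one argument set, mirroring accumulate_hierarchy_args_list.
accumulate_args_list = [("REGULAR_VEHICLE", 2.0), ("PEDESTRIAN", 1.0)]
accumulate_outputs = []
for accumulate_args in tqdm(accumulate_args_list):
    accumulate_outputs.append(accumulate(*accumulate_args))
print(accumulate_outputs)
```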
14 changes: 9 additions & 5 deletions src/av2/evaluation/detection/utils.py

@@ -279,7 +279,7 @@ def accumulate_hierarchy(
     fp: Dict[int, Any] = {}
     gt_name: Dict[int, List[Any]] = {}
     pred_name: Dict[int, List[Any]] = {}
-    taken: Dict[int, Set[Tuple[Any, Any]]] = {}
+    taken: Dict[int, Set[Tuple[Any, Any, Any]]] = {}
     for i in range(len(cfg.affinity_thresholds_m)):
         tp[i] = []
         fp[i] = []

@@ -292,7 +292,11 @@
         min_dist = len(cfg.affinity_thresholds_m) * [np.inf]
         match_gt_idx = len(cfg.affinity_thresholds_m) * [None]

-        keep_sweep = gts_uuids == np.array([gts.shape[0] * [pred_uuid]]).squeeze()
+        if len(gts_uuids) > 0:
+            keep_sweep = np.all(gts_uuids == np.array([gts.shape[0] * [pred_uuid]]).squeeze(), axis=1)
+        else:
+            keep_sweep = []
+
         gt_ind_sweep = np.arange(gts.shape[0])[keep_sweep]
         gts_sweep = gts[keep_sweep]
         gts_cats_sweep = gts_cats[keep_sweep]

@@ -303,7 +307,7 @@

         # Find closest match among ground truth boxes
         for i in range(len(cfg.affinity_thresholds_m)):
-            if gt_cat == cat and not (pred_uuid, gt_idx) in taken[i]:
+            if gt_cat == cat and not (pred_uuid[0], pred_uuid[1], gt_idx) in taken[i]:
                 this_distance = dist_mat[pred_idx][gt_idx]
                 if this_distance < min_dist[i]:
                     min_dist[i] = this_distance

@@ -316,7 +320,7 @@
         # Find closest match among ground truth boxes

         for i in range(len(cfg.affinity_thresholds_m)):
-            if not is_match[i] and not (pred_uuid, gt_idx) in taken[i]:
+            if not is_match[i] and not (pred_uuid[0], pred_uuid[1], gt_idx) in taken[i]:
                 this_distance = dist_mat[pred_idx][gt_idx]
                 if this_distance < min_dist[i]:
                     min_dist[i] = this_distance

@@ -330,7 +334,7 @@
         for i in range(len(cfg.affinity_thresholds_m)):
             if is_match[i]:
-                taken[i].add((pred_uuid, gt_idx))
+                taken[i].add((pred_uuid[0], pred_uuid[1], gt_idx))
                 tp[i].append(1)
                 fp[i].append(0)
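The recurring change in this file widens the taken bookkeeping from (uuid, gt_idx) pairs to explicit triples; pred_uuid appears to be a (log_id, timestamp_ns) pair, which is why the new code indexes pred_uuid[0] and pred_uuid[1]. A sketch of the pattern under that assumption (threshold values are illustrative):

```python
from typing import Any, Dict, Set, Tuple

# Assumed affinity thresholds in meters, one greedy match set per threshold.
affinity_thresholds_m = (0.5, 1.0, 2.0, 4.0)
taken: Dict[int, Set[Tuple[Any, Any, Any]]] = {
    i: set() for i in range(len(affinity_thresholds_m))
}

pred_uuid = ("log_0", 315969904359876000)  # (log_id, timestamp_ns)
gt_idx = 7
for i in range(len(affinity_thresholds_m)):
    key = (pred_uuid[0], pred_uuid[1], gt_idx)
    if key not in taken[i]:
        # Greedy matching: this ground-truth box is now claimed at
        # threshold i and cannot be matched to another prediction.
        taken[i].add(key)
```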
94 changes: 0 additions & 94 deletions src/av2/evaluation/forecasting/SUBMISSION_FORMAT.md

This file was deleted.

76 changes: 38 additions & 38 deletions src/av2/evaluation/forecasting/eval.py

@@ -52,7 +52,7 @@ def evaluate(
     ground_truth: ForecastSequences = convert_forecast_labels(raw_ground_truth)
     ground_truth = filter_max_dist(ground_truth, max_range_m)

-    utils.annotate_frame_metadata(predictions, ground_truth, ["ego_translation"])
+    utils.annotate_frame_metadata(predictions, ground_truth, ["ego_translation_m"])
     predictions = filter_max_dist(predictions, max_range_m)

     if dataset_dir is not None:
@@ -77,20 +77,20 @@
         pred = []

         for agent in gt:
-            if agent["future_translation"].shape[0] < 1:
+            if agent["future_translation_m"].shape[0] < 1:
                 continue

             agent["seq_id"] = seq_id
-            agent["timestamp"] = timestamp_ns
-            agent["velocity"] = utils.agent_velocity(agent)
+            agent["timestamp_ns"] = timestamp_ns
+            agent["velocity_m_per_s"] = utils.agent_velocity_m_per_s(agent)
             agent["trajectory_type"] = utils.trajectory_type(agent, CATEGORY_TO_VELOCITY_M_PER_S)

             gt_agents.append(agent)

         for agent in pred:
             agent["seq_id"] = seq_id
-            agent["timestamp"] = timestamp_ns
-            agent["velocity"] = utils.agent_velocity(agent)
+            agent["timestamp_ns"] = timestamp_ns
+            agent["velocity_m_per_s"] = utils.agent_velocity_m_per_s(agent)
             agent["trajectory_type"] = utils.trajectory_type(agent, CATEGORY_TO_VELOCITY_M_PER_S)

             pred_agents.append(agent)
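After this hunk every agent dict carries unit-suffixed keys. A hypothetical annotated ground-truth agent with made-up values (the real shapes, helpers, and trajectory-type labels come from av2's forecasting utils and may differ):

```python
import numpy as np

gt_agent = {
    "seq_id": "02678d04-cc9f-3148-9f95-1ba66347dff9",  # log identifier (made up)
    "timestamp_ns": 315969904359876000,                # nanoseconds
    "current_translation_m": np.array([10.0, 2.0]),    # meters
    "future_translation_m": np.zeros((6, 2)),          # assumed 6-step horizon, meters
    "velocity_m_per_s": 4.2,                           # assumed scalar speed
    "trajectory_type": "linear",                       # e.g. static / linear / non-linear
}
```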
@@ -138,8 +138,8 @@
     """Perform matching between predicted and ground truth trajectories.

     Args:
-        pred_agents: List of predicted trajectories for a given log_id and timestamp.
-        gt_agents: List of ground truth trajectories for a given log_id and timestamp.
+        pred_agents: List of predicted trajectories for a given log_id and timestamp_ns.
+        gt_agents: List of ground truth trajectories for a given log_id and timestamp_ns.
         top_k: Number of future trajectories to consider when evaluating Forecasting AP, ADE and FDE (K=5 by default).
         class_name: Match class name (e.g. car, pedestrian, bicycle) to determine if a trajectory is included
             in evaluation.
@@ -165,7 +165,7 @@ def match(gt: str, pred: str, profile: str) -> bool:
     sortind = [i for (v, i) in sorted((v, i) for (i, v) in enumerate(conf))][::-1]
     gt_agents_by_frame = defaultdict(list)
     for agent in gt:
-        gt_agents_by_frame[(agent["seq_id"], agent["timestamp"])].append(agent)
+        gt_agents_by_frame[(agent["seq_id"], agent["timestamp_ns"])].append(agent)

     npos = len(gt)
     # ---------------------------------------------

@@ -179,12 +179,12 @@
         min_dist = np.inf
         match_gt_idx = None

-        gt_agents_in_frame = gt_agents_by_frame[(pred_agent["seq_id"], pred_agent["timestamp"])]
+        gt_agents_in_frame = gt_agents_by_frame[(pred_agent["seq_id"], pred_agent["timestamp_ns"])]
         for gt_idx, gt_agent in enumerate(gt_agents_in_frame):
-            if not (pred_agent["seq_id"], pred_agent["timestamp"], gt_idx) in taken:
+            if not (pred_agent["seq_id"], pred_agent["timestamp_ns"], gt_idx) in taken:
                 # Find closest match among ground truth boxes
                 this_distance = utils.center_distance(
-                    gt_agent["current_translation"], pred_agent["current_translation"]
+                    gt_agent["current_translation_m"], pred_agent["current_translation_m"]
                 )
                 if this_distance < min_dist:
                     min_dist = this_distance

@@ -194,18 +194,18 @@
         is_match = min_dist < threshold

         if is_match and match_gt_idx is not None:
-            taken.add((pred_agent["seq_id"], pred_agent["timestamp"], match_gt_idx))
+            taken.add((pred_agent["seq_id"], pred_agent["timestamp_ns"], match_gt_idx))
             gt_match_agent = gt_agents_in_frame[match_gt_idx]

-            gt_len = gt_match_agent["future_translation"].shape[0]
+            gt_len = gt_match_agent["future_translation_m"].shape[0]
             forecast_match_th = [threshold + constants.FORECAST_SCALAR[i] * velocity for i in range(gt_len + 1)]

             if top_k == 1:
                 ind = cast(int, np.argmax(pred_agent["score"]))
                 forecast_dist = [
                     utils.center_distance(
-                        gt_match_agent["future_translation"][i],
-                        pred_agent["prediction"][ind][i],
+                        gt_match_agent["future_translation_m"][i],
+                        pred_agent["prediction_m"][ind][i],
                     )
                     for i in range(gt_len)
                 ]

@@ -221,8 +221,8 @@
             for ind in range(top_k):
                 curr_forecast_dist = [
                     utils.center_distance(
-                        gt_match_agent["future_translation"][i],
-                        pred_agent["prediction"][ind][i],
+                        gt_match_agent["future_translation_m"][i],
+                        pred_agent["prediction_m"][ind][i],
                     )
                     for i in range(gt_len)
                 ]
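utils.center_distance, used throughout this matching loop, is assumed here to be the Euclidean distance between 2-D centers. A minimal stand-in:

```python
import numpy as np

def center_distance(a: np.ndarray, b: np.ndarray) -> float:
    """Euclidean distance between two 2-D box centers (assumed behavior)."""
    return float(np.linalg.norm(a - b))

gt_center = np.array([10.0, 2.0])    # gt_agent["current_translation_m"]
pred_center = np.array([11.0, 2.5])  # pred_agent["current_translation_m"]
print(center_distance(gt_center, pred_center))  # ~1.118 m
```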
@@ -303,27 +303,27 @@ def convert_forecast_labels(labels: Any) -> Any:
     frame_dict = {}
     for frame_idx, frame in enumerate(frames):
         forecast_instances = []
-        for instance in utils.array_dict_iterator(frame, len(frame["translation"])):
+        for instance in utils.array_dict_iterator(frame, len(frame["translation_m"])):
             future_translations: Any = []
             for future_frame in frames[frame_idx + 1 : frame_idx + 1 + constants.NUM_TIMESTEPS]:
                 if instance["track_id"] not in future_frame["track_id"]:
                     break
                 future_translations.append(
-                    future_frame["translation"][future_frame["track_id"] == instance["track_id"]][0]
+                    future_frame["translation_m"][future_frame["track_id"] == instance["track_id"]][0]
                 )

             if len(future_translations) == 0:
                 continue

             forecast_instances.append(
                 {
-                    "current_translation": instance["translation"][:2],
-                    "ego_translation": instance["ego_translation"][:2],
-                    "future_translation": np.array(future_translations)[:, :2],
+                    "current_translation_m": instance["translation_m"][:2],
+                    "ego_translation_m": instance["ego_translation_m"][:2],
+                    "future_translation_m": np.array(future_translations)[:, :2],
                     "name": instance["name"],
                     "size": instance["size"],
                     "yaw": instance["yaw"],
-                    "velocity": instance["velocity"][:2],
+                    "velocity_m_per_s": instance["velocity_m_per_s"][:2],
                     "label": instance["label"],
                 }
             )
@@ -346,14 +346,14 @@ def filter_max_dist(forecasts: ForecastSequences, max_range_m: int) -> ForecastSequences:
         Dictionary of tracks.
     """
     for seq_id in forecasts.keys():
-        for timestamp in forecasts[seq_id].keys():
+        for timestamp_ns in forecasts[seq_id].keys():
             keep_forecasts = [
                 agent
-                for agent in forecasts[seq_id][timestamp]
-                if "ego_translation" in agent
-                and np.linalg.norm(agent["current_translation"] - agent["ego_translation"]) < max_range_m
+                for agent in forecasts[seq_id][timestamp_ns]
+                if "ego_translation_m" in agent
+                and np.linalg.norm(agent["current_translation_m"] - agent["ego_translation_m"]) < max_range_m
             ]
-            forecasts[seq_id][timestamp] = keep_forecasts
+            forecasts[seq_id][timestamp_ns] = keep_forecasts

     return forecasts
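A hypothetical usage of filter_max_dist with the renamed fields; the values are made up, and the import path assumes the module layout shown in this diff:

```python
import numpy as np
from av2.evaluation.forecasting.eval import filter_max_dist

forecasts = {
    "log_0": {
        315969904359876000: [
            {"current_translation_m": np.array([120.0, 30.0]),
             "ego_translation_m": np.array([100.0, 25.0])},   # ~20.6 m away: kept
            {"current_translation_m": np.array([400.0, 30.0]),
             "ego_translation_m": np.array([100.0, 25.0])},   # ~300 m away: dropped
        ]
    }
}
filtered = filter_max_dist(forecasts, max_range_m=50)
print(len(filtered["log_0"][315969904359876000]))  # 1
```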

@@ -386,23 +386,23 @@ def filter_drivable_area(forecasts: ForecastSequences, dataset_dir: str) -> ForecastSequences:
     for log_id in log_ids:
         avm = log_id_to_avm[log_id]

-        for timestamp in forecasts[log_id]:
-            city_SE3_ego = log_id_to_timestamped_poses[log_id][int(timestamp)]
+        for timestamp_ns in forecasts[log_id]:
+            city_SE3_ego = log_id_to_timestamped_poses[log_id][int(timestamp_ns)]

-            translation, size, quat = [], [], []
+            translation_m, size, quat = [], [], []

-            if len(forecasts[log_id][timestamp]) == 0:
+            if len(forecasts[log_id][timestamp_ns]) == 0:
                 continue

-            for box in forecasts[log_id][timestamp]:
-                translation.append(box["current_translation"] - box["ego_translation"])
+            for box in forecasts[log_id][timestamp_ns]:
+                translation_m.append(box["current_translation_m"] - box["ego_translation_m"])
                 size.append(box["size"])
                 quat.append(yaw_to_quaternion3d(box["yaw"]))

-            score = np.ones((len(translation), 1))
+            score = np.ones((len(translation_m), 1))
             boxes = np.concatenate(
                 [
-                    np.array(translation),
+                    np.array(translation_m),
                     np.array(size),
                     np.array(quat),
                     np.array(score),

@@ -411,7 +411,7 @@
                 )
             )

             is_evaluated = compute_objects_in_roi_mask(boxes, city_SE3_ego, avm)
-            forecasts[log_id][timestamp] = list(np.array(forecasts[log_id][timestamp])[is_evaluated])
+            forecasts[log_id][timestamp_ns] = list(np.array(forecasts[log_id][timestamp_ns])[is_evaluated])

     return forecasts