diff --git a/src/av2/evaluation/tracking/eval.py b/src/av2/evaluation/tracking/eval.py index 56d56b9f..98124789 100644 --- a/src/av2/evaluation/tracking/eval.py +++ b/src/av2/evaluation/tracking/eval.py @@ -83,7 +83,10 @@ def _load_raw_file( raw_data = { f"{source}_ids": [frame["track_id"] for frame in tracks], - f"{source}_classes": [frame["label"] for frame in tracks], + f"{source}_classes": [ + np.array([self.full_class_list.index(n) for n in frame["name"]]) + for frame in tracks + ], f"{source}_dets": [ np.concatenate((frame["translation_m"], frame["size"]), axis=-1) for frame in tracks