Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automotive reference implementation #1954

Merged
merged 17 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions automotive/3d-object-detection/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
## Reference implementation fo automotive 3D detection benchmark

## TODO: Instructions for dataset download after it is uploaded somewhere appropriate

## TODO: Instructions for checkpoints downloads after it is uploaded somewhere appropriate

## Running with docker
```
docker build -t auto_inference -f dockerfile.gpu .

docker run --gpus=all -it -v <directory to inference repo>/inference/:/inference -v <directory to waymo dataset>/waymo:/waymo --rm auto_inference

cd /inference/automotive/3d-object-detection
python main.py --dataset waymo --dataset-path /waymo/kitti_format/ --lidar-path <checkpoint_path>/pp_ep36.pth --segmentor-path <checkpoint_path>/best_deeplabv3plus_resnet50_waymo_os16.pth --mlperf_conf /inference/mlperf.conf
128 changes: 128 additions & 0 deletions automotive/3d-object-detection/accuracy_waymo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
Tool to calculate accuracy for loadgen accuracy output found in mlperf_log_accuracy.json
We assume that loadgen's query index is in the same order as
the images in coco's annotations/instances_val2017.json.
"""

from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import json
import os

import numpy as np
from waymo import Waymo
from tools.evaluate import do_eval
# pylint: disable=missing-docstring
CLASSES = Waymo.CLASSES
LABEL2CLASSES = {v: k for k, v in CLASSES.items()}


def get_args():
"""Parse commandline."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--mlperf-accuracy-file",
required=True,
help="path to mlperf_log_accuracy.json")
parser.add_argument(
"--waymo-dir",
required=True,
help="waymo dataset directory")
parser.add_argument(
"--verbose",
action="store_true",
help="verbose messages")
parser.add_argument(
"--output-file",
default="openimages-results.json",
help="path to output file")
parser.add_argument(
"--use-inv-map",
action="store_true",
help="use inverse label map")
args = parser.parse_args()
return args


def main():
args = get_args()

with open(args.mlperf_accuracy_file, "r") as f:
results = json.load(f)

detections = {}
image_ids = set()
seen = set()
no_results = 0

val_dataset = Waymo(
data_root=args.waymo_dir,
split='val',
painted=True,
cam_sync=False)

for j in results:
idx = j['qsl_idx']
# de-dupe in case loadgen sends the same image multiple times
if idx in seen:
continue
seen.add(idx)

# reconstruct from mlperf accuracy log
# what is written by the benchmark is an array of float32's:
# id, box[0], box[1], box[2], box[3], score, detection_class
# note that id is a index into instances_val2017.json, not the actual
# image_id
data = np.frombuffer(bytes.fromhex(j['data']), np.float32)

for i in range(0, len(data), 14):
dimension = [float(x) for x in data[i:i + 3]]
location = [float(x) for x in data[i + 3:i + 6]]
rotation_y = float(data[i + 6])
bbox = [float(x) for x in data[i + 7:i + 11]]
label = int(data[i + 11])
score = float(data[i + 12])
image_idx = int(data[i + 13])
if image_idx not in detections:
detections[image_idx] = {
'name': [],
'dimensions': [],
'location': [],
'rotation_y': [],
'bbox': [],
'score': []
}

detections[image_idx]['name'].append(LABEL2CLASSES[label])
detections[image_idx]['dimensions'].append(dimension)
detections[image_idx]['location'].append(location)
detections[image_idx]['rotation_y'].append(rotation_y)
detections[image_idx]['bbox'].append(bbox)
detections[image_idx]['score'].append(score)
image_ids.add(image_idx)

with open(args.output_file, "w") as fp:
json.dump(detections, fp, sort_keys=True, indent=4)
format_results = {}
for key in detections.keys():
format_results[key] = {k: np.array(v)
for k, v in detections[key].items()}
map_stats = do_eval(
format_results,
val_dataset.data_infos,
CLASSES,
cam_sync=False)

print(map_stats)
if args.verbose:
print("found {} results".format(len(results)))
print("found {} images".format(len(image_ids)))
print("found {} images with no results".format(no_results))
print("ignored {} dupes".format(len(results) - len(seen)))


if __name__ == "__main__":
main()
21 changes: 21 additions & 0 deletions automotive/3d-object-detection/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
abstract backend class
"""


class Backend:
def __init__(self):
self.inputs = []
self.outputs = []

def version(self):
raise NotImplementedError("Backend:version")

def name(self):
raise NotImplementedError("Backend:name")

def load(self, model_path, inputs=None, outputs=None):
raise NotImplementedError("Backend:load")

def predict(self, feed):
raise NotImplementedError("Backend:predict")
24 changes: 24 additions & 0 deletions automotive/3d-object-detection/backend_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
import backend


class BackendDebug(backend.Backend):
def __init__(self, image_size=[3, 1024, 1024], **kwargs):
super(BackendDebug, self).__init__()
self.image_size = image_size

def version(self):
return torch.__version__

def name(self):
return "debug-SUT"

def image_format(self):
return "NCHW"

def load(self):
return self

def predict(self, prompts):
images = []
return images
155 changes: 155 additions & 0 deletions automotive/3d-object-detection/backend_deploy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from typing import Optional, List, Union
import os
import torch
import logging
import backend
from collections import namedtuple
from model.painter import Painter
from model.pointpillars import PointPillars
import numpy as np
from tools.process import keep_bbox_from_image_range
from waymo import Waymo


logging.basicConfig(level=logging.INFO)
log = logging.getLogger("backend-pytorch")


def change_calib_device(calib, cuda):
result = {}
if cuda:
device = 'cuda'
else:
device = 'cpu'
result['R0_rect'] = calib['R0_rect'].to(device=device, dtype=torch.float)
for i in range(5):
result['P' + str(i)] = calib['P' + str(i)
].to(device=device, dtype=torch.float)
result['Tr_velo_to_cam_' +
str(i)] = calib['Tr_velo_to_cam_' +
str(i)].to(device=device, dtype=torch.float)
return result


class BackendDeploy(backend.Backend):
def __init__(
self,
segmentor_path,
lidar_detector_path,
data_path
):
super(BackendDeploy, self).__init__()
self.segmentor_path = segmentor_path
self.lidar_detector_path = lidar_detector_path
# self.segmentation_classes = 18
self.detection_classes = 3
self.data_root = data_path
CLASSES = Waymo.CLASSES
self.LABEL2CLASSES = {v: k for k, v in CLASSES.items()}

def version(self):
return torch.__version__

def name(self):
return "python-SUT"

def load(self):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
PaintArgs = namedtuple(
'PaintArgs', [
'training_path', 'model_path', 'cam_sync'])
painting_args = PaintArgs(
os.path.join(
self.data_root,
'training'),
self.segmentor_path,
False)
self.painter = Painter(painting_args)
self.segmentor = self.painter.model
model = PointPillars(
nclasses=self.detection_classes,
painted=True).to(
device=device)
model.eval()
checkpoint = torch.load(self.lidar_detector_path)
model.load_state_dict(checkpoint["model_state_dict"])
self.lidar_detector = model

return self

def predict(self, inputs):
# TODO: implement predict
dimensions, locations, rotation_y, box2d, class_labels, class_scores, ids = [
], [], [], [], [], [], []
with torch.inference_mode():
device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu")
format_results = {}
model_input = inputs[0]
batched_pts = model_input['pts']
scores_from_cam = []
for i in range(len(model_input['images'])):
segmentation_score = self.segmentor(
model_input['images'][i].to(device))[0]
scores_from_cam.append(
self.painter.get_score(segmentation_score).cpu())
points = self.painter.augment_lidar_class_scores_both(
scores_from_cam, batched_pts, model_input['calib_info'])
batch_results = self.lidar_detector(
batched_pts=[points.to(device=device)], mode='val')
for j, result in enumerate(batch_results):
format_result = {
'class': [],
'truncated': [],
'occluded': [],
'alpha': [],
'bbox': [],
'dimensions': [],
'location': [],
'rotation_y': [],
'score': [],
'idx': -1
}

calib_info = model_input['calib_info']
image_info = model_input['image_info']
idx = model_input['image_info']['image_idx']

calib_info = change_calib_device(calib_info, False)
result_filter = keep_bbox_from_image_range(
result, calib_info, 5, image_info, False)

lidar_bboxes = result_filter['lidar_bboxes']
labels, scores = result_filter['labels'], result_filter['scores']
bboxes2d, camera_bboxes = result_filter['bboxes2d'], result_filter['camera_bboxes']
for lidar_bbox, label, score, bbox2d, camera_bbox in \
zip(lidar_bboxes, labels, scores, bboxes2d, camera_bboxes):
format_result['class'].append(label.item())
format_result['truncated'].append(0.0)
format_result['occluded'].append(0)
alpha = camera_bbox[6] - \
np.arctan2(camera_bbox[0], camera_bbox[2])
format_result['alpha'].append(alpha.item())
format_result['bbox'].append(bbox2d.tolist())
format_result['dimensions'].append(camera_bbox[3:6])
format_result['location'].append(camera_bbox[:3])
format_result['rotation_y'].append(camera_bbox[6].item())
format_result['score'].append(score.item())
format_results['idx'] = idx

# write_label(format_result, os.path.join(saved_submit_path, f'{idx:06d}.txt'))

if len(format_result['dimensions']) > 0:
format_result['dimensions'] = torch.stack(
format_result['dimensions'])
format_result['location'] = torch.stack(
format_result['location'])
dimensions.append(format_result['dimensions'])
locations.append(format_result['location'])
rotation_y.append(format_result['rotation_y'])
class_labels.append(format_result['class'])
class_scores.append(format_result['score'])
box2d.append(format_result['bbox'])
ids.append(format_results['idx'])
# return Boxes, Classes, Scores # Change to desired output
return dimensions, locations, rotation_y, box2d, class_labels, class_scores, ids
Loading
Loading