Automotive reference implementation (#1954)
* Automotive reference implementation sketch

* WIP: automotive reference implementation

* WIP add segmentation, dataset, and util functions to reference

* WIP: reference implementation with issues during post processing

* WIP update dockerfile remove pdb breaks

* WIP: reference implementation that runs samples

* Update README.md with initial docker runs

* WIP: add accuracy checker to reference

* Fix: set lidar detector to evaluation mode

* [Automated Commit] Format Codebase

* Update README.md

---------

Co-authored-by: Pablo Gonzalez <pablo.gonzalez@factored.ai>
Co-authored-by: Radoyeh Shojaei <radoyeh@gmail.com>
Co-authored-by: arjunsuresh <arjunsuresh@users.noreply.github.com>
Co-authored-by: Miro <mirhodak@amd.com>
5 people authored Dec 21, 2024
1 parent be6ff52 commit 2fdb814
Showing 33 changed files with 4,629 additions and 0 deletions.
14 changes: 14 additions & 0 deletions automotive/3d-object-detection/README.md
@@ -0,0 +1,14 @@
## Reference implementation for the automotive 3D detection benchmark

## TODO: Instructions for dataset download after it is uploaded somewhere appropriate

## TODO: Instructions for checkpoint downloads after they are uploaded somewhere appropriate

## Running with Docker
```
docker build -t auto_inference -f dockerfile.gpu .
docker run --gpus=all -it -v <directory to inference repo>/inference/:/inference -v <directory to waymo dataset>/waymo:/waymo --rm auto_inference
cd /inference/automotive/3d-object-detection
python main.py --dataset waymo --dataset-path /waymo/kitti_format/ --lidar-path <checkpoint_path>/pp_ep36.pth --segmentor-path <checkpoint_path>/best_deeplabv3plus_resnet50_waymo_os16.pth --mlperf_conf /inference/mlperf.conf
```
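
To score a run afterwards, the accuracy checker added in this commit can be invoked along these lines (the log path is a placeholder):
```
python accuracy_waymo.py --mlperf-accuracy-file <output_dir>/mlperf_log_accuracy.json --waymo-dir /waymo/kitti_format/
```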
128 changes: 128 additions & 0 deletions automotive/3d-object-detection/accuracy_waymo.py
@@ -0,0 +1,128 @@
"""
Tool to calculate accuracy for loadgen accuracy output found in mlperf_log_accuracy.json.
We assume that loadgen's query index is in the same order as
the samples in the Waymo validation annotations.
"""

from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import json
import os

import numpy as np
from waymo import Waymo
from tools.evaluate import do_eval
# pylint: disable=missing-docstring
CLASSES = Waymo.CLASSES
LABEL2CLASSES = {v: k for k, v in CLASSES.items()}


def get_args():
"""Parse commandline."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--mlperf-accuracy-file",
required=True,
help="path to mlperf_log_accuracy.json")
parser.add_argument(
"--waymo-dir",
required=True,
help="waymo dataset directory")
parser.add_argument(
"--verbose",
action="store_true",
help="verbose messages")
parser.add_argument(
"--output-file",
default="openimages-results.json",
help="path to output file")
parser.add_argument(
"--use-inv-map",
action="store_true",
help="use inverse label map")
args = parser.parse_args()
return args


def main():
args = get_args()

with open(args.mlperf_accuracy_file, "r") as f:
results = json.load(f)

detections = {}
image_ids = set()
seen = set()
no_results = 0

val_dataset = Waymo(
data_root=args.waymo_dir,
split='val',
painted=True,
cam_sync=False)

for j in results:
idx = j['qsl_idx']
# de-dupe in case loadgen sends the same image multiple times
if idx in seen:
continue
seen.add(idx)

        # reconstruct from mlperf accuracy log
        # what is written by the benchmark is an array of float32's,
        # 14 values per detection:
        # dimensions[3], location[3], rotation_y, bbox[4], label, score, image_idx
        data = np.frombuffer(bytes.fromhex(j['data']), np.float32)
        if len(data) == 0:
            # count queries that produced no detections
            no_results += 1
            continue

for i in range(0, len(data), 14):
dimension = [float(x) for x in data[i:i + 3]]
location = [float(x) for x in data[i + 3:i + 6]]
rotation_y = float(data[i + 6])
bbox = [float(x) for x in data[i + 7:i + 11]]
label = int(data[i + 11])
score = float(data[i + 12])
image_idx = int(data[i + 13])
if image_idx not in detections:
detections[image_idx] = {
'name': [],
'dimensions': [],
'location': [],
'rotation_y': [],
'bbox': [],
'score': []
}

detections[image_idx]['name'].append(LABEL2CLASSES[label])
detections[image_idx]['dimensions'].append(dimension)
detections[image_idx]['location'].append(location)
detections[image_idx]['rotation_y'].append(rotation_y)
detections[image_idx]['bbox'].append(bbox)
detections[image_idx]['score'].append(score)
image_ids.add(image_idx)

with open(args.output_file, "w") as fp:
json.dump(detections, fp, sort_keys=True, indent=4)
format_results = {}
for key in detections.keys():
format_results[key] = {k: np.array(v)
for k, v in detections[key].items()}
map_stats = do_eval(
format_results,
val_dataset.data_infos,
CLASSES,
cam_sync=False)

print(map_stats)
if args.verbose:
print("found {} results".format(len(results)))
print("found {} images".format(len(image_ids)))
print("found {} images with no results".format(no_results))
print("ignored {} dupes".format(len(results) - len(seen)))


if __name__ == "__main__":
main()
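
For reference, here is a minimal sketch of one 14-float record in the layout the parsing loop above expects; the concrete values are made up:

```
import numpy as np

# One detection, serialized as the script reads it:
# dimensions[3], location[3], rotation_y, bbox[4], label, score, image_idx
record = np.array(
    [1.7, 1.9, 4.2,           # dimensions (illustrative)
     5.0, 1.5, 20.0,          # location (illustrative)
     0.1,                     # rotation_y
     100., 120., 180., 200.,  # 2D bbox
     1,                       # label, an index into LABEL2CLASSES
     0.92,                    # score
     0],                      # image_idx
    dtype=np.float32)
hex_payload = record.tobytes().hex()  # the format of j['data'] in the accuracy log
```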
21 changes: 21 additions & 0 deletions automotive/3d-object-detection/backend.py
@@ -0,0 +1,21 @@
"""
Abstract backend class.
"""


class Backend:
def __init__(self):
self.inputs = []
self.outputs = []

def version(self):
raise NotImplementedError("Backend:version")

def name(self):
raise NotImplementedError("Backend:name")

def load(self, model_path, inputs=None, outputs=None):
raise NotImplementedError("Backend:load")

def predict(self, feed):
raise NotImplementedError("Backend:predict")
24 changes: 24 additions & 0 deletions automotive/3d-object-detection/backend_debug.py
@@ -0,0 +1,24 @@
import torch
import backend


class BackendDebug(backend.Backend):
def __init__(self, image_size=[3, 1024, 1024], **kwargs):
super(BackendDebug, self).__init__()
self.image_size = image_size

def version(self):
return torch.__version__

def name(self):
return "debug-SUT"

def image_format(self):
return "NCHW"

def load(self):
return self

def predict(self, prompts):
images = []
return images
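
The debug backend needs no model files; a quick usage sketch (torch must be installed, since version() reports torch.__version__):

```
from backend_debug import BackendDebug

sut = BackendDebug().load()
print(sut.name(), sut.version())  # debug-SUT <torch version>
print(sut.predict(["anything"]))  # [] - the debug backend produces no outputs
```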
155 changes: 155 additions & 0 deletions automotive/3d-object-detection/backend_deploy.py
@@ -0,0 +1,155 @@
from typing import Optional, List, Union
import os
import torch
import logging
import backend
from collections import namedtuple
from model.painter import Painter
from model.pointpillars import PointPillars
import numpy as np
from tools.process import keep_bbox_from_image_range
from waymo import Waymo


logging.basicConfig(level=logging.INFO)
log = logging.getLogger("backend-pytorch")


def change_calib_device(calib, cuda):
    result = {}
    device = 'cuda' if cuda else 'cpu'
    result['R0_rect'] = calib['R0_rect'].to(device=device, dtype=torch.float)
    for i in range(5):
        result['P' + str(i)] = calib['P' + str(i)].to(
            device=device, dtype=torch.float)
        result['Tr_velo_to_cam_' + str(i)] = calib['Tr_velo_to_cam_' + str(i)].to(
            device=device, dtype=torch.float)
    return result


class BackendDeploy(backend.Backend):
def __init__(
self,
segmentor_path,
lidar_detector_path,
data_path
):
super(BackendDeploy, self).__init__()
self.segmentor_path = segmentor_path
self.lidar_detector_path = lidar_detector_path
# self.segmentation_classes = 18
self.detection_classes = 3
self.data_root = data_path
CLASSES = Waymo.CLASSES
self.LABEL2CLASSES = {v: k for k, v in CLASSES.items()}

def version(self):
return torch.__version__

def name(self):
return "python-SUT"

def load(self):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        PaintArgs = namedtuple(
            'PaintArgs', ['training_path', 'model_path', 'cam_sync'])
        painting_args = PaintArgs(
            os.path.join(self.data_root, 'training'),
            self.segmentor_path,
            False)
self.painter = Painter(painting_args)
self.segmentor = self.painter.model
        model = PointPillars(
            nclasses=self.detection_classes, painted=True).to(device=device)
model.eval()
checkpoint = torch.load(self.lidar_detector_path)
model.load_state_dict(checkpoint["model_state_dict"])
self.lidar_detector = model

return self

    def predict(self, inputs):
        dimensions, locations, rotation_y, box2d = [], [], [], []
        class_labels, class_scores, ids = [], [], []
with torch.inference_mode():
device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu")
format_results = {}
model_input = inputs[0]
batched_pts = model_input['pts']
scores_from_cam = []
for i in range(len(model_input['images'])):
segmentation_score = self.segmentor(
model_input['images'][i].to(device))[0]
scores_from_cam.append(
self.painter.get_score(segmentation_score).cpu())
points = self.painter.augment_lidar_class_scores_both(
scores_from_cam, batched_pts, model_input['calib_info'])
batch_results = self.lidar_detector(
batched_pts=[points.to(device=device)], mode='val')
            for result in batch_results:
format_result = {
'class': [],
'truncated': [],
'occluded': [],
'alpha': [],
'bbox': [],
'dimensions': [],
'location': [],
'rotation_y': [],
'score': [],
'idx': -1
}

calib_info = model_input['calib_info']
image_info = model_input['image_info']
idx = model_input['image_info']['image_idx']

calib_info = change_calib_device(calib_info, False)
result_filter = keep_bbox_from_image_range(
result, calib_info, 5, image_info, False)

lidar_bboxes = result_filter['lidar_bboxes']
labels, scores = result_filter['labels'], result_filter['scores']
bboxes2d, camera_bboxes = result_filter['bboxes2d'], result_filter['camera_bboxes']
for lidar_bbox, label, score, bbox2d, camera_bbox in \
zip(lidar_bboxes, labels, scores, bboxes2d, camera_bboxes):
format_result['class'].append(label.item())
format_result['truncated'].append(0.0)
format_result['occluded'].append(0)
alpha = camera_bbox[6] - \
np.arctan2(camera_bbox[0], camera_bbox[2])
format_result['alpha'].append(alpha.item())
format_result['bbox'].append(bbox2d.tolist())
format_result['dimensions'].append(camera_bbox[3:6])
format_result['location'].append(camera_bbox[:3])
format_result['rotation_y'].append(camera_bbox[6].item())
format_result['score'].append(score.item())
format_results['idx'] = idx

# write_label(format_result, os.path.join(saved_submit_path, f'{idx:06d}.txt'))

if len(format_result['dimensions']) > 0:
format_result['dimensions'] = torch.stack(
format_result['dimensions'])
format_result['location'] = torch.stack(
format_result['location'])
dimensions.append(format_result['dimensions'])
locations.append(format_result['location'])
rotation_y.append(format_result['rotation_y'])
class_labels.append(format_result['class'])
class_scores.append(format_result['score'])
box2d.append(format_result['bbox'])
ids.append(format_results['idx'])
        return dimensions, locations, rotation_y, box2d, class_labels, class_scores, ids
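
Constructing BackendDeploy only records paths; the checkpoints are read in load(). A hypothetical wiring, with placeholder paths based on the checkpoint names in the README and assuming the repo's model dependencies are installed:

```
from backend_deploy import BackendDeploy

# Placeholder paths; load() reads these files, so they must exist locally.
sut = BackendDeploy(
    segmentor_path="/checkpoints/best_deeplabv3plus_resnet50_waymo_os16.pth",
    lidar_detector_path="/checkpoints/pp_ep36.pth",
    data_path="/waymo/kitti_format",
)
print(sut.name())  # python-SUT
sut.load()         # builds the Painter/segmentor and the PointPillars detector
```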