diff --git a/automotive/3d-object-detection/README.md b/automotive/3d-object-detection/README.md
new file mode 100644
index 000000000..f4848647f
--- /dev/null
+++ b/automotive/3d-object-detection/README.md
@@ -0,0 +1,14 @@
+## Reference implementation for the automotive 3D object detection benchmark
+
+## TODO: Instructions for dataset download after it is uploaded somewhere appropriate
+
+## TODO: Instructions for checkpoint downloads after they are uploaded somewhere appropriate
+
+## Running with docker
+```
+docker build -t auto_inference -f dockerfile.gpu .
+
+docker run --gpus=all -it -v /inference/:/inference -v /waymo:/waymo --rm auto_inference
+
+cd /inference/automotive/3d-object-detection
+python main.py --dataset waymo --dataset-path /waymo/kitti_format/ --lidar-path /pp_ep36.pth --segmentor-path /best_deeplabv3plus_resnet50_waymo_os16.pth --mlperf_conf /inference/mlperf.conf
diff --git a/automotive/3d-object-detection/accuracy_waymo.py b/automotive/3d-object-detection/accuracy_waymo.py
new file mode 100644
index 000000000..c8b5cb72c
--- /dev/null
+++ b/automotive/3d-object-detection/accuracy_waymo.py
@@ -0,0 +1,128 @@
+"""
+Tool to calculate accuracy for loadgen accuracy output found in mlperf_log_accuracy.json
+We assume that loadgen's query index is in the same order as
+the samples in the Waymo validation set.
+"""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import argparse
+import json
+import os
+
+import numpy as np
+from waymo import Waymo
+from tools.evaluate import do_eval
+# pylint: disable=missing-docstring
+CLASSES = Waymo.CLASSES
+LABEL2CLASSES = {v: k for k, v in CLASSES.items()}
+
+
+def get_args():
+    """Parse commandline."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--mlperf-accuracy-file",
+        required=True,
+        help="path to mlperf_log_accuracy.json")
+    parser.add_argument(
+        "--waymo-dir",
+        required=True,
+        help="waymo dataset directory")
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        help="verbose messages")
+    parser.add_argument(
+        "--output-file",
+        default="openimages-results.json",
+        help="path to output file")
+    parser.add_argument(
+        "--use-inv-map",
+        action="store_true",
+        help="use inverse label map")
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+
+    with open(args.mlperf_accuracy_file, "r") as f:
+        results = json.load(f)
+
+    detections = {}
+    image_ids = set()
+    seen = set()
+    no_results = 0
+
+    val_dataset = Waymo(
+        data_root=args.waymo_dir,
+        split='val',
+        painted=True,
+        cam_sync=False)
+
+    for j in results:
+        idx = j['qsl_idx']
+        # de-dupe in case loadgen sends the same image multiple times
+        if idx in seen:
+            continue
+        seen.add(idx)
+
+        # reconstruct from mlperf accuracy log
+        # what is written by the benchmark is an array of float32's,
+        # 14 values per detection: dimensions (3), location (3), rotation_y,
+        # 2D bbox (4), label, score and image_idx, where image_idx is the
+        # dataset index of the frame the detection belongs to
+        data = np.frombuffer(bytes.fromhex(j['data']), np.float32)
+
+        for i in range(0, len(data), 14):
+            dimension = [float(x) for x in data[i:i + 3]]
+            location = [float(x) for x in data[i + 3:i + 6]]
+            rotation_y = float(data[i + 6])
+            bbox = [float(x) for x in data[i + 7:i + 11]]
+            label = int(data[i + 11])
+            score = float(data[i + 12])
+            image_idx = int(data[i + 13])
+            if image_idx not in detections:
+                detections[image_idx] = {
+                    'name': [],
+                    'dimensions': [],
+                    'location': [],
+                    'rotation_y': [],
+                    'bbox': [],
'score': [] + } + + detections[image_idx]['name'].append(LABEL2CLASSES[label]) + detections[image_idx]['dimensions'].append(dimension) + detections[image_idx]['location'].append(location) + detections[image_idx]['rotation_y'].append(rotation_y) + detections[image_idx]['bbox'].append(bbox) + detections[image_idx]['score'].append(score) + image_ids.add(image_idx) + + with open(args.output_file, "w") as fp: + json.dump(detections, fp, sort_keys=True, indent=4) + format_results = {} + for key in detections.keys(): + format_results[key] = {k: np.array(v) + for k, v in detections[key].items()} + map_stats = do_eval( + format_results, + val_dataset.data_infos, + CLASSES, + cam_sync=False) + + print(map_stats) + if args.verbose: + print("found {} results".format(len(results))) + print("found {} images".format(len(image_ids))) + print("found {} images with no results".format(no_results)) + print("ignored {} dupes".format(len(results) - len(seen))) + + +if __name__ == "__main__": + main() diff --git a/automotive/3d-object-detection/backend.py b/automotive/3d-object-detection/backend.py new file mode 100644 index 000000000..58e4f9fa6 --- /dev/null +++ b/automotive/3d-object-detection/backend.py @@ -0,0 +1,21 @@ +""" +abstract backend class +""" + + +class Backend: + def __init__(self): + self.inputs = [] + self.outputs = [] + + def version(self): + raise NotImplementedError("Backend:version") + + def name(self): + raise NotImplementedError("Backend:name") + + def load(self, model_path, inputs=None, outputs=None): + raise NotImplementedError("Backend:load") + + def predict(self, feed): + raise NotImplementedError("Backend:predict") diff --git a/automotive/3d-object-detection/backend_debug.py b/automotive/3d-object-detection/backend_debug.py new file mode 100644 index 000000000..086a84a7a --- /dev/null +++ b/automotive/3d-object-detection/backend_debug.py @@ -0,0 +1,24 @@ +import torch +import backend + + +class BackendDebug(backend.Backend): + def __init__(self, image_size=[3, 1024, 1024], **kwargs): + super(BackendDebug, self).__init__() + self.image_size = image_size + + def version(self): + return torch.__version__ + + def name(self): + return "debug-SUT" + + def image_format(self): + return "NCHW" + + def load(self): + return self + + def predict(self, prompts): + images = [] + return images diff --git a/automotive/3d-object-detection/backend_deploy.py b/automotive/3d-object-detection/backend_deploy.py new file mode 100644 index 000000000..1a2f3dee4 --- /dev/null +++ b/automotive/3d-object-detection/backend_deploy.py @@ -0,0 +1,155 @@ +from typing import Optional, List, Union +import os +import torch +import logging +import backend +from collections import namedtuple +from model.painter import Painter +from model.pointpillars import PointPillars +import numpy as np +from tools.process import keep_bbox_from_image_range +from waymo import Waymo + + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("backend-pytorch") + + +def change_calib_device(calib, cuda): + result = {} + if cuda: + device = 'cuda' + else: + device = 'cpu' + result['R0_rect'] = calib['R0_rect'].to(device=device, dtype=torch.float) + for i in range(5): + result['P' + str(i)] = calib['P' + str(i) + ].to(device=device, dtype=torch.float) + result['Tr_velo_to_cam_' + + str(i)] = calib['Tr_velo_to_cam_' + + str(i)].to(device=device, dtype=torch.float) + return result + + +class BackendDeploy(backend.Backend): + def __init__( + self, + segmentor_path, + lidar_detector_path, + data_path + ): + super(BackendDeploy, 
self).__init__() + self.segmentor_path = segmentor_path + self.lidar_detector_path = lidar_detector_path + # self.segmentation_classes = 18 + self.detection_classes = 3 + self.data_root = data_path + CLASSES = Waymo.CLASSES + self.LABEL2CLASSES = {v: k for k, v in CLASSES.items()} + + def version(self): + return torch.__version__ + + def name(self): + return "python-SUT" + + def load(self): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + PaintArgs = namedtuple( + 'PaintArgs', [ + 'training_path', 'model_path', 'cam_sync']) + painting_args = PaintArgs( + os.path.join( + self.data_root, + 'training'), + self.segmentor_path, + False) + self.painter = Painter(painting_args) + self.segmentor = self.painter.model + model = PointPillars( + nclasses=self.detection_classes, + painted=True).to( + device=device) + model.eval() + checkpoint = torch.load(self.lidar_detector_path) + model.load_state_dict(checkpoint["model_state_dict"]) + self.lidar_detector = model + + return self + + def predict(self, inputs): + # TODO: implement predict + dimensions, locations, rotation_y, box2d, class_labels, class_scores, ids = [ + ], [], [], [], [], [], [] + with torch.inference_mode(): + device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu") + format_results = {} + model_input = inputs[0] + batched_pts = model_input['pts'] + scores_from_cam = [] + for i in range(len(model_input['images'])): + segmentation_score = self.segmentor( + model_input['images'][i].to(device))[0] + scores_from_cam.append( + self.painter.get_score(segmentation_score).cpu()) + points = self.painter.augment_lidar_class_scores_both( + scores_from_cam, batched_pts, model_input['calib_info']) + batch_results = self.lidar_detector( + batched_pts=[points.to(device=device)], mode='val') + for j, result in enumerate(batch_results): + format_result = { + 'class': [], + 'truncated': [], + 'occluded': [], + 'alpha': [], + 'bbox': [], + 'dimensions': [], + 'location': [], + 'rotation_y': [], + 'score': [], + 'idx': -1 + } + + calib_info = model_input['calib_info'] + image_info = model_input['image_info'] + idx = model_input['image_info']['image_idx'] + + calib_info = change_calib_device(calib_info, False) + result_filter = keep_bbox_from_image_range( + result, calib_info, 5, image_info, False) + + lidar_bboxes = result_filter['lidar_bboxes'] + labels, scores = result_filter['labels'], result_filter['scores'] + bboxes2d, camera_bboxes = result_filter['bboxes2d'], result_filter['camera_bboxes'] + for lidar_bbox, label, score, bbox2d, camera_bbox in \ + zip(lidar_bboxes, labels, scores, bboxes2d, camera_bboxes): + format_result['class'].append(label.item()) + format_result['truncated'].append(0.0) + format_result['occluded'].append(0) + alpha = camera_bbox[6] - \ + np.arctan2(camera_bbox[0], camera_bbox[2]) + format_result['alpha'].append(alpha.item()) + format_result['bbox'].append(bbox2d.tolist()) + format_result['dimensions'].append(camera_bbox[3:6]) + format_result['location'].append(camera_bbox[:3]) + format_result['rotation_y'].append(camera_bbox[6].item()) + format_result['score'].append(score.item()) + format_results['idx'] = idx + + # write_label(format_result, os.path.join(saved_submit_path, f'{idx:06d}.txt')) + + if len(format_result['dimensions']) > 0: + format_result['dimensions'] = torch.stack( + format_result['dimensions']) + format_result['location'] = torch.stack( + format_result['location']) + dimensions.append(format_result['dimensions']) + locations.append(format_result['location']) + 
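+                    # Note: the seven per-frame lists collected here (dimensions,
+                    # locations, rotation_y, 2D bboxes, class labels, scores and
+                    # frame ids) are presumably what PostProcessWaymo flattens into
+                    # the 14-float-per-detection records that accuracy_waymo.py
+                    # later parses back out of mlperf_log_accuracy.json.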
rotation_y.append(format_result['rotation_y']) + class_labels.append(format_result['class']) + class_scores.append(format_result['score']) + box2d.append(format_result['bbox']) + ids.append(format_results['idx']) + # return Boxes, Classes, Scores # Change to desired output + return dimensions, locations, rotation_y, box2d, class_labels, class_scores, ids diff --git a/automotive/3d-object-detection/dataset.py b/automotive/3d-object-detection/dataset.py new file mode 100644 index 000000000..04748dd94 --- /dev/null +++ b/automotive/3d-object-detection/dataset.py @@ -0,0 +1,76 @@ +""" +dataset related classes and methods +""" + +# pylint: disable=unused-argument,missing-docstring + +import logging +import sys +import time + +import numpy as np +import torch + + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("dataset") + + +class Dataset: + def __init__(self): + self.items_inmemory = {} + + def preprocess(self, use_cache=True): + raise NotImplementedError("Dataset:preprocess") + + def get_item_count(self): + raise NotImplementedError("Dataset:get_item_count") + + def get_list(self): + raise NotImplementedError("Dataset:get_list") + + def load_query_samples(self, sample_list): + raise NotImplementedError("Dataset:load_query_samples") + + def unload_query_samples(self, sample_list): + raise NotImplementedError("Dataset:unload_query_samples") + + def get_samples(self, id_list): + raise NotImplementedError("Dataset:get_samples") + + def get_item(self, id): + raise NotImplementedError("Dataset:get_item") + + +def preprocess(list_data): + batched_pts_list, batched_gt_bboxes_list = [], [] + batched_labels_list, batched_names_list = [], [] + batched_difficulty_list = [] + batched_img_list, batched_calib_list = [], [] + batched_images = [] + for data_dict in list_data: + pts, gt_bboxes_3d = data_dict['pts'], data_dict['gt_bboxes_3d'] + gt_labels, gt_names = data_dict['gt_labels'], data_dict['gt_names'] + difficulty = data_dict['difficulty'] + image_info, calib_info = data_dict['image_info'], data_dict['calib_info'] + + batched_pts_list.append(torch.from_numpy(pts)) + batched_gt_bboxes_list.append(torch.from_numpy(gt_bboxes_3d)) + batched_labels_list.append(torch.from_numpy(gt_labels)) + batched_names_list.append(gt_names) # List(str) + batched_difficulty_list.append(torch.from_numpy(difficulty)) + batched_img_list.append(image_info) + batched_calib_list.append(calib_info) + batched_images.append(data_dict['images']) + rt_data_dict = dict( + batched_pts=batched_pts_list, + batched_gt_bboxes=batched_gt_bboxes_list, + batched_labels=batched_labels_list, + batched_names=batched_names_list, + batched_difficulty=batched_difficulty_list, + batched_img_info=batched_img_list, + batched_calib_info=batched_calib_list, + batched_images=batched_images + ) + + return rt_data_dict diff --git a/automotive/3d-object-detection/dockerfile cpu b/automotive/3d-object-detection/dockerfile cpu new file mode 100644 index 000000000..875cdf76c --- /dev/null +++ b/automotive/3d-object-detection/dockerfile cpu @@ -0,0 +1,21 @@ +ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:23.08-py3 +FROM ${FROM_IMAGE_NAME} + +ENV DEBIAN_FRONTEND=noninteractive + +# apt dependencies +RUN apt-get update +RUN apt-get install -y ffmpeg libsm6 libxext6 + +# install LDM +COPY . 
/diffusion +RUN cd /diffusion && \ + pip install --no-cache-dir -r requirements.txt + +# install loadgen +RUN cd /tmp && \ + git clone --recursive https://github.com/mlcommons/inference && \ + cd inference/loadgen && \ + pip install pybind11 && \ + CFLAGS="-std=c++14" python setup.py install && \ + rm -rf mlperf \ No newline at end of file diff --git a/automotive/3d-object-detection/dockerfile.gpu b/automotive/3d-object-detection/dockerfile.gpu new file mode 100644 index 000000000..02acca7a3 --- /dev/null +++ b/automotive/3d-object-detection/dockerfile.gpu @@ -0,0 +1,31 @@ +ARG FROM_IMAGE_NAME=pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel +FROM ${FROM_IMAGE_NAME} + +ENV DEBIAN_FRONTEND=noninteractive + +# apt dependencies +RUN apt-get update +RUN apt-get install -y ffmpeg libsm6 libxext6 git + +# install LDM +COPY . /diffusion +RUN cd /diffusion && \ + pip install --no-cache-dir -r requirements.txt + +# install loadgen +RUN cd /tmp && \ + git clone --recursive https://github.com/mlcommons/inference && \ + cd inference/loadgen && \ + pip install pybind11 && \ + CFLAGS="-std=c++14" python setup.py install && \ + rm -rf mlperf + +RUN pip install tqdm +RUN pip install numba +RUN pip install opencv-python +RUN pip install open3d +RUN pip install tensorboard +RUN pip install scikit-image +RUN pip install ninja +RUN pip install visdom +RUN pip install shapely \ No newline at end of file diff --git a/automotive/3d-object-detection/main.py b/automotive/3d-object-detection/main.py new file mode 100644 index 000000000..04269e428 --- /dev/null +++ b/automotive/3d-object-detection/main.py @@ -0,0 +1,469 @@ +""" +mlperf inference benchmarking tool +""" + +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import argparse +import array +import collections +import json +import logging +import os +import sys +import threading +import time +from queue import Queue + +import mlperf_loadgen as lg +import numpy as np +import torch + +import dataset +import waymo + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("main") + +NANO_SEC = 1e9 +MILLI_SEC = 1000 + +SUPPORTED_DATASETS = { + "waymo": ( + waymo.Waymo, + dataset.preprocess, + waymo.PostProcessWaymo(), + {} # "image_size": [3, 1024, 1024]}, + ) +} + + +SUPPORTED_PROFILES = { + "defaults": { + "dataset": "waymo", + "backend": "pytorch", + "model-name": "pointpainting", + }, +} + +SCENARIO_MAP = { + "SingleStream": lg.TestScenario.SingleStream, + "MultiStream": lg.TestScenario.MultiStream, + "Server": lg.TestScenario.Server, + "Offline": lg.TestScenario.Offline, +} + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + choices=SUPPORTED_DATASETS.keys(), + help="dataset") + parser.add_argument( + "--dataset-path", + required=True, + help="path to the dataset") + parser.add_argument( + "--profile", choices=SUPPORTED_PROFILES.keys(), help="standard profiles" + ) + parser.add_argument( + "--scenario", + default="SingleStream", + help="mlperf benchmark scenario, one of " + + str(list(SCENARIO_MAP.keys())), + ) + parser.add_argument( + "--max-batchsize", + type=int, + default=1, + help="max batch size in a single inference", + ) + parser.add_argument("--threads", default=1, type=int, help="threads") + parser.add_argument( + "--accuracy", + action="store_true", + help="enable accuracy pass") + parser.add_argument( + "--find-peak-performance", + action="store_true", + help="enable finding peak performance pass", + ) + 
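+    # Example invocation (from the README):
+    #   python main.py --dataset waymo --dataset-path /waymo/kitti_format/ \
+    #       --lidar-path /pp_ep36.pth \
+    #       --segmentor-path /best_deeplabv3plus_resnet50_waymo_os16.pth \
+    #       --mlperf_conf /inference/mlperf.conf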
parser.add_argument("--backend", help="Name of the backend") + parser.add_argument("--model-name", help="Name of the model") + parser.add_argument("--output", default="output", help="test results") + parser.add_argument("--qps", type=int, help="target qps") + parser.add_argument("--lidar-path", help="Path to model weights") + parser.add_argument("--segmentor-path", help="Path to model weights") + + parser.add_argument( + "--dtype", + default="fp32", + choices=["fp32", "fp16", "bf16"], + help="dtype of the model", + ) + parser.add_argument( + "--device", + default="cuda", + choices=["cuda", "cpu"], + help="device to run the benchmark", + ) + + # file to use mlperf rules compliant parameters + parser.add_argument( + "--mlperf_conf", default="mlperf.conf", help="mlperf rules config" + ) + # file for user LoadGen settings such as target QPS + parser.add_argument( + "--user_conf", + default="user.conf", + help="user config for user LoadGen settings such as target QPS", + ) + # file for LoadGen audit settings + parser.add_argument( + "--audit_conf", default="audit.config", help="config for LoadGen audit settings" + ) + + # below will override mlperf rules compliant settings - don't use for + # official submission + parser.add_argument("--time", type=int, help="time to scan in seconds") + parser.add_argument("--count", type=int, help="dataset items to use") + parser.add_argument("--debug", action="store_true", help="debug") + parser.add_argument( + "--performance-sample-count", type=int, help="performance sample count", default=5000 + ) + parser.add_argument( + "--max-latency", type=float, help="mlperf max latency in pct tile" + ) + parser.add_argument( + "--samples-per-query", + default=8, + type=int, + help="mlperf multi-stream samples per query", + ) + args = parser.parse_args() + + # don't use defaults in argparser. 
Instead we default to a dict, override that with a profile + # and take this as default unless command line give + defaults = SUPPORTED_PROFILES["defaults"] + + if args.profile: + profile = SUPPORTED_PROFILES[args.profile] + defaults.update(profile) + for k, v in defaults.items(): + kc = k.replace("-", "_") + if getattr(args, kc) is None: + setattr(args, kc, v) + + if args.scenario not in SCENARIO_MAP: + parser.error("valid scanarios:" + str(list(SCENARIO_MAP.keys()))) + return args + + +def get_backend(backend, **kwargs): + if backend == "pytorch": + from backend_deploy import BackendDeploy + + backend = BackendDeploy(**kwargs) + + elif backend == "debug": + from backend_debug import BackendDebug + + backend = BackendDebug() + else: + raise ValueError("unknown backend: " + backend) + return backend + + +class Item: + """An item that we queue for processing by the thread pool.""" + + def __init__(self, query_id, content_id, inputs, img=None): + self.query_id = query_id + self.content_id = content_id + self.img = img + self.inputs = inputs + self.start = time.time() + + +class RunnerBase: + def __init__(self, model, ds, threads, post_proc=None, max_batchsize=128): + self.take_accuracy = False + self.ds = ds + self.model = model + self.post_process = post_proc + self.threads = threads + self.take_accuracy = False + self.max_batchsize = max_batchsize + self.result_timing = [] + + def handle_tasks(self, tasks_queue): + pass + + def start_run(self, result_dict, take_accuracy): + self.result_dict = result_dict + self.result_timing = [] + self.take_accuracy = take_accuracy + self.post_process.start() + + def run_one_item(self, qitem: Item): + # run the prediction + processed_results = [] + try: + results = self.model.predict(qitem.inputs) + processed_results = self.post_process( + results, qitem.content_id, qitem.inputs, self.result_dict) + + if self.take_accuracy: + self.post_process.add_results(processed_results) + self.result_timing.append(time.time() - qitem.start) + except Exception as ex: # pylint: disable=broad-except + src = [self.ds.get_item_loc(i) for i in qitem.content_id] + log.error("thread: failed on contentid=%s, %s", src, ex) + # since post_process will not run, fake empty responses + processed_results = [[]] * len(qitem.query_id) + finally: + response_array_refs = [] + response = [] + for idx, query_id in enumerate(qitem.query_id): + response_array = array.array("B", np.array( + processed_results[idx], np.float32).tobytes()) + + response_array_refs.append(response_array) + bi = response_array.buffer_info() + response.append(lg.QuerySampleResponse(query_id, bi[0], bi[1])) + lg.QuerySamplesComplete(response) + + def enqueue(self, query_samples): + idx = [q.index for q in query_samples] + query_id = [q.id for q in query_samples] + if len(query_samples) < self.max_batchsize: + data, label = self.ds.get_samples(idx) + self.run_one_item(Item(query_id, idx, data, label)) + else: + bs = self.max_batchsize + for i in range(0, len(idx), bs): + data, label = self.ds.get_samples(idx[i: i + bs]) + self.run_one_item( + Item(query_id[i: i + bs], idx[i: i + bs], data, label) + ) + + def finish(self): + pass + + +class QueueRunner(RunnerBase): + def __init__(self, model, ds, threads, post_proc=None, max_batchsize=128): + super().__init__(model, ds, threads, post_proc, max_batchsize) + self.tasks = Queue(maxsize=threads * 4) + self.workers = [] + self.result_dict = {} + + for _ in range(self.threads): + worker = threading.Thread( + target=self.handle_tasks, args=( + self.tasks,)) + worker.daemon = 
True + self.workers.append(worker) + worker.start() + + def handle_tasks(self, tasks_queue): + """Worker thread.""" + while True: + qitem = tasks_queue.get() + if qitem is None: + # None in the queue indicates the parent want us to exit + tasks_queue.task_done() + break + self.run_one_item(qitem) + tasks_queue.task_done() + + def enqueue(self, query_samples): + idx = [q.index for q in query_samples] + query_id = [q.id for q in query_samples] + if len(query_samples) < self.max_batchsize: + data, label = self.ds.get_samples(idx) + self.tasks.put(Item(query_id, idx, data, label)) + else: + bs = self.max_batchsize + for i in range(0, len(idx), bs): + ie = i + bs + data, label = self.ds.get_samples(idx[i:ie]) + self.tasks.put(Item(query_id[i:ie], idx[i:ie], data, label)) + + def finish(self): + # exit all threads + for _ in self.workers: + self.tasks.put(None) + for worker in self.workers: + worker.join() + + +def main(): + args = get_args() + + log.info(args) + + # find backend + backend = get_backend( + # TODO: pass model, inference and backend arguments + args.backend, + lidar_detector_path=args.lidar_path, + segmentor_path=args.segmentor_path, + data_path=args.dataset_path + + ) + if args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + # --count applies to accuracy mode only and can be used to limit the number of images + # for testing. + count_override = False + count = args.count + if count: + count_override = True + + # load model to backend + model = backend.load() + + # dataset to use + dataset_class, pre_proc, post_proc, kwargs = SUPPORTED_DATASETS[args.dataset] + ds = dataset_class( + data_root=args.dataset_path, + split='val', + painted=True, + cam_sync=False) + + final_results = { + "runtime": model.name(), + "version": model.version(), + "time": int(time.time()), + "args": vars(args), + "cmdline": str(args), + } + + mlperf_conf = os.path.abspath(args.mlperf_conf) + if not os.path.exists(mlperf_conf): + log.error("{} not found".format(mlperf_conf)) + sys.exit(1) + + user_conf = os.path.abspath(args.user_conf) + if not os.path.exists(user_conf): + log.error("{} not found".format(user_conf)) + sys.exit(1) + + audit_config = os.path.abspath(args.audit_conf) + + if args.output: + output_dir = os.path.abspath(args.output) + os.makedirs(output_dir, exist_ok=True) + os.chdir(output_dir) + + # + # make one pass over the dataset to validate accuracy + # + count = ds.get_item_count() + + # warmup + # TODO: Load warmup samples, the following code is a general + # way of doing this, but might need some fixing + ds.load_query_samples([0]) + for i in range(5): + input = ds.get_samples([0]) + _ = backend.predict(input[0]) + + scenario = SCENARIO_MAP[args.scenario] + runner_map = { + lg.TestScenario.SingleStream: RunnerBase, + lg.TestScenario.MultiStream: QueueRunner, + lg.TestScenario.Server: QueueRunner, + lg.TestScenario.Offline: QueueRunner, + } + runner = runner_map[scenario]( + model, ds, args.threads, post_proc=post_proc, max_batchsize=args.max_batchsize + ) + + def issue_queries(query_samples): + runner.enqueue(query_samples) + + def flush_queries(): + pass + + log_output_settings = lg.LogOutputSettings() + log_output_settings.outdir = output_dir + log_output_settings.copy_summary_to_stdout = False + log_settings = lg.LogSettings() + log_settings.enable_trace = args.debug + log_settings.log_output = log_output_settings + + settings = lg.TestSettings() + settings.FromConfig(mlperf_conf, args.model_name, 
args.scenario) + settings.FromConfig(user_conf, args.model_name, args.scenario) + settings.scenario = scenario + settings.mode = lg.TestMode.PerformanceOnly + if args.accuracy: + settings.mode = lg.TestMode.AccuracyOnly + if args.find_peak_performance: + settings.mode = lg.TestMode.FindPeakPerformance + + if args.time: + # override the time we want to run + settings.min_duration_ms = args.time * MILLI_SEC + settings.max_duration_ms = args.time * MILLI_SEC + + if args.qps: + qps = float(args.qps) + settings.server_target_qps = qps + settings.offline_expected_qps = qps + + if count_override: + settings.min_query_count = count + settings.max_query_count = count + + if args.samples_per_query: + settings.multi_stream_samples_per_query = args.samples_per_query + if args.max_latency: + settings.server_target_latency_ns = int(args.max_latency * NANO_SEC) + settings.multi_stream_expected_latency_ns = int( + args.max_latency * NANO_SEC) + + performance_sample_count = ( + args.performance_sample_count + if args.performance_sample_count + else min(count, 500) + ) + sut = lg.ConstructSUT(issue_queries, flush_queries) + qsl = lg.ConstructQSL( + count, performance_sample_count, ds.load_query_samples, ds.unload_query_samples + ) + + log.info("starting {}".format(scenario)) + result_dict = {"scenario": str(scenario)} + runner.start_run(result_dict, args.accuracy) + + lg.StartTestWithLogSettings(sut, qsl, settings, log_settings, audit_config) + + if args.accuracy: + post_proc.finalize(result_dict, ds) + final_results["accuracy_results"] = result_dict + + runner.finish() + lg.DestroyQSL(qsl) + lg.DestroySUT(sut) + + # + # write final results + # + if args.output: + with open("results.json", "w") as f: + json.dump(final_results, f, sort_keys=True, indent=4) + + +if __name__ == "__main__": + main() diff --git a/automotive/3d-object-detection/model/__init__.py b/automotive/3d-object-detection/model/__init__.py new file mode 100644 index 000000000..340ddce01 --- /dev/null +++ b/automotive/3d-object-detection/model/__init__.py @@ -0,0 +1,2 @@ +from .anchors import Anchors, anchors2bboxes, bboxes2deltas +from .pointpillars import PointPillars diff --git a/automotive/3d-object-detection/model/anchors.py b/automotive/3d-object-detection/model/anchors.py new file mode 100644 index 000000000..7e104ece3 --- /dev/null +++ b/automotive/3d-object-detection/model/anchors.py @@ -0,0 +1,290 @@ +import torch +import math +from tools.process import limit_period, iou2d_nearest + + +class Anchors(): + def __init__(self, ranges, sizes, rotations): + assert len(ranges) == len(sizes) + self.ranges = ranges + self.sizes = sizes + self.rotations = rotations + + def get_anchors(self, feature_map_size, anchor_range, + anchor_size, rotations): + ''' + feature_map_size: (y_l, x_l) + anchor_range: [x1, y1, z1, x2, y2, z2] + anchor_size: [w, l, h] + rotations: [0, 1.57] + return: shape=(y_l, x_l, 2, 7) + ''' + device = feature_map_size.device + x_centers = torch.linspace( + anchor_range[0], + anchor_range[3], + feature_map_size[1] + 1, + device=device) + y_centers = torch.linspace( + anchor_range[1], + anchor_range[4], + feature_map_size[0] + 1, + device=device) + z_centers = torch.linspace( + anchor_range[2], + anchor_range[5], + 1 + 1, + device=device) + + x_shift = (x_centers[1] - x_centers[0]) / 2 + y_shift = (y_centers[1] - y_centers[0]) / 2 + z_shift = (z_centers[1] - z_centers[0]) / 2 + x_centers = x_centers[:feature_map_size[1]] + \ + x_shift # (feature_map_size[1], ) + y_centers = y_centers[:feature_map_size[0]] + \ + y_shift # 
(feature_map_size[0], ) + z_centers = z_centers[:1] + z_shift # (1, ) + + # [feature_map_size[1], feature_map_size[0], 1, 2] * 4 + meshgrids = torch.meshgrid(x_centers, y_centers, z_centers, rotations) + meshgrids = list(meshgrids) + for i in range(len(meshgrids)): + # [feature_map_size[1], feature_map_size[0], 1, 2, 1] + meshgrids[i] = meshgrids[i][..., None] + + anchor_size = anchor_size[None, None, None, None, :] + repeat_shape = [ + feature_map_size[1], + feature_map_size[0], + 1, + len(rotations), + 1] + # [feature_map_size[1], feature_map_size[0], 1, 2, 3] + anchor_size = anchor_size.repeat(repeat_shape) + meshgrids.insert(3, anchor_size) + # [1, feature_map_size[0], feature_map_size[1], 2, 7] + anchors = torch.cat( + meshgrids, + dim=- + 1).permute( + 2, + 1, + 0, + 3, + 4).contiguous() + return anchors.squeeze(0) + + def get_multi_anchors(self, feature_map_size): + ''' + feature_map_size: (y_l, x_l) + ranges: [[x1, y1, z1, x2, y2, z2], [x1, y1, z1, x2, y2, z2], [x1, y1, z1, x2, y2, z2]] + sizes: [[w, l, h], [w, l, h], [w, l, h]] + rotations: [0, 1.57] + return: shape=(y_l, x_l, 3, 2, 7) + ''' + device = feature_map_size.device + ranges = torch.tensor(self.ranges, device=device) + sizes = torch.tensor(self.sizes, device=device) + rotations = torch.tensor(self.rotations, device=device) + multi_anchors = [] + for i in range(len(ranges)): + anchors = self.get_anchors(feature_map_size=feature_map_size, + anchor_range=ranges[i], + anchor_size=sizes[i], + rotations=rotations) + multi_anchors.append(anchors[:, :, None, :, :]) + multi_anchors = torch.cat(multi_anchors, dim=2) + + return multi_anchors + + +def anchors2bboxes(anchors, deltas): + ''' + anchors: (M, 7), (x, y, z, w, l, h, theta) + deltas: (M, 7) + return: (M, 7) + ''' + da = torch.sqrt(anchors[:, 3] ** 2 + anchors[:, 4] ** 2) + x = deltas[:, 0] * da + anchors[:, 0] + y = deltas[:, 1] * da + anchors[:, 1] + z = deltas[:, 2] * anchors[:, 5] + anchors[:, 2] + anchors[:, 5] / 2 + + w = anchors[:, 3] * torch.exp(deltas[:, 3]) + l = anchors[:, 4] * torch.exp(deltas[:, 4]) + h = anchors[:, 5] * torch.exp(deltas[:, 5]) + + z = z - h / 2 + + theta = anchors[:, 6] + deltas[:, 6] + + bboxes = torch.stack([x, y, z, w, l, h, theta], dim=1) + return bboxes + + +def bboxes2deltas(bboxes, anchors): + ''' + bboxes: (M, 7), (x, y, z, w, l, h, theta) + anchors: (M, 7) + return: (M, 7) + ''' + da = torch.sqrt(anchors[:, 3] ** 2 + anchors[:, 4] ** 2) + + dx = (bboxes[:, 0] - anchors[:, 0]) / da + dy = (bboxes[:, 1] - anchors[:, 1]) / da + + zb = bboxes[:, 2] + bboxes[:, 5] / 2 # bottom center + za = anchors[:, 2] + anchors[:, 5] / 2 # bottom center + dz = (zb - za) / anchors[:, 5] # bottom center + + dw = torch.log(bboxes[:, 3] / anchors[:, 3]) + dl = torch.log(bboxes[:, 4] / anchors[:, 4]) + dh = torch.log(bboxes[:, 5] / anchors[:, 5]) + dtheta = bboxes[:, 6] - anchors[:, 6] + + deltas = torch.stack([dx, dy, dz, dw, dl, dh, dtheta], dim=1) + return deltas + + +def anchor_target(batched_anchors, batched_gt_bboxes, + batched_gt_labels, assigners, nclasses): + ''' + batched_anchors: [(y_l, x_l, 3, 2, 7), (y_l, x_l, 3, 2, 7), ... ] + batched_gt_bboxes: [(n1, 7), (n2, 7), ...] + batched_gt_labels: [(n1, ), (n2, ), ...] 
+ return: + dict = {batched_anchors_labels: (bs, n_anchors), + batched_labels_weights: (bs, n_anchors), + batched_anchors_reg: (bs, n_anchors, 7), + batched_reg_weights: (bs, n_anchors), + batched_anchors_dir: (bs, n_anchors), + batched_dir_weights: (bs, n_anchors)} + ''' + assert len(batched_anchors) == len( + batched_gt_bboxes) == len(batched_gt_labels) + batch_size = len(batched_anchors) + n_assigners = len(assigners) + batched_labels, batched_label_weights = [], [] + batched_bbox_reg, batched_bbox_reg_weights = [], [] + batched_dir_labels, batched_dir_labels_weights = [], [] + for i in range(batch_size): + anchors = batched_anchors[i] + gt_bboxes, gt_labels = batched_gt_bboxes[i], batched_gt_labels[i] + # what we want to get next ? + # 1. identify positive anchors and negative anchors -> cls + # 2. identify the regresstion values -> reg + # 3. indentify the direction -> dir_cls + multi_labels, multi_label_weights = [], [] + multi_bbox_reg, multi_bbox_reg_weights = [], [] + multi_dir_labels, multi_dir_labels_weights = [], [] + d1, d2, d3, d4, d5 = anchors.size() + for j in range(n_assigners): # multi anchors + assigner = assigners[j] + pos_iou_thr, neg_iou_thr, min_iou_thr = \ + assigner['pos_iou_thr'], assigner['neg_iou_thr'], assigner['min_iou_thr'] + cur_anchors = anchors[:, :, j, :, :].reshape(-1, 7) + overlaps = iou2d_nearest(gt_bboxes, cur_anchors) + if overlaps.shape[0] == 0: + max_overlaps = torch.zeros_like( + cur_anchors[:, 0], dtype=cur_anchors.dtype) + max_overlaps_idx = torch.zeros_like( + cur_anchors[:, 0], dtype=torch.long) + else: + max_overlaps, max_overlaps_idx = torch.max(overlaps, dim=0) + gt_max_overlaps, _ = torch.max(overlaps, dim=1) + + assigned_gt_inds = - \ + torch.ones_like(cur_anchors[:, 0], dtype=torch.long) + # a. negative anchors + assigned_gt_inds[max_overlaps < neg_iou_thr] = 0 + + # b. positive anchors + # rule 1 + assigned_gt_inds[max_overlaps >= + pos_iou_thr] = max_overlaps_idx[max_overlaps >= pos_iou_thr] + 1 + + # rule 2 + # support one bbox to multi anchors, only if the anchors are with the highest iou. + # rule2 may modify the labels generated by rule 1 + for i in range(len(gt_bboxes)): + if gt_max_overlaps[i] >= min_iou_thr: + assigned_gt_inds[overlaps[i] == gt_max_overlaps[i]] = i + 1 + + pos_flag = assigned_gt_inds > 0 + neg_flag = assigned_gt_inds == 0 + # 1. anchor labels + # -1 is not optimal, for some bboxes are with labels -1 + assigned_gt_labels = torch.zeros_like( + cur_anchors[:, 0], dtype=torch.long) + nclasses + assigned_gt_labels[pos_flag] = gt_labels[assigned_gt_inds[pos_flag] - 1].long() + assigned_gt_labels_weights = torch.zeros_like(cur_anchors[:, 0]) + assigned_gt_labels_weights[pos_flag] = 1 + assigned_gt_labels_weights[neg_flag] = 1 + + # 2. anchor regression + assigned_gt_reg_weights = torch.zeros_like(cur_anchors[:, 0]) + assigned_gt_reg_weights[pos_flag] = 1 + + assigned_gt_reg = torch.zeros_like(cur_anchors) + positive_anchors = cur_anchors[pos_flag] + corr_gt_bboxes = gt_bboxes[assigned_gt_inds[pos_flag] - 1] + assigned_gt_reg[pos_flag] = bboxes2deltas( + corr_gt_bboxes, positive_anchors) + + # 3. 
anchor direction + assigned_gt_dir_weights = torch.zeros_like(cur_anchors[:, 0]) + assigned_gt_dir_weights[pos_flag] = 1 + + assigned_gt_dir = torch.zeros_like( + cur_anchors[:, 0], dtype=torch.long) + dir_cls_targets = limit_period( + corr_gt_bboxes[:, 6].cpu(), 0, 2 * math.pi).to(corr_gt_bboxes) + dir_cls_targets = torch.floor(dir_cls_targets / math.pi).long() + assigned_gt_dir[pos_flag] = torch.clamp( + dir_cls_targets, min=0, max=1) + + multi_labels.append(assigned_gt_labels.reshape(d1, d2, 1, d4)) + multi_label_weights.append( + assigned_gt_labels_weights.reshape( + d1, d2, 1, d4)) + multi_bbox_reg.append(assigned_gt_reg.reshape(d1, d2, 1, d4, -1)) + multi_bbox_reg_weights.append( + assigned_gt_reg_weights.reshape( + d1, d2, 1, d4)) + multi_dir_labels.append(assigned_gt_dir.reshape(d1, d2, 1, d4)) + multi_dir_labels_weights.append( + assigned_gt_dir_weights.reshape( + d1, d2, 1, d4)) + + multi_labels = torch.cat(multi_labels, dim=-2).reshape(-1) + multi_label_weights = torch.cat( + multi_label_weights, dim=-2).reshape(-1) + multi_bbox_reg = torch.cat(multi_bbox_reg, dim=-3).reshape(-1, d5) + multi_bbox_reg_weights = torch.cat( + multi_bbox_reg_weights, dim=-2).reshape(-1) + multi_dir_labels = torch.cat(multi_dir_labels, dim=-2).reshape(-1) + multi_dir_labels_weights = torch.cat( + multi_dir_labels_weights, dim=-2).reshape(-1) + + batched_labels.append(multi_labels) + batched_label_weights.append(multi_label_weights) + batched_bbox_reg.append(multi_bbox_reg) + batched_bbox_reg_weights.append(multi_bbox_reg_weights) + batched_dir_labels.append(multi_dir_labels) + batched_dir_labels_weights.append(multi_dir_labels_weights) + + rt_dict = dict( + batched_labels=torch.stack( + batched_labels, 0), # (bs, y_l * x_l * 3 * 2) + batched_label_weights=torch.stack( + batched_label_weights, 0), # (bs, y_l * x_l * 3 * 2) + batched_bbox_reg=torch.stack( + batched_bbox_reg, 0), # (bs, y_l * x_l * 3 * 2, 7) + batched_bbox_reg_weights=torch.stack( + batched_bbox_reg_weights, 0), # (bs, y_l * x_l * 3 * 2) + batched_dir_labels=torch.stack( + batched_dir_labels, 0), # (bs, y_l * x_l * 3 * 2) + batched_dir_labels_weights=torch.stack( + batched_dir_labels_weights, 0) # (bs, y_l * x_l * 3 * 2) + ) + + return rt_dict diff --git a/automotive/3d-object-detection/model/painter.py b/automotive/3d-object-detection/model/painter.py new file mode 100644 index 000000000..f0680931f --- /dev/null +++ b/automotive/3d-object-detection/model/painter.py @@ -0,0 +1,267 @@ +import argparse +import model.segmentation as network +import os +import numpy as np +import torch +from torchvision import transforms +from PIL import Image +import copy +import sys +from tqdm import tqdm +sys.path.append('..') + + +def get_calib_from_file(calib_file): + """Read in a calibration file and parse into a dictionary.""" + data = {} + + with open(calib_file, 'r') as f: + lines = [line for line in f.readlines() if line.strip()] + for line in lines: + key, value = line.split(':', 1) + # The only non-float values in these files are dates, which + # we don't care about anyway + try: + if key == 'R0_rect': + data['R0'] = torch.tensor([float(x) + for x in value.split()]).reshape(3, 3) + else: + data[key] = torch.tensor([float(x) + for x in value.split()]).reshape(3, 4) + except ValueError: + pass + + return data + + +class Painter: + def __init__(self, args): + self.root_split_path = args.training_path + self.save_path = os.path.join(args.training_path, "painted_lidar/") + if not os.path.exists(self.save_path): + os.mkdir(self.save_path) + + 
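+        # The Painter implements the "pointpainting" step of the pipeline: the
+        # DeepLabV3+ segmentor loaded below produces per-pixel class scores, and
+        # augment_lidar_class_scores_both() projects those scores onto the lidar
+        # points so that PointPillars receives class-score-augmented ("painted")
+        # points as input.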
self.seg_net_index = 0 + self.model = None + print(f'Using Segmentation Network -- deeplabv3plus') + checkpoint_file = args.model_path + model = network.modeling.__dict__['deeplabv3plus_resnet50']( + num_classes=19, output_stride=16) + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state"]) + model.eval() + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model.to(device) + self.model = model + self.cam_sync = args.cam_sync + + def get_lidar(self, idx): + lidar_file = os.path.join( + self.root_split_path, 'velodyne/' + ('%s.bin' % idx)) + return torch.from_numpy(np.fromfile( + str(lidar_file), dtype=np.float32).reshape(-1, 6)) + + def get_image(self, idx, camera): + filename = os.path.join(self.root_split_path, + camera + ('%s.jpg' % idx)) + input_image = Image.open(filename) + preprocess = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[ + 0.485, 0.456, 0.406], std=[ + 0.229, 0.224, 0.225]), + ]) + + input_tensor = preprocess(input_image) + # create a mini-batch as expected by the model + input_batch = input_tensor.unsqueeze(0) + if torch.cuda.is_available(): + input_batch = input_batch.to('cuda') + # move the input and model to GPU for speed if available + if torch.cuda.is_available(): + input_batch = input_batch.to('cuda') + return input_batch + + def get_model_output(self, input_batch): + with torch.no_grad(): + output = self.model(input_batch)[0] + return output + + def get_score(self, model_output): + sf = torch.nn.Softmax(dim=2) + output_permute = model_output.permute(1, 2, 0) + output_permute = sf(output_permute) + output_reassign = torch.zeros( + output_permute.size(0), output_permute.size(1), 6).to( + device=model_output.device) + output_reassign[:, :, 0] = torch.sum( + output_permute[:, :, :11], dim=2) # background + output_reassign[:, :, 1] = output_permute[:, :, 18] # bicycle + output_reassign[:, :, 2] = torch.sum( + output_permute[:, :, [13, 14, 15, 16]], dim=2) # vehicles + output_reassign[:, :, 3] = output_permute[:, :, 11] # person + output_reassign[:, :, 4] = output_permute[:, :, 12] # rider + output_reassign[:, :, 5] = output_permute[:, :, 17] # motorcycle + + return output_reassign + + def get_calib_fromfile(self, idx, device): + calib_file = os.path.join( + self.root_split_path, 'calib/' + ('%s.txt' % idx)) + calib = get_calib_from_file(calib_file) + calib['P0'] = torch.cat([calib['P0'], torch.tensor( + [[0., 0., 0., 1.]])], axis=0).to(device=device) + calib['P1'] = torch.cat([calib['P1'], torch.tensor( + [[0., 0., 0., 1.]])], axis=0).to(device=device) + calib['P2'] = torch.cat([calib['P2'], torch.tensor( + [[0., 0., 0., 1.]])], axis=0).to(device=device) + calib['P3'] = torch.cat([calib['P3'], torch.tensor( + [[0., 0., 0., 1.]])], axis=0).to(device=device) + calib['P4'] = torch.cat([calib['P4'], torch.tensor( + [[0., 0., 0., 1.]])], axis=0).to(device=device) + calib['R0_rect'] = torch.zeros( + [4, 4], dtype=calib['R0'].dtype, device=device) + calib['R0_rect'][3, 3] = 1. 
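+        # P0..P4 were padded with a [0., 0., 0., 1.] row above, and R0_rect and
+        # the Tr_velo_to_cam_* matrices are expanded to 4x4 below, so every
+        # calibration transform can be chained as a homogeneous matrix in
+        # cam_to_lidar() and project_points_mask().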
+ calib['R0_rect'][:3, :3] = calib['R0'].to(device=device) + calib['Tr_velo_to_cam_0'] = torch.cat([calib['Tr_velo_to_cam_0'], torch.tensor( + [[0., 0., 0., 1.]], )], axis=0).to(device=device) + calib['Tr_velo_to_cam_1'] = torch.cat([calib['Tr_velo_to_cam_1'], torch.tensor( + [[0., 0., 0., 1.]], )], axis=0).to(device=device) + calib['Tr_velo_to_cam_2'] = torch.cat([calib['Tr_velo_to_cam_2'], torch.tensor( + [[0., 0., 0., 1.]], )], axis=0).to(device=device) + calib['Tr_velo_to_cam_3'] = torch.cat([calib['Tr_velo_to_cam_3'], torch.tensor( + [[0., 0., 0., 1.]], )], axis=0).to(device=device) + calib['Tr_velo_to_cam_4'] = torch.cat([calib['Tr_velo_to_cam_4'], torch.tensor( + [[0., 0., 0., 1.]], )], axis=0).to(device=device) + return calib + + def cam_to_lidar(self, pointcloud, projection_mats, camera_num): + """ + Takes in lidar in velo coords, returns lidar points in camera coords + + :param pointcloud: (n_points, 4) np.array (x,y,z,r) in velodyne coordinates + :return lidar_cam_coords: (n_points, 4) np.array (x,y,z,r) in camera coordinates + """ + + lidar_velo_coords = copy.deepcopy(pointcloud) + # copy reflectances column + reflectances = copy.deepcopy(lidar_velo_coords[:, -1]) + lidar_velo_coords[:, -1] = 1 # for multiplying with homogeneous matrix + lidar_cam_coords = projection_mats['Tr_velo_to_cam_' + + str(camera_num)].matmul(lidar_velo_coords.transpose(0, 1)) + lidar_cam_coords = lidar_cam_coords.transpose(0, 1) + lidar_cam_coords[:, -1] = reflectances + + return lidar_cam_coords + + def project_points_mask(self, lidar_cam_points, + projection_mats, class_scores, camera_num): + points_projected_on_mask = projection_mats['P' + str(camera_num)].matmul( + projection_mats['R0_rect'].matmul(lidar_cam_points.transpose(0, 1))) + points_projected_on_mask = points_projected_on_mask.transpose(0, 1) + points_projected_on_mask = points_projected_on_mask / \ + (points_projected_on_mask[:, 2].reshape(-1, 1)) + + true_where_x_on_img = (0 < points_projected_on_mask[:, 0]) & ( + points_projected_on_mask[:, 0] < class_scores[camera_num].shape[1]) # x in img coords is cols of img + true_where_y_on_img = (0 < points_projected_on_mask[:, 1]) & ( + points_projected_on_mask[:, 1] < class_scores[camera_num].shape[0]) + true_where_point_on_img = true_where_x_on_img & true_where_y_on_img & ( + lidar_cam_points[:, 2] > 0) + + # filter out points that don't project to image + points_projected_on_mask = points_projected_on_mask[true_where_point_on_img] + # using floor so you don't end up indexing num_rows+1th row or col + points_projected_on_mask = torch.floor(points_projected_on_mask).int() + # drops homogenous coord 1 from every point, giving (N_pts, 2) int + # array + points_projected_on_mask = points_projected_on_mask[:, :2] + return (points_projected_on_mask, true_where_point_on_img) + + def augment_lidar_class_scores_both( + self, class_scores, lidar_raw, projection_mats): + """ + Projects lidar points onto segmentation map, appends class score each point projects onto. 
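+
+        class_scores: list of five per-camera score maps of shape (H, W, n_classes)
+        lidar_raw: raw lidar point cloud (at least five features per point)
+        projection_mats: calibration dict with P0..P4, R0_rect and Tr_velo_to_cam_*
+            entries (cf. get_calib_fromfile())
+
+        Scores from overlapping camera pairs are summed and then halved, i.e.
+        averaged, for points visible in both images; if cam_sync is set, points
+        that fall outside every camera image are dropped.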
+ """ + # lidar_cam_coords = self.cam_to_lidar(lidar_raw, projection_mats) + + ################################ + lidar_cam_coords = self.cam_to_lidar( + lidar_raw[:, :4], projection_mats, 0) + + lidar_cam_coords[:, -1] = 1 # homogenous coords for projection + + points_projected_on_mask_0, true_where_point_on_img_0 = self.project_points_mask( + lidar_cam_coords, projection_mats, class_scores, 0) + + lidar_cam_coords = self.cam_to_lidar( + lidar_raw[:, :4], projection_mats, 1) + lidar_cam_coords[:, -1] = 1 # homogenous coords for projection + + points_projected_on_mask_1, true_where_point_on_img_1 = self.project_points_mask( + lidar_cam_coords, projection_mats, class_scores, 1) + + lidar_cam_coords = self.cam_to_lidar( + lidar_raw[:, :4], projection_mats, 2) + lidar_cam_coords[:, -1] = 1 + points_projected_on_mask_2, true_where_point_on_img_2 = self.project_points_mask( + lidar_cam_coords, projection_mats, class_scores, 2) + + lidar_cam_coords = self.cam_to_lidar( + lidar_raw[:, :4], projection_mats, 3) + lidar_cam_coords[:, -1] = 1 + points_projected_on_mask_3, true_where_point_on_img_3 = self.project_points_mask( + lidar_cam_coords, projection_mats, class_scores, 3) + + lidar_cam_coords = self.cam_to_lidar( + lidar_raw[:, :4], projection_mats, 4) + lidar_cam_coords[:, -1] = 1 + points_projected_on_mask_4, true_where_point_on_img_4 = self.project_points_mask( + lidar_cam_coords, projection_mats, class_scores, 4) + + true_where_point_on_both_0_1 = true_where_point_on_img_0 & true_where_point_on_img_1 + true_where_point_on_both_0_2 = true_where_point_on_img_0 & true_where_point_on_img_2 + true_where_point_on_both_1_3 = true_where_point_on_img_1 & true_where_point_on_img_3 + true_where_point_on_both_2_4 = true_where_point_on_img_2 & true_where_point_on_img_4 + true_where_point_on_img = true_where_point_on_img_1 | true_where_point_on_img_0 | true_where_point_on_img_2 | true_where_point_on_img_3 | true_where_point_on_img_4 + + point_scores_0 = class_scores[0][points_projected_on_mask_0[:, 1], + points_projected_on_mask_0[:, 0]].reshape(-1, class_scores[0].shape[2]) + point_scores_1 = class_scores[1][points_projected_on_mask_1[:, 1], + points_projected_on_mask_1[:, 0]].reshape(-1, class_scores[1].shape[2]) + point_scores_2 = class_scores[2][points_projected_on_mask_2[:, 1], + points_projected_on_mask_2[:, 0]].reshape(-1, class_scores[2].shape[2]) + point_scores_3 = class_scores[3][points_projected_on_mask_3[:, 1], + points_projected_on_mask_3[:, 0]].reshape(-1, class_scores[3].shape[2]) + point_scores_4 = class_scores[4][points_projected_on_mask_4[:, 1], + points_projected_on_mask_4[:, 0]].reshape(-1, class_scores[4].shape[2]) + + augmented_lidar = torch.cat((lidar_raw[:, :5], torch.zeros( + (lidar_raw.shape[0], class_scores[1].shape[2])).to(device=lidar_raw.device)), axis=1) + augmented_lidar[true_where_point_on_img_0, - + class_scores[0].shape[2]:] += point_scores_0 + augmented_lidar[true_where_point_on_img_1, - + class_scores[1].shape[2]:] += point_scores_1 + augmented_lidar[true_where_point_on_img_2, - + class_scores[2].shape[2]:] += point_scores_2 + augmented_lidar[true_where_point_on_img_3, - + class_scores[3].shape[2]:] += point_scores_3 + augmented_lidar[true_where_point_on_img_4, - + class_scores[4].shape[2]:] += point_scores_4 + augmented_lidar[true_where_point_on_both_0_1, -class_scores[0].shape[2]:] = 0.5 * \ + augmented_lidar[true_where_point_on_both_0_1, - + class_scores[0].shape[2]:] + augmented_lidar[true_where_point_on_both_0_2, -class_scores[0].shape[2]:] = 0.5 * \ + 
augmented_lidar[true_where_point_on_both_0_2, - + class_scores[0].shape[2]:] + augmented_lidar[true_where_point_on_both_1_3, -class_scores[1].shape[2]:] = 0.5 * \ + augmented_lidar[true_where_point_on_both_1_3, - + class_scores[1].shape[2]:] + augmented_lidar[true_where_point_on_both_2_4, -class_scores[2].shape[2]:] = 0.5 * \ + augmented_lidar[true_where_point_on_both_2_4, - + class_scores[2].shape[2]:] + if self.cam_sync: + augmented_lidar = augmented_lidar[true_where_point_on_img] + + return augmented_lidar diff --git a/automotive/3d-object-detection/model/pointpillars.py b/automotive/3d-object-detection/model/pointpillars.py new file mode 100644 index 000000000..49257fe5a --- /dev/null +++ b/automotive/3d-object-detection/model/pointpillars.py @@ -0,0 +1,515 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.anchors import Anchors, anchor_target, anchors2bboxes +from ops import Voxelization +import open3d.ml.torch as ml3d +from tools.process import limit_period +import math + + +class PillarLayer(nn.Module): + def __init__(self, voxel_size, point_cloud_range, + max_num_points, max_voxels): + super().__init__() + self.voxel_layer = Voxelization(voxel_size=voxel_size, + point_cloud_range=point_cloud_range, + max_num_points=max_num_points, + max_voxels=max_voxels) + + @torch.no_grad() + def forward(self, batched_pts): + ''' + batched_pts: list[tensor], len(batched_pts) = bs + return: + pillars: (p1 + p2 + ... + pb, num_points, c), + coors_batch: (p1 + p2 + ... + pb, 1 + 3), + num_points_per_pillar: (p1 + p2 + ... + pb, ), (b: batch size) + ''' + pillars, coors, npoints_per_pillar = [], [], [] + for i, pts in enumerate(batched_pts): + voxels_out, coors_out, num_points_per_voxel_out = self.voxel_layer( + pts) + # voxels_out: (max_voxel, num_points, c), coors_out: (max_voxel, 3) + # num_points_per_voxel_out: (max_voxel, ) + pillars.append(voxels_out) + coors.append(coors_out.long()) + npoints_per_pillar.append(num_points_per_voxel_out) + + # (p1 + p2 + ... + pb, num_points, c) + pillars = torch.cat(pillars, dim=0) + npoints_per_pillar = torch.cat( + npoints_per_pillar, + dim=0) # (p1 + p2 + ... + pb, ) + coors_batch = [] + for i, cur_coors in enumerate(coors): + coors_batch.append(F.pad(cur_coors, (1, 0), value=i)) + # (p1 + p2 + ... + pb, 1 + 3) + coors_batch = torch.cat(coors_batch, dim=0) + + return pillars, coors_batch, npoints_per_pillar + + +class PillarEncoder(nn.Module): + def __init__(self, voxel_size, point_cloud_range, in_channel, out_channel): + super().__init__() + self.out_channel = out_channel + self.vx, self.vy = voxel_size[0], voxel_size[1] + self.x_offset = voxel_size[0] / 2 + point_cloud_range[0] + self.y_offset = voxel_size[1] / 2 + point_cloud_range[1] + self.x_l = math.ceil( + (point_cloud_range[3] - + point_cloud_range[0]) / + voxel_size[0]) + self.y_l = math.ceil( + (point_cloud_range[4] - + point_cloud_range[1]) / + voxel_size[1]) + + self.conv = nn.Conv1d(in_channel, out_channel, 1, bias=False) + self.bn = nn.BatchNorm1d(out_channel, eps=1e-3, momentum=0.01) + + def forward(self, pillars, coors_batch, npoints_per_pillar): + ''' + pillars: (p1 + p2 + ... + pb, num_points, c), c = 4 + coors_batch: (p1 + p2 + ... + pb, 1 + 3) + npoints_per_pillar: (p1 + p2 + ... + pb, ) + return: (bs, out_channel, y_l, x_l) + ''' + device = pillars.device + # 1. 
calculate offset to the points center (in each pillar) + offset_pt_center = pillars[:, + :, + :3] - torch.sum(pillars[:, + :, + :3], + dim=1, + keepdim=True) / npoints_per_pillar[:, + None, + None] # (p1 + p2 + ... + pb, num_points, 3) + + # 2. calculate offset to the pillar center + # (p1 + p2 + ... + pb, num_points, 1) + x_offset_pi_center = pillars[:, :, :1] - \ + (coors_batch[:, None, 1:2] * self.vx + self.x_offset) + # (p1 + p2 + ... + pb, num_points, 1) + y_offset_pi_center = pillars[:, :, 1:2] - \ + (coors_batch[:, None, 2:3] * self.vy + self.y_offset) + + # 3. encoder + features = torch.cat([pillars, + offset_pt_center, + x_offset_pi_center, + y_offset_pi_center], + dim=-1) # (p1 + p2 + ... + pb, num_points, 9) + features[:, :, 0:1] = x_offset_pi_center # tmp + features[:, :, 1:2] = y_offset_pi_center # tmp + # In consitent with mmdet3d. + # The reason can be referenced to + # https://github.com/open-mmlab/mmdetection3d/issues/1150 + + # 4. find mask for (0, 0, 0) and update the encoded features + # a very beautiful implementation + voxel_ids = torch.arange( + 0, pillars.size(1)).to(device) # (num_points, ) + # (num_points, p1 + p2 + ... + pb) + mask = voxel_ids[:, None] < npoints_per_pillar[None, :] + # (p1 + p2 + ... + pb, num_points) + mask = mask.permute(1, 0).contiguous() + features *= mask[:, :, None] + + # 5. embedding + # (p1 + p2 + ... + pb, 9, num_points) + features = features.permute(0, 2, 1).contiguous() + # (p1 + p2 + ... + pb, out_channels, num_points) + features = F.relu(self.bn(self.conv(features))) + # (p1 + p2 + ... + pb, out_channels) + pooling_features = torch.max(features, dim=-1)[0] + + # 6. pillar scatter + batched_canvas = [] + bs = coors_batch[-1, 0] + 1 + for i in range(bs): + cur_coors_idx = coors_batch[:, 0] == i + cur_coors = coors_batch[cur_coors_idx, :] + cur_features = pooling_features[cur_coors_idx] + + canvas = torch.zeros( + (self.x_l, + self.y_l, + self.out_channel), + dtype=torch.float32, + device=device) + canvas[cur_coors[:, 1], cur_coors[:, 2]] = cur_features + canvas = canvas.permute(2, 1, 0).contiguous() + batched_canvas.append(canvas) + # (bs, in_channel, self.y_l, self.x_l) + batched_canvas = torch.stack(batched_canvas, dim=0) + return batched_canvas + + +class Backbone(nn.Module): + def __init__(self, in_channel, out_channels, + layer_nums, layer_strides=[2, 2, 2]): + super().__init__() + assert len(out_channels) == len(layer_nums) + assert len(out_channels) == len(layer_strides) + + self.multi_blocks = nn.ModuleList() + for i in range(len(layer_strides)): + blocks = [] + blocks.append( + nn.Conv2d( + in_channel, + out_channels[i], + 3, + stride=layer_strides[i], + bias=False, + padding=1)) + blocks.append( + nn.BatchNorm2d( + out_channels[i], + eps=1e-3, + momentum=0.01)) + blocks.append(nn.ReLU(inplace=True)) + + for _ in range(layer_nums[i]): + blocks.append( + nn.Conv2d( + out_channels[i], + out_channels[i], + 3, + bias=False, + padding=1)) + blocks.append( + nn.BatchNorm2d( + out_channels[i], + eps=1e-3, + momentum=0.01)) + blocks.append(nn.ReLU(inplace=True)) + + in_channel = out_channels[i] + self.multi_blocks.append(nn.Sequential(*blocks)) + + # in consitent with mmdet3d + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + + def forward(self, x): + ''' + x: (b, c, y_l, x_l). Default: (6, 64, 496, 432) + return: list[]. 
Default: [(6, 64, 248, 216), (6, 128, 124, 108), (6, 256, 62, 54)] + ''' + outs = [] + for i in range(len(self.multi_blocks)): + x = self.multi_blocks[i](x) + outs.append(x) + return outs + + +class Neck(nn.Module): + def __init__(self, in_channels, upsample_strides, out_channels): + super().__init__() + assert len(in_channels) == len(upsample_strides) + assert len(upsample_strides) == len(out_channels) + + self.decoder_blocks = nn.ModuleList() + for i in range(len(in_channels)): + decoder_block = [] + decoder_block.append(nn.ConvTranspose2d(in_channels[i], + out_channels[i], + upsample_strides[i], + stride=upsample_strides[i], + bias=False)) + decoder_block.append( + nn.BatchNorm2d( + out_channels[i], + eps=1e-3, + momentum=0.01)) + decoder_block.append(nn.ReLU(inplace=True)) + + self.decoder_blocks.append(nn.Sequential(*decoder_block)) + + # in consitent with mmdet3d + for m in self.modules(): + if isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + + def forward(self, x): + ''' + x: [(bs, 64, 248, 216), (bs, 128, 124, 108), (bs, 256, 62, 54)] + return: (bs, 384, 248, 216) + ''' + outs = [] + for i in range(len(self.decoder_blocks)): + xi = self.decoder_blocks[i](x[i]) # (bs, 128, 248, 216) + outs.append(xi) + out = torch.cat(outs, dim=1) + return out + + +class Head(nn.Module): + def __init__(self, in_channel, n_anchors, n_classes): + super().__init__() + + self.conv_cls = nn.Conv2d(in_channel, n_anchors * n_classes, 1) + self.conv_reg = nn.Conv2d(in_channel, n_anchors * 7, 1) + self.conv_dir_cls = nn.Conv2d(in_channel, n_anchors * 2, 1) + + # in consitent with mmdet3d + conv_layer_id = 0 + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, mean=0, std=0.01) + if conv_layer_id == 0: + prior_prob = 0.01 + bias_init = float(-math.log((1 - prior_prob) / prior_prob)) + nn.init.constant_(m.bias, bias_init) + else: + nn.init.constant_(m.bias, 0) + conv_layer_id += 1 + + def forward(self, x): + ''' + x: (bs, 384, 248, 216) + return: + bbox_cls_pred: (bs, n_anchors*3, 248, 216) + bbox_pred: (bs, n_anchors*7, 248, 216) + bbox_dir_cls_pred: (bs, n_anchors*2, 248, 216) + ''' + bbox_cls_pred = self.conv_cls(x) + bbox_pred = self.conv_reg(x) + bbox_dir_cls_pred = self.conv_dir_cls(x) + return bbox_cls_pred, bbox_pred, bbox_dir_cls_pred + + +class PointPillars(nn.Module): + def __init__(self, + nclasses=3, + voxel_size=[0.32, 0.32, 6], + point_cloud_range=[-74.88, -74.88, -2, 74.88, 74.88, 4], + max_num_points=20, + max_voxels=(32000, 32000), + painted=False): + super().__init__() + self.nclasses = nclasses + self.pillar_layer = PillarLayer(voxel_size=voxel_size, + point_cloud_range=point_cloud_range, + max_num_points=max_num_points, + max_voxels=max_voxels) + if painted: + pillar_channel = 16 + else: + pillar_channel = 10 + self.pillar_encoder = PillarEncoder(voxel_size=voxel_size, + point_cloud_range=point_cloud_range, + in_channel=pillar_channel, + out_channel=64) + self.backbone = Backbone(in_channel=64, + out_channels=[64, 128, 256], + layer_nums=[3, 5, 5], + layer_strides=[1, 2, 2]) + self.neck = Neck(in_channels=[64, 128, 256], + upsample_strides=[1, 2, 4], + out_channels=[128, 128, 128]) + self.head = Head( + in_channel=384, + n_anchors=2 * nclasses, + n_classes=nclasses) + + # anchors + ranges = [[-74.88, -74.88, -0.0345, 74.88, 74.88, -0.0345], + [-74.88, -74.88, 0, 74.88, 74.88, 0], + [-74.88, -74.88, -0.1188, 74.88, 74.88, -0.1188]] + sizes = [[0.84, .91, 1.74], [.84, 1.81, 1.77], [2.08, 4.73, 
1.77]] + rotations = [0, 1.57] + self.anchors_generator = Anchors(ranges=ranges, + sizes=sizes, + rotations=rotations) + + # train + self.assigners = [ + {'pos_iou_thr': 0.5, 'neg_iou_thr': 0.3, 'min_iou_thr': 0.3}, + {'pos_iou_thr': 0.5, 'neg_iou_thr': 0.3, 'min_iou_thr': 0.3}, + {'pos_iou_thr': 0.55, 'neg_iou_thr': 0.4, 'min_iou_thr': 0.4}, + ] + + # val and test + self.nms_pre = 4096 + self.nms_thr = 0.25 + self.score_thr = 0.1 + self.max_num = 500 + + def get_predicted_bboxes_single( + self, bbox_cls_pred, bbox_pred, bbox_dir_cls_pred, anchors): + ''' + bbox_cls_pred: (n_anchors*3, 248, 216) + bbox_pred: (n_anchors*7, 248, 216) + bbox_dir_cls_pred: (n_anchors*2, 248, 216) + anchors: (y_l, x_l, 3, 2, 7) + return: + bboxes: (k, 7) + labels: (k, ) + scores: (k, ) + ''' + # 0. pre-process + bbox_cls_pred = bbox_cls_pred.permute( + 1, 2, 0).reshape(-1, self.nclasses) + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 7) + bbox_dir_cls_pred = bbox_dir_cls_pred.permute(1, 2, 0).reshape(-1, 2) + anchors = anchors.reshape(-1, 7) + + bbox_cls_pred = torch.sigmoid(bbox_cls_pred) + bbox_dir_cls_pred = torch.max(bbox_dir_cls_pred, dim=1)[1] + + # 1. obtain self.nms_pre bboxes based on scores + inds = bbox_cls_pred.max(1)[0].topk(self.nms_pre)[1] + bbox_cls_pred = bbox_cls_pred[inds] + bbox_pred = bbox_pred[inds] + bbox_dir_cls_pred = bbox_dir_cls_pred[inds] + anchors = anchors[inds] + + # 2. decode predicted offsets to bboxes + bbox_pred = anchors2bboxes(anchors, bbox_pred) + + # 3. nms + bbox_pred2d_xy = bbox_pred[:, [0, 1]] + bbox_pred2d_lw = bbox_pred[:, [3, 4]] + bbox_pred2d = torch.cat([bbox_pred2d_xy - bbox_pred2d_lw / 2, + bbox_pred2d_xy + bbox_pred2d_lw / 2, + bbox_pred[:, 6:]], dim=-1) # (n_anchors, 5) + + ret_bboxes, ret_labels, ret_scores = [], [], [] + for i in range(self.nclasses): + # 3.1 filter bboxes with scores below self.score_thr + cur_bbox_cls_pred = bbox_cls_pred[:, i] + score_inds = cur_bbox_cls_pred > self.score_thr + if score_inds.sum() == 0: + continue + + cur_bbox_cls_pred = cur_bbox_cls_pred[score_inds] + cur_bbox_pred2d = bbox_pred2d[score_inds] + cur_bbox_pred = bbox_pred[score_inds] + cur_bbox_dir_cls_pred = bbox_dir_cls_pred[score_inds] + + # 3.2 nms core + keep_inds = ml3d.ops.nms( + cur_bbox_pred2d, cur_bbox_cls_pred, self.nms_thr) + + cur_bbox_cls_pred = cur_bbox_cls_pred[keep_inds] + cur_bbox_pred = cur_bbox_pred[keep_inds] + cur_bbox_dir_cls_pred = cur_bbox_dir_cls_pred[keep_inds] + cur_bbox_pred[:, - + 1] = limit_period(cur_bbox_pred[:, - + 1].detach().cpu(), 1, math.pi).to(cur_bbox_pred) # [-pi, 0] + cur_bbox_pred[:, -1] += (1 - cur_bbox_dir_cls_pred) * math.pi + + ret_bboxes.append(cur_bbox_pred) + ret_labels.append(torch.zeros_like( + cur_bbox_pred[:, 0], dtype=torch.long) + i) + ret_scores.append(cur_bbox_cls_pred) + + # 4. 
filter some bboxes if bboxes number is above self.max_num + if len(ret_bboxes) == 0: + return { + 'lidar_bboxes': torch.empty((0, 7)).detach().cpu(), + 'labels': torch.empty(0).detach().cpu(), + 'scores': torch.empty(0).detach().cpu() + } + ret_bboxes = torch.cat(ret_bboxes, 0) + ret_labels = torch.cat(ret_labels, 0) + ret_scores = torch.cat(ret_scores, 0) + if ret_bboxes.size(0) > self.max_num: + final_inds = ret_scores.topk(self.max_num)[1] + ret_bboxes = ret_bboxes[final_inds] + ret_labels = ret_labels[final_inds] + ret_scores = ret_scores[final_inds] + result = { + 'lidar_bboxes': ret_bboxes.detach().cpu(), + 'labels': ret_labels.detach().cpu(), + 'scores': ret_scores.detach().cpu() + } + return result + + def get_predicted_bboxes( + self, bbox_cls_pred, bbox_pred, bbox_dir_cls_pred, batched_anchors): + ''' + bbox_cls_pred: (bs, n_anchors*3, 248, 216) + bbox_pred: (bs, n_anchors*7, 248, 216) + bbox_dir_cls_pred: (bs, n_anchors*2, 248, 216) + batched_anchors: (bs, y_l, x_l, 3, 2, 7) + return: + bboxes: [(k1, 7), (k2, 7), ... ] + labels: [(k1, ), (k2, ), ... ] + scores: [(k1, ), (k2, ), ... ] + ''' + results = [] + bs = bbox_cls_pred.size(0) + for i in range(bs): + result = self.get_predicted_bboxes_single(bbox_cls_pred=bbox_cls_pred[i], + bbox_pred=bbox_pred[i], + bbox_dir_cls_pred=bbox_dir_cls_pred[i], + anchors=batched_anchors[i]) + results.append(result) + return results + + def forward(self, batched_pts, mode='test', + batched_gt_bboxes=None, batched_gt_labels=None): + batch_size = len(batched_pts) + # batched_pts: list[tensor] -> pillars: (p1 + p2 + ... + pb, num_points, c), + # coors_batch: (p1 + p2 + ... + pb, 1 + 3), + # num_points_per_pillar: (p1 + p2 + ... + pb, ), (b: batch size) + pillars, coors_batch, npoints_per_pillar = self.pillar_layer( + batched_pts) + + # pillars: (p1 + p2 + ... + pb, num_points, c), c = 4 + # coors_batch: (p1 + p2 + ... + pb, 1 + 3) + # npoints_per_pillar: (p1 + p2 + ... 
+ pb, ) + # -> pillar_features: (bs, out_channel, y_l, x_l) + pillar_features = self.pillar_encoder( + pillars, coors_batch, npoints_per_pillar) + + # xs: [(bs, 64, 248, 216), (bs, 128, 124, 108), (bs, 256, 62, 54)] + xs = self.backbone(pillar_features) + + # x: (bs, 384, 248, 216) + x = self.neck(xs) + + # bbox_cls_pred: (bs, n_anchors*3, 248, 216) + # bbox_pred: (bs, n_anchors*7, 248, 216) + # bbox_dir_cls_pred: (bs, n_anchors*2, 248, 216) + bbox_cls_pred, bbox_pred, bbox_dir_cls_pred = self.head(x) + + # anchors + device = bbox_cls_pred.device + feature_map_size = torch.tensor( + list(bbox_cls_pred.size()[-2:]), device=device) + anchors = self.anchors_generator.get_multi_anchors(feature_map_size) + batched_anchors = [anchors for _ in range(batch_size)] + + if mode == 'train': + anchor_target_dict = anchor_target(batched_anchors=batched_anchors, + batched_gt_bboxes=batched_gt_bboxes, + batched_gt_labels=batched_gt_labels, + assigners=self.assigners, + nclasses=self.nclasses) + + return bbox_cls_pred, bbox_pred, bbox_dir_cls_pred, anchor_target_dict + elif mode == 'val': + results = self.get_predicted_bboxes(bbox_cls_pred=bbox_cls_pred, + bbox_pred=bbox_pred, + bbox_dir_cls_pred=bbox_dir_cls_pred, + batched_anchors=batched_anchors) + return results + + elif mode == 'test': + results = self.get_predicted_bboxes(bbox_cls_pred=bbox_cls_pred, + bbox_pred=bbox_pred, + bbox_dir_cls_pred=bbox_dir_cls_pred, + batched_anchors=batched_anchors) + return results + else: + raise ValueError diff --git a/automotive/3d-object-detection/model/segmentation/__init__.py b/automotive/3d-object-detection/model/segmentation/__init__.py new file mode 100644 index 000000000..f6d9c2307 --- /dev/null +++ b/automotive/3d-object-detection/model/segmentation/__init__.py @@ -0,0 +1,2 @@ +from .modeling import * +from ._deeplab import convert_to_separable_conv diff --git a/automotive/3d-object-detection/model/segmentation/_deeplab.py b/automotive/3d-object-detection/model/segmentation/_deeplab.py new file mode 100644 index 000000000..4c01a651c --- /dev/null +++ b/automotive/3d-object-detection/model/segmentation/_deeplab.py @@ -0,0 +1,210 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import _SimpleSegmentationModel + + +__all__ = ["DeepLabV3"] + + +class DeepLabV3(_SimpleSegmentationModel): + """ + Implements DeepLabV3 model from + `"Rethinking Atrous Convolution for Semantic Image Segmentation" + `_. + + Arguments: + backbone (nn.Module): the network used to compute the features for the model. + The backbone should return an OrderedDict[Tensor], with the key being + "out" for the last feature map used, and "aux" if an auxiliary classifier + is used. + classifier (nn.Module): module that takes the "out" element returned from + the backbone and returns a dense prediction. 
+ aux_classifier (nn.Module, optional): auxiliary classifier used during training + """ + pass + + +class DeepLabHeadV3Plus(nn.Module): + def __init__(self, in_channels, low_level_channels, + num_classes, aspp_dilate=[12, 24, 36]): + super(DeepLabHeadV3Plus, self).__init__() + self.project = nn.Sequential( + nn.Conv2d(low_level_channels, 48, 1, bias=False), + nn.BatchNorm2d(48), + nn.ReLU(inplace=True), + ) + + self.aspp = ASPP(in_channels, aspp_dilate) + + self.classifier = nn.Sequential( + nn.Conv2d(304, 256, 3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, num_classes, 1) + ) + self._init_weight() + + def forward(self, feature): + low_level_feature = self.project(feature['low_level']) + output_feature = self.aspp(feature['out']) + output_feature = F.interpolate(output_feature, + size=low_level_feature.shape[2:], + mode='bilinear', + align_corners=False) + return self.classifier( + torch.cat([low_level_feature, output_feature], dim=1)) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +class DeepLabHead(nn.Module): + def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]): + super(DeepLabHead, self).__init__() + + self.classifier = nn.Sequential( + ASPP(in_channels, aspp_dilate), + nn.Conv2d(256, 256, 3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, num_classes, 1) + ) + self._init_weight() + + def forward(self, feature): + return self.classifier(feature['out']) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +class AtrousSeparableConvolution(nn.Module): + """ Atrous Separable Convolution + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, bias=True): + super(AtrousSeparableConvolution, self).__init__() + self.body = nn.Sequential( + # Separable Conv + nn.Conv2d( + in_channels, + in_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + groups=in_channels), + # PointWise Conv + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias), + ) + + self._init_weight() + + def forward(self, x): + return self.body(x) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + modules = [ + nn.Conv2d( + in_channels, + out_channels, + 3, + padding=dilation, + dilation=dilation, + bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ] + super(ASPPConv, self).__init__(*modules) + + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels, out_channels): + super(ASPPPooling, self).__init__( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True)) + + def forward(self, x): + size = x.shape[-2:] + x = super(ASPPPooling, self).forward(x) + return F.interpolate( + x, size=size, mode='bilinear', 
align_corners=False) + + +class ASPP(nn.Module): + def __init__(self, in_channels, atrous_rates): + super(ASPP, self).__init__() + out_channels = 256 + modules = [] + modules.append(nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True))) + + rate1, rate2, rate3 = tuple(atrous_rates) + modules.append(ASPPConv(in_channels, out_channels, rate1)) + modules.append(ASPPConv(in_channels, out_channels, rate2)) + modules.append(ASPPConv(in_channels, out_channels, rate3)) + modules.append(ASPPPooling(in_channels, out_channels)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Dropout(0.1),) + + def forward(self, x): + res = [] + for conv in self.convs: + res.append(conv(x)) + res = torch.cat(res, dim=1) + return self.project(res) + + +def convert_to_separable_conv(module): + new_module = module + if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1: + new_module = AtrousSeparableConvolution(module.in_channels, + module.out_channels, + module.kernel_size, + module.stride, + module.padding, + module.dilation, + module.bias) + for name, child in module.named_children(): + new_module.add_module(name, convert_to_separable_conv(child)) + return new_module diff --git a/automotive/3d-object-detection/model/segmentation/backbone/__init__.py b/automotive/3d-object-detection/model/segmentation/backbone/__init__.py new file mode 100644 index 000000000..c2cbcf4d2 --- /dev/null +++ b/automotive/3d-object-detection/model/segmentation/backbone/__init__.py @@ -0,0 +1 @@ +from . import resnet diff --git a/automotive/3d-object-detection/model/segmentation/backbone/resnet.py b/automotive/3d-object-detection/model/segmentation/backbone/resnet.py new file mode 100644 index 000000000..70809b3ba --- /dev/null +++ b/automotive/3d-object-detection/model/segmentation/backbone/resnet.py @@ -0,0 +1,353 @@ +import torch +import torch.nn as nn +try: # for torchvision<0.4 + from torchvision.models.utils import load_state_dict_from_url +except BaseException: # for torchvision>=0.4 + from torch.hub import load_state_dict_from_url + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, 
kernel_size=1, + stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when + # stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when + # stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + 
dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to + # https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained=False, progress=True, **kwargs): + r"""ResNet-101 
model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) diff --git a/automotive/3d-object-detection/model/segmentation/modeling.py b/automotive/3d-object-detection/model/segmentation/modeling.py new file mode 100644 index 000000000..06c3757b6 --- /dev/null +++ b/automotive/3d-object-detection/model/segmentation/modeling.py @@ -0,0 +1,103 @@ +from .utils import IntermediateLayerGetter +from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3 +from .backbone import (resnet) + + +def _segm_resnet(name, backbone_name, num_classes, + output_stride, pretrained_backbone): + + if output_stride == 8: + replace_stride_with_dilation = [False, True, True] + aspp_dilate = [12, 24, 36] + else: + replace_stride_with_dilation = [False, False, True] + aspp_dilate = [6, 12, 18] + + backbone = resnet.__dict__[backbone_name]( + pretrained=pretrained_backbone, + replace_stride_with_dilation=replace_stride_with_dilation) + + inplanes = 2048 + low_level_planes = 256 + + if name == 'deeplabv3plus': + return_layers = {'layer4': 'out', 'layer1': 'low_level'} + classifier = DeepLabHeadV3Plus( + inplanes, low_level_planes, num_classes, aspp_dilate) + elif name == 'deeplabv3': + return_layers = {'layer4': 'out'} + classifier = DeepLabHead(inplanes, num_classes, aspp_dilate) + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + model = DeepLabV3(backbone, classifier) + return model + + +def _load_model(arch_type, backbone, num_classes, + output_stride, pretrained_backbone): + + if backbone.startswith('resnet'): + model = _segm_resnet( + arch_type, + backbone, + num_classes, + output_stride=output_stride, + pretrained_backbone=pretrained_backbone) + else: + raise NotImplementedError + return model + + +# Deeplab v3 +def deeplabv3_resnet50(num_classes=21, output_stride=8, + pretrained_backbone=True): + """Constructs a DeepLabV3 model with a ResNet-50 backbone. + + Args: + num_classes (int): number of classes. + output_stride (int): output stride for deeplab. + pretrained_backbone (bool): If True, use the pretrained backbone. + """ + return _load_model('deeplabv3', 'resnet50', num_classes, + output_stride=output_stride, pretrained_backbone=pretrained_backbone) + + +def deeplabv3_resnet101(num_classes=21, output_stride=8, + pretrained_backbone=True): + """Constructs a DeepLabV3 model with a ResNet-101 backbone. + + Args: + num_classes (int): number of classes. + output_stride (int): output stride for deeplab. + pretrained_backbone (bool): If True, use the pretrained backbone. + """ + return _load_model('deeplabv3', 'resnet101', num_classes, + output_stride=output_stride, pretrained_backbone=pretrained_backbone) + +# Deeplab v3+ + + +def deeplabv3plus_resnet50( + num_classes=21, output_stride=8, pretrained_backbone=True): + """Constructs a DeepLabV3 model with a ResNet-50 backbone. + + Args: + num_classes (int): number of classes. + output_stride (int): output stride for deeplab. + pretrained_backbone (bool): If True, use the pretrained backbone. + """ + return _load_model('deeplabv3plus', 'resnet50', num_classes, + output_stride=output_stride, pretrained_backbone=pretrained_backbone) + + +def deeplabv3plus_resnet101( + num_classes=21, output_stride=8, pretrained_backbone=True): + """Constructs a DeepLabV3+ model with a ResNet-101 backbone. 
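+
+    Example (illustrative sketch only; the class count, input size and
+    output_stride below are placeholder values, not settings prescribed by
+    this benchmark)::
+
+        >>> import torch
+        >>> model = deeplabv3plus_resnet101(num_classes=19, output_stride=16,
+        ...                                 pretrained_backbone=False)
+        >>> logits = model(torch.rand(1, 3, 512, 512))  # -> (1, 19, 512, 512)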
+ + Args: + num_classes (int): number of classes. + output_stride (int): output stride for deeplab. + pretrained_backbone (bool): If True, use the pretrained backbone. + """ + return _load_model('deeplabv3plus', 'resnet101', num_classes, + output_stride=output_stride, pretrained_backbone=pretrained_backbone) diff --git a/automotive/3d-object-detection/model/segmentation/utils.py b/automotive/3d-object-detection/model/segmentation/utils.py new file mode 100644 index 000000000..bfc7e90e3 --- /dev/null +++ b/automotive/3d-object-detection/model/segmentation/utils.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +from collections import OrderedDict + + +class _SimpleSegmentationModel(nn.Module): + def __init__(self, backbone, classifier): + super(_SimpleSegmentationModel, self).__init__() + self.backbone = backbone + self.classifier = classifier + + def forward(self, x): + input_shape = x.shape[-2:] + features = self.backbone(x) + x = self.classifier(features) + x = F.interpolate( + x, + size=input_shape, + mode='bilinear', + align_corners=False) + return x + + +class IntermediateLayerGetter(nn.ModuleDict): + """ + Module wrapper that returns intermediate layers from a model + + It has a strong assumption that the modules have been registered + into the model in the same order as they are used. + This means that one should **not** reuse the same nn.Module + twice in the forward if you want this to work. + + Additionally, it is only able to query submodules that are directly + assigned to the model. So if `model` is passed, `model.feature1` can + be returned, but not `model.feature1.layer2`. + + Arguments: + model (nn.Module): model on which we will extract the features + return_layers (Dict[name, new_name]): a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). 
+ + Examples:: + + >>> m = torchvision.models.resnet18(pretrained=True) + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, + >>> {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = new_m(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + """ + + def __init__(self, model, return_layers, hrnet_flag=False): + if not set(return_layers).issubset( + [name for name, _ in model.named_children()]): + raise ValueError("return_layers are not present in model") + + self.hrnet_flag = hrnet_flag + + orig_return_layers = return_layers + return_layers = {k: v for k, v in return_layers.items()} + layers = OrderedDict() + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super(IntermediateLayerGetter, self).__init__(layers) + self.return_layers = orig_return_layers + + def forward(self, x): + out = OrderedDict() + for name, module in self.named_children(): + if self.hrnet_flag and name.startswith( + 'transition'): # if using hrnet, you need to take care of transition + if name == 'transition1': # in transition1, you need to split the module to two streams first + x = [trans(x) for trans in module] + else: # all other transition is just an extra one stream split + x.append(module(x[-1])) + # other models (ex:resnet,mobilenet) are convolutions in series. + else: + x = module(x) + + if name in self.return_layers: + out_name = self.return_layers[name] + if name == 'stage4' and self.hrnet_flag: # In HRNetV2, we upsample and concat all outputs streams together + # Upsample to size of highest resolution stream + output_h, output_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate( + x[1], + size=( + output_h, + output_w), + mode='bilinear', + align_corners=False) + x2 = F.interpolate( + x[2], + size=( + output_h, + output_w), + mode='bilinear', + align_corners=False) + x3 = F.interpolate( + x[3], + size=( + output_h, + output_w), + mode='bilinear', + align_corners=False) + x = torch.cat([x[0], x1, x2, x3], dim=1) + out[out_name] = x + else: + out[out_name] = x + return out diff --git a/automotive/3d-object-detection/ops/__init__.py b/automotive/3d-object-detection/ops/__init__.py new file mode 100644 index 000000000..66b739937 --- /dev/null +++ b/automotive/3d-object-detection/ops/__init__.py @@ -0,0 +1 @@ +from .voxel_module import Voxelization diff --git a/automotive/3d-object-detection/ops/iou3d/iou3d.cpp b/automotive/3d-object-detection/ops/iou3d/iou3d.cpp new file mode 100644 index 000000000..a337a93c7 --- /dev/null +++ b/automotive/3d-object-detection/ops/iou3d/iou3d.cpp @@ -0,0 +1,213 @@ +// Modified from +// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp + +/* +3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others) +Written by Shaoshuai Shi +All Rights Reserved 2019-2020. 
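+
+This extension exposes rotated bird's-eye-view box overlap / IoU and oriented
+NMS to Python through the pybind11 bindings at the bottom of this file; the
+heavy lifting happens in the CUDA kernels defined in iou3d_kernel.cu.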
+*/ + +#include +#include +#include +#include + +#include +#include + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +#define CHECK_ERROR(ans) \ + { gpuAssert((ans), __FILE__, __LINE__); } +inline void gpuAssert(cudaError_t code, const char *file, int line, + bool abort = true) { + if (code != cudaSuccess) { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, + line); + if (abort) + exit(code); + } +} + +const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; + +void boxesoverlapLauncher(const int num_a, const float *boxes_a, + const int num_b, const float *boxes_b, + float *ans_overlap); +void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b, + const float *boxes_b, float *ans_iou); +void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num, + float nms_overlap_thresh); +void nmsNormalLauncher(const float *boxes, unsigned long long *mask, + int boxes_num, float nms_overlap_thresh); + +int boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, + at::Tensor ans_overlap) { + // params boxes_a: (N, 5) [x1, y1, x2, y2, ry] + // params boxes_b: (M, 5) + // params ans_overlap: (N, M) + + CHECK_INPUT(boxes_a); + CHECK_INPUT(boxes_b); + CHECK_INPUT(ans_overlap); + + int num_a = boxes_a.size(0); + int num_b = boxes_b.size(0); + + const float *boxes_a_data = boxes_a.data_ptr(); + const float *boxes_b_data = boxes_b.data_ptr(); + float *ans_overlap_data = ans_overlap.data_ptr(); + + boxesoverlapLauncher(num_a, boxes_a_data, num_b, boxes_b_data, + ans_overlap_data); + + return 1; +} + +int boxes_iou_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, + at::Tensor ans_iou) { + // params boxes_a: (N, 5) [x1, y1, x2, y2, ry] + // params boxes_b: (M, 5) + // params ans_overlap: (N, M) + + CHECK_INPUT(boxes_a); + CHECK_INPUT(boxes_b); + CHECK_INPUT(ans_iou); + + int num_a = boxes_a.size(0); + int num_b = boxes_b.size(0); + + const float *boxes_a_data = boxes_a.data_ptr(); + const float *boxes_b_data = boxes_b.data_ptr(); + float *ans_iou_data = ans_iou.data_ptr(); + + boxesioubevLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_iou_data); + + return 1; +} + +int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh, + int device_id) { + // params boxes: (N, 5) [x1, y1, x2, y2, ry] + // params keep: (N) + + CHECK_INPUT(boxes); + CHECK_CONTIGUOUS(keep); + cudaSetDevice(device_id); + + int boxes_num = boxes.size(0); + const float *boxes_data = boxes.data_ptr(); + int64_t *keep_data = keep.data_ptr(); + + const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + + unsigned long long *mask_data = NULL; + CHECK_ERROR(cudaMalloc((void **)&mask_data, + boxes_num * col_blocks * sizeof(unsigned long long))); + nmsLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh); + + // unsigned long long mask_cpu[boxes_num * col_blocks]; + // unsigned long long *mask_cpu = new unsigned long long [boxes_num * + // col_blocks]; + std::vector mask_cpu(boxes_num * col_blocks); + + // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); + CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, + boxes_num * col_blocks * sizeof(unsigned long long), + cudaMemcpyDeviceToHost)); + + cudaFree(mask_data); + + unsigned long long *remv_cpu = new unsigned long 
long[col_blocks](); + + int num_to_keep = 0; + + for (int i = 0; i < boxes_num; i++) { + int nblock = i / THREADS_PER_BLOCK_NMS; + int inblock = i % THREADS_PER_BLOCK_NMS; + + if (!(remv_cpu[nblock] & (1ULL << inblock))) { + keep_data[num_to_keep++] = i; + unsigned long long *p = &mask_cpu[0] + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv_cpu[j] |= p[j]; + } + } + } + delete[] remv_cpu; + if (cudaSuccess != cudaGetLastError()) + printf("Error!\n"); + + return num_to_keep; +} + +int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh, + int device_id) { + // params boxes: (N, 5) [x1, y1, x2, y2, ry] + // params keep: (N) + + CHECK_INPUT(boxes); + CHECK_CONTIGUOUS(keep); + cudaSetDevice(device_id); + + int boxes_num = boxes.size(0); + const float *boxes_data = boxes.data_ptr(); + int64_t *keep_data = keep.data_ptr(); + + const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + + unsigned long long *mask_data = NULL; + CHECK_ERROR(cudaMalloc((void **)&mask_data, + boxes_num * col_blocks * sizeof(unsigned long long))); + nmsNormalLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh); + + // unsigned long long mask_cpu[boxes_num * col_blocks]; + // unsigned long long *mask_cpu = new unsigned long long [boxes_num * + // col_blocks]; + std::vector mask_cpu(boxes_num * col_blocks); + + // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); + CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, + boxes_num * col_blocks * sizeof(unsigned long long), + cudaMemcpyDeviceToHost)); + + cudaFree(mask_data); + + unsigned long long *remv_cpu = new unsigned long long[col_blocks](); + + int num_to_keep = 0; + + for (int i = 0; i < boxes_num; i++) { + int nblock = i / THREADS_PER_BLOCK_NMS; + int inblock = i % THREADS_PER_BLOCK_NMS; + + if (!(remv_cpu[nblock] & (1ULL << inblock))) { + keep_data[num_to_keep++] = i; + unsigned long long *p = &mask_cpu[0] + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv_cpu[j] |= p[j]; + } + } + } + delete[] remv_cpu; + if (cudaSuccess != cudaGetLastError()) + printf("Error!\n"); + + return num_to_keep; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("boxes_overlap_bev_gpu", &boxes_overlap_bev_gpu, + "oriented boxes overlap"); + m.def("boxes_iou_bev_gpu", &boxes_iou_bev_gpu, "oriented boxes iou"); + m.def("nms_gpu", &nms_gpu, "oriented nms gpu"); + m.def("nms_normal_gpu", &nms_normal_gpu, "nms gpu"); +} diff --git a/automotive/3d-object-detection/ops/iou3d/iou3d_kernel.cu b/automotive/3d-object-detection/ops/iou3d/iou3d_kernel.cu new file mode 100644 index 000000000..861aea3c5 --- /dev/null +++ b/automotive/3d-object-detection/ops/iou3d/iou3d_kernel.cu @@ -0,0 +1,439 @@ +// Modified from +// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu + +/* +3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others) +Written by Shaoshuai Shi +All Rights Reserved 2019-2020. 
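+
+How it works: box_overlap() rotates the corners of the two oriented boxes,
+intersects their edges pairwise and sums the signed areas of the resulting
+clip polygon; the NMS kernels compare each box against one block of boxes at a
+time and emit a per-block suppression bitmask, which the host code in
+iou3d.cpp walks to pick the boxes to keep.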
+*/ + +#include +#define THREADS_PER_BLOCK 16 +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +//#define DEBUG +const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; +__device__ const float EPS = 1e-8; +struct Point { + float x, y; + __device__ Point() {} + __device__ Point(double _x, double _y) { x = _x, y = _y; } + + __device__ void set(float _x, float _y) { + x = _x; + y = _y; + } + + __device__ Point operator+(const Point &b) const { + return Point(x + b.x, y + b.y); + } + + __device__ Point operator-(const Point &b) const { + return Point(x - b.x, y - b.y); + } +}; + +__device__ inline float cross(const Point &a, const Point &b) { + return a.x * b.y - a.y * b.x; +} + +__device__ inline float cross(const Point &p1, const Point &p2, + const Point &p0) { + return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y); +} + +__device__ int check_rect_cross(const Point &p1, const Point &p2, + const Point &q1, const Point &q2) { + int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) && + min(q1.x, q2.x) <= max(p1.x, p2.x) && + min(p1.y, p2.y) <= max(q1.y, q2.y) && + min(q1.y, q2.y) <= max(p1.y, p2.y); + return ret; +} + +__device__ inline int check_in_box2d(const float *box, const Point &p) { + // params: box (5) [x1, y1, x2, y2, angle] + const float MARGIN = 1e-5; + + float center_x = (box[0] + box[2]) / 2; + float center_y = (box[1] + box[3]) / 2; + float angle_cos = cos(-box[4]), + angle_sin = + sin(-box[4]); // rotate the point in the opposite direction of box + float rot_x = + (p.x - center_x) * angle_cos + (p.y - center_y) * angle_sin + center_x; + float rot_y = + -(p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + center_y; +#ifdef DEBUG + printf("box: (%.3f, %.3f, %.3f, %.3f, %.3f)\n", box[0], box[1], box[2], + box[3], box[4]); + printf( + "center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, " + "%.3f)\n", + center_x, center_y, angle_cos, angle_sin, p.x, p.y, rot_x, rot_y); +#endif + return (rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN && + rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN); +} + +__device__ inline int intersection(const Point &p1, const Point &p0, + const Point &q1, const Point &q0, + Point &ans) { + // fast exclusion + if (check_rect_cross(p0, p1, q0, q1) == 0) return 0; + + // check cross standing + float s1 = cross(q0, p1, p0); + float s2 = cross(p1, q1, p0); + float s3 = cross(p0, q1, q0); + float s4 = cross(q1, p1, q0); + + if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0; + + // calculate intersection of two lines + float s5 = cross(q1, p1, p0); + if (fabs(s5 - s1) > EPS) { + ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1); + ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1); + + } else { + float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y; + float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y; + float D = a0 * b1 - a1 * b0; + + ans.x = (b0 * c1 - b1 * c0) / D; + ans.y = (a1 * c0 - a0 * c1) / D; + } + + return 1; +} + +__device__ inline void rotate_around_center(const Point ¢er, + const float angle_cos, + const float angle_sin, Point &p) { + float new_x = + (p.x - center.x) * angle_cos + (p.y - center.y) * angle_sin + center.x; + float new_y = + -(p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y; + p.set(new_x, new_y); +} + +__device__ inline int point_cmp(const Point &a, const Point &b, + const Point ¢er) { + return atan2(a.y - center.y, a.x - center.x) > + atan2(b.y - center.y, b.x - center.x); +} + +__device__ inline float box_overlap(const float *box_a, 
const float *box_b) { + // params: box_a (5) [x1, y1, x2, y2, angle] + // params: box_b (5) [x1, y1, x2, y2, angle] + + float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3], + a_angle = box_a[4]; + float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3], + b_angle = box_b[4]; + + Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2); + Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2); +#ifdef DEBUG + printf( + "a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n", + a_x1, a_y1, a_x2, a_y2, a_angle, b_x1, b_y1, b_x2, b_y2, b_angle); + printf("center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y, + center_b.x, center_b.y); +#endif + + Point box_a_corners[5]; + box_a_corners[0].set(a_x1, a_y1); + box_a_corners[1].set(a_x2, a_y1); + box_a_corners[2].set(a_x2, a_y2); + box_a_corners[3].set(a_x1, a_y2); + + Point box_b_corners[5]; + box_b_corners[0].set(b_x1, b_y1); + box_b_corners[1].set(b_x2, b_y1); + box_b_corners[2].set(b_x2, b_y2); + box_b_corners[3].set(b_x1, b_y2); + + // get oriented corners + float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle); + float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle); + + for (int k = 0; k < 4; k++) { +#ifdef DEBUG + printf("before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, + box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x, + box_b_corners[k].y); +#endif + rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]); + rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]); +#ifdef DEBUG + printf("corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x, + box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y); +#endif + } + + box_a_corners[4] = box_a_corners[0]; + box_b_corners[4] = box_b_corners[0]; + + // get intersection of lines + Point cross_points[16]; + Point poly_center; + int cnt = 0, flag = 0; + + poly_center.set(0, 0); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + flag = intersection(box_a_corners[i + 1], box_a_corners[i], + box_b_corners[j + 1], box_b_corners[j], + cross_points[cnt]); + if (flag) { + poly_center = poly_center + cross_points[cnt]; + cnt++; + } + } + } + + // check corners + for (int k = 0; k < 4; k++) { + if (check_in_box2d(box_a, box_b_corners[k])) { + poly_center = poly_center + box_b_corners[k]; + cross_points[cnt] = box_b_corners[k]; + cnt++; + } + if (check_in_box2d(box_b, box_a_corners[k])) { + poly_center = poly_center + box_a_corners[k]; + cross_points[cnt] = box_a_corners[k]; + cnt++; + } + } + + poly_center.x /= cnt; + poly_center.y /= cnt; + + // sort the points of polygon + Point temp; + for (int j = 0; j < cnt - 1; j++) { + for (int i = 0; i < cnt - j - 1; i++) { + if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) { + temp = cross_points[i]; + cross_points[i] = cross_points[i + 1]; + cross_points[i + 1] = temp; + } + } + } + +#ifdef DEBUG + printf("cnt=%d\n", cnt); + for (int i = 0; i < cnt; i++) { + printf("All cross point %d: (%.3f, %.3f)\n", i, cross_points[i].x, + cross_points[i].y); + } +#endif + + // get the overlap areas + float area = 0; + for (int k = 0; k < cnt - 1; k++) { + area += cross(cross_points[k] - cross_points[0], + cross_points[k + 1] - cross_points[0]); + } + + return fabs(area) / 2.0; +} + +__device__ inline float iou_bev(const float *box_a, const float *box_b) { + // params: box_a (5) [x1, y1, x2, y2, angle] + // params: box_b (5) [x1, y1, x2, y2, angle] + float sa = (box_a[2] - box_a[0]) * (box_a[3] - 
box_a[1]); + float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]); + float s_overlap = box_overlap(box_a, box_b); + return s_overlap / fmaxf(sa + sb - s_overlap, EPS); +} + +__global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a, + const int num_b, const float *boxes_b, + float *ans_overlap) { + const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y; + const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x; + + if (a_idx >= num_a || b_idx >= num_b) { + return; + } + const float *cur_box_a = boxes_a + a_idx * 5; + const float *cur_box_b = boxes_b + b_idx * 5; + float s_overlap = box_overlap(cur_box_a, cur_box_b); + ans_overlap[a_idx * num_b + b_idx] = s_overlap; +} + +__global__ void boxes_iou_bev_kernel(const int num_a, const float *boxes_a, + const int num_b, const float *boxes_b, + float *ans_iou) { + const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y; + const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x; + + if (a_idx >= num_a || b_idx >= num_b) { + return; + } + + const float *cur_box_a = boxes_a + a_idx * 5; + const float *cur_box_b = boxes_b + b_idx * 5; + float cur_iou_bev = iou_bev(cur_box_a, cur_box_b); + ans_iou[a_idx * num_b + b_idx] = cur_iou_bev; +} + +__global__ void nms_kernel(const int boxes_num, const float nms_overlap_thresh, + const float *boxes, unsigned long long *mask) { + // params: boxes (N, 5) [x1, y1, x2, y2, ry] + // params: mask (N, N/THREADS_PER_BLOCK_NMS) + + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + + __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5]; + + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x; + const float *cur_box = boxes + cur_box_idx * 5; + + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +__device__ inline float iou_normal(float const *const a, float const *const b) { + float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); + float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); + float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0]) * (a[3] - a[1]); + float Sb = (b[2] - b[0]) * (b[3] - b[1]); + return interS / fmaxf(Sa + Sb - interS, EPS); +} + +__global__ void nms_normal_kernel(const int boxes_num, + const float nms_overlap_thresh, + const float *boxes, + unsigned 
long long *mask) {
+  // params: boxes (N, 5) [x1, y1, x2, y2, ry]
+  // params: mask (N, N/THREADS_PER_BLOCK_NMS)
+
+  const int row_start = blockIdx.y;
+  const int col_start = blockIdx.x;
+
+  // if (row_start > col_start) return;
+
+  const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
+                             THREADS_PER_BLOCK_NMS);
+  const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
+                             THREADS_PER_BLOCK_NMS);
+
+  __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
+
+  if (threadIdx.x < col_size) {
+    block_boxes[threadIdx.x * 5 + 0] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
+    block_boxes[threadIdx.x * 5 + 1] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
+    block_boxes[threadIdx.x * 5 + 2] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];
+    block_boxes[threadIdx.x * 5 + 3] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];
+    block_boxes[threadIdx.x * 5 + 4] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];
+  }
+  __syncthreads();
+
+  if (threadIdx.x < row_size) {
+    const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
+    const float *cur_box = boxes + cur_box_idx * 5;
+
+    int i = 0;
+    unsigned long long t = 0;
+    int start = 0;
+    if (row_start == col_start) {
+      start = threadIdx.x + 1;
+    }
+    for (i = start; i < col_size; i++) {
+      if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
+        t |= 1ULL << i;
+      }
+    }
+    const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+    mask[cur_box_idx * col_blocks + col_start] = t;
+  }
+}
+
+void boxesoverlapLauncher(const int num_a, const float *boxes_a,
+                          const int num_b, const float *boxes_b,
+                          float *ans_overlap) {
+  dim3 blocks(
+      DIVUP(num_b, THREADS_PER_BLOCK),
+      DIVUP(num_a, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
+
+  boxes_overlap_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
+                                            ans_overlap);
+#ifdef DEBUG
+  cudaDeviceSynchronize();  // for using printf in kernel function
+#endif
+}
+
+void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b,
+                         const float *boxes_b, float *ans_iou) {
+  dim3 blocks(
+      DIVUP(num_b, THREADS_PER_BLOCK),
+      DIVUP(num_a, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
+
+  boxes_iou_bev_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
+                                            ans_iou);
+}
+
+void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num,
+                 float nms_overlap_thresh) {
+  dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
+              DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
+  dim3 threads(THREADS_PER_BLOCK_NMS);
+  nms_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes, mask);
+}
+
+void nmsNormalLauncher(const float *boxes, unsigned long long *mask,
+                       int boxes_num, float nms_overlap_thresh) {
+  dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
+              DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
+  dim3 threads(THREADS_PER_BLOCK_NMS);
+  nms_normal_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes,
+                                         mask);
+}
diff --git a/automotive/3d-object-detection/ops/voxel_module.py b/automotive/3d-object-detection/ops/voxel_module.py
new file mode 100644
index 000000000..9850dd289
--- /dev/null
+++ b/automotive/3d-object-detection/ops/voxel_module.py
@@ -0,0 +1,253 @@
+# This file is modified from
+# https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/ops/voxel/voxelize.py
+
+import torch
+import torch.nn as nn
+import numpy as np
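+
+# Overview: `points_to_voxel` below assigns each point to a voxel-grid cell via
+# the two numba-jitted kernels (`reverse_index=True` yields zyx voxel
+# coordinates, `reverse_index=False` yields xyz); points that fall outside
+# `coors_range`, exceed `max_points` in a voxel, or arrive after `max_voxels`
+# voxels have been created are dropped.  The `Voxelization` nn.Module at the
+# end of the file wraps this for torch tensors (the original CUDA
+# `_Voxelization.apply` call survives only as a comment in its forward()).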
+import numba + + +@numba.jit(nopython=True) +def _points_to_voxel_reverse_kernel(points, + voxel_size, + coors_range, + num_points_per_voxel, + coor_to_voxelidx, + voxels, + coors, + max_points=35, + max_voxels=20000): + # put all computations to one loop. + # we shouldn't create large array in main jit code, otherwise + # reduce performance + N = points.shape[0] + # ndim = points.shape[1] - 1 + ndim = 3 + ndim_minus_1 = ndim - 1 + grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size + # np.round(grid_size) + # grid_size = np.round(grid_size).astype(np.int64)(np.int32) + grid_size = np.round(grid_size, 0, grid_size).astype(np.int32) + coor = np.zeros(shape=(3, ), dtype=np.int32) + voxel_num = 0 + failed = False + for i in range(N): + failed = False + for j in range(ndim): + c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j]) + if c < 0 or c >= grid_size[j]: + failed = True + break + coor[ndim_minus_1 - j] = c + if failed: + continue + voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]] + if voxelidx == -1: + voxelidx = voxel_num + if voxel_num >= max_voxels: + break + voxel_num += 1 + coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx + coors[voxelidx] = coor + num = num_points_per_voxel[voxelidx] + if num < max_points: + voxels[voxelidx, num] = points[i] + num_points_per_voxel[voxelidx] += 1 + return voxel_num + + +@numba.jit(nopython=True) +def _points_to_voxel_kernel(points, + voxel_size, + coors_range, + num_points_per_voxel, + coor_to_voxelidx, + voxels, + coors, + max_points=35, + max_voxels=20000): + # need mutex if write in cuda, but numba.cuda don't support mutex. + # in addition, pytorch don't support cuda in dataloader(tensorflow support this). + # put all computations to one loop. + # we shouldn't create large array in main jit code, otherwise + # decrease performance + N = points.shape[0] + # ndim = points.shape[1] - 1 + ndim = 3 + grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size + # grid_size = np.round(grid_size).astype(np.int64)(np.int32) + grid_size = np.round(grid_size, 0, grid_size).astype(np.int32) + + lower_bound = coors_range[:3] + upper_bound = coors_range[3:] + coor = np.zeros(shape=(3, ), dtype=np.int32) + voxel_num = 0 + failed = False + for i in range(N): + failed = False + for j in range(ndim): + c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j]) + if c < 0 or c >= grid_size[j]: + failed = True + break + coor[j] = c + if failed: + continue + voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]] + if voxelidx == -1: + voxelidx = voxel_num + if voxel_num >= max_voxels: + break + voxel_num += 1 + coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx + coors[voxelidx] = coor + num = num_points_per_voxel[voxelidx] + if num < max_points: + voxels[voxelidx, num] = points[i] + num_points_per_voxel[voxelidx] += 1 + return voxel_num + + +def points_to_voxel(points, + voxel_size, + coors_range, + max_points=35, + reverse_index=True, + max_voxels=20000): + """convert kitti points(N, >=3) to voxels. This version calculate + everything in one loop. now it takes only 4.2ms(complete point cloud) + with jit and 3.2ghz cpu.(don't calculate other features) + Note: this function in ubuntu seems faster than windows 10. + + Args: + points: [N, ndim] float tensor. points[:, :3] contain xyz points and + points[:, 3:] contain other information such as reflectivity. + voxel_size: [3] list/tuple or array, float. xyz, indicate voxel size + coors_range: [6] list/tuple or array, float. indicate voxel range. 
+ format: xyzxyz, minmax + max_points: int. indicate maximum points contained in a voxel. + reverse_index: boolean. indicate whether return reversed coordinates. + if points has xyz format and reverse_index is True, output + coordinates will be zyx format, but points in features always + xyz format. + max_voxels: int. indicate maximum voxels this function create. + for second, 20000 is a good choice. you should shuffle points + before call this function because max_voxels may drop some points. + + Returns: + voxels: [M, max_points, ndim] float tensor. only contain points. + coordinates: [M, 3] int32 tensor. + num_points_per_voxel: [M] int32 tensor. + """ + if not isinstance(voxel_size, np.ndarray): + voxel_size = np.array(voxel_size, dtype=points.dtype) + if not isinstance(coors_range, np.ndarray): + coors_range = np.array(coors_range, dtype=points.dtype) + voxelmap_shape = (coors_range[3:] - coors_range[:3]) / voxel_size + voxelmap_shape = tuple(np.round(voxelmap_shape).astype(np.int32).tolist()) + if reverse_index: + voxelmap_shape = voxelmap_shape[::-1] + # don't create large array in jit(nopython=True) code. + num_points_per_voxel = np.zeros(shape=(max_voxels, ), dtype=np.int32) + coor_to_voxelidx = -np.ones(shape=voxelmap_shape, dtype=np.int32) + voxels = np.zeros( + shape=(max_voxels, max_points, points.shape[-1]), dtype=points.dtype) + coors = np.zeros(shape=(max_voxels, 3), dtype=np.int32) + if reverse_index: + voxel_num = _points_to_voxel_reverse_kernel( + points, voxel_size, coors_range, num_points_per_voxel, + coor_to_voxelidx, voxels, coors, max_points, max_voxels) + + else: + voxel_num = _points_to_voxel_kernel( + points, voxel_size, coors_range, num_points_per_voxel, + coor_to_voxelidx, voxels, coors, max_points, max_voxels) + + coors = coors[:voxel_num] + voxels = voxels[:voxel_num] + num_points_per_voxel = num_points_per_voxel[:voxel_num] + # voxels[:, :, -3:] = voxels[:, :, :3] - \ + # voxels[:, :, :3].sum(axis=1, keepdims=True)/num_points_per_voxel.reshape(-1, 1, 1) + return voxels, coors, num_points_per_voxel + + +class Voxelization(nn.Module): + + def __init__(self, + voxel_size, + point_cloud_range, + max_num_points, + max_voxels, + deterministic=True): + super(Voxelization, self).__init__() + """ + Args: + voxel_size (list): list [x, y, z] size of three dimension + point_cloud_range (list): + [x_min, y_min, z_min, x_max, y_max, z_max] + max_num_points (int): max number of points per voxel + max_voxels (tuple): max number of voxels in + (training, testing) time + deterministic: bool. whether to invoke the non-deterministic + version of hard-voxelization implementations. non-deterministic + version is considerablly fast but is not deterministic. only + affects hard voxelization. default True. for more information + of this argument and the implementation insights, please refer + to the following links: + https://github.com/open-mmlab/mmdetection3d/issues/894 + https://github.com/open-mmlab/mmdetection3d/pull/904 + it is an experimental feature and we will appreciate it if + you could share with us the failing cases. 
+ """ + self.voxel_size = voxel_size + self.point_cloud_range = point_cloud_range + self.max_num_points = max_num_points + self.max_voxels = max_voxels + self.deterministic = deterministic + + point_cloud_range = torch.tensor( + point_cloud_range, dtype=torch.float32) + + voxel_size = torch.tensor(voxel_size, dtype=torch.float32) + grid_size = (point_cloud_range[3:] - + point_cloud_range[:3]) / voxel_size + grid_size = torch.round(grid_size).long() + input_feat_shape = grid_size[:2] + self.grid_size = grid_size + # the origin shape is as [x-len, y-len, z-len] + # [w, h, d] -> [d, h, w] + self.pcd_shape = [*input_feat_shape, 1][::-1] + + def forward(self, input): + """ + input: shape=(N, c) + """ + if self.training: + max_voxels = self.max_voxels[0] + else: + max_voxels = self.max_voxels[1] + voxel_parts = points_to_voxel(input.detach().cpu().numpy(), + self.voxel_size, + self.point_cloud_range, + self.max_num_points, + reverse_index=False, + max_voxels=max_voxels) + # return _Voxelization.apply(input, self.voxel_size, self.point_cloud_range, + # self.max_num_points, max_voxels, + # self.deterministic) + voxels = torch.from_numpy(voxel_parts[0]).to(device=input.device) + coors = torch.from_numpy(voxel_parts[1]).to(device=input.device) + num_points_per_voxel = torch.from_numpy( + voxel_parts[2]).to( + device=input.device) + return voxels, coors, num_points_per_voxel + + def __repr__(self): + tmpstr = self.__class__.__name__ + '(' + tmpstr += 'voxel_size=' + str(self.voxel_size) + tmpstr += ', point_cloud_range=' + str(self.point_cloud_range) + tmpstr += ', max_num_points=' + str(self.max_num_points) + tmpstr += ', max_voxels=' + str(self.max_voxels) + tmpstr += ', deterministic=' + str(self.deterministic) + tmpstr += ')' + return tmpstr diff --git a/automotive/3d-object-detection/output/mlperf_log_accuracy.json b/automotive/3d-object-detection/output/mlperf_log_accuracy.json new file mode 100644 index 000000000..8e2f0bef1 --- /dev/null +++ b/automotive/3d-object-detection/output/mlperf_log_accuracy.json @@ -0,0 +1 @@ +[ \ No newline at end of file diff --git a/automotive/3d-object-detection/output/mlperf_log_detail.txt b/automotive/3d-object-detection/output/mlperf_log_detail.txt new file mode 100644 index 000000000..bfe9a7812 --- /dev/null +++ b/automotive/3d-object-detection/output/mlperf_log_detail.txt @@ -0,0 +1,64 @@ +:::MLLOG {"key": "error_invalid_config", "value": "Multiple conf files are used. 
This is not valid for official submission.", "time_ms": 18446744073709.254158, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": true, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 538, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "loadgen_version", "value": "4.1 @ f74d16f541", "time_ms": 0.002363, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "version.cc", "line_no": 53, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "loadgen_build_date_local", "value": "2024-10-28T22:57:16.905751", "time_ms": 0.002363, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "version.cc", "line_no": 55, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "loadgen_build_date_utc", "value": "2024-10-28T22:57:16.905756", "time_ms": 0.002363, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "version.cc", "line_no": 56, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "loadgen_git_commit_date", "value": "2024-10-22T15:38:02-05:00", "time_ms": 0.002363, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "version.cc", "line_no": 57, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "loadgen_git_log_message", "value": "f74d16f54131d9080b9e45f234cc23e0ebaaf20c Apply min_new_tokens=2 to mixtral-8x7b, address #1777 (#1884)\necb880167756cb4b36ad70766b8d3254bfb06d26 [Postmortem 4.1] Make mlperf.conf static in loadgen, enable automatic Pypi release (#1882)\nf5c8f1758374aeaba26b2e84d31690111cfdf054 Fix bug: Loadgen ignoring token latency targets in user conf (#1874)\n976bb1ad9c7946be79507f3ff67955c27426af52 Set correct remote repo (#1871)\n41fa8aadd1ba0ecc97f6a519d8b42b04278e5f24 Add format files github action (#1682)\n518b454fd8647bfbd23a074e875e87353f33393e Tflite tpu (#1449)\ne0fdec1c7a75c98cfc194f13d62ac4388d419c8a Fix link in GettingStarted.ipynb (#1512)\n92bd8198d15411d7fb7d7c27f8904bc5a0bcfe7a Fix warning in the submission checker (#1808)\n224cfbf5c0e82cae6d48620025b7e1258ae3666a Fix typo in reference datatype (#1851)\n3ef1249b7f50a250c02c568342e0aea6638fc5a7 Fix docs (#1853)\na0874c100c54cbc54fb743ac8bf9fb5fadc64135 Update build_wheels.yml (#1758)\n6eff09986e337ccf03f675c9f244d8ee93644e16 Extend the final report generation script to output a json file of results (#1825)\n54f3f93a73cc8ca5e3319ad87fb325e510574f56 Add binding for server_num_issue_query_threads parameter (#1862)\nc4d0b3ea98e6fe7252e50cb573f0d523da7979df Update docs: SCC24, fix broken redirect (#1843)\n7d2f0c41e5cd79c9178702867392e38f57953338 Update DLRM readme (#1811)\ncf5fddc5d0746bf3820eb0ab7294bbf709d788ab Enable systems to be marked as power only (#1850)", "time_ms": 0.002363, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "version.cc", "line_no": 58, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "loadgen_git_status_message", "value": "", "time_ms": 0.002363, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "version.cc", "line_no": 60, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "loadgen_file_sha1", "value": 
{"/.clang-format":"012aad77e5206c89d50718c46c119d1f3cb056b2","/CMakeLists.txt":"b73434348f7860471606aaa395b570e81113cb6d","/MANIFEST.in":"8d3c4ac6c325b7b9a0fd4cf4a4108cbeff8d5025","/README.md":"20a55bb946c2c0bbb564ced2af1e48efd096b3a8","/README_BUILD.md":"5f6c6a784e9cd6995db47f9b9f70b1769909c9d8","/README_FAQ.md":"01f9ae9887f50bc030dc6107e740f40c43ca388f","/VERSION.txt":"cb67dcc41adcbb7849a0a808a501ee9ccd951d92","/__init__.py":"da39a3ee5e6b4b0d3255bfef95601890afd80709","/bindings/c_api.cc":"32181da9e161c285f8fe46ddaa49e6cba2f9f918","/bindings/c_api.h":"91f58bd79b83b278f3240174a9af747fc38aff74","/bindings/python_api.cc":"9f538d2a5390c77ae0bc3f8a351bcdb2587bc66c","/diagram_network_submission.png":"53dba8ad4272190ceb6335c12fd25e53dc02a8cb","/diagram_submission.png":"84c2f79309b237cef652aef6a187ba8e875a3952","/early_stopping.cc":"0cd7b546a389deac73f7955cd39255ed76557d62","/early_stopping.h":"158fcae6a5f47e82150d6416fa1f7bcef37e77fe","/issue_query_controller.cc":"126e952d00f4ea9efd12405fb209aa3ed585e4b2","/issue_query_controller.h":"923d9d5cdf598e3ec33d7a1110a31f7e11527ec7","/loadgen.cc":"6650091ba7a918f343b06eb7a5aa540eae87275f","/loadgen.h":"e00fdc6dbc85a8c9a8485dbcbfe2944f81251c4e","/loadgen_integration_diagram.svg":"47f748307536f80cfc606947b440dd732afc2637","/logging.cc":"197efc96d178e5d33a750d07fa7b2966417506ea","/logging.h":"ddb961df7bcc145bcd7cce8c21f7cf075350dcbe","/mlperf.conf":"0a4daef277bb3151139980e484dd5e644bf36e18","/pyproject.toml":"712fab87b72ba67ef2a068d0f9f47da65130342f","/query_dispatch_library.h":"13ad6d842200cb161d6927eb74a3fafd79c46c75","/query_sample.h":"e9187c8612bbdc972305b789feb6e15c26e96cfe","/query_sample_library.h":"8323a2225be1dff31f08ecc86b76eb3de06568bc","/requirements.txt":"a5ff7e77caa6e9e22ada90f0de0c865c987bf167","/results.cc":"34e2d2a44324cb07c884f92146ecbb8ef9d704e2","/results.h":"fce22d5a588d91fd968a6b25c27896dba87bc276","/setup.py":"a722046e05858c6d9f38f0e2b3fe425334beef28","/system_under_test.h":"18d4809589dae33317d88d9beeb5491a6e1ccdec","/test_settings.h":"476ecd4032f3bafe6f201df25d68aca4e177f659","/test_settings_internal.cc":"ce4322c849d24ffafc28a37b5e528a4cb4df227d","/test_settings_internal.h":"f1d5335b53ca610c30e0edc5d07999a27b5b4b9a","/utils.cc":"3df8fdabf6eaea4697cf25d1dcb89cae88e36efd","/utils.h":"40775e32d619ea6356826ae5ea4174c7911f6894","/version.cc":"cbec2a5f98f9786c8c3d8b06b3d12df0b6550fa0","/version.h":"9d574baa64424e9c708fcfedd3dbb0b518a65fcc","/version_generator.py":"eea9b9cb1a06cd1abe1bbdaee82f9af31527fedb"}, "time_ms": 0.002363, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "version.cc", "line_no": 67, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "test_datetime", "value": "2024-10-29T02:08:01Z", "time_ms": 0.002578, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 1198, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "sut_name", "value": "PySUT", "time_ms": 0.002578, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 1199, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "get_sut_name_duration_ns", "value": 182, "time_ms": 0.002578, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 1200, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "qsl_name", "value": "PyQSL", "time_ms": 
0.002578, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 1201, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "qsl_reported_total_count", "value": 198, "time_ms": 0.002578, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 1202, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "qsl_reported_performance_count", "value": 5000, "time_ms": 0.002578, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 1203, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_scenario", "value": "SingleStream", "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 271, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_test_mode", "value": "PerformanceOnly", "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 272, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_single_stream_expected_latency_ns", "value": 1e+07, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 277, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_single_stream_target_latency_percentile", "value": 0.9, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 279, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_min_duration_ms", "value": 600000, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 315, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_max_duration_ms", "value": 0, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 316, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_min_query_count", "value": 100, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 317, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_max_query_count", "value": 0, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 318, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_qsl_rng_seed", "value": 3066443479025735752, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 319, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_sample_index_rng_seed", "value": 10688027786191513374, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", 
"line_no": 320, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_schedule_rng_seed", "value": 14962580496156340209, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 322, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_accuracy_log_rng_seed", "value": 0, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 323, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_accuracy_log_probability", "value": 0, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 325, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_accuracy_log_sampling_target", "value": 0, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 327, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_print_timestamps", "value": false, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 329, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_performance_issue_unique", "value": false, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 330, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_performance_issue_same", "value": false, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 332, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_performance_issue_same_index", "value": 0, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 334, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_performance_sample_count_override", "value": 0, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 336, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "requested_sample_concatenate_permutation", "value": false, "time_ms": 0.005484, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 338, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_scenario", "value": "SingleStream", "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 417, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_test_mode", "value": "PerformanceOnly", "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 418, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_samples_per_query", 
"value": 1, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 420, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_target_qps", "value": 100, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 421, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_target_latency_ns", "value": 0, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 422, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_target_latency_percentile", "value": 0.9, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 423, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_max_async_queries", "value": 1, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 425, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_target_duration_ms", "value": 600000, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 426, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_min_duration_ms", "value": 600000, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 428, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_max_duration_ms", "value": 0, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 429, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_min_query_count", "value": 100, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 430, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_max_query_count", "value": 0, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 431, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_min_sample_count", "value": 100, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 432, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_qsl_rng_seed", "value": 3066443479025735752, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 433, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_sample_index_rng_seed", "value": 10688027786191513374, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": 
"test_settings_internal.cc", "line_no": 434, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_schedule_rng_seed", "value": 14962580496156340209, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 436, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_accuracy_log_rng_seed", "value": 0, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 437, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_accuracy_log_probability", "value": 0, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 439, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_accuracy_log_sampling_target", "value": 0, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 441, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_print_timestamps", "value": false, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 443, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_performance_issue_unique", "value": false, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 444, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_performance_issue_same", "value": false, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 446, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_performance_issue_same_index", "value": 0, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 448, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_performance_sample_count", "value": 5000, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 450, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "effective_sample_concatenate_permutation", "value": false, "time_ms": 0.005668, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "test_settings_internal.cc", "line_no": 452, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "generic_message", "value": "Starting performance mode", "time_ms": 0.007752, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 841, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "loaded_qsl_set", "value": 
[166,119,189,146,100,142,64,87,84,70,135,40,195,77,168,185,85,163,193,29,190,30,53,50,99,164,157,167,156,165,49,27,187,79,125,170,123,182,88,9,101,6,128,62,151,124,92,38,42,56,66,90,51,181,48,12,148,35,178,15,82,106,67,176,143,153,147,97,102,61,83,109,158,118,52,107,131,114,113,69,0,39,55,127,173,155,160,75,162,5,21,24,121,186,18,59,145,7,25,26,74,14,91,2,76,141,86,63,196,134,110,137,20,180,23,34,105,188,117,68,133,108,129,154,194,116,44,197,57,126,41,80,58,111,43,32,78,10,98,28,89,152,22,179,47,73,159,149,130,144,11,60,17,3,112,115,103,174,171,36,140,37,72,96,31,104,138,8,120,81,19,175,13,184,16,71,172,65,136,169,93,1,132,191,45,192,94,54,95,46,183,177,122,139,161,33,150,4], "time_ms": 0.015220, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 613, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "generated_query_count", "value": 120001, "time_ms": 25.346542, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 428, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "generated_samples_per_query", "value": 1, "time_ms": 25.346542, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 429, "pid": 3279, "tid": 3279}} +:::MLLOG {"key": "generated_query_duration", "value": 1200010000000, "time_ms": 25.346542, "namespace": "mlperf::logging", "event_type": "POINT_IN_TIME", "metadata": {"is_error": false, "is_warning": false, "file": "loadgen.cc", "line_no": 430, "pid": 3279, "tid": 3279}} diff --git a/automotive/3d-object-detection/output/mlperf_log_summary.txt b/automotive/3d-object-detection/output/mlperf_log_summary.txt new file mode 100644 index 000000000..e69de29bb diff --git a/automotive/3d-object-detection/output/mlperf_log_trace.json b/automotive/3d-object-detection/output/mlperf_log_trace.json new file mode 100644 index 000000000..e69de29bb diff --git a/automotive/3d-object-detection/requirements.txt b/automotive/3d-object-detection/requirements.txt new file mode 100644 index 000000000..a5bb20b23 --- /dev/null +++ b/automotive/3d-object-detection/requirements.txt @@ -0,0 +1 @@ +# TODO: Add requirements diff --git a/automotive/3d-object-detection/tools/download_dataset.py b/automotive/3d-object-detection/tools/download_dataset.py new file mode 100644 index 000000000..8a567ccf5 --- /dev/null +++ b/automotive/3d-object-detection/tools/download_dataset.py @@ -0,0 +1 @@ +# TODO: script to download dataset diff --git a/automotive/3d-object-detection/tools/evaluate.py b/automotive/3d-object-detection/tools/evaluate.py new file mode 100644 index 000000000..3bc56ee4f --- /dev/null +++ b/automotive/3d-object-detection/tools/evaluate.py @@ -0,0 +1,261 @@ +import argparse +import numpy as np +import os +import torch +from tqdm import tqdm + +from tools.process import iou2d, iou3d_camera + + +def get_score_thresholds(tp_scores, total_num_valid_gt, num_sample_pts=41): + score_thresholds = [] + tp_scores = sorted(tp_scores)[::-1] + cur_recall, pts_ind = 0, 0 + for i, score in enumerate(tp_scores): + lrecall = (i + 1) / total_num_valid_gt + rrecall = (i + 2) / total_num_valid_gt + + if i == len(tp_scores) - 1: + score_thresholds.append(score) + break + + if (lrecall + rrecall) / 2 < cur_recall: + continue + + score_thresholds.append(score) + pts_ind += 1 + cur_recall = pts_ind / (num_sample_pts - 1) + return score_thresholds + 
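The recall-sampled thresholding in `get_score_thresholds` above is easy to sanity-check in isolation. A minimal sketch, assuming it is run from the benchmark directory with its dependencies installed; the score distribution and ground-truth count are made up for illustration:

```python
import numpy as np
from tools.evaluate import get_score_thresholds

# Hypothetical detector: 400 true-positive scores against 500 valid ground-truth
# boxes, i.e. a maximum recall of 0.8. All numbers here are fabricated.
rng = np.random.default_rng(0)
tp_scores = rng.uniform(0.05, 0.95, size=400).tolist()

thresholds = get_score_thresholds(tp_scores, total_num_valid_gt=500)

# Rather than evaluating at all 400 operating points, the PR curve is sampled at
# up to num_sample_pts (41 by default) evenly spaced recall values: each kept
# score is roughly the one at which recall first crosses the next sampled recall
# level, plus the lowest score, so far fewer thresholds come back.
print(len(tp_scores), "raw scores ->", len(thresholds), "sampled thresholds")
```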
+ +def convert_calib(calib, cuda): + result = {} + if cuda: + device = 'cuda' + else: + device = 'cpu' + result['R0_rect'] = torch.from_numpy( + calib['R0_rect']).to( + device=device, + dtype=torch.float) + for i in range(5): + result['P' + str(i)] = torch.from_numpy(calib['P' + str(i)] + ).to(device=device, dtype=torch.float) + result['Tr_velo_to_cam_' + + str(i)] = torch.from_numpy(calib['Tr_velo_to_cam_' + + str(i)]).to(device=device, dtype=torch.float) + return result + + +def do_eval(det_results, gt_results, CLASSES, cam_sync=False): + ''' + det_results: list, + gt_results: dict(id -> det_results) + CLASSES: dict + ''' + assert len(det_results) == len(gt_results) + + # 1. calculate iou + ious = { + 'bbox_3d': [] + } + # ids = list(sorted([g['image']['image_idx'] for g in gt_results])) + if cam_sync: + annos_label = 'cam_sync_annos' + else: + annos_label = 'annos' + for id in range(len(gt_results)): + gt_result = gt_results[id][annos_label] + det_result = det_results[gt_results[id]['image']['image_idx']] + + # 1.2, bev iou + gt_location = gt_result['location'].astype(np.float32) + gt_dimensions = gt_result['dimensions'].astype(np.float32) + gt_rotation_y = gt_result['rotation_y'].astype(np.float32) + det_location = det_result['location'].astype(np.float32).reshape(-1, 3) + det_dimensions = det_result['dimensions'].astype( + np.float32).reshape(-1, 3) + det_rotation_y = det_result['rotation_y'].astype(np.float32) + + # 1.3, 3dbboxes iou + gt_bboxes3d = np.concatenate( + [gt_location, gt_dimensions, gt_rotation_y[:, None]], axis=-1) + det_bboxes3d = np.concatenate( + [det_location, det_dimensions, det_rotation_y[:, None]], axis=-1) + iou3d_v = iou3d_camera( + torch.from_numpy(gt_bboxes3d).cuda(), + torch.from_numpy(det_bboxes3d).cuda()) + ious['bbox_3d'].append(iou3d_v.cpu().numpy()) + + MIN_IOUS = { + 'Pedestrian': [0.5], + 'Cyclist': [0.5], + 'Car': [0.7] + } + MIN_HEIGHT = [-1] + + overall_results = {} + for e_ind, eval_type in enumerate(['bbox_3d']): + eval_ious = ious[eval_type] + eval_ap_results, eval_aos_results = {}, {} + for cls in CLASSES: + eval_ap_results[cls] = [] + eval_aos_results[cls] = [] + CLS_MIN_IOU = MIN_IOUS[cls][e_ind] + for difficulty in [0]: + # 1. 
bbox property + total_gt_ignores, total_det_ignores, total_dc_bboxes, total_scores = [], [], [], [] + for id in range(len(gt_results)): + gt_result = gt_results[id][annos_label] + det_result = det_results[gt_results[id] + ['image']['image_idx']] + + # 1.1 gt bbox property + cur_gt_names = gt_result['name'] + cur_difficulty = gt_result['difficulty'] + gt_ignores, dc_bboxes = [], [] + for j, cur_gt_name in enumerate(cur_gt_names): + ignore = cur_difficulty[j] < 0 or cur_difficulty[j] > difficulty + if cur_gt_name == cls: + valid_class = 1 + elif cls == 'Pedestrian' and cur_gt_name == 'Person_sitting': + valid_class = 0 + elif cls == 'Car' and cur_gt_name == 'Van': + valid_class = 0 + else: + valid_class = -1 + + if valid_class == 1 and not ignore: + gt_ignores.append(0) + elif valid_class == 0 or (valid_class == 1 and ignore): + gt_ignores.append(1) + else: + gt_ignores.append(-1) + + if cur_gt_name == 'DontCare': + dc_bboxes.append(gt_result['bbox'][j]) + total_gt_ignores.append(gt_ignores) + total_dc_bboxes.append(np.array(dc_bboxes)) + + # 1.2 det bbox property + cur_det_names = det_result['name'] + if len(cur_det_names) == 0: + cur_det_heights = np.empty_like(det_result['bbox']) + else: + cur_det_heights = det_result['bbox'][:, + 3] - det_result['bbox'][:, 1] + det_ignores = [] + for j, cur_det_name in enumerate(cur_det_names): + if cur_det_heights[j] < MIN_HEIGHT[difficulty]: + det_ignores.append(1) + elif cur_det_name == cls: + det_ignores.append(0) + else: + det_ignores.append(-1) + total_det_ignores.append(det_ignores) + total_scores.append(det_result['score']) + + # 2. calculate scores thresholds for PR curve + tp_scores = [] + for i in range(len(gt_results)): + cur_eval_ious = eval_ious[i] + gt_ignores, det_ignores = total_gt_ignores[i], total_det_ignores[i] + scores = total_scores[i] + + nn, mm = cur_eval_ious.shape + assigned = np.zeros((mm, ), dtype=np.bool_) + for j in range(nn): + if gt_ignores[j] == -1: + continue + match_id, match_score = -1, -1 + for k in range(mm): + if not assigned[k] and det_ignores[k] >= 0 and cur_eval_ious[j, + k] > CLS_MIN_IOU and scores[k] > match_score: + match_id = k + match_score = scores[k] + if match_id != -1: + assigned[match_id] = True + if det_ignores[match_id] == 0 and gt_ignores[j] == 0: + tp_scores.append(match_score) + total_num_valid_gt = np.sum( + [np.sum(np.array(gt_ignores) == 0) for gt_ignores in total_gt_ignores]) + score_thresholds = get_score_thresholds( + tp_scores, total_num_valid_gt) + + # 3. 
draw PR curve and calculate mAP + tps, fns, fps, total_aos = [], [], [], [] + + for score_threshold in score_thresholds: + tp, fn, fp = 0, 0, 0 + aos = 0 + for i in range(len(gt_results)): + cur_eval_ious = eval_ious[i] + gt_ignores, det_ignores = total_gt_ignores[i], total_det_ignores[i] + scores = total_scores[i] + + nn, mm = cur_eval_ious.shape + assigned = np.zeros((mm, ), dtype=np.bool_) + for j in range(nn): + if gt_ignores[j] == -1: + continue + match_id, match_iou = -1, -1 + for k in range(mm): + if not assigned[k] and det_ignores[k] >= 0 and scores[ + k] >= score_threshold and cur_eval_ious[j, k] > CLS_MIN_IOU: + + if det_ignores[k] == 0 and cur_eval_ious[j, + k] > match_iou: + match_iou = cur_eval_ious[j, k] + match_id = k + elif det_ignores[k] == 1 and match_iou == -1: + match_id = k + + if match_id != -1: + assigned[match_id] = True + if det_ignores[match_id] == 0 and gt_ignores[j] == 0: + tp += 1 + else: + if gt_ignores[j] == 0: + fn += 1 + + for k in range(mm): + if det_ignores[k] == 0 and scores[k] >= score_threshold and not assigned[k]: + fp += 1 + + # In case 2d bbox evaluation, we should consider + # dontcare bboxes + if eval_type == 'bbox_2d': + dc_bboxes = total_dc_bboxes[i] + det_bboxes = det_results[gt_results[i] + ['image']['image_idx']]['bbox'] + if len(dc_bboxes) > 0: + ious_dc_det = iou2d( + torch.from_numpy(det_bboxes), + torch.from_numpy(dc_bboxes), + metric=1).numpy().T + for j in range(len(dc_bboxes)): + for k in range(len(det_bboxes)): + if det_ignores[k] == 0 and scores[k] >= score_threshold and not assigned[k]: + if ious_dc_det[j, k] > CLS_MIN_IOU: + fp -= 1 + assigned[k] = True + + tps.append(tp) + fns.append(fn) + fps.append(fp) + if eval_type == 'bbox_2d': + total_aos.append(aos) + + tps, fns, fps = np.array(tps), np.array(fns), np.array(fps) + + recalls = tps / (tps + fns) + precisions = tps / (tps + fps) + for i in range(len(score_thresholds)): + precisions[i] = np.max(precisions[i:]) + + sums_AP = 0 + for i in range(0, len(score_thresholds), 4): + sums_AP += precisions[i] + mAP = sums_AP / 11 * 100 + eval_ap_results[cls].append(mAP) + return eval_ap_results diff --git a/automotive/3d-object-detection/tools/process.py b/automotive/3d-object-detection/tools/process.py new file mode 100644 index 000000000..59705d533 --- /dev/null +++ b/automotive/3d-object-detection/tools/process.py @@ -0,0 +1,341 @@ +import shapely.geometry +import numpy as np +import torch +import copy + + +def bbox_camera2lidar(bboxes, tr_velo_to_cam, r0_rect): + ''' + bboxes: shape=(N, 7) + tr_velo_to_cam: shape=(4, 4) + r0_rect: shape=(4, 4) + return: shape=(N, 7) + ''' + x_size, y_size, z_size = bboxes[:, 3:4], bboxes[:, 4:5], bboxes[:, 5:6] + xyz_size = np.concatenate([z_size, x_size, y_size], axis=1) + extended_xyz = np.pad( + bboxes[:, :3], ((0, 0), (0, 1)), 'constant', constant_values=1.0) + rt_mat = np.linalg.inv(r0_rect @ tr_velo_to_cam) + xyz = extended_xyz @ rt_mat.T + bboxes_lidar = np.concatenate( + [xyz[:, :3], xyz_size, bboxes[:, 6:]], axis=1) + return np.array(bboxes_lidar, dtype=np.float32) + + +def bbox_lidar2camera(bboxes, tr_velo_to_cam, r0_rect): + ''' + bboxes: shape=(N, 7) + tr_velo_to_cam: shape=(4, 4) + r0_rect: shape=(4, 4) + return: shape=(N, 7) + ''' + x_size, y_size, z_size = bboxes[:, 3:4], bboxes[:, 4:5], bboxes[:, 5:6] + xyz_size = torch.cat([y_size, z_size, x_size], axis=1) + extended_xyz = torch.nn.functional.pad( + bboxes[:, :3], (0, 1), 'constant', value=1.0) + rt_mat = r0_rect @ tr_velo_to_cam + xyz = extended_xyz @ rt_mat.T + bboxes_camera = 
torch.cat([xyz[:, :3], xyz_size, bboxes[:, 6:]], axis=1) + return bboxes_camera + + +def bbox3d2corners_camera(bboxes): + ''' + bboxes: shape=(n, 7) + return: shape=(n, 8, 3) + z (front) 6 ------ 5 + / / | / | + / 2 -|---- 1 | + / | | | | + |o ------> x(right) | 7 -----| 4 + | |/ o |/ + | 3 ------ 0 + | + v y(down) + ''' + centers, dims, angles = bboxes[:, :3], bboxes[:, 3:6], bboxes[:, 6] + + # 1.generate bbox corner coordinates, clockwise from minimal point + bboxes_corners = torch.tensor([[0.5, 0.0, -0.5], [0.5, -1.0, -0.5], [-0.5, -1.0, -0.5], [-0.5, 0.0, -0.5], + [0.5, 0.0, 0.5], [0.5, -1.0, 0.5], [-0.5, -1.0, 0.5], [-0.5, 0.0, 0.5]]) + # (1, 8, 3) * (n, 1, 3) -> (n, 8, 3) + bboxes_corners = bboxes_corners[None, :, :] * dims[:, None, :] + + # 2. rotate around y axis + rot_sin, rot_cos = torch.sin(angles), torch.cos(angles) + # in fact, angle + rot_mat = torch.stack([torch.stack([rot_cos, torch.zeros_like(rot_cos), rot_sin]), + torch.stack([torch.zeros_like(rot_cos), torch.ones_like( + rot_cos), torch.zeros_like(rot_cos)]), + torch.stack([-rot_sin, torch.zeros_like(rot_cos), rot_cos])]) # (3, 3, n) + rot_mat = torch.permute(rot_mat, (2, 1, 0)) # (n, 3, 3) + bboxes_corners = bboxes_corners @ rot_mat # (n, 8, 3) + + # 3. translate to centers + bboxes_corners += centers[:, None, :] + return bboxes_corners.clone().detach() + + +def points_camera2image(points, P2): + ''' + points: shape=(N, 8, 3) + P2: shape=(4, 4) + return: shape=(N, 8, 2) + ''' + extended_points = torch.nn.functional.pad( + points, (0, 1), 'constant', value=1.0) # (n, 8, 4) + image_points = extended_points @ P2.T # (N, 8, 4) + image_points = image_points[:, :, :2] / image_points[:, :, 2:3] + return image_points.clone().detach() + + +def keep_bbox_from_image_range( + result, calib_info, num_images, image_info, cam_sync=False): + r0_rect = calib_info['R0_rect'] + lidar_bboxes = result['lidar_bboxes'] + labels = result['labels'] + scores = result['scores'] + total_keep_flag = torch.zeros(lidar_bboxes.size(dim=0)).bool() + for i in range(num_images): + h, w = image_info['camera'][i]['image_shape'] + tr_velo_to_cam = calib_info['Tr_velo_to_cam_' + str(i)] + P = calib_info['P' + str(i)] + camera_bboxes = bbox_lidar2camera( + lidar_bboxes, tr_velo_to_cam, r0_rect) # (n, 7) + if i == 0: + main_camera_bboxes = camera_bboxes.clone() + bboxes_points = bbox3d2corners_camera(camera_bboxes) # (n, 8, 3) + image_points = points_camera2image(bboxes_points, P) # (n, 8, 2) + image_x1y1 = torch.min(image_points, axis=1)[0] # (n, 2) + image_x1y1 = torch.maximum(image_x1y1, torch.tensor(0)) + image_x2y2 = torch.max(image_points, axis=1)[0] # (n, 2) + image_x2y2 = torch.minimum(image_x2y2, torch.tensor([w, h])) + bboxes2d = torch.cat([image_x1y1, image_x2y2], axis=-1) + + keep_flag = (image_x1y1[:, 0] < w) & (image_x1y1[:, 1] < h) & ( + image_x2y2[:, 0] > 0) & (image_x2y2[:, 1] > 0) & (camera_bboxes[:, 2] > 0) + total_keep_flag = total_keep_flag | keep_flag + if cam_sync: + result = { + 'lidar_bboxes': lidar_bboxes[total_keep_flag], + 'labels': labels[total_keep_flag], + 'scores': scores[total_keep_flag], + 'bboxes2d': bboxes2d[total_keep_flag], + 'camera_bboxes': main_camera_bboxes[total_keep_flag] + } + else: + result = { + 'lidar_bboxes': lidar_bboxes, + 'labels': labels, + 'scores': scores, + 'bboxes2d': bboxes2d, + 'camera_bboxes': main_camera_bboxes + } + return result + + +def limit_period(val, offset=0.5, period=np.pi): + """ + val: array or float + offset: float + period: float + return: Value in the range of [-offset * period, 
(1-offset) * period] + """ + limited_val = val - np.floor(val / period + offset) * period + return limited_val + + +def iou2d(bboxes1, bboxes2, metric=0): + ''' + bboxes1: (n, 4), (x1, y1, x2, y2) + bboxes2: (m, 4), (x1, y1, x2, y2) + return: (n, m) + ''' + rows = len(bboxes1) + cols = len(bboxes2) + if rows * cols == 0: + return torch.empty((rows, cols)) + bboxes_x1 = torch.maximum( + bboxes1[:, 0][:, None], bboxes2[:, 0][None, :]) # (n, m) + bboxes_y1 = torch.maximum( + bboxes1[:, 1][:, None], bboxes2[:, 1][None, :]) # (n, m) + bboxes_x2 = torch.minimum(bboxes1[:, 2][:, None], bboxes2[:, 2][None, :]) + bboxes_y2 = torch.minimum(bboxes1[:, 3][:, None], bboxes2[:, 3][None, :]) + + bboxes_w = torch.clamp(bboxes_x2 - bboxes_x1, min=0) + bboxes_h = torch.clamp(bboxes_y2 - bboxes_y1, min=0) + + iou_area = bboxes_w * bboxes_h # (n, m) + + bboxes1_wh = bboxes1[:, 2:] - bboxes1[:, :2] + area1 = bboxes1_wh[:, 0] * bboxes1_wh[:, 1] # (n, ) + bboxes2_wh = bboxes2[:, 2:] - bboxes2[:, :2] + area2 = bboxes2_wh[:, 0] * bboxes2_wh[:, 1] # (m, ) + if metric == 0: + iou = iou_area / (area1[:, None] + area2[None, :] - iou_area + 1e-8) + elif metric == 1: + iou = iou_area / (area1[:, None] + 1e-8) + return iou + + +def nearest_bev(bboxes): + ''' + bboxes: (n, 7), (x, y, z, w, l, h, theta) + return: (n, 4), (x1, y1, x2, y2) + ''' + bboxes_bev = copy.deepcopy(bboxes[:, [0, 1, 3, 4]]) + bboxes_angle = limit_period( + bboxes[:, 6].cpu(), offset=0.5, period=np.pi).to(bboxes_bev) + bboxes_bev = torch.where(torch.abs( + bboxes_angle[:, None]) > np.pi / 4, bboxes_bev[:, [0, 1, 3, 2]], bboxes_bev) + + bboxes_xy = bboxes_bev[:, :2] + bboxes_wl = bboxes_bev[:, 2:] + bboxes_bev_x1y1x2y2 = torch.cat( + [bboxes_xy - bboxes_wl / 2, bboxes_xy + bboxes_wl / 2], dim=-1) + return bboxes_bev_x1y1x2y2 + + +def iou2d_nearest(bboxes1, bboxes2): + ''' + bboxes1: (n, 7), (x, y, z, w, l, h, theta) + bboxes2: (m, 7), + return: (n, m) + ''' + bboxes1_bev = nearest_bev(bboxes1) + bboxes2_bev = nearest_bev(bboxes2) + iou = iou2d(bboxes1_bev, bboxes2_bev) + return iou + + +def limit_period(val, offset=0.5, period=np.pi): + """ + val: array or float + offset: float + period: float + return: Value in the range of [-offset * period, (1-offset) * period] + """ + limited_val = val - np.floor(val / period + offset) * period + return limited_val + + +def iou3d_camera(bboxes1, bboxes2): + ''' + bboxes1: (n, 7), (x, y, z, w, l, h, theta) + bboxes2: (m, 7) + return: (n, m) + ''' + rows = len(bboxes1) + cols = len(bboxes2) + if rows * cols == 0: + return torch.empty((rows, cols)) + # 1. height overlap + bboxes1_bottom, bboxes2_bottom = bboxes1[:, 1] - \ + bboxes1[:, 4], bboxes2[:, 1] - bboxes2[:, 4] # (n, ), (m, ) + bboxes1_top, bboxes2_top = bboxes1[:, 1], bboxes2[:, 1] # (n, ), (m, ) + bboxes_bottom = torch.maximum( + bboxes1_bottom[:, None], bboxes2_bottom[None, :]) # (n, m) + bboxes_top = torch.minimum(bboxes1_top[:, None], bboxes2_top[None, :]) + height_overlap = torch.clamp(bboxes_top - bboxes_bottom, min=0) + + # 2. 
bev overlap + bboxes1_x1y1 = bboxes1[:, [0, 2]] - bboxes1[:, [3, 5]] / 2 + bboxes1_x2y2 = bboxes1[:, [0, 2]] + bboxes1[:, [3, 5]] / 2 + bboxes2_x1y1 = bboxes2[:, [0, 2]] - bboxes2[:, [3, 5]] / 2 + bboxes2_x2y2 = bboxes2[:, [0, 2]] + bboxes2[:, [3, 5]] / 2 + bboxes1_bev = torch.cat( + [bboxes1_x1y1, bboxes1_x2y2, bboxes1[:, 6:]], dim=-1) + bboxes2_bev = torch.cat( + [bboxes2_x1y1, bboxes2_x2y2, bboxes2[:, 6:]], dim=-1) + bev_overlap = ( + rotated_box_iou( + bboxes1_bev, + bboxes2_bev)).to( + device=height_overlap.device) # (n, m) + + # 3. overlap and volume + overlap = height_overlap * bev_overlap + volume1 = bboxes1[:, 3] * bboxes1[:, 4] * bboxes1[:, 5] + volume2 = bboxes2[:, 3] * bboxes2[:, 4] * bboxes2[:, 5] + volume = volume1[:, None] + volume2[None, :] # (n, m) + + # 4. iou + iou = overlap / (volume - overlap + 1e-8) + + return iou + + +def boxes_overlap_bev(boxes_a, boxes_b): + """Calculate boxes Overlap in the bird view. + + Args: + boxes_a (torch.Tensor): Input boxes a with shape (M, 5). + boxes_b (torch.Tensor): Input boxes b with shape (N, 5). + + Returns: + ans_overlap (torch.Tensor): Overlap result with shape (M, N). + """ + ans_overlap = boxes_a.new_zeros( + torch.Size((boxes_a.shape[0], boxes_b.shape[0]))) + if ans_overlap.size(0) * ans_overlap.size(1) == 0: + return ans_overlap + boxes_overlap_bev_gpu( + boxes_a.contiguous(), + boxes_b.contiguous(), + ans_overlap) + + return ans_overlap + + +def rotated_box_iou(boxes1, boxes2): + """ + Calculates IoU for rotated bounding boxes. + + Args: + boxes1 (torch.Tensor): Tensor of shape (N, 5) representing rotated boxes in format (x_center, y_center, width, height, angle). + boxes2 (torch.Tensor): Tensor of shape (M, 5) representing rotated boxes in the same format. + + Returns: + torch.Tensor: IoU matrix of shape (N, M). + """ + + # Convert boxes to polygons + polygons1 = boxes_to_polygons(boxes1) + polygons2 = boxes_to_polygons(boxes2) + + # Calculate IoU for each pair of polygons + ious = torch.zeros((boxes1.shape[0], boxes2.shape[0])) + overlaps = torch.zeros((boxes1.shape[0], boxes2.shape[0])) + for i in range(boxes1.shape[0]): + for j in range(boxes2.shape[0]): + intersection = polygon_intersection(polygons1[i], polygons2[j]) + union = polygon_union(polygons1[i], polygons2[j]) + ious[i, j] = intersection / union + overlaps[i, j] = intersection + + return overlaps + + +def boxes_to_polygons(boxes): + # Implementation to convert boxes to polygons + polygons = [] + for box in boxes: + x_min = box[0] + y_min = box[1] + x_max = box[2] + y_max = box[3] + polygon = shapely.geometry.Polygon( + [(x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)]) + polygon = shapely.affinity.rotate( + polygon, -1 * box[4], use_radians=True) + polygons.append(polygon) + return polygons + + +def polygon_intersection(polygon1, polygon2): + return shapely.intersection(polygon1, polygon2).area + + +def polygon_union(polygon1, polygon2): + # Implementation to calculate union area of polygons + return shapely.union(polygon1, polygon2).area diff --git a/automotive/3d-object-detection/user.conf b/automotive/3d-object-detection/user.conf new file mode 100644 index 000000000..07a10bbe2 --- /dev/null +++ b/automotive/3d-object-detection/user.conf @@ -0,0 +1,4 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. 
+# All times are in milli seconds diff --git a/automotive/3d-object-detection/waymo.py b/automotive/3d-object-detection/waymo.py new file mode 100644 index 000000000..5818ed01e --- /dev/null +++ b/automotive/3d-object-detection/waymo.py @@ -0,0 +1,247 @@ +""" +implementation of waymo dataset +""" + +# pylint: disable=unused-argument,missing-docstring + +import json +import logging +import os +import time + +from PIL import Image +import numpy as np +import pandas as pd +import dataset +import os +import pickle +import torch +from torchvision import transforms +import tools.process as process + + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("waymo") + + +def read_pickle(file_path, suffix='.pkl'): + assert os.path.splitext(file_path)[1] == suffix + with open(file_path, 'rb') as f: + data = pickle.load(f) + return data + + +def point_range_filter(data_dict, point_range): + ''' + data_dict: dict(pts, gt_bboxes_3d, gt_labels, gt_names, difficulty) + point_range: [x1, y1, z1, x2, y2, z2] + ''' + pts = data_dict['pts'] + flag_x_low = pts[:, 0] > point_range[0] + flag_y_low = pts[:, 1] > point_range[1] + flag_z_low = pts[:, 2] > point_range[2] + flag_x_high = pts[:, 0] < point_range[3] + flag_y_high = pts[:, 1] < point_range[4] + flag_z_high = pts[:, 2] < point_range[5] + keep_mask = flag_x_low & flag_y_low & flag_z_low & flag_x_high & flag_y_high & flag_z_high + pts = pts[keep_mask] + data_dict.update({'pts': pts}) + return data_dict + + +def read_points(file_path, dim=4): + suffix = os.path.splitext(file_path)[1] + assert suffix in ['.bin', '.ply', '.npy'] + if suffix == '.bin': + return np.fromfile(file_path, dtype=np.float32).reshape(-1, dim) + elif suffix == '.npy': + return np.load(file_path).astype(np.float32) + else: + raise NotImplementedError + + +class Waymo(dataset.Dataset): + CLASSES = { + 'Pedestrian': 0, + 'Cyclist': 1, + 'Car': 2 + } + + def __init__(self, data_root, split, + pts_prefix='velodyne_reduced', painted=True, cam_sync=False): + super().__init__() + assert split in ['train', 'val', 'trainval', 'test'] + self.data_root = data_root + self.split = split + self.pts_prefix = pts_prefix + info_file = f'waymo_infos_val.pkl' + self.painted = painted + self.cam_sync = cam_sync + self.point_range_filter = [-74.88, -74.88, -2, 74.88, 74.88, 4] + if painted or cam_sync: + info_file = f'painted_waymo_infos_{split}.pkl' + else: + info_file = f'waymo_infos_{split}.pkl' + self.data_infos = read_pickle(os.path.join(data_root, info_file)) + self.sorted_ids = range(len(self.data_infos)) + + def preprocess(self, input): + image_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[ + 0.485, 0.456, 0.406], std=[ + 0.229, 0.224, 0.225]), + ]) + input_images = input['images'] + for i in range(len(input_images)): + input['images'][i] = image_transform( + input['images'][i]).unsqueeze(0) + return input + + def get_list(self): + raise NotImplementedError("Dataset:get_list") + + def load_query_samples(self, sample_list): + # TODO: Load queries into memory, if needed + pass + + def unload_query_samples(self, sample_list): + # TODO: Unload queries from memory, if needed + pass + + def get_samples(self, id_list): + data = [] + labels = [] + for id in id_list: + item = self.get_item(id) + data.append({'pts': item['pts'], + 'images': item['images'], + 'calib_info': item['calib_info'], + 'image_info': item['image_info']}) + labels.append({'gt_labels': item['gt_labels'], + 'calib_info': item['calib_info'], + 'gt_names': item['gt_names'], + }) + 
return data, labels + + def get_item(self, id): + data_info = self.data_infos[self.sorted_ids[id]] + image_info, calib_info, annos_info = \ + data_info['image'], data_info['calib'].copy(), data_info['annos'] + # point cloud input + velodyne_path = data_info['point_cloud']['velodyne_path'] + pts_path = os.path.join(self.data_root, velodyne_path) + if self.cam_sync: + annos_info = data_info['cam_sync_annos'] + pts = read_points(pts_path, 6) + pts = pts[:, :5] + + # calib input: for bbox coordinates transformation between Camera and Lidar. + # because + tr_velo_to_cam = calib_info['Tr_velo_to_cam_0'].astype(np.float32) + r0_rect = calib_info['R0_rect'].astype(np.float32) + for key in calib_info.keys(): + calib_info[key] = torch.from_numpy( + calib_info[key]).type( + torch.float32) + + # annotations input + annos_info = self.remove_dont_care(annos_info) + annos_name = annos_info['name'] + annos_location = annos_info['location'] + annos_dimension = annos_info['dimensions'] + rotation_y = annos_info['rotation_y'] + gt_bboxes = np.concatenate( + [annos_location, annos_dimension, rotation_y[:, None]], axis=1).astype(np.float32) + gt_bboxes_3d = process.bbox_camera2lidar( + gt_bboxes, tr_velo_to_cam, r0_rect) + gt_labels = [self.CLASSES.get(name, -1) for name in annos_name] + data_dict = { + 'pts': torch.from_numpy(pts), + 'gt_bboxes_3d': torch.from_numpy(gt_bboxes_3d), + 'gt_labels': torch.from_numpy(np.array(gt_labels)), + 'gt_names': annos_name, + 'difficulty': annos_info['difficulty'], + 'image_info': image_info, + 'calib_info': calib_info + } + data_dict = point_range_filter( + data_dict, point_range=self.point_range_filter) + images = [] + for i in range(5): + image = self.get_image( + image_info['image_idx'], + 'image_' + str(i) + '/') + images.append(image) + data_dict['images'] = images + + return data_dict + + def remove_dont_care(self, annos_info): + keep_ids = [ + i for i, name in enumerate( + annos_info['name']) if name != 'DontCare'] + for k, v in annos_info.items(): + annos_info[k] = v[keep_ids] + return annos_info + + def get_image(self, idx, camera): + filename = os.path.join( + self.data_root, 'training', camera + ('%s.jpg' % idx)) + input_image = Image.open(filename) + preprocess = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[ + 0.485, 0.456, 0.406], std=[ + 0.229, 0.224, 0.225]), + ]) + + input_tensor = preprocess(input_image) + input_batch = input_tensor.unsqueeze(0) + return input_batch + + +class PostProcessWaymo: + def __init__( + self, # Postprocess parameters + ): + self.content_ids = [] + # TODO: Init Postprocess parameters + self.results = [] + + def add_results(self, results): + self.results.extend(results) + + def __call__(self, results, content_id, inputs, result_dict): + self.content_ids.extend(content_id) + processed_results = [] + for idx in range(len(content_id)): + processed_results.append([]) + detection_num = len(results[0][idx]) + for detection in range(0, detection_num): + processed_results[idx].append([ + results[0][idx][detection][0], + results[0][idx][detection][1], + results[0][idx][detection][2], + results[1][idx][detection][0], + results[1][idx][detection][1], + results[1][idx][detection][2], + results[2][idx][detection], + results[3][idx][detection][0], + results[3][idx][detection][1], + results[3][idx][detection][2], + results[3][idx][detection][3], + results[4][idx][detection], + results[5][idx][detection], + results[6][idx] + ]) + return processed_results + + def start(self): + self.results = [] + + def 
finalize(self, result_dict, ds=None): + + return result_dict
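`PostProcessWaymo` above flattens every detection into a fixed 14-value record before handing it back to the harness. A minimal sketch of the shape contract its `__call__` assumes, runnable from the benchmark directory with its dependencies installed; the arrays and the sample id below are fabricated:

```python
import numpy as np
from waymo import PostProcessWaymo

# __call__ indexes results[0] and results[1] with 3 values per detection,
# results[3] with 4, results[2]/[4]/[5] as scalars per detection, and results[6]
# as one value per sample. Two fabricated detections for a single sample:
num_dets = 2
results = [
    np.zeros((1, num_dets, 3)),   # 3 values per detection
    np.ones((1, num_dets, 3)),    # 3 values per detection
    np.full((1, num_dets), 0.1),  # scalar per detection
    np.zeros((1, num_dets, 4)),   # 4 values per detection
    np.full((1, num_dets), 2.0),  # scalar per detection
    np.full((1, num_dets), 0.9),  # scalar per detection
    [7],                          # one value per sample
]

post_proc = PostProcessWaymo()
post_proc.start()
processed = post_proc(results, content_id=[7], inputs=None, result_dict={})

assert len(processed) == 1 and len(processed[0]) == num_dets
assert all(len(record) == 14 for record in processed[0])  # 3+3+1+4+1+1+1 values
```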
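The voxelizer in `ops/voxel_module.py` can likewise be exercised on its own. A rough sketch: the random point cloud, voxel size and `max_*` limits are illustrative values rather than the benchmark's actual configuration; only the point-cloud range matches the filter used in `waymo.py`:

```python
import numpy as np
import torch
from ops.voxel_module import Voxelization

# Fabricated point cloud: 10k points with 5 features (x, y, z plus two extras).
pts = np.random.uniform(-70.0, 70.0, size=(10000, 5)).astype(np.float32)
pts[:, 2] = np.random.uniform(-2.0, 4.0, size=10000)  # keep z inside the range
points = torch.from_numpy(pts)

voxel_layer = Voxelization(
    voxel_size=[0.32, 0.32, 6.0],                      # illustrative only
    point_cloud_range=[-74.88, -74.88, -2, 74.88, 74.88, 4],
    max_num_points=32,
    max_voxels=(16000, 40000))
voxel_layer.eval()   # forward() picks max_voxels[1] when not in training mode

voxels, coors, num_points = voxel_layer(points)
# voxels:     (M, 32, 5) zero-padded point features of each non-empty voxel
# coors:      (M, 3) integer voxel coordinates (xyz order, since reverse_index=False)
# num_points: (M,) number of points actually written into each voxel
print(voxels.shape, coors.shape, num_points.shape)
```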
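Two of the small geometry helpers in `tools/process.py` are easiest to follow with concrete numbers; the angles and boxes below are made up:

```python
import numpy as np
import torch
from tools.process import limit_period, nearest_bev

# limit_period wraps a value into [-offset*period, (1-offset)*period); with the
# defaults that is [-pi/2, pi/2), e.g. 3.0 rad maps to 3.0 - pi, about -0.14 rad.
print(limit_period(np.array([0.3, 3.0, -2.0])))

# nearest_bev uses the wrapped yaw to pick an axis-aligned BEV box: a box rotated
# by ~90 degrees has its width and length swapped. Format: (x, y, z, w, l, h, theta).
boxes = torch.tensor([[0., 0., 0., 2., 4., 1.5, 0.0],
                      [0., 0., 0., 2., 4., 1.5, np.pi / 2]])
print(nearest_bev(boxes))
# expected: [[-1., -2., 1., 2.],   w along x, l along y
#            [-2., -1., 2., 1.]]   swapped, because |wrapped yaw| > pi/4
```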
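The axis-aligned `iou2d` helper takes boxes in (x1, y1, x2, y2) form; a quick worked check, again with made-up boxes:

```python
import torch
from tools.process import iou2d

# The 2x2 boxes at (0,0)-(2,2) and (1,1)-(3,3) overlap in a 1x1 patch, so the
# IoU is 1 / (4 + 4 - 1), about 0.143; the far-away box gives 0.
a = torch.tensor([[0., 0., 2., 2.]])
b = torch.tensor([[1., 1., 3., 3.],
                  [10., 10., 12., 12.]])
print(iou2d(a, b))            # tensor([[0.1429, 0.0000]])

# metric=1 divides by the area of the first box only, which is how overlap with
# DontCare regions is measured in tools/evaluate.py.
print(iou2d(a, b, metric=1))  # tensor([[0.2500, 0.0000]])
```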