-
Notifications
You must be signed in to change notification settings - Fork 0
/
helpers.py
82 lines (71 loc) · 3.48 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def get_remote_file_names(
    data_directory,
    start_file=(
        "gs://waymo_open_dataset_v_1_2_0_individual_files/testing/"
        "segment-11987368976578218644_1340_000_1360_000_with_camera_labels.tfrecord"
    ),
):
    """List the entries under ``data_directory`` via ``gsutil ls``, resuming at ``start_file``.

    Args:
        data_directory: Location handed to ``gsutil ls`` as a single argument.
        start_file: Listing entry to resume from; every entry before it is
            dropped. Defaults to the segment name that used to be hard-coded
            here. Pass ``None`` to get the complete listing.

    Returns:
        The listed names from ``start_file`` (inclusive) to the end of the
        listing, in the order ``gsutil`` printed them.

    Raises:
        subprocess.CalledProcessError: If ``gsutil ls`` exits with a non-zero
            status.
        FileNotFoundError: If the ``gsutil`` executable cannot be found.
        ValueError: If ``start_file`` is not part of the listing.
    """
    import subprocess

    # Pass an argument list instead of a shell string so that spaces or shell
    # metacharacters in data_directory cannot alter the command.
    listing = subprocess.check_output(["gsutil", "ls", data_directory]).decode("utf-8")
    # splitlines() copes with a missing trailing newline (the old
    # split("\n")[:-1] would silently drop the last entry in that case);
    # blank lines are not file names, so skip them.
    file_names_list = [name for name in listing.splitlines() if name]
    print("There are", len(file_names_list), "files to read")
    if start_file is None:
        return file_names_list
    start_index = file_names_list.index(start_file)
    print(start_index)
    return file_names_list[start_index:]
def get_remote_file(remote_file, target_folder):
    """Copy ``remote_file`` into ``target_folder`` with ``gsutil cp`` unless it is already there.

    Args:
        remote_file: Source location handed to ``gsutil cp``; its base name is
            used as the local file name.
        target_folder: Local destination folder handed to ``gsutil cp``.

    Returns:
        The local path ``target_folder/<basename of remote_file>``. The path is
        returned even when the copy failed; a warning is printed in that case.

    Raises:
        FileNotFoundError: If a copy is needed and the ``gsutil`` executable
            cannot be found.
    """
    import os
    import subprocess

    file_to_copy = os.path.join(target_folder, os.path.basename(remote_file))
    if not os.path.exists(file_to_copy):
        print("Copying", file_to_copy)
        # Argument list (no shell): paths containing spaces or shell
        # metacharacters reach gsutil verbatim.
        return_code = subprocess.call(["gsutil", "cp", remote_file, target_folder])
        if return_code != 0:
            # Previously the exit status was dropped, so a failed download was
            # indistinguishable from a successful one.
            print("Warning: gsutil cp exited with status", return_code, "for", remote_file)
    return file_to_copy
def create_prediction(output_file, context, image, bboxes, probs, ratio):
    """Append the detections for one camera image to a serialized ``metrics_pb2.Objects`` file.

    The existing contents of ``output_file`` are parsed, one
    ``metrics_pb2.Object`` is appended per box, and the whole message is
    written back. Nothing is written if an invalid object type is found.

    Args:
        output_file: Path of an existing file holding a serialized
            ``metrics_pb2.Objects`` message; it is overwritten with the
            extended message.
        context: Object whose ``name`` is stored as ``context_name``.
        image: Object whose ``pose_timestamp`` and ``name`` are stored as
            ``frame_timestamp_micros`` and ``camera_name``.
        bboxes: Mapping from an object type name (``"TYPE_VEHICLE"``,
            ``"TYPE_CYCLIST"`` or ``"TYPE_PEDESTRIAN"``) to a sequence of
            boxes, each indexable at positions 0-3. Presumably
            ``(x, y, width, height)`` with a top-left origin, judging from the
            centre computation below -- TODO confirm with the caller.
        probs: Mapping from the same type names to per-box scores, indexed in
            parallel with ``bboxes``.
        ratio: Divisor applied to every box coordinate; presumably the resize
            factor of the image the detector ran on -- TODO confirm.

    Raises:
        FileNotFoundError: If ``output_file`` does not exist yet.
        ValueError: If a non-empty entry of ``bboxes`` uses an unsupported
            object type name.
    """
    from waymo_open_dataset import label_pb2
    from waymo_open_dataset.protos import metrics_pb2

    label_types = {
        "TYPE_VEHICLE": label_pb2.Label.TYPE_VEHICLE,
        "TYPE_CYCLIST": label_pb2.Label.TYPE_CYCLIST,
        "TYPE_PEDESTRIAN": label_pb2.Label.TYPE_PEDESTRIAN,
    }

    objects = metrics_pb2.Objects()
    with open(output_file, "rb") as existing:
        objects.ParseFromString(existing.read())

    for object_type, boxes in bboxes.items():
        for i, bbox in enumerate(boxes):
            if object_type not in label_types:
                # The original `raise "<str>"` is itself a TypeError on
                # Python 3; raise a real exception naming the bad type.
                raise ValueError("Invalid Object Type: {!r}".format(object_type))

            o = metrics_pb2.Object()
            # The following 3 fields are used to uniquely identify a frame a prediction
            # is predicted at. Make sure you set them to values exactly the same as what
            # we provided in the raw data. Otherwise your prediction is considered as a
            # false negative.
            o.context_name = context.name
            # The frame timestamp for the prediction. See Frame::timestamp_micros in
            # dataset.proto.
            # NOTE(review): this stores image.pose_timestamp truncated to an int, not
            # a frame's timestamp_micros -- confirm the two agree in unit and value.
            o.frame_timestamp_micros = int(image.pose_timestamp)
            # This is only needed for 2D detection or tracking tasks.
            # Set it to the camera name the prediction is for.
            o.camera_name = image.name

            # Populating box and score. Coordinates are divided by `ratio` and
            # truncated to whole numbers.
            box = label_pb2.Label.Box()
            box.center_x = int((bbox[0] + bbox[2] / 2) / ratio)
            box.center_y = int((bbox[1] + bbox[3] / 2) / ratio)
            box.center_z = 0
            # NOTE(review): length comes from bbox[3] (the extent used for center_y)
            # and width from bbox[2] (the extent used for center_x). Confirm this
            # matches the axis convention of label.proto's Box.length / Box.width.
            box.length = int(bbox[3] / ratio)
            box.width = int(bbox[2] / ratio)
            box.height = 0
            box.heading = 0
            o.object.box.CopyFrom(box)
            # This must be within [0.0, 1.0]. It is better to filter those boxes with
            # small scores to speed up metrics computation.
            o.score = probs[object_type][i]
            o.object.type = label_types[object_type]

            objects.objects.append(o)

    # Note that a reasonable detector should limit its maximum number of boxes
    # predicted per frame. A reasonable value is around 400. A huge number of
    # boxes can slow down metrics computation.

    # Write objects to a file.
    with open(output_file, "wb") as out:
        out.write(objects.SerializeToString())