-
Notifications
You must be signed in to change notification settings - Fork 2.9k
/
Copy pathobject_detector.cc
592 lines (536 loc) · 22.5 KB
/
object_detector.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstring>
#include <functional>
// for setprecision
#include <iomanip>
#include <iterator>
#include <numeric>
#include <sstream>
#include <utility>
#include "include/object_detector.h"
namespace PaddleDetection {
// Load the exported inference model and create the Paddle predictor.
//
// model_dir:  directory holding `model.pdmodel` and `model.pdiparams`.
// batch_size: passed to the TensorRT engine and used as the batch dimension of
//             the TensorRT dynamic-shape ranges.
// run_mode:   only consulted when device_ == "GPU". "paddle" runs without
//             TensorRT; "trt_fp32" / "trt_fp16" / "trt_int8" enable a TensorRT
//             engine with that precision. Any other value prints a warning and
//             falls back to TensorRT FP32.
//
// The target device comes from `device_`: "GPU", "XPU", anything else -> CPU.
void ObjectDetector::LoadModel(const std::string &model_dir,
                               const int batch_size,
                               const std::string &run_mode) {
  paddle_infer::Config config;
  std::string prog_file = model_dir + OS_PATH_SEP + "model.pdmodel";
  std::string params_file = model_dir + OS_PATH_SEP + "model.pdiparams";
  config.SetModel(prog_file, params_file);
  if (this->device_ == "GPU") {
    // 200 is the initial GPU memory pool size (MB per the Paddle Inference
    // API) on device gpu_id_.
    config.EnableUseGpu(200, this->gpu_id_);
    config.SwitchIrOptim(true);
    // use tensorrt
    if (run_mode != "paddle") {
      auto precision = paddle_infer::Config::Precision::kFloat32;
      if (run_mode == "trt_fp32") {
        precision = paddle_infer::Config::Precision::kFloat32;
      } else if (run_mode == "trt_fp16") {
        precision = paddle_infer::Config::Precision::kHalf;
      } else if (run_mode == "trt_int8") {
        precision = paddle_infer::Config::Precision::kInt8;
      } else {
        // Unknown mode: warn and keep the FP32 default chosen above. The
        // message ends with a newline so it does not run into later output.
        printf("run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or "
               "'trt_int8'\n");
      }
      // set tensorrt (1 << 30 is the workspace-size argument)
      config.EnableTensorRtEngine(1 << 30, batch_size, this->min_subgraph_size_,
                                  precision, false, this->trt_calib_mode_);
      // set use dynamic shape
      if (this->use_dynamic_shape_) {
        // Allowed min / max / optimal shapes of the "image" input: a square
        // NCHW image with 3 channels and the configured batch size.
        const std::vector<int> min_input_shape = {
            batch_size, 3, this->trt_min_shape_, this->trt_min_shape_};
        const std::vector<int> max_input_shape = {
            batch_size, 3, this->trt_max_shape_, this->trt_max_shape_};
        const std::vector<int> opt_input_shape = {
            batch_size, 3, this->trt_opt_shape_, this->trt_opt_shape_};
        const std::map<std::string, std::vector<int>> map_min_input_shape = {
            {"image", min_input_shape}};
        const std::map<std::string, std::vector<int>> map_max_input_shape = {
            {"image", max_input_shape}};
        const std::map<std::string, std::vector<int>> map_opt_input_shape = {
            {"image", opt_input_shape}};
        config.SetTRTDynamicShapeInfo(map_min_input_shape, map_max_input_shape,
                                      map_opt_input_shape);
        std::cout << "TensorRT dynamic shape enabled" << std::endl;
      }
    }
  } else if (this->device_ == "XPU") {
    // NOTE(review): 10 MiB passed as EnableXpu's first argument (the L3
    // workspace size in the Paddle API) — confirm for the Paddle version used.
    config.EnableXpu(10 * 1024 * 1024);
  } else {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }
  config.SwitchUseFeedFetchOps(false);
  config.SwitchIrOptim(true);
  config.DisableGlogInfo();
  // Memory optimization
  config.EnableMemoryOptim();
  // CreatePredictor already returns a temporary; wrapping it in std::move was
  // redundant.
  predictor_ = CreatePredictor(config);
}
// Draw detection results (boxes, rotated boxes and/or masks) on a copy of
// `img` and return it; `img` itself is not modified.
//
// results:  detections. `rect` is {xmin, ymin, xmax, ymax} for axis-aligned
//           boxes, or the 4 corner points {x1, y1, ..., x4, y4} when is_rbox.
//           A non-empty `mask` is laid out row-major over img.rows x img.cols.
// labels:   class names indexed by ObjectResult::class_id.
// colormap: 3 color components per class (see GenerateColorMap), indexed by
//           class_id.
// is_rbox:  draw `rect` as a closed 4-point polygon instead of a rectangle.
cv::Mat
VisualizeResult(const cv::Mat &img,
                const std::vector<PaddleDetection::ObjectResult> &results,
                const std::vector<std::string> &labels,
                const std::vector<int> &colormap, const bool is_rbox = false) {
  cv::Mat vis_img = img.clone();
  const int img_h = vis_img.rows;
  const int img_w = vis_img.cols;
  // Number of int values that fit in one img_h x img_w mask.
  const size_t mask_capacity = static_cast<size_t>(img_h) * img_w;
  // Text style shared by every label.
  const int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL;
  const double font_scale = 0.5;
  // OpenCV's text APIs take an int thickness. The previous code passed a
  // float 0.5, which was implicitly truncated to 0; keep 0 explicitly so the
  // rendered output is unchanged.
  const int thickness = 0;
  for (const auto &res : results) {
    // Label text: "<class name> <confidence with 4 decimals>".
    std::ostringstream oss;
    oss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
    oss << labels[res.class_id] << " ";
    oss << res.confidence;
    const std::string text = oss.str();
    const cv::Scalar roi_color(colormap[3 * res.class_id + 0],
                               colormap[3 * res.class_id + 1],
                               colormap[3 * res.class_id + 2]);
    const cv::Size text_size =
        cv::getTextSize(text, font_face, font_scale, thickness, nullptr);
    if (is_rbox) {
      // Connect the 4 corner points in order; % 8 closes the polygon.
      for (int k = 0; k < 4; k++) {
        const cv::Point pt1(res.rect[(k * 2) % 8], res.rect[(k * 2 + 1) % 8]);
        const cv::Point pt2(res.rect[(k * 2 + 2) % 8],
                            res.rect[(k * 2 + 3) % 8]);
        cv::line(vis_img, pt1, pt2, roi_color, 2);
      }
    } else {
      const int w = res.rect[2] - res.rect[0];
      const int h = res.rect[3] - res.rect[1];
      cv::rectangle(vis_img, cv::Rect(res.rect[0], res.rect[1], w, h),
                    roi_color, 2);
      // Borrow the mask instead of copying it for every result.
      const std::vector<int> &mask_v = res.mask;
      if (!mask_v.empty()) {
        // Zero-fill and copy at most img_h * img_w values. A mask longer than
        // the image used to overflow this buffer, and a shorter one left
        // uninitialized pixels in it.
        cv::Mat mask = cv::Mat::zeros(img_h, img_w, CV_32S);
        const size_t copy_count = std::min(mask_v.size(), mask_capacity);
        std::memcpy(mask.data, mask_v.data(), copy_count * sizeof(int));
        cv::Mat colored_img = vis_img.clone();
        std::vector<cv::Mat> contours;
        cv::Mat hierarchy;
        mask.convertTo(mask, CV_8U);
        cv::findContours(mask, contours, hierarchy, cv::RETR_CCOMP,
                         cv::CHAIN_APPROX_SIMPLE);
        cv::drawContours(colored_img, contours, -1, roi_color, -1, cv::LINE_8,
                         hierarchy, 100);
        // Blend 40% filled-contour color with 60% of the current image, and
        // write it back only where the mask is non-zero.
        colored_img = 0.4 * colored_img + 0.6 * vis_img;
        colored_img.copyTo(vis_img, mask);
      }
    }
    // Label on a filled background whose bottom edge is the box's top edge.
    const cv::Point origin(res.rect[0], res.rect[1]);
    const cv::Rect text_back(res.rect[0], res.rect[1] - text_size.height,
                             text_size.width, text_size.height);
    cv::rectangle(vis_img, text_back, roi_color, -1);
    cv::putText(vis_img, text, origin, font_face, font_scale,
                cv::Scalar(255, 255, 255), thickness);
  }
  return vis_img;
}
// Preprocess one BGR image: convert it to RGB and run the configured
// preprocessing pipeline, leaving the result in `inputs_`.
void ObjectDetector::Preprocess(const cv::Mat &ori_im) {
  // Convert into a separate Mat so the caller's image stays untouched (it is
  // kept for postprocessing). Writing the RGB pixels straight into the new
  // buffer avoids the old clone-then-convert-in-place double copy.
  cv::Mat im;
  cv::cvtColor(ori_im, im, cv::COLOR_BGR2RGB);
  preprocessor_.Run(&im, &inputs_);
}
// Convert the raw detector output into ObjectResult entries.
//
// mats:              images of the batch; only their count is used.
// result:            cleared, then filled with the detections of every image
//                    in image order.
// bbox_num:          number of detections per image.
// output_data_:      detections laid out as rows of 6 floats
//                    {class_id, score, xmin, ymin, xmax, ymax}, or rows of 10
//                    floats {class_id, score, x1, y1, ..., x4, y4} when
//                    is_rbox.
// output_mask_data_: equal-length per-box masks, concatenated; only read when
//                    config_.mask_ is set. Values <= -1 are dropped.
// Coordinates are copied as-is (truncated to int); no rescaling is applied.
void ObjectDetector::Postprocess(
    const std::vector<cv::Mat> mats,
    std::vector<PaddleDetection::ObjectResult> *result,
    std::vector<int> bbox_num, std::vector<float> output_data_,
    std::vector<int> output_mask_data_, bool is_rbox = false) {
  result->clear();
  const int total_num = std::accumulate(bbox_num.begin(), bbox_num.end(), 0);
  // Mask length per box. Only computed when there is at least one box: with
  // zero detections the old code divided by zero here.
  int out_mask_dim = -1;
  if (config_.mask_ && total_num > 0) {
    out_mask_dim = static_cast<int>(output_mask_data_.size() / total_num);
  }
  // Floats per detection row.
  const int row_len = is_rbox ? 10 : 6;
  // Never index past the end of bbox_num, even if it is shorter than mats.
  const size_t num_images = std::min(mats.size(), bbox_num.size());
  int start_idx = 0;
  for (size_t im_id = 0; im_id < num_images; ++im_id) {
    const int end_idx = start_idx + bbox_num[im_id];
    for (int j = start_idx; j < end_idx; ++j) {
      const float *row =
          output_data_.data() + static_cast<size_t>(j) * row_len;
      PaddleDetection::ObjectResult result_item;
      // Class id is stored as a float; round it to the nearest integer.
      result_item.class_id = static_cast<int>(std::round(row[0]));
      result_item.confidence = row[1];
      if (is_rbox) {
        // Four corner points of the rotated box.
        result_item.rect = {
            static_cast<int>(row[2]), static_cast<int>(row[3]),
            static_cast<int>(row[4]), static_cast<int>(row[5]),
            static_cast<int>(row[6]), static_cast<int>(row[7]),
            static_cast<int>(row[8]), static_cast<int>(row[9])};
      } else {
        result_item.rect = {static_cast<int>(row[2]), static_cast<int>(row[3]),
                            static_cast<int>(row[4]), static_cast<int>(row[5])};
        if (config_.mask_) {
          std::vector<int> mask;
          for (int k = 0; k < out_mask_dim; ++k) {
            const int value = output_mask_data_[k + j * out_mask_dim];
            if (value > -1) {
              mask.push_back(value);
            }
          }
          result_item.mask = std::move(mask);
        }
      }
      result->push_back(std::move(result_item));
    }
    start_idx = end_idx;
  }
}
// Convert SOLOv2 outputs into ObjectResult entries. Each box is derived from
// the extent of its instance mask.
//
// mats:                  input images; each image's rows x cols is the size
//                        of its masks.
// result:                appended with every detection scoring >= threshold
//                        (not cleared here).
// bbox_num:              appended with the number of kept detections per
//                        image.
// out_bbox_num_data_:    number of raw detections per image.
// out_label_data_,
// out_score_data_:       per-detection class id / score.
// out_global_mask_data_: one rows*cols uint8 mask per raw detection.
// NOTE(review): for batches > 1 the per-detection outputs are assumed to be
// concatenated image after image in `mats` order — confirm against the
// exported model. For a single image this is simply the flat output.
void ObjectDetector::SOLOv2Postprocess(
    const std::vector<cv::Mat> mats, std::vector<ObjectResult> *result,
    std::vector<int> *bbox_num, std::vector<int> out_bbox_num_data_,
    std::vector<int64_t> out_label_data_, std::vector<float> out_score_data_,
    std::vector<uint8_t> out_global_mask_data_, float threshold) {
  // Index of the current image's first detection, and of its first mask
  // value, in the flat outputs. The old code restarted at 0 for every image,
  // so later images re-read the first image's detections.
  size_t det_offset = 0;
  size_t mask_offset = 0;
  for (size_t im_id = 0; im_id < mats.size(); ++im_id) {
    const cv::Mat &mat = mats[im_id];
    const int rows = mat.rows;
    const int cols = mat.cols;
    const size_t area = static_cast<size_t>(rows) * cols;
    const int num_dets = out_bbox_num_data_[im_id];
    int valid_bbox_count = 0;
    for (int bbox_id = 0; bbox_id < num_dets; ++bbox_id) {
      const size_t idx = det_offset + bbox_id;
      if (out_score_data_[idx] < threshold) {
        continue;
      }
      ObjectResult result_item;
      result_item.class_id = out_label_data_[idx];
      result_item.confidence = out_score_data_[idx];
      // Widen this detection's uint8 mask to int, as ObjectResult stores it.
      const uint8_t *mask_begin = out_global_mask_data_.data() + mask_offset +
                                  static_cast<size_t>(bbox_id) * area;
      std::vector<int> global_mask(mask_begin, mask_begin + area);
      // Find the tightest box around the mask from its row/column sums.
      cv::Mat mask_fp;
      cv::Mat(rows, cols, CV_32SC1, global_mask.data())
          .convertTo(mask_fp, CV_32FC1);
      cv::Mat col_sum;
      cv::Mat row_sum;
      cv::reduce(mask_fp, col_sum, 0, cv::REDUCE_SUM, CV_32FC1);  // 1 x cols
      cv::reduce(mask_fp, row_sum, 1, cv::REDUCE_SUM, CV_32FC1);  // rows x 1
      std::vector<float> sum_of_row(rows);
      std::vector<float> sum_of_col(cols);
      for (int r = 0; r < rows; ++r) {
        sum_of_row[r] = row_sum.at<float>(r, 0);
      }
      for (int c = 0; c < cols; ++c) {
        sum_of_col[c] = col_sum.at<float>(0, c);
      }
      // A row/column is part of the box when it contains mask pixels. The
      // predicate takes float (the sums' type) rather than truncating to int.
      const auto occupied = [](float s) { return s > 0.5f; };
      const int y1 = static_cast<int>(std::distance(
          sum_of_row.begin(),
          std::find_if(sum_of_row.begin(), sum_of_row.end(), occupied)));
      const int x1 = static_cast<int>(std::distance(
          sum_of_col.begin(),
          std::find_if(sum_of_col.begin(), sum_of_col.end(), occupied)));
      // Reverse searches yield one past the last occupied row/column.
      const int y2 = static_cast<int>(std::distance(
          std::find_if(sum_of_row.rbegin(), sum_of_row.rend(), occupied),
          sum_of_row.rend()));
      const int x2 = static_cast<int>(std::distance(
          std::find_if(sum_of_col.rbegin(), sum_of_col.rend(), occupied),
          sum_of_col.rend()));
      result_item.rect = {x1, y1, x2, y2};
      result_item.mask = std::move(global_mask);
      result->push_back(std::move(result_item));
      ++valid_bbox_count;
    }
    bbox_num->push_back(valid_bbox_count);
    det_offset += num_dets;
    mask_offset += static_cast<size_t>(num_dets) * area;
  }
}
// Run detection on a batch of images.
//
// imgs:      BGR input images; all are preprocessed and run as one batch.
// threshold: score threshold; only used by the SOLOv2 postprocess.
// warmup:    number of untimed predictor runs before timing starts.
// repeats:   number of timed predictor runs; the last run's outputs are used.
// result:    cleared, then filled with the detections of all images.
// bbox_num:  cleared, then filled with the number of detections per image
//            (the PicoDet path pushes a single total for the whole batch).
// times:     appended with three values in ms: preprocess time, mean
//            inference time per repeat (0 if repeats <= 0), postprocess time.
void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
                             const double threshold, const int warmup,
                             const int repeats,
                             std::vector<PaddleDetection::ObjectResult> *result,
                             std::vector<int> *bbox_num,
                             std::vector<double> *times) {
  auto preprocess_start = std::chrono::steady_clock::now();
  int batch_size = imgs.size();
  // Batched network input, im_shape and scale_factor tensors.
  std::vector<float> in_data_all;
  std::vector<float> im_shape_all(batch_size * 2);
  std::vector<float> scale_factor_all(batch_size * 2);
  // Host copies of the predictor outputs.
  std::vector<int> out_bbox_num_data_;
  std::vector<int> out_mask_data_;
  // these parameters are for SOLOv2 output
  std::vector<float> out_score_data_;
  std::vector<uint8_t> out_global_mask_data_;
  std::vector<int64_t> out_label_data_;
  // in_net img for each batch
  std::vector<cv::Mat> in_net_img_all(batch_size);
  // Preprocess each image and gather its tensors.
  for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
    cv::Mat im = imgs.at(bs_idx);
    Preprocess(im);
    im_shape_all[bs_idx * 2] = inputs_.im_shape_[0];
    im_shape_all[bs_idx * 2 + 1] = inputs_.im_shape_[1];
    scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0];
    scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1];
    in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(),
                       inputs_.im_data_.end());
    // collect in_net img
    in_net_img_all[bs_idx] = inputs_.in_net_im_;
  }
  // For batches whose network inputs need padding (as decided by
  // CheckDynamicInput), rebuild in_data_all from the padded images, one
  // channel plane after another per image.
  if (batch_size > 1 && CheckDynamicInput(in_net_img_all)) {
    in_data_all.clear();
    std::vector<cv::Mat> pad_img_all = PadBatch(in_net_img_all);
    int rh = pad_img_all[0].rows;
    int rw = pad_img_all[0].cols;
    int rc = pad_img_all[0].channels();
    for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
      cv::Mat pad_img = pad_img_all[bs_idx];
      pad_img.convertTo(pad_img, CV_32FC3);
      std::vector<float> pad_data;
      pad_data.resize(rc * rh * rw);
      float *base = pad_data.data();
      for (int i = 0; i < rc; ++i) {
        cv::extractChannel(pad_img,
                           cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
      }
      in_data_all.insert(in_data_all.end(), pad_data.begin(), pad_data.end());
    }
    // update in_net_shape
    inputs_.in_net_shape_ = {static_cast<float>(rh), static_cast<float>(rw)};
  }
  auto preprocess_end = std::chrono::steady_clock::now();
  // Prepare input tensor
  auto input_names = predictor_->GetInputNames();
  for (const auto &tensor_name : input_names) {
    auto in_tensor = predictor_->GetInputHandle(tensor_name);
    if (tensor_name == "image") {
      int rh = inputs_.in_net_shape_[0];
      int rw = inputs_.in_net_shape_[1];
      in_tensor->Reshape({batch_size, 3, rh, rw});
      in_tensor->CopyFromCpu(in_data_all.data());
    } else if (tensor_name == "im_shape") {
      in_tensor->Reshape({batch_size, 2});
      in_tensor->CopyFromCpu(im_shape_all.data());
    } else if (tensor_name == "scale_factor") {
      in_tensor->Reshape({batch_size, 2});
      in_tensor->CopyFromCpu(scale_factor_all.data());
    }
  }
  // Outputs of the most recent predictor run.
  std::vector<std::vector<float>> out_tensor_list;
  std::vector<std::vector<int>> output_shape_list;
  bool is_rbox = false;
  int reg_max = 7;
  int num_class = 80;
  const bool is_solov2 = config_.arch_ == "SOLOv2";

  // Run the predictor once and copy every output to host memory. Each run
  // starts from empty lists and records output shapes. Previously warmup runs
  // kept appending to out_tensor_list and recorded no shapes, so
  // repeats == 0 crashed on output_shape_list[0].
  auto run_and_fetch = [&]() {
    predictor_->Run();
    out_tensor_list.clear();
    output_shape_list.clear();
    auto output_names = predictor_->GetOutputNames();
    for (size_t j = 0; j < output_names.size(); j++) {
      auto output_tensor = predictor_->GetOutputHandle(output_names[j]);
      std::vector<int> output_shape = output_tensor->shape();
      int out_num = std::accumulate(output_shape.begin(), output_shape.end(),
                                    1, std::multiplies<int>());
      output_shape_list.push_back(output_shape);
      if (is_solov2) {
        // SOLOv2 outputs by position: bbox_num, labels, scores, masks.
        if (j == 0) {
          out_bbox_num_data_.resize(out_num);
          output_tensor->CopyToCpu(out_bbox_num_data_.data());
        } else if (j == 1) {
          out_label_data_.resize(out_num);
          output_tensor->CopyToCpu(out_label_data_.data());
        } else if (j == 2) {
          out_score_data_.resize(out_num);
          output_tensor->CopyToCpu(out_score_data_.data());
        } else if (config_.mask_ && (j == 3)) {
          out_global_mask_data_.resize(out_num);
          output_tensor->CopyToCpu(out_global_mask_data_.data());
        }
      } else if (config_.mask_ && (j == 2)) {
        out_mask_data_.resize(out_num);
        output_tensor->CopyToCpu(out_mask_data_.data());
      } else if (output_tensor->type() == paddle_infer::DataType::INT32) {
        out_bbox_num_data_.resize(out_num);
        output_tensor->CopyToCpu(out_bbox_num_data_.data());
      } else {
        std::vector<float> out_data(out_num);
        output_tensor->CopyToCpu(out_data.data());
        out_tensor_list.push_back(std::move(out_data));
      }
    }
  };

  // Untimed warmup, then timed repeats.
  auto inference_start = std::chrono::steady_clock::now();
  for (int i = 0; i < warmup; i++) {
    run_and_fetch();
  }
  inference_start = std::chrono::steady_clock::now();
  for (int i = 0; i < repeats; i++) {
    run_and_fetch();
  }
  auto inference_end = std::chrono::steady_clock::now();
  auto postprocess_start = std::chrono::steady_clock::now();
  // Postprocessing result
  result->clear();
  bbox_num->clear();
  // If the predictor never ran (warmup and repeats both <= 0), there are no
  // outputs to decode; leave result and bbox_num empty.
  if (!output_shape_list.empty()) {
    if (config_.arch_ == "PicoDet") {
      // PicoDetPostProcess takes raw pointers. Point straight into
      // out_tensor_list, which outlives the call, instead of copying each
      // tensor into a new[] buffer that was never freed (a leak per call).
      std::vector<const float *> output_data_list_;
      for (size_t i = 0; i < out_tensor_list.size(); i++) {
        if (i == 0) {
          num_class = output_shape_list[i][2];
        }
        if (i == config_.fpn_stride_.size()) {
          reg_max = output_shape_list[i][2] / 4 - 1;
        }
        output_data_list_.push_back(out_tensor_list[i].data());
      }
      PaddleDetection::PicoDetPostProcess(
          result, output_data_list_, config_.fpn_stride_, inputs_.im_shape_,
          inputs_.scale_factor_,
          config_.nms_info_["score_threshold"].as<float>(),
          config_.nms_info_["nms_threshold"].as<float>(), num_class, reg_max);
      bbox_num->push_back(result->size());
    } else if (is_solov2) {
      SOLOv2Postprocess(imgs, result, bbox_num, out_bbox_num_data_,
                        out_label_data_, out_score_data_,
                        out_global_mask_data_, threshold);
    } else if (!out_tensor_list.empty()) {
      // A last dimension divisible by 10 means 10-value rotated-box rows.
      is_rbox =
          output_shape_list[0][output_shape_list[0].size() - 1] % 10 == 0;
      Postprocess(imgs, result, out_bbox_num_data_, out_tensor_list[0],
                  out_mask_data_, is_rbox);
      for (size_t k = 0; k < out_bbox_num_data_.size(); k++) {
        bbox_num->push_back(out_bbox_num_data_[k]);
      }
    }
  }
  auto postprocess_end = std::chrono::steady_clock::now();
  std::chrono::duration<float> preprocess_diff =
      preprocess_end - preprocess_start;
  times->push_back(static_cast<double>(preprocess_diff.count() * 1000));
  std::chrono::duration<float> inference_diff = inference_end - inference_start;
  // Mean per timed run; avoid dividing by zero when repeats <= 0.
  times->push_back(
      repeats > 0
          ? static_cast<double>(inference_diff.count() / repeats * 1000)
          : 0.0);
  std::chrono::duration<float> postprocess_diff =
      postprocess_end - postprocess_start;
  times->push_back(static_cast<double>(postprocess_diff.count() * 1000));
}
// Build a color palette with three integer components per class id, stored
// consecutively (class c owns entries 3c, 3c+1, 3c+2).
//
// The bits of the class id are dealt round-robin over the three components:
// bit (3*m + ch) of the id sets bit (7 - m) of component ch. Class 0 is
// therefore (0, 0, 0), and nearby ids get clearly different colors.
std::vector<int> GenerateColorMap(int num_class) {
  std::vector<int> palette(3 * num_class, 0);
  for (int cls = 0; cls < num_class; ++cls) {
    int *rgb = &palette[3 * cls];
    int shift = 7;
    for (int bits = cls; bits != 0; bits >>= 3, --shift) {
      for (int ch = 0; ch < 3; ++ch) {
        rgb[ch] |= ((bits >> ch) & 1) << shift;
      }
    }
  }
  return palette;
}
} // namespace PaddleDetection