-
Notifications
You must be signed in to change notification settings - Fork 9
/
run_detection.py
148 lines (117 loc) · 4.28 KB
/
run_detection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import argparse
import sys
from pathlib import Path
import cv2
import numpy as np
import openvino.runtime as ov
import openvino_xai as xai
from openvino_xai.common.utils import logger
from openvino_xai.explainer.explainer import ExplainMode
from openvino_xai.methods.black_box.base import Preset
def get_argument_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this detection-XAI example.

    Positional arguments:
        model_path: path to the detection model (OpenVINO IR).
        image_path: path to the input image to explain.
    Optional arguments:
        --output: directory to save generated saliency maps;
            when omitted, maps are generated but not saved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("model_path", help="Path to the detection model (OpenVINO IR).")
    parser.add_argument("image_path", help="Path to the input image to explain.")
    parser.add_argument(
        "--output",
        default=None,
        type=str,
        help="Directory to save saliency maps (not saved when omitted).",
    )
    return parser
def preprocess_fn(x: np.ndarray, dsize: tuple = (992, 736)) -> np.ndarray:
    """Resize an HWC image and reshape it to an NCHW batch of one.

    :param x: Input image as an HWC array (as returned by cv2.imread).
    :param dsize: Target (width, height) for cv2.resize. Defaults to
        (992, 736) for OTX ATSS; use (416, 416) for OTX YOLOX.
    :return: Array of shape (1, C, H, W) ready for model inference.
    """
    # TODO: make sure it is correct
    x = cv2.resize(src=x, dsize=dsize)
    x = x.transpose((2, 0, 1))  # HWC -> CHW
    x = np.expand_dims(x, 0)  # CHW -> NCHW (batch of 1)
    return x
def postprocess_fn(x) -> np.ndarray:
    """Split the raw model output dict into (boxes, scores, labels)."""
    raw_boxes = x["boxes"]
    coordinates = raw_boxes[:, :, :4]  # first 4 values: box coordinates
    confidences = raw_boxes[:, :, 4]  # 5th value: detection score
    return coordinates, confidences, x["labels"]
def explain_white_box(args):
    """
    White-box scenario.

    Per-class saliency map generation for single-stage detection models
    (using DetClassProbabilityMap). The XAI branch is inserted into the
    model, so the model gains an additional 'saliency_map' output.
    """
    # Load the detection model from IR.
    model: ov.Model = ov.Core().read_model(args.model_path)

    # Classification-head output node names for OTX ATSS.
    # For OTX YOLOX use instead:
    # [
    #     "/bbox_head/multi_level_conv_cls.0/Conv/WithoutBiases",
    #     "/bbox_head/multi_level_conv_cls.1/Conv/WithoutBiases",
    #     "/bbox_head/multi_level_conv_cls.2/Conv/WithoutBiases",
    # ]
    cls_head_output_node_names = [
        "/bbox_head/atss_cls_1/Conv/WithoutBiases",
        "/bbox_head/atss_cls_2/Conv/WithoutBiases",
        "/bbox_head/atss_cls_3/Conv/WithoutBiases",
        "/bbox_head/atss_cls_4/Conv/WithoutBiases",
    ]

    # Build the explainer with the XAI branch inserted at the class heads.
    explainer = xai.Explainer(
        model=model,
        task=xai.Task.DETECTION,
        preprocess_fn=preprocess_fn,
        explain_mode=ExplainMode.WHITEBOX,  # defaults to AUTO
        target_layer=cls_head_output_node_names,
        saliency_map_size=(23, 23),  # Optional
    )

    # Explanation parameters can differ per explain call.
    input_image = cv2.imread(args.image_path)
    explanation = explainer(
        input_image,
        targets=[0, 1, 2],  # target classes to explain
        overlay=True,
    )

    logger.info(
        f"Generated {len(explanation.saliency_map)} detection "
        f"saliency maps of layout {explanation.layout} with shape {explanation.shape}."
    )

    # Dump saliency maps to disk for visual inspection, if requested.
    if args.output is not None:
        save_dir = Path(args.output) / "detection_white_box"
        image_stem = Path(args.image_path).stem
        explanation.save(save_dir, f"{image_stem}_")
def explain_black_box(args):
    """
    Black-box scenario.

    Per-box saliency map generation for all detection models
    (using AISEDetection).
    """
    # Load the detection model from IR.
    model: ov.Model = ov.Core().read_model(args.model_path)

    # Black-box mode perturbs inputs, so a postprocess_fn is required
    # to interpret the raw model outputs.
    explainer = xai.Explainer(
        model=model,
        task=xai.Task.DETECTION,
        preprocess_fn=preprocess_fn,
        postprocess_fn=postprocess_fn,
        explain_mode=ExplainMode.BLACKBOX,  # defaults to AUTO
    )

    # Explanation parameters can differ per explain call.
    input_image = cv2.imread(args.image_path)
    explanation = explainer(
        input_image,
        targets=[0],  # target boxes to explain
        overlay=True,
        preset=Preset.SPEED,
    )

    logger.info(
        f"Generated {len(explanation.saliency_map)} detection "
        f"saliency maps of layout {explanation.layout} with shape {explanation.shape}."
    )

    # Dump saliency maps to disk for visual inspection, if requested.
    if args.output is not None:
        save_dir = Path(args.output) / "detection_black_box"
        image_stem = Path(args.image_path).stem
        explanation.save(save_dir, f"{image_stem}_")
def main(argv):
    """Parse CLI arguments and run both explanation scenarios."""
    args = get_argument_parser().parse_args(argv)
    explain_white_box(args)
    explain_black_box(args)


if __name__ == "__main__":
    main(sys.argv[1:])