-
Notifications
You must be signed in to change notification settings - Fork 87
/
demo.py
executable file
·206 lines (176 loc) · 7.62 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#!/usr/bin/env python
import datetime
import logging
import pathlib
from typing import Optional
import cv2
import numpy as np
import yacs.config
from gaze_estimation import GazeEstimationMethod, GazeEstimator
from gaze_estimation.gaze_estimator.common import (Face, FacePartsName,
Visualizer)
from gaze_estimation.utils import load_config
# Configure the root logger at import time with the INFO threshold so that
# the logger.info(...) calls made by the demo are not filtered out.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class Demo:
    """Run gaze estimation on a camera or video stream and visualize it.

    Frames are read from an OpenCV capture, passed through the configured
    ``GazeEstimator`` and annotated through a ``Visualizer``.  Depending on
    ``config.demo`` the annotated frames are shown in a window and/or
    written to a video file.

    While a window is shown, these keys are handled between frames:
    ``q``/Esc quit, ``b`` toggles the bounding box, ``l`` the landmarks,
    ``h`` the head pose axes, ``n`` the normalized-image window and ``t``
    the 3D template model.
    """

    # Key codes that end the main loop (27 is the ASCII code of Esc).
    QUIT_KEYS = {27, ord('q')}

    def __init__(self, config: yacs.config.CfgNode):
        """Build the estimator, visualizer, capture and optional writer.

        Args:
            config: Demo configuration; ``config.demo`` and ``config.mode``
                are read by this class.

        Raises:
            ValueError: If no input source is configured or the output file
                extension is not supported.
            RuntimeError: If the output video file cannot be opened.
        """
        self.config = config
        self.gaze_estimator = GazeEstimator(config)
        self.visualizer = Visualizer(self.gaze_estimator.camera)

        self.cap = self._create_capture()
        try:
            self.output_dir = self._create_output_dir()
            self.writer = self._create_video_writer()
        except Exception:
            # Do not leave the capture device open when setup fails half-way.
            self.cap.release()
            raise

        self.stop = False
        # Initial overlay toggles; they can be flipped at runtime by
        # _wait_key().
        self.show_bbox = self.config.demo.show_bbox
        self.show_head_pose = self.config.demo.show_head_pose
        self.show_landmarks = self.config.demo.show_landmarks
        self.show_normalized_image = self.config.demo.show_normalized_image
        self.show_template_model = self.config.demo.show_template_model

    def run(self) -> None:
        """Process frames until the input ends or a quit key is pressed.

        The capture and the video writer are released when the loop ends,
        including when frame processing raises.
        """
        try:
            while True:
                if self.config.demo.display_on_screen:
                    self._wait_key()
                    if self.stop:
                        break

                ok, frame = self.cap.read()
                if not ok:
                    break

                self._process_frame(frame)
        finally:
            self.cap.release()
            if self.writer is not None:
                self.writer.release()

    def _process_frame(self, frame: np.ndarray) -> None:
        """Estimate gaze on one frame, draw the overlays and output it."""
        camera = self.gaze_estimator.camera
        undistorted = cv2.undistort(frame, camera.camera_matrix,
                                    camera.dist_coefficients)

        # Detection runs on the undistorted frame; drawing happens on a copy
        # of the raw frame so the captured buffer itself is not modified.
        self.visualizer.set_image(frame.copy())
        faces = self.gaze_estimator.detect_faces(undistorted)
        for face in faces:
            self.gaze_estimator.estimate_gaze(undistorted, face)
            self._draw_face_bbox(face)
            self._draw_head_pose(face)
            self._draw_landmarks(face)
            self._draw_face_template_model(face)
            self._draw_gaze_vector(face)
            self._display_normalized_image(face)

        if self.config.demo.use_camera:
            # Reverse the second axis so live camera output reads like a
            # mirror.
            self.visualizer.image = self.visualizer.image[:, ::-1]
        if self.writer is not None:
            self.writer.write(self.visualizer.image)
        if self.config.demo.display_on_screen:
            cv2.imshow('frame', self.visualizer.image)

    def _create_capture(self) -> cv2.VideoCapture:
        """Open camera 0 or the configured video file.

        The camera takes precedence when ``config.demo.use_camera`` is set.

        Raises:
            ValueError: If neither a camera nor a video path is configured.
        """
        if self.config.demo.use_camera:
            cap = cv2.VideoCapture(0)
        elif self.config.demo.video_path:
            cap = cv2.VideoCapture(self.config.demo.video_path)
        else:
            raise ValueError(
                'No input source: set demo.use_camera or demo.video_path')
        # Request the resolution of the camera model used by the estimator.
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.gaze_estimator.camera.width)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.gaze_estimator.camera.height)
        return cap

    def _create_output_dir(self) -> Optional[pathlib.Path]:
        """Create the output directory, or return None if none is set."""
        if not self.config.demo.output_dir:
            return None
        output_dir = pathlib.Path(self.config.demo.output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)
        return output_dir

    @staticmethod
    def _create_timestamp() -> str:
        """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
        dt = datetime.datetime.now()
        return dt.strftime('%Y%m%d_%H%M%S')

    def _create_video_writer(self) -> Optional[cv2.VideoWriter]:
        """Open a timestamped output video inside the output directory.

        Returns:
            The opened writer, or None when no output directory is set.

        Raises:
            ValueError: If the configured extension is neither ``mp4`` nor
                ``avi``.
            RuntimeError: If the video file cannot be opened for writing.
        """
        if self.output_dir is None:
            return None
        ext = self.config.demo.output_file_extension
        if ext == 'mp4':
            fourcc = cv2.VideoWriter_fourcc(*'H264')
        elif ext == 'avi':
            fourcc = cv2.VideoWriter_fourcc(*'PIM1')
        else:
            raise ValueError(f'Unsupported output file extension: {ext!r}')
        output_path = self.output_dir / f'{self._create_timestamp()}.{ext}'
        # The output is written at a fixed 30 frames per second.
        writer = cv2.VideoWriter(output_path.as_posix(), fourcc, 30,
                                 (self.gaze_estimator.camera.width,
                                  self.gaze_estimator.camera.height))
        # The constructor always returns an object, even when the codec or
        # path is unusable, so comparing against None can never detect a
        # failure; isOpened() is the actual success indicator.
        if not writer.isOpened():
            raise RuntimeError(
                f'Failed to open video writer for {output_path.as_posix()}')
        return writer

    def _wait_key(self) -> None:
        """Poll the keyboard and update the stop flag or overlay toggles."""
        key = cv2.waitKey(self.config.demo.wait_time) & 0xff
        if key in self.QUIT_KEYS:
            self.stop = True
        elif key == ord('b'):
            self.show_bbox = not self.show_bbox
        elif key == ord('l'):
            self.show_landmarks = not self.show_landmarks
        elif key == ord('h'):
            self.show_head_pose = not self.show_head_pose
        elif key == ord('n'):
            self.show_normalized_image = not self.show_normalized_image
        elif key == ord('t'):
            self.show_template_model = not self.show_template_model

    def _draw_face_bbox(self, face: Face) -> None:
        """Draw the face bounding box when that overlay is enabled."""
        if not self.show_bbox:
            return
        self.visualizer.draw_bbox(face.bbox)

    def _draw_head_pose(self, face: Face) -> None:
        """Draw the head model axes and log the head pose when enabled."""
        if not self.show_head_pose:
            return
        # Draw the axes of the model coordinate system
        length = self.config.demo.head_pose_axis_length
        self.visualizer.draw_model_axes(face, length, lw=2)

        euler_angles = face.head_pose_rot.as_euler('XYZ', degrees=True)
        pitch, yaw, roll = face.change_coordinate_system(euler_angles)
        logger.info(f'[head] pitch: {pitch:.2f}, yaw: {yaw:.2f}, '
                    f'roll: {roll:.2f}, distance: {face.distance:.2f}')

    def _draw_landmarks(self, face: Face) -> None:
        """Draw the detected face landmarks when that overlay is enabled."""
        if not self.show_landmarks:
            return
        self.visualizer.draw_points(face.landmarks,
                                    color=(0, 255, 255),
                                    size=1)

    def _draw_face_template_model(self, face: Face) -> None:
        """Draw the 3D face template points when that overlay is enabled."""
        if not self.show_template_model:
            return
        # Every channel must fit in 0-255; the previous value 525 was a typo.
        self.visualizer.draw_3d_points(face.model3d,
                                       color=(255, 0, 255),
                                       size=1)

    def _display_normalized_image(self, face: Face) -> None:
        """Show the normalized eye or face image in its own window.

        Nothing is shown unless on-screen display and the normalized-image
        toggle are both enabled.

        Raises:
            ValueError: If ``config.mode`` is not a supported method name.
        """
        if not self.config.demo.display_on_screen:
            return
        if not self.show_normalized_image:
            return
        if self.config.mode == GazeEstimationMethod.MPIIGaze.name:
            # Put the two eye images side by side, right eye first.
            reye = face.reye.normalized_image
            leye = face.leye.normalized_image
            normalized = np.hstack([reye, leye])
        elif self.config.mode == GazeEstimationMethod.MPIIFaceGaze.name:
            normalized = face.normalized_image
        else:
            raise ValueError(
                f'Unsupported gaze estimation mode: {self.config.mode!r}')
        if self.config.demo.use_camera:
            # Same second-axis reversal that is applied to the main frame.
            normalized = normalized[:, ::-1]
        cv2.imshow('normalized', normalized)

    def _draw_gaze_vector(self, face: Face) -> None:
        """Draw the estimated gaze direction(s) and log pitch and yaw.

        In MPIIGaze mode one line is drawn per eye; in MPIIFaceGaze mode a
        single line is drawn from the face center.

        Raises:
            ValueError: If ``config.mode`` is not a supported method name.
        """
        length = self.config.demo.gaze_visualization_length
        if self.config.mode == GazeEstimationMethod.MPIIGaze.name:
            for key in [FacePartsName.REYE, FacePartsName.LEYE]:
                eye = getattr(face, key.name.lower())
                self.visualizer.draw_3d_line(
                    eye.center, eye.center + length * eye.gaze_vector)
                pitch, yaw = np.rad2deg(eye.vector_to_angle(eye.gaze_vector))
                logger.info(
                    f'[{key.name.lower()}] pitch: {pitch:.2f}, yaw: {yaw:.2f}')
        elif self.config.mode == GazeEstimationMethod.MPIIFaceGaze.name:
            self.visualizer.draw_3d_line(
                face.center, face.center + length * face.gaze_vector)
            pitch, yaw = np.rad2deg(face.vector_to_angle(face.gaze_vector))
            logger.info(f'[face] pitch: {pitch:.2f}, yaw: {yaw:.2f}')
        else:
            raise ValueError(
                f'Unsupported gaze estimation mode: {self.config.mode!r}')
def main():
    """Load the configuration, build the demo from it and run it."""
    Demo(load_config()).run()


# Start the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()