diff --git a/gdl/datasets/FaceAlignmentTools.py b/gdl/datasets/FaceAlignmentTools.py new file mode 100644 index 0000000..b92e010 --- /dev/null +++ b/gdl/datasets/FaceAlignmentTools.py @@ -0,0 +1,108 @@ +import numpy as np +from pathlib import Path +from gdl.datasets.ImageDatasetHelpers import bbox2point, bbpoint_warp +import skvideo +import types + + +def align_face(image, landmarks, landmark_type, scale_adjustment, target_size_height, target_size_width=None,): + """ + Returns an image with the face aligned to the center of the image. + :param image: The full resolution image in which to align the face. + :param landmarks: The landmarks of the face in the image (in the original image coordinates). + :param landmark_type: The type of landmarks. Such as 'kpt68' or 'bbox' or 'mediapipe'. + :param scale_adjustment: The scale adjustment to apply to the image. + :param target_size_height: The height of the output image. + :param target_size_width: The width of the output image. If not provided, it is assumed to be the same as target_size_height. + :return: The aligned face image. The image will be in range [0,1]. + """ + # landmarks_for_alignment = "mediapipe" + left = landmarks[:,0].min() + top = landmarks[:,1].min() + right = landmarks[:,0].max() + bottom = landmarks[:,1].max() + + old_size, center = bbox2point(left, right, top, bottom, type=landmark_type) + size = (old_size * scale_adjustment).astype(np.int32) + + img_warped, lmk_warped = bbpoint_warp(image, center, size, target_size_height, target_size_width, landmarks=landmarks) + + return img_warped + + +def align_video(video, centers, sizes, landmarks, target_size_height, target_size_width=None, ): + """ + Returns a video with the face aligned to the center of the image. + :param video: The full resolution video in which to align the face. + :param landmarks: The landmarks of the face in the video (in the original video coordinates). + :param target_size_height: The height of the output video. + :param target_size_width: The width of the output video. If not provided, it is assumed to be the same as target_size_height. + :return: The aligned face video. The video will be in range [0,1]. + """ + if isinstance(video, (str, Path)): + video = skvideo.io.vread(video) + elif isinstance(video, (np.ndarray, types.GeneratorType)): + pass + else: + raise ValueError("video must be a string, Path, or numpy array") + + aligned_video = [] + warped_landmarks = [] + if isinstance(video, np.ndarray): + for i in range(len(centers)): + img_warped, lmk_warped = bbpoint_warp(video[i], centers[i], sizes[i], + target_size_height=target_size_height, target_size_width=target_size_width, + landmarks=landmarks[i]) + aligned_video.append(img_warped) + warped_landmarks += [lmk_warped] + + elif isinstance(video, types.GeneratorType): + for i, frame in enumerate(video): + img_warped, lmk_warped = bbpoint_warp(frame, centers[i], sizes[i], + target_size_height=target_size_height, target_size_width=target_size_width, + landmarks=landmarks[i]) + aligned_video.append(img_warped) + warped_landmarks += [lmk_warped] + + aligned_video = np.stack(aligned_video, axis=0) + return aligned_video, warped_landmarks + + +def align_and_save_video(video, out_video_path, centers, sizes, landmarks, target_size_height, target_size_width=None, output_dict=None): + """ + Returns a video with the face aligned to the center of the image. + :param video: The full resolution video in which to align the face. + :param landmarks: The landmarks of the face in the video (in the original video coordinates). + :param target_size_height: The height of the output video. + :param target_size_width: The width of the output video. If not provided, it is assumed to be the same as target_size_height. + :return: The aligned face video. The video will be in range [0,1]. + """ + if isinstance(video, (str, Path)): + video = skvideo.io.vread(video) + elif isinstance(video, (np.ndarray, types.GeneratorType)): + pass + else: + raise ValueError("video must be a string, Path, or numpy array") + + writer = skvideo.io.FFmpegWriter(str(out_video_path), outputdict=output_dict) + warped_landmarks = [] + if isinstance(video, np.ndarray): + for i in range(len(centers)): + img_warped, lmk_warped = bbpoint_warp(video[i], centers[i], sizes[i], + target_size_height=target_size_height, target_size_width=target_size_width, + landmarks=landmarks[i]) + img_warped = (img_warped * 255).astype(np.uint8) + writer.writeFrame(img_warped) + warped_landmarks += [lmk_warped] + + elif isinstance(video, types.GeneratorType): + for i, frame in enumerate(video): + img_warped, lmk_warped = bbpoint_warp(frame, centers[i], sizes[i], + target_size_height=target_size_height, target_size_width=target_size_width, + landmarks=landmarks[i]) + img_warped = (img_warped * 255).astype(np.uint8) + writer.writeFrame(img_warped) + warped_landmarks += [lmk_warped] + writer.close() + + return warped_landmarks \ No newline at end of file diff --git a/gdl/datasets/ImageDatasetHelpers.py b/gdl/datasets/ImageDatasetHelpers.py index 1dfea76..4fbed4d 100644 --- a/gdl/datasets/ImageDatasetHelpers.py +++ b/gdl/datasets/ImageDatasetHelpers.py @@ -27,12 +27,25 @@ def bbox2point(left, right, top, bottom, type='bbox'): ''' if type == 'kpt68': old_size = (right - left + bottom - top) / 2 * 1.1 - center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) + center_x = right - (right - left) / 2.0 + center_y = bottom - (bottom - top) / 2.0 + # center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) elif type == 'bbox': old_size = (right - left + bottom - top) / 2 - center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size * 0.12]) + center_x = right - (right - left) / 2.0 + center_y = bottom - (bottom - top) / 2.0 + old_size * 0.12 + # center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size * 0.12]) + elif type == "mediapipe": + old_size = (right - left + bottom - top) / 2 * 1.1 + center_x = right - (right - left) / 2.0 + center_y = bottom - (bottom - top) / 2.0 + # center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) else: - raise NotImplementedError + raise NotImplementedError(f" bbox2point not implemented for {type} ") + if isinstance(center_x, np.ndarray): + center = np.stack([center_x, center_y], axis=1) + else: + center = np.array([center_x, center_y]) return old_size, center @@ -53,15 +66,31 @@ def point2transform(center, size, target_size_height, target_size_width): return tform -def bbpoint_warp(image, center, size, target_size_height, target_size_width=None, output_shape=None, inv=True, landmarks=None): +def bbpoint_warp(image, center, size, target_size_height, target_size_width=None, output_shape=None, inv=True, landmarks=None, + order=3 # order of interpolation, bicubic by default + ): target_size_width = target_size_width or target_size_height tform = point2transform(center, size, target_size_height, target_size_width) tf = tform.inverse if inv else tform output_shape = output_shape or (target_size_height, target_size_width) - dst_image = warp(image, tf, output_shape=output_shape, order=3) + dst_image = warp(image, tf, output_shape=output_shape, order=order) if landmarks is None: return dst_image # points need the matrix - tf_lmk = tform if inv else tform.inverse - dst_landmarks = tf_lmk(landmarks) + if isinstance(landmarks, np.ndarray): + assert isinstance(landmarks, np.ndarray) + tf_lmk = tform if inv else tform.inverse + dst_landmarks = tf_lmk(landmarks[:, :2]) + elif isinstance(landmarks, list): + tf_lmk = tform if inv else tform.inverse + dst_landmarks = [] + for i in range(len(landmarks)): + dst_landmarks += [tf_lmk(landmarks[i][:, :2])] + elif isinstance(landmarks, dict): + tf_lmk = tform if inv else tform.inverse + dst_landmarks = {} + for key, value in landmarks.items(): + dst_landmarks[key] = tf_lmk(landmarks[key][:, :2]) + else: + raise ValueError("landmarks must be np.ndarray, list or dict") return dst_image, dst_landmarks \ No newline at end of file diff --git a/gdl/datasets/VideoFaceDetectionDataset.py b/gdl/datasets/VideoFaceDetectionDataset.py new file mode 100644 index 0000000..2d2b606 --- /dev/null +++ b/gdl/datasets/VideoFaceDetectionDataset.py @@ -0,0 +1,144 @@ +""" +Author: Radek Danecek +Copyright (c) 2022, Radek Danecek +All rights reserved. + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# Using this computer program means that you agree to the terms +# in the LICENSE file included with this software distribution. +# Any use not explicitly granted by the LICENSE is prohibited. +# +# Copyright©2022 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# For comments or questions, please email us at emoca@tue.mpg.de +# For commercial licensing contact, please contact ps-license@tuebingen.mpg.de +""" + + +import numpy as np +import torch +from PIL import Image +from skimage.io import imread +from torchvision.transforms import ToTensor + +from gdl.utils.FaceDetector import load_landmark +from gdl.datasets.FaceAlignmentTools import align_face + +from skvideo.io import vread, vreader +from types import GeneratorType +import pickle as pkl + +class VideoFaceDetectionDataset(torch.utils.data.Dataset): + + def __init__(self, video_name, landmark_path, image_transforms=None, + align_landmarks=False, vid_read=None, output_im_range=None, + scale_adjustment=1.25, + target_size_height=256, + target_size_width=256, + ): + super().__init__() + self.video_name = video_name + self.landmark_path = landmark_path / "landmarks_original.pkl" + # if landmark_list is not None and len(lanmark_file_name) != len(image_list): + # raise RuntimeError("There must be a landmark for every image") + self.image_transforms = image_transforms + self.vid_read = vid_read or 'skvreader' # 'skvread' + self.prev_index = -1 + + self.scale_adjustment=scale_adjustment + self.target_size_height=target_size_height + self.target_size_width=target_size_width + + self.video_frames = None + if self.vid_read == "skvread": + self.video_frames = vread(str(self.video_name)) + elif self.vid_read == "skvreader": + self.video_frames = vreader(str(self.video_name)) + + with open(self.landmark_path, "rb") as f: + self.landmark_list = pkl.load(f) + + with open(landmark_path / "landmark_types.pkl", "rb") as f: + self.landmark_types = pkl.load(f) + + self.total_len = 0 + self.frame_map = {} # detection index to frame map + self.index_for_frame_map = {} # detection index to frame map + for i in range(len(self.landmark_list)): + for j in range(len(self.landmark_list[i])): + self.frame_map[self.total_len + j] = i + self.index_for_frame_map[self.total_len + j] = j + self.total_len += len(self.landmark_list[i]) + + self.output_im_range = output_im_range + + + def __getitem__(self, index): + # if index < len(self.image_list): + # x = self.mnist_data[index] + # raise IndexError("Out of bounds") + if index != self.prev_index+1 and self.vid_read != 'skvread': + raise RuntimeError("This dataset is meant to be accessed in ordered way only (and with 0 or 1 workers)") + + frame_index = self.frame_map[index] + detection_in_frame_index = self.index_for_frame_map[index] + landmark = self.landmark_list[frame_index][detection_in_frame_index] + landmark_type = self.landmark_types[frame_index][detection_in_frame_index] + + if isinstance(self.video_frames, np.ndarray): + img = self.video_frames[frame_index, ...] + elif isinstance(self.video_frames, GeneratorType): + img = next(self.video_frames) + else: + raise NotImplementedError() + + # try: + # if self.vid_read == 'skvread': + # img = vread(self.image_list[index]) + # img = img.transpose([2, 0, 1]).astype(np.float32) + # img_torch = torch.from_numpy(img) + # path = str(self.image_list[index]) + # elif self.vid_read == 'pil': + # img = Image.open(self.image_list[index]) + # img_torch = ToTensor()(img) + # path = str(self.image_list[index]) + # # path = f"{index:05d}" + # else: + # raise ValueError(f"Invalid image reading method {self.im_read}") + # except Exception as e: + # print(f"Failed to read '{self.image_list[index]}'. File is probably corrupted. Rerun data processing") + # raise e + + # crop out the face + img = align_face(img, landmark, landmark_type, scale_adjustment=1.25, target_size_height=256, target_size_width=256,) + if self.output_im_range == 255: + img = img * 255.0 + img = img.astype(np.float32) + img_torch = ToTensor()(img) + + # # plot img with pyplot + # import matplotlib.pyplot as plt + # plt.figure() + # plt.imshow(img) + # plt.show() + # # plot image with plotly + # import plotly.graph_objects as go + # fig = go.Figure(data=go.Image(z=img*255.,)) + # fig.show() + + + if self.image_transforms is not None: + img_torch = self.image_transforms(img_torch) + + batch = {"image" : img_torch, + # "path" : path + } + + self.prev_index += 1 + return batch + + def __len__(self): + return self.total_len \ No newline at end of file